#DGE analysis script
# by ~J.
#intended for use with output from edgeR2.R script

####~load libraries~~~~~~~~~~####
library(rstudioapi)
library(ggplot2)
library(ggrepel)
library(amap)
library(reshape2)
library(svglite)

####~housekeeping~~~~~~~~~~~~####
rm(list = ls()) #clear the environment so objects from an earlier run cannot leak in
#work relative to the folder that holds this script (needs an RStudio session)
setwd(dirname(rstudioapi::getActiveDocumentContext()$path))
count <- 0 #running index used to number the per-analysis output folders

###~~output dir~~~~~~~~~~~~~~####
output <- "../06_dge_analysis/"
#only create the folder when it is missing, so that re-running the script does
#not raise an "already exists" warning; a failed creation still surfaces in setwd()
if (!dir.exists(output)) {
  dir.create(output)
}
setwd(output) #relative paths used from here on are resolved from the output folder

###~~specify data~~~~~~~~~~~~####
#input files, written relative to the output folder
ss_file <- file.path("..", "02_reference_data", "sample_sheet.csv") #sample sheet
fa_file <- file.path("..", "02_reference_data", "functional_annotation.csv") #eggnog mappings
em_file <- file.path("..", "04_edger", "cpm.csv") #expression matrix with CPM
#stem of the edgeR result files; the analysis name and ".csv" are appended later
de_file_start <- file.path("..", "04_edger", "results_")

#contrasts to process, one output folder each
analyses <- c("PREGvNON", "PREGvPRE", "PREGvPOST", "POSTvPRE")

#sample groups in plotting order, with the labels shown in plot legends
sample_order <- c("pre rep", "pregnant", "post rep")
sample_labels <- c("pre-pregnancy", "pregnancy", "post-parturition")

###~~logfile~~~~~~~~~~~~~~~~~####
#open the log in append mode: sink() only honours `append` for file names, not
#for connections, so without this an earlier log from the same day is overwritten
log_file <- file(paste0("05_dge_analysis_", Sys.Date(), ".log"), open = "a")
sink(log_file, type = "output") #printed output goes to the log
sink(log_file, type = "message") #messages, warnings and errors go there too
print(Sys.time()) #explicit print: a bare value is not echoed when the script is source()d

####~big loop~~~~~~~~~~~~~~~~####
#run the whole workflow once per contrast in `analyses`; each pass reads
#results_<analysis>.csv and writes into its own numbered folder
for (analysis in analyses) {

####~analysis start~~~~~~~~~~####
count <- count + 1
#pad the running index with a leading 0 while it is below 10, so the folder
#names look like "01_<analysis>", "02_<analysis>", ...
index_prefix <- if (count < 10) "0" else ""
anal_dir <- paste0(index_prefix, count, "_", analysis)
dir.create(anal_dir)

####~load data~~~~~~~~~~~~~~~####
#inputs are re-read on every pass because em and fa are renamed in place further down
ss <- read.csv(ss_file, row.names = 3) #sample sheet, indexed by its third column
fa <- read.csv(fa_file, row.names = 1) #functional annotation, indexed by its first column
em <- read.csv(em_file, row.names = 1) #expression matrix, indexed by its first column
de_file <- paste0(de_file_start, analysis, ".csv")
de <- read.csv(de_file, row.names = 1) #differential expression results for this analysis
setwd(anal_dir) #output will now go to the directory for the current analysis

####~parse data~~~~~~~~~~~~~~####
# Return `df` with its rows sorted by ascending values of column `x`.
#   df: data frame to sort
#   x:  name of the column to sort on (defaults to "FDR")
sortByFDR <- function(df, x = "FDR") {
  df[order(df[, x], decreasing = FALSE), ]
}

# Rename the rows of `df` using the eggnog preferred names held in `x`.
#   df: data frame whose row names are NCBI gene names
#   x:  annotation table with NCBI gene names as row names and a
#       `Preferred_name` column (defaults to the global `fa`)
# Returns `df` restricted to the genes present in `x`, with row names
# "<preferred name>_<ncbi name>" for "LOC..." genes that have a usable preferred
# name, and the unchanged NCBI name for every other gene.
geneNamesFromNog <- function(df, x = fa) {
  ncbi <- row.names(x) #get ncbi gene names from fa file
  nog <- x$Preferred_name #get eggnog gene names from fa file
  if (is.null(nog)) {
    stop("annotation table has no 'Preferred_name' column", call. = FALSE)
  }

  #a preferred name is usable when it is neither missing nor a placeholder;
  #read.csv() turns "NA" cells into real NA values, which %in% does not match
  #against the string "NA", hence the separate is.na() test
  placeholders <- c("", "0", "NA", "NaN", "-")
  has_nog_name <- !is.na(nog) & !(nog %in% placeholders)
  is_loc_gene <- grepl("^LOC\\d+", ncbi)

  #new names indexed by the ncbi names, preferring the NCBI annotation for non "LOC..." genes
  gene2gene <- data.frame(
    newnames = ifelse(is_loc_gene & has_nog_name, paste0(nog, "_", ncbi), ncbi),
    row.names = ncbi,
    stringsAsFactors = FALSE
  )

  df <- merge(df, gene2gene, by = 0) #merge the new names into the old df
  row.names(df) <- df$newnames #rename rows using the new names
  #drop the two helper columns added by the merge
  df[, setdiff(names(df), c("Row.names", "newnames")), drop = FALSE]
}

###~~expression matrix~~~~~~~####
em <- em[, row.names(ss)] #select and reorder columns in em based on row names in ss
em <- geneNamesFromNog(em)
#z-score every gene across samples; scale() works on columns, hence the transposes.
#The work is done on a matrix and check.names is switched off so that gene names
#containing characters such as "-" are not rewritten by data.frame() and the row
#names of em_scaled keep matching those of em
em_scaled <- data.frame(t(scale(t(as.matrix(em)))), check.names = FALSE)
em_scaled <- na.omit(em_scaled) #drop genes with undefined scaled values (e.g. zero variance)

##~~~write out em~~~~~~~~~~~~###
write.csv(em, file = "em.csv")
write.csv(em_scaled, file = "em_scaled.csv")

###~~master~~~~~~~~~~~~~~~~~~####
de <- geneNamesFromNog(de) #give the DE results the same row names as em
#combine DGE results with CPM on their row names; merge() puts those names in
#the first column, which is kept as SYMBOL
master <- merge(em, de, by = 0)
row.names(master) <- master[["Row.names"]]
names(master)[1] <- "SYMBOL"

#derived columns: mean CPM over the sample columns (which sit right after
#SYMBOL), -log10 of the FDR, and the significance call
sample_columns <- 2:(nrow(ss) + 1)
master$mean <- rowMeans(master[, sample_columns])
master$mlog10p <- -log10(master$FDR)
master$sig <- as.factor(master$FDR < 0.1 & abs(master$logFC) > 1.0)

#add eggnog functional annotation to master, again matching on row names
fa <- geneNamesFromNog(fa)
master <- merge(master, fa, by = 0)
row.names(master) <- master[["Row.names"]]
master <- master[, -1] #the merge key column is no longer needed

##~~~sig genes~~~~~~~~~~~~~~~####
#significant genes, most significant first
master_sig <- sortByFDR(subset(master, sig == TRUE))
write.csv(master_sig, file = "sig.csv") #write out master sig

#expression values (raw and scaled) for the significant genes only
sig_genes <- master_sig$SYMBOL
em_sig <- em[sig_genes, ]
em_scaled_sig <- em_scaled[sig_genes, ]

##~~~sig up and sig down~~~~~####
master_sig_up <- subset(master_sig, logFC > 0)
master_sig_down <- subset(master_sig, logFC < 0)
master_non_sig <- subset(master, sig == FALSE)

##~~~remake master~~~~~~~~~~~####
#label every subset with its direction. The label vector is built to the exact
#row count because assigning a single string to a data frame with zero rows is
#an error, which would stop the script whenever a contrast has no genes in one
#of the three groups
master_non_sig$direction <- rep("ns", nrow(master_non_sig))
master_sig_up$direction <- rep("up", nrow(master_sig_up))
master_sig_down$direction <- rep("down", nrow(master_sig_down))

#rebuild the tables so that they carry the direction column
master_sig <- rbind(master_sig_up, master_sig_down)
master <- rbind(master_sig, master_non_sig)
master$direction <- factor(master$direction, levels = c("up", "down", "ns"))

##~~~write out masters~~~~~~~####

###~~top up and down genes~~~####
#rank both directions by FDR before picking the top genes
master_sig_up <- sortByFDR(master_sig_up)
master_sig_down <- sortByFDR(master_sig_down)

#first 5/10/20 rows of each ranked table
#NOTE(review): when a table has fewer rows than requested, indexing past its
#end fills the remainder with all-NA rows
top5_sig_up <- master_sig_up[seq_len(5), ]
top10_sig_up <- master_sig_up[seq_len(10), ]
top20_sig_up <- master_sig_up[seq_len(20), ]

top5_sig_down <- master_sig_down[seq_len(5), ]
top10_sig_down <- master_sig_down[seq_len(10), ]
top20_sig_down <- master_sig_down[seq_len(20), ]

write.csv(master_sig_up, file = "sig_up.csv") #write out master sig up
write.csv(master_sig_down, file = "sig_down.csv") #write out master sig down

###~~re-sort master~~~~~~~~~~####
#fix the level order of the two grouping columns (this order is what the plot
#colours and legends follow), then rank all genes by FDR
master[["direction"]] <- factor(master[["direction"]], levels = c("up", "down", "ns"))
master[["sig"]] <- factor(master[["sig"]], levels = c("TRUE", "FALSE"))
master <- sortByFDR(master)
write.csv(master, file = "master.csv") #write out final version of master

###~~gene lists~~~~~~~~~~~~~~####
#collect the gene name vectors first ...
all_genes <- row.names(master)
genes_non_sig <- row.names(master_non_sig)
genes_sig_up <- row.names(master_sig_up)
genes_sig_down <- row.names(master_sig_down)

#... then write each one to its own text file, one gene per line
write(all_genes, "gene_universe.txt")
write(sig_genes, "genes_sig.txt")
write(genes_non_sig, "genes_non_sig.txt")
write(genes_sig_up, "genes_sig_up.txt")
write(genes_sig_down, "genes_sig_down.txt")

####~theme~~~~~~~~~~~~~~~~~~~####

#basic theme: only sets the title, axis text and axis title sizes
js_theme <- theme(
  plot.title = element_text(size = 14),
  axis.title.x = element_text(size = 18),
  axis.title.y = element_text(size = 18),
  axis.text.x = element_text(size = 10),
  axis.text.y = element_text(size = 10)
)

#main theme: no plot title, white panel with light grey grid lines and border,
#large legend text without a legend title
theme_j <- theme(
  plot.title = element_blank(),
  axis.title.x = element_text(size = 18),
  axis.title.y = element_text(size = 18),
  axis.text.x = element_text(size = 10),
  axis.text.y = element_text(size = 10),
  panel.background = element_rect(fill = "white", colour = "lightgrey"),
  panel.grid.major = element_line(linewidth = 0.5, linetype = "solid", colour = "lightgrey"),
  panel.grid.minor = element_line(linewidth = 0.25, linetype = "solid", colour = "lightgrey"),
  legend.title = element_blank(),
  legend.text = element_text(size = 18),
  legend.key = element_blank(),
  legend.key.size = unit(1, "cm")
)

###~~colourblind palettes~~~~####
#the two palettes share every colour except the first: black in cb2, grey in cb1
palette_cb2 <- c(
  "#000000", "#E69F00", "#56B4E9", "#009E73",
  "#F0E442", "#0072B2", "#D55E00", "#CC79A7"
)
palette_cb1 <- replace(palette_cb2, 1, "#999999")

####~make plots~~~~~~~~~~~~~~####

###~~MA~~~~~~~~~~~~~~~~~~~~~~####
#MA plot: log10 of the mean CPM on x against log2 fold change on y, coloured by
#direction, with the five most significant up and down genes labelled
ma_plot <- ggplot(master, aes(x = log10(mean), y = logFC, colour = direction)) +
  geom_point(size = 0.9) +
  #the x axis shows log10 of the mean CPM, so label it as such
  labs(title = "MA plot", x = "log10(mean CPM)", y = "log2(FC)") +
  theme_bw() +
  geom_vline(xintercept = 2, linetype = "dashed", colour = "grey", linewidth = 0.5) +
  geom_hline(yintercept = 0, linetype = "dashed", colour = "grey", linewidth = 0.5) +
  scale_colour_manual(
    values = c("darkred", "skyblue", "black"),
    labels = c("up", "down", "non-sig"),
    name = ""
  ) +
  geom_label_repel(data = top5_sig_up, aes(label = SYMBOL), show.legend = FALSE) +
  geom_label_repel(data = top5_sig_down, aes(label = SYMBOL), show.legend = FALSE)
ggsave("ma.svg", plot = ma_plot)

###~~volcano~~~~~~~~~~~~~~~~~####
#volcano plot: log2 fold change against -log10(FDR), coloured by direction,
#with the five most significant up and down genes labelled
volcano_plot <- ggplot(master, aes(x = logFC, y = mlog10p, colour = direction)) +
  geom_point() +
  labs(title = "Volcano plot", x = "log2(fold change)", y = "-log10(p-adj)") +
  theme_j +
  #dashed guides at the cut-offs used for the significance call
  #(|logFC| > 1 and FDR < 0.1), so the guides agree with the point colours;
  #line widths use `linewidth`, as the rest of the file does
  geom_vline(xintercept = -1, linetype = "dashed", colour = "grey", linewidth = 0.5) +
  geom_vline(xintercept = 1, linetype = "dashed", colour = "grey", linewidth = 0.5) +
  geom_hline(yintercept = -log10(0.1), linetype = "dashed", colour = "grey") +
  scale_colour_manual(
    values = c("darkred", "skyblue", "black"),
    labels = c("up", "down", "non-sig"),
    name = ""
  ) +
  #labels for up genes are pushed right and kept at x >= 1
  geom_label_repel(
    data = top5_sig_up,
    seed = 2,
    max.iter = Inf,
    aes(label = SYMBOL),
    nudge_x = 5,
    xlim = c(1, NA),
    show.legend = FALSE
  ) +
  #labels for down genes are pushed left and kept at x <= -1
  geom_label_repel(
    data = top5_sig_down,
    seed = 36,
    max.iter = Inf,
    aes(label = SYMBOL),
    nudge_x = -5,
    xlim = c(NA, -1),
    show.legend = FALSE
  ) +
  expand_limits(x = c(-10, 10))
ggsave("volcano.svg", plot = volcano_plot, width = 10, height = 8)
#a bare object name is not auto-printed inside a for loop, so draw it explicitly
print(volcano_plot)

###~~boxplots~~~~~~~~~~~~~~~~####

# Draw boxplots of the expression of selected genes per sample group and save
# them as SVG files.
#   candidate_genes: character vector of gene names (row names of `em`)
#   em:              expression table with genes in rows and samples in columns
#                    (defaults to the global em_scaled); its columns are assumed
#                    to follow the row order of the sample sheet `ss` - TODO confirm
#                    when passing any other table
#   palette:         fill colours handed to scale_fill_manual()
#   savename:        file name stem; writes <savename>.svg and <savename>_faceted.svg
# Also reads ss, sample_order, sample_labels and theme_j from the calling
# environment. Genes that are not in `em` are dropped; when none are left
# nothing is plotted and NULL is returned invisibly.
makeMyBoxplots <- function(candidate_genes, em = em_scaled, palette = palette_cb1, savename) {
  #~make gene table~~~~~~~~~####
  #ignore names that are missing from em (for example the NA rows produced when
  #fewer genes exist than were asked for), which would only yield empty boxes
  candidate_genes <- candidate_genes[candidate_genes %in% row.names(em)]
  if (length(candidate_genes) == 0) {
    message("makeMyBoxplots: no requested gene found in the expression table, skipping ", savename)
    return(invisible(NULL))
  }
  #use the `em` argument rather than the global table, and keep gene names
  #exactly as they are (check.names = FALSE)
  gene_data <- data.frame(t(em[candidate_genes, , drop = FALSE]), check.names = FALSE)
  gene_data$sample_group <- ss$Condition
  gene_data.m <- melt(gene_data, id.vars = "sample_group")
  #reorder the groups, using the melted column so lengths match without recycling
  gene_data.m$sample_group <- factor(gene_data.m$sample_group, levels = sample_order)

  #~make boxplot~~~~~~~~~~~~####
  boxplot <- ggplot(gene_data.m, aes(x = variable, y = value, fill = sample_group)) +
    geom_boxplot(outlier.size = 0, show.legend = TRUE) +
    theme_j +
    xlab(element_blank()) +
    ylab("expression") +
    scale_fill_manual(values = palette, labels = sample_labels) +
    #must be placed after all other theme calls, rotates x axis text
    theme(axis.text.x = element_text(angle = 45, hjust = 1))

  #~make faceted boxplot~~~~####
  #one panel per gene, five panels per row, axis text and ticks hidden
  faceted_boxplot <- ggplot(gene_data.m, aes(y = value, fill = sample_group)) +
    geom_boxplot(outlier.size = 0, show.legend = TRUE) +
    theme_j +
    theme(
      axis.text.x = element_blank(),
      axis.ticks.x = element_blank(),
      axis.text.y = element_blank(),
      axis.ticks.y = element_blank(),
      strip.text = element_text(size = 7)
    ) +
    xlab(element_blank()) +
    ylab("relative expression") +
    scale_fill_manual(values = palette, labels = sample_labels) +
    facet_wrap(~variable, ncol = 5)

  #~save out~~~~~~~~~~~~~~~~####
  ggsave(paste0(savename, ".svg"), boxplot, height = 7, width = 10)
  ggsave(paste0(savename, "_faceted.svg"), faceted_boxplot, height = 7, width = 10)
}

##~~~top 10~~~~~~~~~~~~~~~~~~####
#overall top 10 and top 20 genes of the FDR-ranked master table
top10 <- master[seq_len(10), ]
top20 <- master[seq_len(20), ]

#every gene set that gets a pair of boxplot files, keyed by the file name stem
boxplot_sets <- list(
  "top10" = top10,
  "top20" = top20,
  "5up" = top5_sig_up,
  "10up" = top10_sig_up,
  "20up" = top20_sig_up,
  "5down" = top5_sig_down,
  "10down" = top10_sig_down,
  "20down" = top20_sig_down
)

#draw the boxplots of each set from the scaled expression matrix
for (set_name in names(boxplot_sets)) {
  candidate_genes <- as.vector(row.names(boxplot_sets[[set_name]])) #genes of this set as a vector
  makeMyBoxplots(candidate_genes = candidate_genes, em = em_scaled, savename = set_name)
}

####~close loop~~~~~~~~~~~~~~####
setwd("..") #step back out of the per-analysis folder entered with setwd(anal_dir)
}

####~end of script~~~~~~~~~~~####
#close the log file connection, which also ends the output and message sinks
closeAllConnections()
