# EdgeR analysis script 
# by ~J.
# for analysis of DGE in pregnant Z. vivipara oviduct (WV)

####~load libraries~~~~~~~~~~####
library(GenomicFeatures)
library(AnnotationDbi)
library(tximport)
library(edgeR)
library(svglite)

####~housekeeping~~~~~~~~~~~~####
# Start from an empty workspace and anchor all relative paths on the folder
# that contains this script.
rm(list = ls()) #clear the environment
# NOTE: rstudioapi::getActiveDocumentContext() needs a running RStudio session,
# so this script is meant to be run from within RStudio.
setwd(dirname(rstudioapi::getActiveDocumentContext()$path)) #set wd to Scripts folder

###~~output directory~~~~~~~~####
output <- "../04_edger/" #specify where the output should go
# Only create the directory when it is missing, so re-running the script does
# not raise an "already exists" warning.
if (!dir.exists(output)) {
  dir.create(output) #create directory for output
}
setwd(output) #set the new output directory as the working directory

###~~specify data~~~~~~~~~~~~####
# Both paths are relative to the current working directory (the output folder
# at this point in the script).
refdata <- "../02_reference_data/" #specify where the reference data is kept
saldata <- "../03_salmon/" #specify where the salmon quant files are

###~~logfile~~~~~~~~~~~~~~~~~####
# One log file per calendar day. The connection is opened explicitly in append
# mode ("at"): sink() only honours its `append` argument when it is given a
# file name, and an unopened connection would be opened for writing and
# truncated, wiping the log of any earlier run on the same day.
log_file <- file(paste0("01_edgeR_", Sys.Date(), ".log"), open = "at")
sink(log_file, type = "output") #divert printed output to the log
sink(log_file, type = "message") #divert messages, warnings and errors to the log

####~load data~~~~~~~~~~~~~~~####

###~~sample sheet~~~~~~~~~~~~####
# Sample sheet: the first column becomes the row names (sample IDs). The
# columns Batch, Barcode and Group are used further down in this script.
ss <- read.csv(paste0(refdata, "sample_sheet.csv"), row.names = 1)
stopifnot(all(c("Batch", "Barcode", "Group") %in% colnames(ss)))

###~~quant files~~~~~~~~~~~~~####

##~~~make tx2gene~~~~~~~~~~~~####
# Build a transcript -> gene lookup table from the annotation so that tximport
# can summarise transcript-level quantifications to gene level.
txdb <- makeTxDbFromGFF(file = paste0(refdata, "annotation.gff"), format = "gff")
k <- keys(txdb, keytype = "TXNAME")
tx2gene <- AnnotationDbi::select(txdb, keys = k, columns = "GENEID", keytype = "TXNAME")

##~~~load quant files~~~~~~~~####
# One salmon output folder per sample, named <Batch>_<Barcode>, located under
# `saldata` (trailing slashes are stripped so file.path() yields clean paths).
salmonQuantFiles <- file.path(
  sub("/+$", "", saldata),
  paste(ss$Batch, ss$Barcode, sep = "_"),
  "quant.sf"
) #makes a vector of filepaths to the quant data
names(salmonQuantFiles) <- row.names(ss) #associate filepaths with sampleIDs from ss

# Fail early with the list of missing files rather than a generic import error.
missing_quant <- salmonQuantFiles[!file.exists(salmonQuantFiles)]
if (length(missing_quant) > 0) {
  stop(
    "Missing salmon quant files: ", paste(missing_quant, collapse = ", "),
    call. = FALSE
  )
}

txi <- tximport(salmonQuantFiles, type = "salmon", tx2gene = tx2gene) #import salmon quant files for DGE
cts <- txi$counts #get gene counts for DGE
write.csv(cts, "salmon_counts.csv") #save gene counts

####~prepare DGEList~~~~~~~~~####
# Wrap the gene-level counts in a DGEList (one group label per sample, taken
# from the sample sheet) and compute library-size normalisation factors.
# NOTE(review): only txi$counts is used here; tximport's length information is
# not passed on as offsets - confirm this is intended.
y <- normLibSizes(DGEList(counts = cts, group = ss$Group))

# Intercept + group design, used for expression filtering and dispersion
# estimation only (the design for testing is built separately).
design <- model.matrix(~ group, data = y$samples)

# Flag genes with enough reads to be worth testing (filterByExpr defaults,
# roughly 10+ reads in about one group's worth of samples).
# NOTE(review): normalisation factors are computed before filtering and are
# not recomputed afterwards - confirm this order is intended.
keep <- filterByExpr(y, design = design)

# Drop the low-expression genes, then estimate dispersions on what remains.
y <- estimateDisp(y[keep, ], design = design)

###~~get CPM~~~~~~~~~~~~~~~~~####
# Counts per million on the natural scale for the filtered genes.
# NOTE(review): nothing in this script sets y$offset explicitly - confirm the
# `offset` argument has the intended effect here.
cpms <- edgeR::cpm(y, offset = y$offset, log = FALSE)
write.csv(cpms, "cpm.csv")

###~~get log2 CPM~~~~~~~~~~~~####
logcpm <- edgeR::cpm(y, log = TRUE)
write.csv(logcpm, "log2cpm.csv")

###~~remove batch effect~~~~~####
# Batch-corrected log2 CPM for visualisation/export only (the GLM below works
# on the counts in `y`). The group design is passed to removeBatchEffect() so
# that differences between experimental groups are preserved; without it, any
# group signal that is correlated with batch would be removed along with the
# batch effect.
batch_design <- model.matrix(~ group, data = y$samples)
logcpm.nb <- removeBatchEffect(logcpm, batch = ss$Batch, design = batch_design)
write.csv(logcpm.nb, "log2cpm_NOBATCH.csv")

####~some basic plots~~~~~~~~####
# Each plot follows the same pattern: open an SVG device writing to a file in
# the working directory, draw the plot from the filtered DGEList `y`, then
# close the device so the file is finalised.

# MDS plot of the samples -> mds.svg
svglite("mds.svg", width = 4, height = 4)
plotMDS(y) #visualise variation between samples
dev.off()

# Biological coefficient of variation plot (needs the dispersions estimated
# earlier with estimateDisp) -> bcv.svg
svglite("bcv.svg", width = 4, height = 4)
plotBCV(y) #visualise dispersion estimates
dev.off()

####~GLM analysis of DGE~~~~~####

###~~fit GLM~~~~~~~~~~~~~~~~~####
# Group-means design (no intercept): one coefficient per level of the group
# factor, so contrasts can be written directly as differences between groups.
# NOTE(review): ss$Batch is removed from the exported logCPM table but is not
# a term in this design - confirm that is intended.
design <- model.matrix(~ 0 + group, data = y$samples)

# Rename the columns to syntactically valid group names for makeContrasts().
# The names are taken from the levels of the very factor model.matrix() used,
# in the same order, so a label can never end up on the wrong column (sorting
# the make.names()-converted labels independently could reorder them).
colnames(design) <- make.names(levels(factor(y$samples$group)))

# Quasi-likelihood negative binomial GLM fit for every gene.
fit <- glmQLFit(y, design)

# Contrast matrix: one named column per comparison, expressed in terms of the
# column names of `design` (which must therefore include WV.live, WV.m1,
# WV.m2, EO.m1 and EO.m2). For each contrast "A-B", positive values mean
# higher expression in A than in B.
my.contrasts = makeContrasts(
  WV.m1vl = WV.m1-WV.live, #compare WV cells at 1 month to live
  WV.m2vl = WV.m2-WV.live, #compare WV cells at 2 month to live
  WV.m2vm1 = WV.m2-WV.m1, #compare WV cells at 2 month to 1 month
  EO.m2vm1 = EO.m2-EO.m1, #compare EO cells at 2 month to 1 month
  EOvWV.m1 = EO.m1-WV.m1, #compare EO and WV cells at 1 month
  EOvWV.m2 = EO.m2-WV.m2, #compare EO and WV cells at 2 month
  EOvWV.DiC = (EO.m2-EO.m1)-(WV.m2-WV.m1), #difference between the 1->2 month change in EO and the 1->2 month change in WV
  levels=design)
# Quasi-likelihood F-test for a single column of the contrast matrix,
# selected by its name, using the fitted model `fit`.
test_contrast <- function(contrast_name) {
  glmQLFTest(fit, contrast = my.contrasts[, contrast_name])
}

# One test object per contrast defined in my.contrasts.
qlf.WV.m1vl <- test_contrast("WV.m1vl")
qlf.WV.m2vl <- test_contrast("WV.m2vl")
qlf.WV.m2vm1 <- test_contrast("WV.m2vm1")
qlf.EO.m2vm1 <- test_contrast("EO.m2vm1")
qlf.EOvWV.m1 <- test_contrast("EOvWV.m1")
qlf.EOvWV.m2 <- test_contrast("EOvWV.m2")
qlf.EOvWV.DiC <- test_contrast("EOvWV.DiC")

###~~get DEGs~~~~~~~~~~~~~~~~####
# Full result table for one test: every gene in `y` is returned (n = nrow(y)),
# ordered by p-value, rather than only the top few.
all_tags <- function(qlf_test) {
  topTags(qlf_test, n = nrow(y), sort.by = "PValue")
}

# One complete, p-value-sorted table per contrast.
res.WV.m1vl <- all_tags(qlf.WV.m1vl)
res.WV.m2vl <- all_tags(qlf.WV.m2vl)
res.WV.m2vm1 <- all_tags(qlf.WV.m2vm1)
res.EO.m2vm1 <- all_tags(qlf.EO.m2vm1)
res.EOvWV.m1 <- all_tags(qlf.EOvWV.m1)
res.EOvWV.m2 <- all_tags(qlf.EOvWV.m2)
res.EOvWV.DiC <- all_tags(qlf.EOvWV.DiC)

##~~~write out DEGs~~~~~~~~~~####
# Pair every output file name with the result table it should contain, then
# write them all to the working directory in one pass.
deg_outputs <- list(
  "results_WV_m1vlive.csv" = res.WV.m1vl,
  "results_WV_m2vlive.csv" = res.WV.m2vl,
  "results_WV_m2vm1.csv" = res.WV.m2vm1,
  "results_EO_m2vm1.csv" = res.EO.m2vm1,
  "results_EOvWV_m1.csv" = res.EOvWV.m1,
  "results_EOvWV_m2.csv" = res.EOvWV.m2,
  "results_EOvWV_DiC.csv" = res.EOvWV.DiC
)
for (csv_name in names(deg_outputs)) {
  write.csv(deg_outputs[[csv_name]], csv_name)
}

####~fin~~~~~~~~~~~~~~~~~~~~~####
# Close every open user connection (this includes the log file connection)
# and, in doing so, end the sink() diversions set up at the top of the script.
closeAllConnections()
