#DRAFT script for running variancePartition
# by ~J.
#intended for use with output from edgeR2.R script

####~load libraries~~~~~~~~~~####
library(variancePartition)
library(ggplot2)
library(svglite)

####~housekeeping~~~~~~~~~~~~####
# browser()  # uncomment to step through the script interactively
# Start from an empty workspace, then make the folder containing this script
# the working directory (rstudioapi: only works when run from RStudio).
rm(list = ls())
setwd(dirname(rstudioapi::getActiveDocumentContext()[["path"]]))

###~~specify data~~~~~~~~~~~~####
# Input paths, relative to this script's folder.
em_file <- "../04_edger/cpm.csv" # expression matrix with CPM
ss_file <- "../02_reference_data/sample_sheet.csv" # sample sheet

###~~specify forms~~~~~~~~~~~####
# Variance-partition model. Edit the variables as needed; random effects are
# written as (1 | effect).
form1 <- ~ DiC + (1 | Individual) + (1 | Lineage)
# Variables compared pairwise in the CCA (canonical correlation) plot.
form2 <- ~ DiC + Individual + Lineage

###~~specify output~~~~~~~~~~####
output <- "../08_variance_partition/"
# showWarnings = FALSE: re-running the script must not warn that the folder
# already exists; recursive = TRUE also creates any missing parent folders.
dir.create(output, showWarnings = FALSE, recursive = TRUE)
setwd(output) # all outputs below are written into this folder

###~~logfile~~~~~~~~~~~~~~~~~####
# Pass sink() a file path, not a file() connection. sink() ignores `append`
# for connections and opens an unopened one in "wt" mode. Previously, a
# same-day re-run therefore overwrote the log despite append = TRUE.
# split = TRUE keeps printing to the console as well as to the log.
# NOTE(review): the log prefix is "07_" while the output folder is
# "08_variance_partition" — confirm which step number is intended.
log_file <- paste0("07_variance_partition_", Sys.Date(), ".log")
sink(log_file, append = TRUE, type = "output", split = TRUE)

####~load data~~~~~~~~~~~~~~~####
# check.names = FALSE keeps sample IDs verbatim. By default read.csv rewrites
# non-syntactic names (e.g. "1A" -> "X1A", "S-1" -> "S.1"). The samples could
# then not be found when em's columns are selected by the sample-sheet row
# names.
em <- read.csv(em_file, row.names = 1, check.names = FALSE) # expression matrix
ss <- read.csv(ss_file, row.names = 1) # sample sheet, one row per sample

####~parse data~~~~~~~~~~~~~~####
# Keep every sample except the WV live oviduct group. which() drops both
# FALSE and NA comparisons, so samples with a missing Group are excluded too.
ss <- ss[which(ss[["Group"]] != "WV.live"), , drop = FALSE]

###~~prune columns~~~~~~~~~~~####
# Remove the excluded samples from em and order its columns like ss's rows.
em <- em[, rownames(ss)]

###~~prune rows with 0 reads~####
# Keep only genes whose values sum to a positive number over the retained
# samples (genes with a zero or NA total are dropped).
em <- em[which(rowSums(em) > 0), , drop = FALSE]

###~~scale em~~~~~~~~~~~~~~~~####
# z-score each gene across samples. scale() standardises columns, so the
# matrix is transposed in and back out. Working on a matrix (rather than
# wrapping in data.frame(), whose check.names mangles non-syntactic names)
# keeps em.scaled's gene/sample names identical to em's. Genes with zero
# variance across samples come out as NaN.
em.scaled <- as.data.frame(t(scale(t(as.matrix(em)))))

####~remove batch effects~~~~####
# Fit a Batch-only random-effect model per gene and keep its residuals,
# i.e. expression with the Batch component removed. The per-gene residual
# vectors are stacked back into a gene x sample matrix.
# NOTE(review): variancePartition expects log-scale expression; confirm that
# cpm.csv holds log CPM rather than linear CPM.
residList <- fitVarPartModel(
  exprObj = em,
  formula = ~ (1 | Batch),
  data = ss,
  fxn = residuals
)
residMatrix <- do.call(rbind, residList)

####~variance partition~~~~~~~####
# Split the batch-corrected variance of every gene across the terms of form1.
varPartResid <- fitExtractVarPartModel(
  exprObj = residMatrix,
  formula = form1,
  data = ss
)

####~visualise output~~~~~~~~####
###~~CCA~~~~~~~~~~~~~~~~~~~~~####
# Canonical correlation between each pair of form2 variables, plotted into
# cca.svg. The svglite device is closed explicitly once the plot is drawn.
C <- canCorPairs(formula = form2, data = ss)
svglite(filename = "cca.svg")
cca <- plotCorrMatrix(C)
dev.off()

###~~basic violin plot~~~~~~~####
# Reorder the variance components with sortCols(), then draw a violin plot
# of each component's contribution to total variance across genes.
vp <- sortCols(varPartResid)
violin <- plotVarPart(vp)
ggsave(filename = "violin_plot.svg", plot = violin)

####~save some stats~~~~~~~~~####
# Median (across genes) fraction of variance explained by each component,
# written one line per component to summary.txt.
vp.DiC <- median(vp$DiC)
vp.ind <- median(vp$Individual)
vp.lineage <- median(vp$Lineage)
writeLines(
  paste(
    "Variance explained",
    c("(DiC)", "(individual)", "(lineage)"),
    "=",
    c(vp.DiC, vp.ind, vp.lineage)
  ),
  con = "summary.txt"
)

####~top gene for parity~~~~~####
# Plot the scaled expression of the gene with the largest Lineage variance
# fraction, stratified by Lineage.
# NOTE(review): this section says "parity" but uses the Lineage variable —
# confirm that Lineage is the parity grouping.
em.scaled <- as.matrix(em.scaled)
# Look the gene up by name instead of reusing varPartResid's row position.
# varPartResid is built from the residual matrix, not from em, so positions
# are only safe when every gene was fitted in the same order.
top_gene <- rownames(varPartResid)[which.max(varPartResid$Lineage)]
i <- match(top_gene, rownames(em))
stopifnot(!is.na(i)) # top gene must exist in the expression matrix
# em's columns were ordered by ss's rows, so samples line up with ss$Lineage.
GE <- data.frame(Expression = em.scaled[i, ], Lineage = ss$Lineage)
parity_top1 <- plotStratify(Expression ~ Lineage, GE, main = top_gene)
ggsave("parity_top1.svg", plot = parity_top1)

####~top 20 genes for parity~####
# Percent-variance bars for the genes with the largest Lineage fraction.
# NOTE(review): "parity" here refers to the Lineage component — confirm.
n_top <- 20L
by_parity <- order(varPartResid$Lineage, decreasing = TRUE)
vp.par <- varPartResid[by_parity, ]
# seq_len(min(...)) instead of 1:20: with fewer than n_top genes, 1:20 would
# pad the table with NA rows.
vp.par.20 <- vp.par[seq_len(min(n_top, nrow(vp.par))), ]
top20parity <- row.names(vp.par.20)
parity_top20 <- plotPercentBars(vp.par.20)
ggsave("parity_top_20.svg", plot = parity_top20)

####~fin~~~~~~~~~~~~~~~~~~~~~####
# Close all user connections. This also removes the sink() diversion set up
# for the log file, so console output is no longer copied to the log.
closeAllConnections()
