#R script for WGCNA
#based on guidelines from: https://horvath.genetics.ucla.edu/html/CoexpressionNetwork/Rpackages/WGCNA/Tutorials/
#by J.

####~load libraries~~~~~~~~~~####
library(WGCNA)
library(svglite)

####~housekeeping~~~~~~~~~~~~####
rm(list=ls()) #clear the environment (NB: does not unload packages or reset options - restarting R is cleaner)
# Work relative to this script's folder. rstudioapi can only report the script
# path inside a running RStudio session (and only for a saved document), so
# outside RStudio (e.g. Rscript) the current working directory is kept - launch
# the script from its own folder in that case.
if (requireNamespace("rstudioapi", quietly = TRUE) && rstudioapi::isAvailable()) {
  scriptPath = rstudioapi::getActiveDocumentContext()$path
  if (nzchar(scriptPath)) {
    setwd(dirname(scriptPath))
  }
}
options(stringsAsFactors = FALSE) #recommended by the WGCNA tutorials; already the default since R 4.0.0
allowWGCNAThreads() #allows use of multiple threads for analysis

###~~specify data~~~~~~~~~~~~####
traits_file = "../02_reference_data/sample_sheet.csv" #sample sheet: first column = sample IDs (used as row names), other columns = traits
data_file = "../04_edger/log2cpm_NOBATCH.csv" #genes in rows, samples in columns; should be vst() or log2 transformed expression values

###~~specify output~~~~~~~~~~####
output = "../09_wgcna"
dir.create(output, showWarnings = FALSE) #no warning when the folder already exists from a previous run
setwd(output) #all output files below are written relative to this folder

###~~logfile~~~~~~~~~~~~~~~~~####
# sink() ignores append= when it is given a connection: it opens the connection
# itself in "wt" mode, which truncates the file. Open it in append mode here so
# several runs on the same day accumulate in one log instead of overwriting it.
log_file = file(paste0("08_wgcna_", Sys.Date(), ".log"), open = "a")
sink(log_file, type = "output", split = TRUE) #printed output goes to both the console and the log

####~load data~~~~~~~~~~~~~~~####
allTraits = read.csv(traits_file, row.names = 1)
expressionData = read.csv(data_file, row.names = 1)

####~parse data~~~~~~~~~~~~~~####

###~~expression data~~~~~~~~~####
# WGCNA expects samples in rows and genes in columns, so transpose the table
datExpr0 = as.data.frame(t(expressionData))
# Sample names are copied from the sample sheet by POSITION (read.csv may have
# altered the expression column headers, e.g. by adding an "X" prefix), which is
# only correct if both files list the samples in the same order. Warn when the
# names do not line up after the same make.names() normalisation read.csv applies.
if (!identical(make.names(row.names(allTraits), unique = TRUE), row.names(datExpr0))) {
  warning("Sample names in ", data_file, " do not match ", traits_file,
          " - check that both files list the samples in the same order")
}
row.names(datExpr0) = row.names(allTraits)
gsg = goodSamplesGenes(datExpr0, verbose = 3) #check for genes/samples with too many missing values
print(gsg$allOK) # if TRUE all genes have passed the cutoff (print so it also reaches the log)

##~~~remove bad data~~~~~~~~~####
#only needed if gsg$allOK is FALSE - otherwise does nothing
if (!gsg$allOK) {
  # Print the gene and sample names that are removed:
  if (any(!gsg$goodGenes)) {
    printFlush(paste("Removing genes:", paste(names(datExpr0)[!gsg$goodGenes], collapse = ", ")))
  }
  if (any(!gsg$goodSamples)) {
    printFlush(paste("Removing samples:", paste(rownames(datExpr0)[!gsg$goodSamples], collapse = ", ")))
  }
  # Remove the offending genes and samples from the data:
  datExpr0 = datExpr0[gsg$goodSamples, gsg$goodGenes]
}

##~~~sample clustering~~~~~~~####
#note that this is NOT gene clustering - that comes later
#goal is to find outlier samples
sampleTree = hclust(dist(datExpr0), method = "average") #build the sample tree
svglite(filename = "sample_clustering_1.svg")
# par() only affects the device that is currently open, so the text size and
# margins must be set AFTER svglite() opens the file (they were previously set
# on whatever device was open before, and never reached the saved plot).
par(cex = 0.6)
par(mar = c(0,4,2,0))
plot(sampleTree,
     main = "Sample clustering to detect outliers",
     sub = "",
     xlab = "",
     cex.lab = 1.5,
     cex.axis = 1.5,
     cex.main = 2) #plot the sample clusters, check for outliers
dev.off()

#~~~~finalise expr data~~~~~~####
datExpr = datExpr0 #if clusters look weird, can remove some here

###~~trait data~~~~~~~~~~~~~~####
traits = allTraits[,-8] #remove "group" column as this replicates existing information

##~~~convert to numericals~~~####
# Each recoding below keeps only rows with one of the two listed values
# (subset() also drops NA), so samples with any other value lose their trait row.

#~~~~Lineage~~~~~~~~~~~~~~~~~####
traitsWV = subset(traits, Lineage == "WV")
traitsWV$viviparous = 1
traitsEO = subset(traits, Lineage == "EO")
traitsEO$viviparous = 0
traits = rbind(traitsWV, traitsEO)

#~~~~Condition~~~~~~~~~~~~~~~####
traitsCultured = subset(traits, Condition == "cultured")
traitsCultured$cultured = 1
traitsLive = subset(traits, Condition == "post rep")
traitsLive$cultured = 0
traits = rbind(traitsCultured, traitsLive)

##~~~remove nonsense~~~~~~~~~####
# Keep the three numeric traits by NAME rather than by position (was traits[,7:9]),
# so the result does not depend on the column layout of the sample sheet. The
# order (DiC, viviparous, cultured) matches the labels used for the heatmap.
traits = traits[, c("DiC", "viviparous", "cultured")]
# Expression samples without a trait row would silently become all-NA rows below.
missingSamples = setdiff(row.names(datExpr), row.names(traits))
if (length(missingSamples) > 0) {
  warning("No usable trait data for samples: ", paste(missingSamples, collapse = ", "))
}
datTraits = traits[row.names(datExpr),] #align trait rows with expression samples by name
collectGarbage()

####~recluster samples~~~~~~~####
# Re-cluster the final sample set and show one colour row per trait beneath
# the dendrogram, to see whether sample groupings follow the traits.
traitColours = numbers2colors(datTraits, signed = FALSE) #one colour column per trait
sampleTree2 = hclust(dist(datExpr), method = "average")
svglite("sample_clustering_2.svg", width = 10, height = 8) #save graph to here
plotDendroAndColors(dendro = sampleTree2,
                    colors = traitColours,
                    groupLabels = names(datTraits),
                    main = "Cluster dendrogram & trait heatmap")
dev.off()

####~choose SFP~~~~~~~~~~~~~~####
# Helps choose a soft-thresholding power (SFP) for network construction:
# check pickSFPgraphs.svg and re-run with a different SFP if needed.

###~~choose a set of SFPs~~~~####
powers = c(1:10, seq(from = 12, to = 20, by = 2)) # 1-10, then 12, 14, 16, 18, 20

###~~test SFPs with tool~~~~~####
# network topology analysis for every candidate power
sft = pickSoftThreshold(datExpr, powerVector = powers, verbose = 5)
cex1 = 0.9 #text size of the power labels drawn on both plots
svglite(filename = "pickSFPgraphs.svg", width = 8, height = 6)
par(mfrow = c(1, 2)) #lay the device out as 1 row x 2 columns so the two plots sit side by side
local({
  power = sft$fitIndices[, 1]
  signedR2 = -sign(sft$fitIndices[, 3]) * sft$fitIndices[, 2]
  meanK = sft$fitIndices[, 5]

  # left panel: scale-free topology fit (signed R^2) for each power
  plot(power, signedR2, type = "n",
       main = "Scale independence",
       xlab = "Soft Threshold (power)",
       ylab = "Scale Free Topology Model Fit,signed R^2")
  text(power, signedR2, labels = powers, cex = cex1, col = "red")
  abline(h = 0.90, col = "red") #horizontal line at the R^2 = 0.9 cut-off

  # right panel: mean connectivity for each power
  plot(power, meanK, type = "n",
       main = "Mean connectivity",
       xlab = "Soft Threshold (power)",
       ylab = "Mean Connectivity")
  text(power, meanK, labels = powers, cex = cex1, col = "red")
})
dev.off()
#looks like 7 is best! (picked by eye from pickSFPgraphs.svg; sft$powerEstimate holds the tool's own suggestion)

####~1 step NC & MI~~~~~~~~~~####
# one step easy process for:
# - network construction
# - module identification
# alternatives are available! see 2B & 2C of the WGCNA tutorial if needed...
net = blockwiseModules(datExpr,
                       power = 7, #soft-thresholding power chosen in the previous section
                       TOMType = "unsigned", #topological overlap computed ignoring the sign of correlations (see ?TOMsimilarity)
                       minModuleSize = 30, #a "relatively large" minimum module size
                       reassignThreshold = 0, #p-value ratio threshold for reassigning genes between modules (see ?blockwiseModules)
                       mergeCutHeight = 0.25, #merge modules whose eigengene dissimilarity is below this (eigengene correlation above 0.75)
                       numericLabels = TRUE, #alternative is colour labels
                       pamRespectsDendro = FALSE,
                       saveTOMs = TRUE, #TOM = Topological Overlap Matrix
                       saveTOMFileBase = "lizardTOM", #save TOM as .RData file
                       verbose = 3)
# The rest of the script uses only net$dendrograms[[1]] / net$blockGenes[[1]],
# which covers every gene only when blockwiseModules used a single block.
if (length(net$dendrograms) > 1) {
  warning("blockwiseModules split the genes into ", length(net$dendrograms),
          " blocks; the dendrogram and network heatmap below only use block 1 ",
          "(consider raising maxBlockSize)")
}
net$colors #contains module assignment
net$MEs #contains module eigengenes for each module

###~~output~~~~~~~~~~~~~~~~~~####

##~~~modules~~~~~~~~~~~~~~~~~####
table(net$colors) #number of genes per module label; label 0 = genes not in any module

##~~~dendrogram~~~~~~~~~~~~~~####
# Keep the per-gene module labels and their colour equivalents for later sections.
moduleLabels = net$colors
moduleColours = labels2colors(moduleLabels) #numeric labels -> colour names
mergedColours = moduleColours #same colours, used for the dendrogram plot
geneTree = net$dendrograms[[1]] #gene dendrogram of the first block
svglite(filename = "cluster_dendrogram.svg", height = 8, width = 8)
plotDendroAndColors(geneTree,
                    mergedColours[net$blockGenes[[1]]], #colours of the genes in block 1
                    "Module colors",
                    dendroLabels = FALSE,
                    hang = 0.03,
                    addGuide = TRUE,
                    guideHang = 0.05)
dev.off()
#NB - can use recutBlockwiseTrees to modify the dendrogram without recomputing it
MEs = net$MEs

####~define # genes & samples####
nSamples = nrow(datExpr) #rows of datExpr are samples
nGenes = ncol(datExpr)   #columns of datExpr are genes

####~recalculate MEs~~~~~~~~~####
# Eigengenes recomputed from the colour labels, so the columns are named
# "ME<colour>"; orderMEs() then reorders the columns (see ?orderMEs).
MEs0 = moduleEigengenes(datExpr, colors = moduleColours)$eigengenes
MEs = orderMEs(MEs0)
# correlation of every eigengene with every trait, plus its Student p-value
moduleTraitCor = cor(MEs, datTraits, use = "pairwise.complete.obs")
moduleTraitPvalue = corPvalueStudent(moduleTraitCor, nSamples)

####~graphical representation####
svglite(filename = "heatmap.svg", height = 8, width = 8)
# each heatmap cell shows the correlation, with its p-value in brackets underneath
textMatrix = paste0(signif(moduleTraitCor, 2), "\n(", signif(moduleTraitPvalue, 1), ")")
dim(textMatrix) = dim(moduleTraitCor)
par(mar = c(6, 8.5, 3, 3)) #margins in lines of text (bottom, left, top, right); wide left margin for module names

# Readable x-axis labels looked up by trait NAME, so they stay attached to the
# right column even if the column order of datTraits changes. A trait without
# an entry here falls back to its column name.
traitLabels = c(DiC = "Days in culture",
                viviparous = "Reproductive mode",
                cultured = "Cultured")
heatmapXLabels = unname(traitLabels[colnames(moduleTraitCor)])
unlabelled = is.na(heatmapXLabels)
heatmapXLabels[unlabelled] = colnames(moduleTraitCor)[unlabelled]

labeledHeatmap(Matrix = moduleTraitCor,
               xLabels = heatmapXLabels,
               yLabels = names(MEs),
               ySymbols = names(MEs),
               colorLabels = FALSE,
               colors = blueWhiteRed(50),
               textMatrix = textMatrix,
               setStdMargins = FALSE,
               cex.text = 0.5,
               zlim = c(-1,1),
               main = "Module-trait relationships")
dev.off()

####~gene-trait rel parity~~~####
###~~define variable~~~~~~~~~####
parity = data.frame(viviparous = datTraits$viviparous) #parity trait as a one-column data frame

###~~name module colours~~~~~####
modNames = substring(names(MEs), 3) #drop the leading "ME" to leave the module colour
# module membership (MM): correlation of every gene with every module eigengene
geneModuleMembership = as.data.frame(cor(datExpr, MEs, use = "p"))
MMPvalue = as.data.frame(corPvalueStudent(as.matrix(geneModuleMembership), nSamples))
names(geneModuleMembership) = paste0("MM", modNames) #renamed only after MMPvalue was derived

# gene significance (GS): correlation of every gene with parity
geneTraitSignificance = as.data.frame(cor(datExpr, parity, use = "p"))
GSPvalue = as.data.frame(corPvalueStudent(as.matrix(geneTraitSignificance), nSamples))
names(geneTraitSignificance) = paste0("GS.", names(parity))
names(GSPvalue) = paste0("p.GS.", names(parity))

####~intramodular analysis~~~####
# want to identify genes with high sig (GS) for interesting trait
# AND high module membership (MM) in interesting modules
# yellow is the module of interest for parity (chosen from heatmap.svg)

###~~scatterplot of GS vs MM~####
module = "yellow"
column = match(module, modNames)
if (is.na(column)) {
  stop("Module '", module, "' not found; available modules: ",
       paste(modNames, collapse = ", "))
}
moduleGenes = moduleColours == module
# name the file after the module actually plotted (it was hard-coded to "Brown")
svglite(paste0("scatterplot_", module, "_parity.svg"))
verboseScatterplot(abs(geneModuleMembership[moduleGenes, column]),
                   abs(geneTraitSignificance[moduleGenes, 1]),
                   xlab = paste("Module Membership in", module, "module"),
                   ylab = "Gene significance for parity mode",
                   main = "Module membership vs. gene significance\n",
                   cex.main = 1.2, cex.lab = 1.2, cex.axis = 1.2, col = module)
dev.off()
#this plot shows correlation between GS and MM
#you can try this for other traits/modules, guided by the previous figure

####~summary output~~~~~~~~~~####

###~~create honkin df~~~~~~~~####
# one row per gene: module colour, gene significance for parity and its p-value
geneInfo0 = data.frame(moduleColor = moduleColours,
                       geneTraitSignificance,
                       GSPvalue)

##~~~order by sig for parity~####
# module indices ordered by |eigengene-parity correlation|, strongest first
modOrder = order(-abs(cor(MEs, parity, use = "p")))

##~~~add module membership~~~####
# Append an MM.<colour> / p.MM.<colour> column pair for each module in modOrder.
# Built in a single step instead of growing geneInfo0 inside a loop.
geneInfo0 = cbind(geneInfo0, local({
  nMods = length(modOrder)
  mm = geneModuleMembership[, modOrder, drop = FALSE]
  mmP = MMPvalue[, modOrder, drop = FALSE]
  names(mm) = paste0("MM.", modNames[modOrder])
  names(mmP) = paste0("p.MM.", modNames[modOrder])
  # column order MM.1, p.MM.1, MM.2, p.MM.2, ...
  pairedCols = as.vector(rbind(seq_len(nMods), seq_len(nMods) + nMods))
  cbind(mm, mmP)[, pairedCols, drop = FALSE]
}))

##~~~order genes~~~~~~~~~~~~~####
# order 1st by module colour (alphabetically), then by decreasing |GS| for parity
geneOrder = order(geneInfo0$moduleColor, -abs(geneInfo0$GS.viviparous))
geneInfo = geneInfo0[geneOrder,]

###~~write out summary~~~~~~~####
write.csv(geneInfo, file = "parity.csv")

####~gene-trait rel DiC~~~~~~####

###~~define variable~~~~~~~~~####
DiC = as.data.frame(datTraits$DiC) #days in culture as a one-column data frame
names(DiC) = "DiC"

###~~name module colours~~~~~####
# Module membership (MM) tables, computed the same way as in the parity section
# so that this section does not depend on it.
modNames = substring(names(MEs), 3) #drop the leading "ME" to leave the module colour
geneModuleMembership = as.data.frame(cor(datExpr, MEs, use = "p"))
MMPvalue = as.data.frame(corPvalueStudent(as.matrix(geneModuleMembership), nSamples))
# these columns hold module membership, so they take the "MM" prefix
# (they were previously given a "DiC" prefix by mistake)
names(geneModuleMembership) = paste0("MM", modNames)

# gene significance (GS): correlation of every gene with days in culture
geneTraitSignificance = as.data.frame(cor(datExpr, DiC, use = "p"))
GSPvalue = as.data.frame(corPvalueStudent(as.matrix(geneTraitSignificance), nSamples))

names(geneTraitSignificance) = paste0("GS.", names(DiC))
names(GSPvalue) = paste0("p.GS.", names(DiC))

####~intramodular analysis~~~####
# want to identify genes with high sig (GS) for DiC
# AND high module membership (MM) in interesting modules
# yellow is one of the modules of interest for DiC (see the DiC gene list below)

###~~scatterplot of GS vs MM~####
module = "yellow"
column = match(module, modNames)
if (is.na(column)) {
  stop("Module '", module, "' not found; available modules: ",
       paste(modNames, collapse = ", "))
}
moduleGenes = moduleColours == module
# name the file after the module actually plotted (it was hard-coded to "Brown")
svglite(paste0("scatterplot_", module, "_DiC.svg"))
verboseScatterplot(abs(geneModuleMembership[moduleGenes, column]),
                   abs(geneTraitSignificance[moduleGenes, 1]),
                   xlab = paste("Module Membership in", module, "module"),
                   ylab = "Gene significance for DiC",
                   main = "Module membership vs. gene significance\n",
                   cex.main = 1.2, cex.lab = 1.2, cex.axis = 1.2, col = module)
dev.off()
#this plot shows correlation between GS and MM
#NOTE(review): a weak GS-MM trend here may simply mean this module is not strongly
#tied to DiC at the gene level - compare with the module's DiC cell in heatmap.svg
#before assuming a bug

####~module gene lists~~~~~~~~####
# Write the IDs of the genes in each module of interest to genes-<colour>.txt
# (one ID per line); the gene IDs are the row names of geneInfo0. A module that
# appears in several lists is written once per call, always with the same content.
writeModuleGenes = function(modules, info = geneInfo0) {
  geneIDs = row.names(info)
  for (mod in unique(modules)) {
    inModule = info$moduleColor == mod
    if (!any(inModule)) {
      warning("Module '", mod, "' has no genes; writing an empty gene list")
    }
    write(geneIDs[inModule], file = paste0("genes-", mod, ".txt"))
  }
  invisible(NULL)
}

####~parity gene list~~~~~~~~####
intModules.parity = c("yellow") #choose interesting modules
writeModuleGenes(intModules.parity)

####~cultured gene list~~~~~~####
intModules.cultured = c("black",
                        "pink",
                        "red",
                        "yellow",
                        "brown",
                        "turquoise")
writeModuleGenes(intModules.cultured)

####~DiC gene list~~~~~~~~~~~####
# own variable, so the cultured list above is no longer overwritten
intModules.DiC = c("cyan",
                   "pink",
                   "midnightblue",
                   "red",
                   "salmon",
                   "yellow",
                   "greenyellow",
                   "purple")
writeModuleGenes(intModules.DiC)
  
####~network heatmap~~~~~~~~~####

###~~all genes~~~~~~~~~~~~~~~####
# geneTree is the dendrogram of block 1 only; it has to cover every gene in
# datExpr for the TOM rows/columns and the dendrogram to line up.
if (length(geneTree$order) != ncol(datExpr)) {
  stop("geneTree covers ", length(geneTree$order), " of ", ncol(datExpr),
       " genes (blockwiseModules used several blocks); cannot draw the all-genes network heatmap")
}
# Calculate topological overlap anew (blockwiseModules also saved it to disk via
# saveTOMs / saveTOMFileBase, which could be loaded instead).
dissTOM = 1 - TOMsimilarityFromExpr(datExpr, power = 7) #same power as in blockwiseModules
# Transform dissTOM with a power to make moderately strong connections more visible in the heatmap
plotTOM = dissTOM^7
# Set diagonal to NA for a nicer plot
diag(plotTOM) = NA
# Call the plot function
svglite("network_heatmap.svg")
TOMplot(plotTOM, geneTree, moduleColours, main = "Network heatmap plot, all genes")
dev.off() #NB this takes ages