#Code accompanying the paper, "Phenotypic plasticity as a route to population shifts via tipping points".
#Produces Table S.3.3 and Fig. S.3.3 from the paper
#Does a sensitivity analysis on the parameters in the reaction norms for through-juvenile survival
#and maximum adult fecundity, by performing Latin hypercube sampling over parameter space and checking whether
#hysteresis occurs.


using Random

using DataFrames
using PlotlyJS
using LatinHypercubeSampling
using StatsBase, GLM
using StatsModels, RegressionFormulae
using StatsPlots

# Stage durations, annotated as E, L, pi+J. NOTE(review): four entries for three
# labels, and only the first three are paired with `deltas` below — confirm the mapping.
taus = [0.6, 5, 4.1, 5.9]
tau = sum(taus)
# Per-stage rates paired with the first three durations (presumably mortality rates).
deltas = [0.07, 0.004, 0.0025]
# Survival through each of the first three stages: S = exp(-tau * delta).
SE, SL, SP = (exp(-t * d) for (t, d) in zip(taus, deltas))
deltaA = 0.27

KL = 50000

# Composite constants built from the stage survivals.
C1 = deltaA * (1 - SL) / (deltas[2] * SL)
C2 = deltaA / (SE * SL)


function q(a,p)#This sets up the adult fecundity classes
    q=p[1][1]*(p[1][2]*log(a)+p[1][3]+0im)^p[1][4];
    if (imag(q) ==0 && real(q)>0) #&& a>=0.14 
        return real(q);
    else
        return 0;
    end
end

function qprime(a,p)
    q= (p[1][2]/a)*p[1][4]*p[1][1]*(p[1][2]*log(a)+p[1][3]+0im)^(p[1][4]-1)
    if imag(q) ==0 #&& a>=0.14 
        return real(q);
    else
        return 0;
    end
end

function G(x,p)#logistic quadratic regression for G=ln(SJ/(1-SJ))
    #-4+3.36*log(x)-0.44*log(x)^2;
    p[2][2]+ p[2][3]*log(x)+p[2][4]*log(x)^2
end

function Gprime(x,p)
    p[2][3]/x+p[2][4]*2*log(x)/x
end

function SJ(a,p)#Equation for the survival probability classes
    Gs=exp(G(p[2][1]/(5*a),p))/(1+exp(G(p[2][1]/(5*a),p)));#G(KL/5a)
    if isnan(Gs)
        return 0;
    #elseif a<1.4
    #    return 0;
    else
        return Gs;
    end
end

function SJprime(a,p)
    Gs=(-(p[2][1]/5)/a^2*Gprime(p[2][1]/(5*a),p)*exp(G(p[2][1]/(5*a),p)))/((1+exp(G(p[2][1]/(5*a),p)))^2)
    if isnan(Gs)
        return 0;
    #elseif a<1.4
    #    return 0;
    else
        return Gs;
    end
end

function logDR(a,p)
    log(deltaA/(q(a,p)*SE*SL*SP*SJ(a,p)))
end

function logDRprime(a,p)
    -(qprime(a,p)/q(a,p)+SJprime(a,p)/SJ(a,p))
end

function aq(q,p)#Goes from q to the classes a
    exp(((q/p[1][1])^(1/p[1][4])-p[1][3])/p[1][2]);
end

function KAfroma(a,p)
    -(KL*SJ(a,p))/(C1*a*log(deltaA/(q(a,p)*SE*SL*SJ(a,p))))
end


m=64*4;



h1=3.95;
h2=6.90;
h3=-0.97;
h4=0.78;

KLexp=10000;
SJ1=-4;
SJ2=3.36;
SJ3=-0.44;

astart=exp(-h3/h2);
aend=exp(1/h2*((59/h1)^(1/h4)-h3));

an=1001;
as=Vector(LinRange(astart,aend,an));

changeRange=0.5;
sampleNumber=10000
plan = randomLHC(sampleNumber,7)
scaled_plan = scaleLHC(plan,[((1-changeRange)*h1,(1+changeRange)*h1),((1-changeRange)*h2,(1+changeRange)*h2),((1-changeRange)*h3,(1+changeRange)*h3),((1-changeRange)*h4,(1+changeRange)*h4),
                             ((1-changeRange)*SJ1,(1+changeRange)*SJ1),((1-changeRange)*SJ2,(1+changeRange)*SJ2),((1-changeRange)*SJ3,(1+changeRange)*SJ3)])

df=DataFrame(scaled_plan,["h1","h2","h3","h4","SJ1","SJ2","SJ3"]);
bifurcationCount=zeros(sampleNumber);
isBifurcation=zeros(sampleNumber);
for i = 1:sampleNumber
    params=scaled_plan[i,:]
    p=[[params[1],params[2],params[3],params[4]],[KLexp,params[5],params[6],params[7]]];
    grads=zeros(an);
    for j=1:an
        grads[j]=as[j]*((SJprime(as[j],p)/SJ(as[j],p))-(logDRprime(as[j],p)/logDR(as[j],p)))-1;
    end
    gradflips=grads[1:an-1].*grads[2:an];
    for j in (1:an-1)[gradflips.<0]
        if KAfroma(as[j],p)<0
            gradflips[j]=NaN;
        elseif  KAfroma(as[j],p)>1000
            gradflips[j]=NaN;
        end
    end        
    gradsum=sum(gradflips.<0)
    bifurcationCount[i]=gradsum
    if gradsum==2
        isBifurcation[i]=1
    end
end
df[!,:isBifurcation]=isBifurcation;
df[!,:bifurcationCount]=bifurcationCount;

# Relative deviation of each sampled parameter from its baseline value, stored as
# <name>Ratio columns of df.
varNames=["Δv₁","Δv₂","Δv₃","Δh₁","Δh₂","Δh₃","Δh₄"]
varValues=[SJ1,SJ2,SJ3,h1,h2,h3,h4]
variables=[:SJ1,:SJ2,:SJ3,:h1,:h2,:h3,:h4]
varRatios=[:SJ1Ratio,:SJ2Ratio,:SJ3Ratio,:h1Ratio,:h2Ratio,:h3Ratio,:h4Ratio]
for i in eachindex(variables)
    df[!,varRatios[i]]=(df[!,variables[i]].-varValues[i])./varValues[i]
end

# Subplot cell of each parameter on a 3x3 grid. Survival parameters go on row 1 and
# fecundity parameters on rows 2-3; cells (2,3) and (3,3) stay empty.
plotrows=[1,1,1,2,2,3,3]
plotcols=[1,2,3,1,2,1,2]
nplotcols=3

fig = make_subplots(
    rows=3, cols=nplotcols
)

# Only samples flagged as showing hysteresis are plotted.
tipped = isBifurcation .== 1
axisLayout = Dict{Symbol,Any}()
for i in eachindex(varNames)
    # The trace gets no `color` or axis range: a violin trace does not use them, and the
    # range is applied to the subplot axis in the layout below.
    add_trace!(fig, PlotlyJS.violin(x=df[!,varRatios[i]][tipped], orientation="h",
                                    marker=attr(color="#FF4136", size=2, opacity=0.5,
                                                line=attr(color="#FF4136", width=2)),
                                    box_visible=true, points="all", name="",
                                    showlegend=false),
               row=plotrows[i], col=plotcols[i])
    # Derive the axis name from the cell rather than hard-coding it. Subplot axes are
    # numbered row-major, and the first one is `xaxis`, not `xaxis1`.
    k = (plotrows[i]-1)*nplotcols + plotcols[i]
    suffix = k == 1 ? "" : string(k)
    axisLayout[Symbol("xaxis$(suffix)_title")] = varNames[i]
    axisLayout[Symbol("yaxis$(suffix)_title")] = "KDE of tipping point induction"
    axisLayout[Symbol("xaxis$(suffix)_range")] = [-0.6, 0.6]
end
relayout!(fig; width=1600, height=1200, font_size=16, axisLayout...)
fig


# PlotlyJS.savefig(fig, "ParameterSensitivityAnalysis.png"; scale=3)  # uncomment to save a high-resolution PNG

# Binomial GLMs with a probit link for whether hysteresis occurred (isBifurcation),
# regressed on the sampled parameters. The fits are named `logit*` even though the
# link is ProbitLink.

# Main effects only.
fm1 = @formula(isBifurcation ~ h1 + h2 + h3 + h4 + SJ1 + SJ2 + SJ3);
logit1 = glm(fm1, df, Binomial(), ProbitLink())
logit1CT = coeftable(logit1)

# Main effects plus all pairwise interactions (the `(...)^2` term expansion).
fm2 = @formula(isBifurcation ~ (h1 + h2 + h3 + h4 + SJ1 + SJ2 + SJ3)^2);
s = schema(fm2, df)
ts = apply_schema(fm2, s, RegressionModel)
logit = glm(ts, df, Binomial(), ProbitLink())
logitCT = coeftable(logit)

# Coefficient table as a DataFrame. The first four coeftable columns are the estimate,
# standard error, z statistic and p-value.
regressionDF = let cols = logitCT.cols
    DataFrame(variable = coefnames(logit),
              Estimate = cols[1],
              StdError = cols[2],
              z_val = cols[3],
              z_abs = abs.(cols[3]),
              p_val = cols[4])
end

# Print the terms significant at the 5% level, largest |z| first.
print(sort!(subset(regressionDF, :p_val => ByRow(<(0.05))), :z_abs, rev=true))
