#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Jun 13 10:51:41 2024

@author: fmuir
"""

import os
import numpy as np
import pickle
from datetime import datetime,timedelta
import matplotlib.pyplot as plt
from itertools import combinations
import pandas as pd

from Toolshed import Predictions, PredictionsPlotting

# Force TensorFlow onto the CPU by making no GPU devices visible to it.
# This has to happen before TensorFlow initialises its runtime.
import tensorflow as tf
tf.config.set_visible_devices(devices=[], device_type='GPU')
# Optional TensorBoard support (IPython magics, for interactive sessions only):
# %load_ext tensorboard
# import tensorboard

#%% Load Transect Data
# Site name: used as the directory and filename stem for this site's files
sitename = 'StAndrewsEastS2Full2024'
# Root data directory, resolved against the current working directory
filepath = os.path.join(os.getcwd(), 'Data')

# Transect intersection layers carrying the coastal change variables
# (presumably base/VE-WL, water, topography and wave layers — see LoadIntersections)
(TransectInterGDF,
 TransectInterGDFWater,
 TransectInterGDFTopo,
 TransectInterGDFWave) = Predictions.LoadIntersections(filepath, sitename)

# Symbol dictionary mapping variable names to LaTeX labels used in plots
SymbolDict = dict(
    VE=r'$VE$',
    WL=r'$WL$',
    tideelev=r'$z_{tide,sat}$',
    beachwidth=r'$d_{VE,WL}$',
    tideelevFD=r'$\bar{z}_{tide}$',
    tideelevMx=r'$z^{*}_{tide}$',
    WaveHsFD=r'$H_{s}$',
    WaveDirFD=r'$\bar\theta$',
    WaveDirsin=r'$sin(\bar\theta)$',
    WaveDircos=r'$cos(\bar\theta)$',
    WaveTpFD=r'$T_{p}$',
    WaveAlphaFD=r'$\alpha$',
    Runups=r'$R_{2}$',
    Iribarren=r'$\xi_{0}$',
    WL_u=r'$WL_{u}$',
    VE_u=r'$VE_{u}$',
    WL_d=r'$WL_{d}$',
    VE_d=r'$VE_{d}$',
)

#%% Initialise storage dicts

# Merge the per-layer transect intersections into one per-transect table
CoastalDF = Predictions.CompileTransectData(TransectInterGDF, TransectInterGDFWater, TransectInterGDFTopo, TransectInterGDFWave)

# One placeholder slot (None) per transect position for every kind of result;
# slots left as None mark transects that were never modelled.
PredDicts = dict.fromkeys(range(len(CoastalDF)))
FutureOutputs = dict.fromkeys(range(len(CoastalDF)))
TransectsDFTrain = dict.fromkeys(range(len(CoastalDF)))
TransectsDFTest = dict.fromkeys(range(len(CoastalDF)))

#%% Full-site run (looped through transects)

# Fraction of each timeseries used for training/validation; the remainder is
# held back as unseen test data. Defined once so that the manual split below
# and the split performed inside PrepData can never drift apart.
TrainTestPortion = 0.8825

# Define training and target features (identical for every transect, so built once)
# TrainFeats = ['WaveHsFD', 'Runups', 'WaveDirFD', 'WaveTpFD']#, 'tideelev']
TrainFeats = ['WaveHsEW', 'Runups', 'WaveDirEW', 'WaveTpEW', 'WL_u-10', 'VE_u-10', 'WL_d-10', 'VE_d-10']
TargFeats = ['VE', 'WL']

for Tr in CoastalDF[::5].index:  # every 5th Tr (50m gap)
    print(f"{Tr}/{len(CoastalDF)}")
    # Ignore Tr with less than a year of data (one img per month).
    # NOTE(review): Tr is an index *label* but is used with .iloc (positional);
    # the two only coincide if CoastalDF has a default RangeIndex — confirm.
    if len(CoastalDF['VE'].iloc[Tr]) < 12 or len(CoastalDF['WL'].iloc[Tr]) < 12:
        continue

    # Interpolate over transect data to get daily metrics
    TransectDF = Predictions.InterpVEWLWv(CoastalDF, Tr, IntpKind='pchip')

    # Separate timeseries into training/validation and testing portions
    SplitIdx = int(TrainTestPortion * len(TransectDF))
    TransectDFTrain = TransectDF.iloc[:SplitIdx]
    TransectDFTest = TransectDF.iloc[SplitIdx:]
    TransectsDFTrain[Tr] = TransectDFTrain
    TransectsDFTest[Tr] = TransectDFTest

    # Prep data (same train/test fraction as the split above)
    PredDict, VarDFDayTrain, VarDFDayTest = Predictions.PrepData(TransectDF,
                                                                  MLabels=['Tr'+str(Tr)],
                                                                  ValidSizes=[0.1],
                                                                  TSteps=[10],
                                                                  TrainFeatCols=[TrainFeats],
                                                                  TargFeatCols=[TargFeats],
                                                                  TrainTestPortion=TrainTestPortion)
    # Compile LSTM based on provided hyperparameters
    PredDict = Predictions.CompileRNN(PredDict,
                                      epochNums=[150],
                                      batchSizes=[64],
                                      denseLayers=[128],
                                      dropoutRt=[0.2],
                                      learnRt=[0.001],
                                      hiddenLscale=[6],
                                      LossFn='Shoreshop',
                                      DynamicLR=False)

    # Train LSTM to predict target features
    PredDict = Predictions.TrainRNN(PredDict, filepath, sitename, EarlyStop=True)
    PredDicts[Tr] = PredDict

    # Use the trained model to predict target features in timeseries over the test period
    FutureOutput = Predictions.FuturePredict(PredDict, VarDFDayTest)
    # Assess the performance of the predictions against unseen test data
    FutureOutput = Predictions.ShorelineRMSE(FutureOutput, TransectDFTest)
    FutureOutputs[Tr] = FutureOutput

# Drop transects that were skipped in the loop (their slots are still None)
PredDictsClean = {k: v for k, v in PredDicts.items() if v is not None}
FutureOutputsClean = {k: v for k, v in FutureOutputs.items() if v is not None}

# Save the cleaned prediction outputs to a pickle file. Only
# FutureOutputsClean is saved: the per-transect train/test splits and trained
# models are not. Create the output directory first so the write cannot fail
# on a fresh site folder.
pkldir = os.path.join(filepath, sitename, 'predictions')
os.makedirs(pkldir, exist_ok=True)
pklpath = os.path.join(pkldir, sitename+'_FullPredict_EW_neighbours.pkl')
with open(pklpath, 'wb') as f:
    pickle.dump(FutureOutputsClean, f)
    
    
#%% Read in already trained and predicted test data
# Reload the saved outputs so the cells below can run without re-training.
# This must come BEFORE the reindexing cell; otherwise the plotted dict would
# be built from stale (or undefined) data instead of what was just loaded.
pklpath = os.path.join(filepath, sitename, 'predictions', sitename+'_FullPredict_EW_neighbours.pkl')
with open(pklpath, 'rb') as f:
    FutureOutputsClean = pickle.load(f)

#%% Reindex transect IDs for plotting
# Transects were modelled at every 5th ID, so divide by 5 to get consecutive
# plotting positions (values become floats, as with the original numpy division).
FutureOutputsPlot = {Tr / 5: Output for Tr, Output in FutureOutputsClean.items()}

#%% Plot full site root mean square errors
PredictionsPlotting.PlotSiteRMSE(FutureOutputsPlot, filepath, sitename, '_EW_neighbours')

#%% Save RMSEs to transect shapefile
# NOTE(review): unlike the other cells, this passes the *uncleaned*
# FutureOutputs (skipped transects are None) together with CoastalDF;
# presumably SaveRMSEtoSHP aligns it row-by-row with CoastalDF and tolerates
# None entries — confirm. FutureOutputs is only filled by the training loop,
# not by reloading the pickle.
CoastalGDF = Predictions.SaveRMSEtoSHP(filepath, sitename, TransectInterGDFWater, CoastalDF, FutureOutputs, '_EW_neighbours')

#%% Plot the relationship between long-term change rate and RMSE for each transect
PredictionsPlotting.PlotRMSE_Rt(CoastalGDF, filepath, sitename, '_EW_neighbours')

#%% RMSE statistics totalled
# Site-wide RMSE summaries for VE and WL over all modelled transects; the
# *10d values presumably cover only the first 10 days of the test period —
# confirm in Predictions.RMSE_Stats.
VERMSE, WLRMSE, VERMSE10d, WLRMSE10d = Predictions.RMSE_Stats(FutureOutputsClean)

#%% Chosen impact class
PlotDateRange = [datetime(2023,10,1), datetime(2023,11,5)] # Storm Babet
for Tr in [55]:
    # Skip (rather than crash) when this transect has no outputs: it may have
    # been skipped for having <12 observations, or the train/test splits may be
    # missing because outputs were reloaded from pickle without re-running the
    # training loop (the only place TransectsDFTrain/Test are filled).
    if Tr not in FutureOutputsClean or TransectsDFTrain.get(Tr) is None:
        print(f"Transect {Tr} has no prediction/training data; skipping.")
        continue
    TransectDFTrain = TransectsDFTrain[Tr]
    TransectDFTest = TransectsDFTest[Tr]

    # Impact classes from the observed (train + test) series and from the predictions
    ImpactClass = Predictions.ClassifyImpact(pd.concat([TransectDFTrain,TransectDFTest]),Method='combi')
    PredImpactClass = Predictions.ClassifyImpact(FutureOutputsClean[Tr]['output'][0], Method='combi')

    # FutureImpacts = Predictions.ApplyImpactClasses(ImpactClass, FutureOutputs)
    # PredictionsPlotting.PlotImpactClasses(filepath, sitename, Tr, ImpactClass, pd.concat([TransectDFTrain,TransectDFTest]))
    # PredictionsPlotting.PlotImpactClasses(filepath, sitename, TransectIDs[0], PredImpactClass, FullFutureOutputs['output'][0])

    PredictionsPlotting.PlotFutureShort(0, Tr, TransectDFTrain, TransectDFTest, FutureOutputsClean[Tr],
                                        filepath, sitename, PlotDateRange, Storm=[datetime(2023,10,18), datetime(2023,10,21)],
                                        ImpactClass=PredImpactClass)

#%% Impact classes for whole site
# Classify the predicted test-period series of every modelled transect, keyed
# by transect ID in the same order as FutureOutputsClean. Observed-data classes
# could be obtained the same way from
# pd.concat([TransectsDFTrain[Tr], TransectsDFTest[Tr]]).
PredImpactClasses = {Tr: Predictions.ClassifyImpact(Output['output'][0], Method='combi')
                     for Tr, Output in FutureOutputsClean.items()}
    








