# -*- coding: utf-8 -*-
"""
Created on Wed Dec  4 14:36:55 2024

@author: 2313518K
"""

import numpy as np
from tkinter import Tk, filedialog
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
from scipy.stats import exponnorm
import os  # Add this to the top of your script

# Function to extract IRF data from file
def extract_IRF(file, start_p, end_p):
    """Load a two-column trace from a text file and return a row slice of it.

    The first line of the file is skipped (treated as a header).  Column 0 is
    read as the time axis and column 1 as the measured values.

    Parameters
    ----------
    file : str, path-like or file object
        Anything accepted by ``numpy.loadtxt``.
    start_p, end_p : int
        Row slice bounds (``start_p`` inclusive, ``end_p`` exclusive) applied
        to both columns; ordinary slicing rules apply, so bounds past the end
        of the data are clipped.

    Returns
    -------
    tuple of numpy.ndarray
        ``(values, times)`` -- note the order: values first, time axis second.
    """
    # ndmin=2 keeps the array two-dimensional even when the file contains a
    # single data row; without it loadtxt returns a 1-D array and the column
    # indexing below raises IndexError.
    data = np.loadtxt(file, skiprows=1, ndmin=2)
    return data[start_p:end_p, 1], data[start_p:end_p, 0]

# Fitting function using exponnorm
def exponnorm_fit(x, loc, scale, shape, amplitude):
    """Scaled exponentially modified normal density used as the fit model.

    Evaluates ``scipy.stats.exponnorm.pdf`` at ``x`` with ``shape`` as the
    distribution's shape parameter, ``loc`` as its location and ``scale`` as
    its scale, then multiplies the density by ``amplitude``.

    Note the parameter order of this function (loc, scale, shape, amplitude)
    differs from scipy's (shape, loc, scale).
    """
    density = exponnorm.pdf(x, shape, loc=loc, scale=scale)
    return amplitude * density

# Function to calculate FWHM from fitted parameters
def calculate_fwhm(shape, scale):
    """Return ``2*sqrt(2*ln 2) * scale * sqrt(1 + shape**2)``.

    ``2*sqrt(2*ln 2)`` is the factor that converts a Gaussian standard
    deviation into a Gaussian FWHM.

    NOTE(review): ``scale * sqrt(1 + shape**2)`` appears to be the standard
    deviation of the exponnorm distribution, which would make this a
    Gaussian-equivalent width rather than the exact FWHM of the skewed
    curve -- confirm that this approximation is intended.
    """
    gaussian_sigma_to_fwhm = 2 * np.sqrt(2 * np.log(2))
    width = gaussian_sigma_to_fwhm * scale * np.sqrt(1 + shape**2)
    return width

def calculate_fwhm_numerical(data, time):
    """Estimate the full width at half maximum directly from sampled data.

    The width is the time difference between the first and the last sample
    whose value is at least half of the maximum of ``data``.  No
    interpolation between samples is performed, so the result is quantised
    to the spacing of ``time``.

    Parameters
    ----------
    data : array-like
        Sampled signal values (1-D).
    time : array-like
        Time stamp of each sample, same length as ``data``.

    Returns
    -------
    float or None
        The width in the units of ``time``, or ``None`` when fewer than two
        samples reach the half-maximum level (including empty input).
    """
    data = np.asarray(data)
    time = np.asarray(time)
    if data.size == 0:
        return None

    half_max = np.max(data) / 2

    # Indices of every sample at or above the half-maximum level.
    above_half = np.flatnonzero(data >= half_max)
    if above_half.size < 2:
        # A single None (not a tuple) so callers can test `is None`,
        # matching the scalar returned on success.
        return None

    return time[above_half[-1]] - time[above_half[0]]

def calculate_r_squared(y_obs, y_fit):
    """Coefficient of determination between observed and fitted values.

    Computed as ``1 - SS_res / SS_tot`` where ``SS_res`` is the sum of
    squared differences between ``y_obs`` and ``y_fit`` and ``SS_tot`` is the
    sum of squared deviations of ``y_obs`` from its own mean.  The two inputs
    must be evaluated at the same sample points.
    """
    residual_ss = np.sum(np.square(y_obs - y_fit))
    total_ss = np.sum(np.square(y_obs - np.mean(y_obs)))
    return 1 - (residual_ss / total_ss)

# Calculate R^2


# Load data
start_p = 0
end_p = 10000

# Keep a handle on the hidden Tk root so it can be destroyed once the dialog
# has closed instead of lingering for the rest of the session.
root = Tk()
root.withdraw()  # Hide the root window
file = filedialog.askopenfilename(
    title="Select IRF data file",
    initialdir="Z:/Photon1K-PC/Daniel/Measurements/Devices/NASA_JPL/1_R5C13_differential_160nm",
)
root.destroy()

# askopenfilename returns an empty value when the dialog is cancelled; stop
# with a clear message rather than failing later inside np.loadtxt.
if not file:
    raise SystemExit("No file selected; nothing to analyse.")

data_arr, time_index = extract_IRF(file, start_p, end_p)
file_name = os.path.basename(file)

# Initial guess for the parameters [loc, scale, shape, amplitude]
peak_time = time_index[np.argmax(data_arr)]
initial_guess = [peak_time, 10, 15, np.max(data_arr)]

# Fit/plot window: 100 time-axis units either side of the peak.  The bounds
# are compared against time_index below, so they are built from the peak
# *time* rather than from the peak's array index (the two only coincide when
# the time axis starts at 0 with unit spacing).
plot_range_min = peak_time - 100
plot_range_max = peak_time + 100

# Filter the data within the range
filtered_indices = (time_index >= plot_range_min) & (time_index <= plot_range_max)
time_index_filtered = time_index[filtered_indices]
data_arr_filtered = data_arr[filtered_indices]

params, covariance = curve_fit(exponnorm_fit, time_index_filtered, data_arr_filtered, p0=initial_guess)

# Extract the fitted parameters
loc, scale, shape, amplitude = params

# 1-sigma parameter uncertainties from the covariance diagonal, and the
# corresponding 95% half-widths (normal-distribution assumption).
param_errors = np.sqrt(np.diag(covariance))
confidence_intervals = 1.96 * param_errors

# R^2 compares each measured point with the model evaluated at that point's
# own time stamp, so the model is evaluated on time_index_filtered itself
# (a linspace grid is not guaranteed to line up with the samples).
fitted_at_samples = exponnorm_fit(time_index_filtered, loc, scale, shape, amplitude)
r_squared = calculate_r_squared(data_arr_filtered, fitted_at_samples)
print(f"R^2 = {r_squared:.4f}")

# Dense evaluation of the fitted curve, used for plotting and for reading the
# FWHM off the curve with fine time resolution.
x_vals_filtered = np.linspace(plot_range_min, plot_range_max, 100000)
fitted_vals_filtered = exponnorm_fit(x_vals_filtered, loc, scale, shape, amplitude)

fwhm = calculate_fwhm_numerical(fitted_vals_filtered, x_vals_filtered)
fwhm_numerical = calculate_fwhm_numerical(data_arr_filtered, time_index_filtered)

# Approximate 95% confidence interval for the FWHM: first-order error
# propagation through FWHM ~ scale * sqrt(1 + shape**2), whose relative
# sensitivities are 1/scale and shape / (1 + shape**2).  Correlations between
# the fitted parameters are ignored.
scale_err, shape_err = param_errors[1], param_errors[2]
fwhm_err = fwhm * np.sqrt((scale_err / scale)**2 + (shape * shape_err / (1 + shape**2))**2)
fwhm_lower = fwhm - 1.96 * fwhm_err
fwhm_upper = fwhm + 1.96 * fwhm_err

# Plot the data and the fitted distribution within the range
# Journal-compliant figure size (one-column width: 3.37 inches, or two-column width: 6.69 inches)
#fig_width = 6.69  # Adjust to 3.37 for one-column width
fig_width = 3.37
#fig_height = fig_width * 0.6  # Maintain a suitable aspect ratio
fig_height = fig_width * 0.67
# Use fig_width here so that switching the width above actually takes effect.
plt.figure(figsize=(fig_width, fig_height))

# Plotting the data; both curves are normalised to their own maximum
plt.plot(
    time_index_filtered,
    data_arr_filtered / np.max(data_arr_filtered),
    label="4400nm",
    color='blue',
    lw=0.5,  # Ensure line width is >= 0.5 pt
)
plt.plot(
    x_vals_filtered,
    fitted_vals_filtered / np.max(fitted_vals_filtered),
    label=f'{round(fwhm,2)} ps ',
    color='red',
    lw=1.0,  # Increase line width for better visibility
    linestyle='--'
)

# Print the parameters of the fit, including FWHM
print("Exponnorm Fit Parameters:")
print(f"    Loc = {loc:.2f} ± {confidence_intervals[0]:.2f} ps")
print(f"    Scale = {scale:.2f} ± {confidence_intervals[1]:.2f} ps")
print(f"    Shape = {shape:.2f} ± {confidence_intervals[2]:.2f}")
print(f"    Amplitude = {amplitude:.2f} ± {confidence_intervals[3]:.2f}")
print(f"    FWHM = {fwhm:.2f} ps")
print(f"    95% CI for FWHM: (error {fwhm_err:.2f}, lower: {fwhm_lower:.2f}, upper: {fwhm_upper:.2f}) ps")
if fwhm_numerical is not None:
    print(f"    Numerical FWHM = {fwhm_numerical:.2f} ps")
else:
    print("    Numerical FWHM calculation failed")

# Labels and title
#plt.xlabel('Time (ps)', fontsize=8)  # Minimum 8-point font size
#plt.ylabel('Normalised counts', fontsize=8)
#plt.title(f'Instrument Response Function {file_name}', fontsize=10)  # Slightly larger for the title

plt.xticks(fontsize=8)
plt.yticks(fontsize=8)

# Grid, legend, and layout
#plt.grid(linewidth=0.5, linestyle='--', alpha=0.7)
plt.legend(fontsize=8)  # Minimum legend font size
plt.tight_layout()  # Optimize layout to fit within specified dimensions

# Save and/or show
#plt.savefig('1550nm.svg', dpi=300, bbox_inches='tight')  # Ensure high-resolution output
plt.show()