# -*- coding: utf-8 -*-
"""
Created on Thu Dec  5 10:15:30 2024

@author: 2313518K
"""
import matplotlib.pyplot as plt
import numpy as np
from scipy.interpolate import interp1d


def multiplication_error_propagation(sigma_a, A, sigma_b, B, C=None):
    """Propagate uncertainties through a product or quotient (C = A*B or C = A/B).

    Independent (uncorrelated) errors combine in quadrature on the relative scale:

        sigma_C = |C| * sqrt((sigma_a / A)**2 + (sigma_b / B)**2)

    Parameters
    ----------
    sigma_a, sigma_b : float or array_like
        Absolute uncertainties of A and B.
    A, B : float or array_like
        Input values. They appear as divisors, so they must be non-zero.
    C : float or array_like, optional
        The result whose uncertainty is wanted (A*B or A/B). Defaults to A*B.

    Returns
    -------
    float or ndarray
        Absolute (non-negative) uncertainty of C.
    """
    if C is None:
        C = A * B
    # |C|: an uncertainty is a magnitude, so a negative result must not flip its sign.
    sigma_c = np.abs(C) * np.sqrt((sigma_a / A)**2 + (sigma_b / B)**2)
    return sigma_c

def addition_error_propagation(sigma_a, sigma_b):
    """Combine two independent absolute uncertainties for a sum or difference.

    For C = A + B or C = A - B the variances add:

        sigma_C = sqrt(sigma_a**2 + sigma_b**2)
    """
    combined_variance = sigma_a**2 + sigma_b**2
    return np.sqrt(combined_variance)


# First measurement band: wavelengths (nm) and the full spectral range (nm)
# spanned at each point.
wavelength = np.array([1550, 2328, 2790, 2900, 3500, 3800])
wavelength_range = np.array([8, 71, 120, 112, 126, 186])

# Second measurement band, same layout.
wavelength_2 = np.array([4400, 4750, 5030, 5438])
wavelength_range_2 = np.array([157, 114, 239, 148])

# A uniform distribution of full width w has standard deviation w / (2*sqrt(3)),
# so each range is presumably treated as a flat-top spectral width.
_uniform_sd_divisor = 2*np.sqrt(3)
wavelength_error = wavelength_range / _uniform_sd_divisor
wavelength_error_2 = wavelength_range_2 / _uniform_sd_divisor
del _uniform_sd_divisor
#electronic jitter error propagation ########################
# Electronic jitter = RMS voltage noise / discriminator slew rate, scaled by
# 2.355 (~2*sqrt(2*ln 2), the Gaussian RMS-to-FWHM factor -- presumably the intent).
rms_noise = 8.46e-3  # V
rms_noise_err = 2.5248E-3  # V

discriminator_slew_rate = 1.352e9  # V/s
discriminator_slew_rate_error = 127.72e6  # V/s

electronic_jitter = (rms_noise/discriminator_slew_rate)*2.355  # s
# Relative errors of a quotient add in quadrature.
electronic_jitter_error = electronic_jitter * np.sqrt((discriminator_slew_rate_error/discriminator_slew_rate)**2 + (rms_noise_err/rms_noise)**2)

# Relative uncertainty assumed for the jitter values below, which have no
# measured error of their own (5 %).
typical_relative_error = 0.05

#laser jitter ###############################################
laser_jitter = 3e-12  # s, typical customer value
laser_jitter_error = laser_jitter*typical_relative_error
#timetagger jitter ##########################################
# 1.2 (ps) measured using the time-tagger device, scaled by 2.35.
# NOTE(review): 2.35 here vs 2.355 for the electronic jitter above -- confirm
# which factor is intended (the difference shifts this term by ~0.2 %).
timetagger_jitter = (1.2*2.35)*1e-12  # s
timetagger_jitter_error = timetagger_jitter*typical_relative_error
#diode jitter ############################################################
diode_jitter = 15e-12  # s, typical customer value
# Derived from diode_jitter (instead of repeating the literal) so the two cannot drift apart.
diode_jitter_error = diode_jitter*typical_relative_error
#############################################################
# Tabulated chromatic dispersion, one text file per spectral band. Column 0 is
# read as wavelength (nm) and column 1 as dispersion; the values are scaled by
# 1e-12 below to give s/(nm km) (so presumably stored in ps/(nm km) -- confirm).
file_path = "chromatic_dispersion_1500nm_to_3800nm.txt"
file_path_2 = "chromatic_dispersion_4400_to_5500nm.txt"
data = np.loadtxt(file_path)
data_2 = np.loadtxt(file_path_2)

wavelengths, dispersion = data[:, 0], data[:, 1]
wavelengths_2, dispersion_2 = data_2[:, 0], data_2[:, 1]

# Cubic-spline interpolants through the tabulated points. This interpolates the
# data exactly; it is not a least-squares polynomial fit. Queries outside the
# tabulated wavelength range raise ValueError (interp1d default bounds_error).
poly_func = interp1d(wavelengths, dispersion, kind='cubic')
poly_func_2 = interp1d(wavelengths_2, dispersion_2, kind='cubic')

# Dense 1 nm grids per band (not used elsewhere in this script) and the
# dispersion at each measurement wavelength, converted to s/(nm km).
interp_wavelengths = np.arange(1500, 3801)
interp_dispersion = poly_func(interp_wavelengths)
D_CD = poly_func(wavelength) * 1e-12  # s/(nm km)

interp_wavelengths_2 = np.arange(4400, 5501)
interp_dispersion_2 = poly_func_2(interp_wavelengths_2)
D_CD_2 = poly_func_2(wavelength_2) * 1e-12  # s/(nm km)

# Both bands in one array, in the same order as the wavelengths are concatenated.
D_CD = np.concatenate((D_CD, D_CD_2))

# Fibre length and its uncertainty.
L = (11)*1e-3  # km
L_error = 1e-5  # km
# Spectral standard deviation of every measurement point, both bands in order (nm).
del_lambda = np.concatenate([wavelength_error, wavelength_error_2])  # nm

# Chromatic-dispersion broadening |D| * L * delta_lambda (s). D_CD is interpolated
# data and may be negative at some wavelengths, but a jitter is a magnitude, so
# the absolute value is taken (downstream code only squares these terms).
chromatic_dispersion_jitter = np.abs(D_CD)*L*del_lambda  # s
# Assumed 5 % relative uncertainty.
# NOTE(review): L_error is defined but not propagated here -- confirm intended.
chromatic_dispersion_jitter_error = chromatic_dispersion_jitter*0.05
 
#system jitter error ############################################################
# Sum every non-intrinsic jitter contribution in quadrature and propagate the
# uncertainties:
#     J       = sqrt(sum_i J_i**2)
#     sigma_J = sqrt(sum_i (J_i * sigma_i)**2) / J
# The scalar terms broadcast against the per-wavelength chromatic-dispersion
# term, so both results hold one value per wavelength.
_jitter_terms = (
    (laser_jitter, laser_jitter_error),
    (timetagger_jitter, timetagger_jitter_error),
    (diode_jitter, diode_jitter_error),
    (chromatic_dispersion_jitter, chromatic_dispersion_jitter_error),
    (electronic_jitter, electronic_jitter_error),
)
system_jitter = np.sqrt(sum(value**2 for value, _ in _jitter_terms))
system_jitter_error = (1 / system_jitter) * np.sqrt(
    sum((value * sigma)**2 for value, sigma in _jitter_terms)
)
del _jitter_terms

# Seconds -> picoseconds, to match the measured values.
system_jitter *= 1e12
system_jitter_error *= 1e12


#total (measured) jitter ############################################################
# Measured total timing jitter (ps) and its uncertainty, one value per wavelength,
# ordered like the two concatenated wavelength bands (6 + 4 points).
total_jitter = np.array([25.23, 29.73, 33.68, 34.26, 39.3, 45.2, 46.86, 49.25, 55.98, 59.24])
total_jitter_error = np.array([0.1, 1.48, 0.94, 0.93, 2.53, 0.6, 1.08, 1.64, 0.98, 1.2])

#intrinsic jitter error ############################################################
# Remove the other contributions in quadrature:
#     J_int     = sqrt(J_total**2 - J_other**2)
#     sigma_int = sqrt((J_total*sigma_total)**2 + (J_other*sigma_other)**2) / J_int
# NOTE(review): any point where system_jitter exceeds total_jitter becomes NaN
# (numpy warns rather than raising).
_total_sq = total_jitter**2
_other_sq = system_jitter**2
intrinsic_jitter = np.sqrt(_total_sq - _other_sq)
_weighted_var = (total_jitter * total_jitter_error)**2 + (system_jitter * system_jitter_error)**2
intrinsic_jitter_error = (1 / intrinsic_jitter) * np.sqrt(_weighted_var)
del _total_sq, _other_sq, _weighted_var

#############################################################

# Set figure size for single-column width (85 mm)
fig_width_mm = 85  # Nature Photonics single-column width
fig_height_mm = fig_width_mm * 0.67  # Proportional height
fig_width_inches = fig_width_mm / 25.4  # Convert mm to inches
fig_height_inches = fig_height_mm / 25.4  # Convert mm to inches

# Create the figure and axis
fig, ax = plt.subplots(figsize=(fig_width_inches, fig_height_inches))

# Merge both bands so every plotted series shares one x-array
# (same order as total_jitter / system_jitter).
wavelength = np.concatenate([wavelength, wavelength_2])
wavelength_error = np.concatenate([wavelength_error, wavelength_error_2])

# Measured total jitter, with wavelength and jitter error bars
ax.errorbar(wavelength, total_jitter, xerr=wavelength_error, yerr=total_jitter_error,
            fmt='o', color='blue', ecolor='lightblue', elinewidth=2, capsize=2,
            label='$J_{system}$', markersize=4)

# Intrinsic jitter (total with the other contributions removed in quadrature)
ax.errorbar(wavelength, intrinsic_jitter, xerr=wavelength_error, yerr=intrinsic_jitter_error,
            fmt='s', color='red', ecolor='pink', elinewidth=2, capsize=2,
            label='$J_{intrinsic}$', markersize=4)

# Summed other contributions as a dashed line, with a +/- system_jitter_error band
ax.plot(wavelength, system_jitter, label='$J_{other}$', linestyle='--',
        color='black', linewidth=0.75)
ax.fill_between(wavelength, system_jitter-system_jitter_error, system_jitter+system_jitter_error, alpha=0.2, color='grey')

# Labels, legend, and formatting
ax.set_xlabel("Wavelength (nm)", fontsize=7, fontname='Arial')
ax.set_ylabel("Timing Jitter (ps)", fontsize=7, fontname='Arial')
ax.tick_params(axis='both', which='major', labelsize=7)

# Legend (drawn with matplotlib's default frame; pass frameon=False to remove it)
ax.legend(loc="best", fontsize=7)

# Save the figure created above explicitly rather than relying on pyplot's current figure
fig.savefig("Jitter_contributions.svg", dpi=300, bbox_inches='tight')

# Display the plot
plt.show()

