from packages import *
st.set_page_config(page_title = "NIRS Utils", page_icon = ":goat:", layout = "wide")
from common import *
local_css(css_file / "style_model.css") # load the model-page-specific css

from utils.data_handling import *




# rolling hash used as a cache-busting token: each p_hash call folds a new value
# into the running hash so st.cache_data functions recompute when inputs change
hash_ = ''
def p_hash(add):
    global hash_
    hash_ = hash_data(hash_+str(add))
    return hash_
# Initialize a counter in session state; it is hashed into the cache key so the
# 're-model' button can force st.cache_data to recompute
if 'counter' not in st.session_state:
    st.session_state.counter = 0
def increment():
    st.session_state.counter += 1

# ####################################  Methods ##############################################
def delete_files(keep):
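    """Remove files under 'report/', except the logo and any file whose extension is listed in `keep`."""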
    # Walk through the directory
    for root, dirs, files in os.walk('report/', topdown=False):
        for file in files:
            if file != 'logo_cefe.png' and not any(file.endswith(ext) for ext in keep):
                os.remove(os.path.join(root, file))

class lw:
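    """Minimal wrapper exposing the Julia LW-PLS JSON results through the same
    attributes (model_, best_hyperparams_, pred_data_) as the Python regression objects."""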
    def __init__(self, Reg_json, pred):
        self.model_ = Reg_json['model']
        self.best_hyperparams_ = Reg_json['best_lwplsr_params']
        self.pred_data_ = [json_normalize(Reg_json[i]) for i in pred]


################ clean the results dir #############
delete_files(keep = ['.py', '.pyc','.bib'])
for i in ['model', 'dataset', 'figures']:
    dirpath = Path('./report/out/') / i
    dirpath.mkdir(parents = True, exist_ok = True) # exist_ok makes a separate existence check unnecessary
# ####################################### page preamble #######################################
st.header("Calibration Model Development") # page title
st.markdown("Create a predictive model, then use it for predicting your target variable (chemical data) from NIRS spectra")
c0, c1 = st.columns([1, .4])
c0.image("./images/model_creation.png", use_column_width = True) # graphical abstract

################################################################# Begin : I- Data loading and preparation ######################################
files_format = ['csv', 'dx'] # supported file formats
file = c1.radio('Select file format:', options = files_format, horizontal = True) # select a file format

spectra = DataFrame() # preallocate the spectral data block
y = DataFrame() # preallocate the target(s) data block

match file:
    # load csv file
    case 'csv':
        with c1:
            # Load X-block data
            xcal_csv = st.file_uploader("Select NIRS Data", type = "csv", help = " :mushroom: select a csv matrix with samples as rows and wavelengths as columns")
            if xcal_csv:
                sepx = st.radio("Select separator (X file): ", options = [";", ","], key = 0, horizontal = True)
                hdrx = st.checkbox("Sample names (X file): ")
                colx = 0 if hdrx else False
            else:
                st.info('Info: Insert your spectral data file above!')
            
            # Load Y-block data
            ycal_csv = st.file_uploader("Select corresponding Chemical Data", type = "csv", help = " :mushroom: select a csv matrix with samples as rows and chemical values as a column")
            if ycal_csv:
                sepy = st.radio("Select separator (Y file): ", options = [";", ","], key = 2, horizontal = True)
                hdry = st.checkbox("Sample names (Y file): ")
                coly = 0 if hdry else False
            else:
                st.info('Info: Insert your target data file above!')


            # AFTER LOADING BOTH X AND Y FILES
            if xcal_csv and ycal_csv:
                # concatenate the decoded contents of both uploads for hashing
                xy_str = ''
                for f in (xcal_csv, ycal_csv):
                    xy_str += f.getvalue().decode('utf-8')
                p_hash([xy_str + str(xcal_csv.name) + str(ycal_csv.name), hdrx, sepx, hdry, sepy])
                # 'change' below is a pure cache key: passing the running hash makes
                # csv_loader rerun whenever the data or dialect settings change
                
                @st.cache_data
                def csv_loader(change):
                    delete_files(keep = ['.py', '.pyc','.bib'])
                    file_name = str(xcal_csv.name) +' and '+ str(ycal_csv.name)
                    xfile = read_csv(xcal_csv, decimal = '.', sep = sepx, index_col = colx, header = 0)
                    yfile = read_csv(ycal_csv, decimal = '.', sep = sepy, index_col = coly)
                    return xfile, yfile, file_name
                
                xfile, yfile, file_name = csv_loader(change = hash_)



                if yfile.shape[1] > 0 and xfile.shape[1] > 0:

                    # prepare x data
                    try:
                        spectra, meta_data = col_cat(xfile)
                        spectra = DataFrame(spectra).astype(float)
                    except Exception:
                        st.error('Error: The format of the X-file does not correspond to the expected dialect settings. To read the file correctly, please adjust the separator parameters.')
                    
                    # prepare y data
                    try:
                        chem_data, idx = col_cat(yfile)
                    except Exception:
                        st.error('Error: The format of the Y-file does not correspond to the expected dialect settings. To read the file correctly, please adjust the separator parameters.')

                    if 'chem_data' in globals():
                        if chem_data.shape[1]>1:
                            yname = c1.selectbox('Select a target', options = ['']+chem_data.columns.tolist(), format_func = lambda x: x if x else "<Select>")
                            if yname:
                                y = chem_data.loc[:, yname]
                            else:
                                c1.info('Info: Select the target analyte from the drop down list!')
                        elif chem_data.shape[1] == 1:
                            y = chem_data.iloc[:, 0]
                            yname = chem_data.columns[0]
                        
                    ### warning
                    if not y.empty:
                        if spectra.shape[0] != y.shape[0]:
                            st.error('Error: X and Y have different sample sizes')
                            y = DataFrame() # reset to empty instances, not the DataFrame class itself
                            spectra = DataFrame()

                else:
                    st.error('Error: The data has not been loaded successfully, please consider tuning the dialect settings!')
    
    # Load .dx file
    case 'dx':
        with c1:
            data_file = st.file_uploader("Select Data", type = ".dx", help = " :mushroom: select a dx file")
            if data_file:
                file_name = str(data_file.name)
                ## creating the temp file
                with NamedTemporaryFile(delete = False, suffix = ".dx") as tmp:
                    tmp.write(data_file.read())
                    tmp_path = tmp.name
                    with open(tmp.name, 'r') as dd:
                        dxdata = dd.read()
                        p_hash(str(dxdata)+str(data_file.name))

                ## load and parse the temp dx file
                @st.cache_data
                def dx_loader(change):
                    chem_data, spectra, meta_data, meta_data_st = read_dx(file = tmp_path)
                    os.unlink(tmp_path)
                    return chem_data, spectra, meta_data, meta_data_st
                chem_data, spectra, meta_data, meta_data_st = dx_loader(change = hash_)
                
                if not spectra.empty:
                    st.success("Info: The data have been loaded successfully", icon = "✅")

                if chem_data.shape[1]>0:
                    yname = st.selectbox('Select the target analyte', options = ['']+chem_data.columns.tolist(), format_func = lambda x: x if x else "<Select>" )
                    if yname:
                        measured = chem_data.loc[:, yname] > 0 # keep only samples with a measured (positive) value
                        y = chem_data.loc[:, yname].loc[measured]
                        spectra = spectra.loc[measured]
                    else:
                        st.info('Info: Please select the target analyte from the dropdown list!')
                else:
                    st.warning('Warning: your file includes no target variables to model!', icon = "⚠️")


            else:
                st.info('Info: Load your file here!')
################################################### END : I- Data loading and preparation ####################################################






################################################### BEGIN : visualize and split the data ####################################################
st.subheader("I - Data visualization", divider = 'blue')
if not spectra.empty and not y.empty:
    p_hash(y)
    p_hash(np.mean(spectra))
    if np.array(spectra.columns).dtype.kind in ['i', 'f']:
        colnames = spectra.columns
    else:
        colnames = np.arange(spectra.shape[1])
    
    X_train, X_test, y_train, y_test, train_index, test_index = data_split(x=spectra, y=y)
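    # train_index/test_index are reused later to colour the diagnostic plots by split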
    


    #### insight on loaded data
    spectra_plot = plot_spectra(spectra, xunits = 'Wavelength/Wavenumber', yunits = "Signal intensity")
    target_plot = hist(y = y, y_train = y_train, y_test = y_test, target_name=yname)
    stats = DataFrame([desc_stats(y_train), desc_stats(y_test), desc_stats(y)], index =['train', 'test', 'total'] ).round(2) 

    c2, c3 = st.columns([1, .4])
    with c2:
        st.pyplot(spectra_plot) ######## Loaded graph
        st.pyplot(target_plot)

    with c3:
        st.write('Loaded data summary')
        st.write(stats)

################################################### END : visualize and split the data #######################################################




###################################################     BEGIN : Create Model     ####################################################
model_type = None # initialize the selected regression algorithm
Reg = None  # initialize the regression model object

st.subheader("II - Model creation", divider = 'blue')
if not spectra.empty and not y.empty:
    c4, c5, c6 = st.columns([1, 1, 3])
    with c4:
        # select type of supervised modelling problem
        var_nature = ['Continuous', 'Categorical']
        mode = c4.radio("The nature of the target variable:", options = var_nature)
        p_hash(mode)
        match mode:
            case "Continuous":
                reg_algo = ["", "PLS", "LW-PLS", "TPE-iPLS"]
                st.markdown(f'Example1: Quantifying the volume of nectar consumed by a pollinator during a foraging session.')
                st.markdown(f"Example2: Measure the sugar content, amino acids, or other compounds in nectar from different flower species.")
            case 'Categorical':
                reg_algo = ["", "PLS", "LW-PLS", "TPE-iPLS", 'LDA']
                st.markdown(f"Example1: Classifying pollinators into categories such as bees, butterflies, moths, and beetles.")
                st.markdown(f"Example2: Classifying plants based on their health status, such as healthy, stressed, or diseased, using NIR spectral data.")
    with c5:
        model_type = c5.selectbox("Choose a modelling algorithm:", options = reg_algo, key = 12, format_func = lambda x: x if x else "<Select>")
    
    with c6:
        st.markdown("-------------")
        match model_type:
            case "PLS":
                st.markdown("#### For further details on the PLS (Partial Least Squares) algorithm, check the following reference:")
                st.markdown('##### https://www.tandfonline.com/doi/abs/10.1080/03610921003778225')
                
            case "LW-PLS":
                st.markdown("#### For further details on the LW-PLS (Locally Weighted - Partial Least Squares) algorithm, check the following reference:")
                st.markdown('##### https://analyticalsciencejournals.onlinelibrary.wiley.com/doi/full/10.1002/cem.3117')
            
            case "TPE-iPLS":
                st.markdown("#### For further details on the TPE-iPLS (Tree-structured Parzen Estimator based interval-Partial Least Squares) algorithm, which is a wrapper method for interval selection, check the following references:")
                st.markdown("##### https://papers.nips.cc/paper_files/paper/2011/file/86e8f7ab32cfd12577bc2619bc635690-Paper.pdf")
                st.markdown('##### https://www.tandfonline.com/doi/abs/10.1080/03610921003778225')
                st.markdown('##### https://journals.sagepub.com/doi/abs/10.1366/0003702001949500')
        st.markdown("-------------")

    
    p_hash(model_type)


    # number of folds for cross-validation (CV) on the training set
    nb_folds = 3

    # Model creation
    with c5:
        @st.cache_data
        def RequestingModelCreation(change):
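            # fits the selected algorithm; 'change' only keys the cache, so a new hash forces a re-fit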
            folds = KF_CV.CV(X_train, y_train, nb_folds) # split the training data into nb_folds for cross-validation

            match model_type:
                case 'PLS':
                    Reg = Plsr(train = [X_train, y_train], test = [X_test, y_test], n_iter = 100, cv = nb_folds)
                    rega = Reg.selected_features_

                case 'LW-PLS':
                    # export data to csv for Julia train/test
                    global x_train_np, y_train_np, x_test_np, y_test_np
                    data_to_work_with = ['x_train_np', 'y_train_np', 'x_test_np', 'y_test_np']
                    x_train_np, y_train_np, x_test_np, y_test_np = X_train.to_numpy(), y_train.to_numpy(), X_test.to_numpy(), y_test.to_numpy()
                    # Cross-Validation calculation
                    d = {}
                    for i in range(nb_folds):
                        fold_idx = folds[list(folds)[i]]
                        # train folds = every row except fold i; test fold = the rows of fold i
                        d["xtr_fold{0}".format(i+1)] = np.delete(x_train_np, fold_idx, axis = 0)
                        d["ytr_fold{0}".format(i+1)] = np.delete(y_train_np, fold_idx, axis = 0)
                        d["xte_fold{0}".format(i+1)] = x_train_np[fold_idx]
                        d["yte_fold{0}".format(i+1)] = y_train_np[fold_idx]
                        data_to_work_with.append("xtr_fold{0}".format(i+1))
                        data_to_work_with.append("ytr_fold{0}".format(i+1))
                        data_to_work_with.append("xte_fold{0}".format(i+1))
                        data_to_work_with.append("yte_fold{0}".format(i+1))
                    # check best pre-treatment with a global PLSR model
                    preReg = Plsr(train = [X_train, y_train], test = [X_test, y_test], n_iter=100)
                    temp_path = Path('temp/')
                    with open(temp_path / "lwplsr_preTreatments.json", "w+") as outfile:
                        json.dump(preReg.best_hyperparams_, outfile)
                    # export Xtrain, Xtest, Ytrain, Ytest and all CV folds to temp folder as csv files
                    for i in data_to_work_with:
                        if 'fold' in i:
                            j = d[i]
                        else:
                            j = globals()[i]
                        np.savetxt(temp_path / str(i + ".csv"), j, delimiter=",")
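                    # empty marker file; presumably signals the Julia helper (LWPLSR_Call.py) that a run is pending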
                    open(temp_path / 'model', 'w').close()
                    # run Julia Jchemo as subprocess
                    import subprocess
                    subprocess_path = Path("utils/")
                    subprocess.run([sys.executable, subprocess_path / "LWPLSR_Call.py"])
                    # retrieve json results from Julia JChemo
                    try:
                        with open(temp_path / "lwplsr_outputs.json", "r") as outfile:
                            Reg_json = json.load(outfile)
                            # delete csv files
                            for i in data_to_work_with: os.unlink(temp_path / str(i + ".csv"))
                        # delete json file after import
                        os.unlink(temp_path / "lwplsr_outputs.json")
                        os.unlink(temp_path / "lwplsr_preTreatments.json")
                        os.unlink(temp_path / 'model')
                        # format result data into Reg object
                        pred = ['pred_data_train', 'pred_data_test']### keys of the dict
                        for i in range(nb_folds):
                            pred.append("CV" + str(i+1)) ### add cv folds keys to pred
                            
                        Reg = lw(Reg_json = Reg_json, pred = pred)
                        Reg.CV_results_ = DataFrame()
                        Reg.cv_data_ = {'YpredCV' : {}, 'idxCV' : {}}
                        # set indexes to Reg.pred_data (train, test, folds idx)
                        for i in range(len(pred)):
                            Reg.pred_data_[i] = Reg.pred_data_[i].T.reset_index().drop(columns = ['index'])
                            if i == 0: # train predictions
                                Reg.pred_data_[i].index = list(y_train.index)
                                Reg.pred_data_[i] = Reg.pred_data_[i].iloc[:,0]
                            elif i == 1: # test predictions
                                Reg.pred_data_[i].index = list(y_test.index)
                                Reg.pred_data_[i] = Reg.pred_data_[i].iloc[:,0]
                            else: # cross-validation fold i-2
                                Reg.pred_data_[i].index = folds[list(folds)[i-2]]
                                Reg.cv_data_['YpredCV']['Fold' + str(i-1)] = np.array(Reg.pred_data_[i]).reshape(-1)
                                Reg.cv_data_['idxCV']['Fold' + str(i-1)] = np.array(folds[list(folds)[i-2]]).reshape(-1)

                        Reg.CV_results_ = KF_CV.metrics_cv(y = y_train, ypcv = Reg.cv_data_['YpredCV'], folds = folds)[1]
                        # data for the measured vs predicted cross-validation plots
                        Reg.cv_data_ = KF_CV().meas_pred_eq(y = np.array(y_train), ypcv = Reg.cv_data_['YpredCV'], folds = folds)
                        Reg.pretreated_spectra_ = preReg.pretreated_spectra_
                        
                        Reg.best_hyperparams_print = {**preReg.best_hyperparams_, **Reg.best_hyperparams_}
                        Reg.best_hyperparams_ = {**preReg.best_hyperparams_, **Reg.best_hyperparams_}

                        Reg.__hash__ = hash_data(Reg.best_hyperparams_print)
                    except FileNotFoundError:
                        # the Julia run failed or produced no output; clean up the exported csv files
                        Reg = None
                        for i in data_to_work_with: os.unlink(temp_path / str(i + ".csv"))

                case 'TPE-iPLS':
                    Reg = TpeIpls(train = [X_train, y_train], test=[X_test, y_test], n_intervall = s, n_iter=it, cv = nb_folds)
                    
                    global intervalls, intervalls_with_cols
                    intervalls = Reg.selected_features_.T.copy()
                    intervalls_with_cols = Reg.selected_features_.T.copy().astype(str)
                    
                    # map the selected variable indices to their wavelength column names
                    for i in range(intervalls.shape[0]):
                        for j in range(intervalls.shape[1]):
                            intervalls_with_cols.iloc[i,j] = spectra.columns[intervalls.iloc[i,j]]
                    rega = Reg.selected_features_

                    st.session_state.intervalls = Reg.selected_features_.T
                    st.session_state.intervalls_with_cols = intervalls_with_cols
            return Reg
        




        if model_type:
            info = st.info('Info: The model is being created. This may take a few minutes.')
            if model_type == 'TPE-iPLS': # for iPLS, ask for the number of intervals and iterations
                s = st.number_input(label = 'Enter the maximum number of intervals', min_value = 1, max_value = 6)
                it = st.number_input(label = 'Enter the number of iterations', min_value = 2, max_value = 500, value = 250)
            else:
                s, it = None, None
            p_hash(str(s)+str(it))
                
            remodel_button = st.button('re-model the data', key=4, help=None, type="primary", use_container_width=True, on_click=increment)
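            # the button's on_click increments the counter, which feeds the hash below and busts the model cache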
            p_hash(st.session_state.counter)
            Reg = RequestingModelCreation(change = hash_)
            reg_model = Reg.model_
            hash_ = hash(Reg)
        else:
            st.info('Info: Choose a modelling algorithm from the dropdown list!')
                
        if model_type:
            info.empty()
            if Reg:
                st.success('Success! Your model has been created and is ready to use.')
            else:
                st.error("Error: Model creation failed. Please try again.")
        
        if model_type == 'TPE-iPLS':
            # reuse the cached interval selections only if both entries exist
            if 'intervalls' in st.session_state and 'intervalls_with_cols' in st.session_state:
                intervalls = st.session_state.intervalls
                intervalls_with_cols = st.session_state.intervalls_with_cols

if Reg:
    # fitted values and predicted values
    yc = Reg.pred_data_[0]
    yt = Reg.pred_data_[1]

    
    c7, c8 = st.columns([2 ,4])
    with c7:
        # Show and export the preprocessing methods
        st.write('-- Spectral preprocessing info --')
        st.write(Reg.best_hyperparams_print)
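        # persist the tuned preprocessing settings so they ship with the downloadable report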
        @st.cache_data(show_spinner =False)
        def preprocessings(change):
            with open('report/out/Preprocessing.json', "w") as outfile:
                json.dump(Reg.best_hyperparams_, outfile)
        preprocessings(change=hash_)

        # Show the model performance table
        st.write("-- Model performance --")
        model_per = DataFrame(metrics(c = [y_train, yc], t = [y_test, yt], method = 'regression').scores_)
        st.dataframe(model_per)


    
    @st.cache_data(show_spinner =False)
    def prep_important(change, model_type, model_hash):
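        # 'change' and 'model_hash' only key the cache; the figure is rebuilt when the model changes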
        fig, (ax1, ax2) = plt.subplots(2,1, figsize = (12, 4), sharex=True)
        ax1.plot(colnames, np.mean(X_train, axis = 0), color = 'black', label = 'Average spectrum (Raw)')
        ax2.plot(colnames, np.mean(Reg.pretreated_spectra_, axis = 0), color = 'black', label = 'Average spectrum (Pretreated)')
        ax2.set_xlabel('Wavelengths')
        plt.tight_layout()

        for ax in (ax1, ax2):
            ax.grid(color = 'grey', linestyle = ':', linewidth = 0.2)
            ax.margins(x = 0)
            ax.legend(loc = 'upper right')
            ax.set_ylabel('Intensity')
            if model_type == 'TPE-iPLS':
                # shade the spectral intervals selected by TPE-iPLS
                for j in range(s):
                    if np.array(spectra.columns).dtype.kind in ['i','f']:
                        wl_min, wl_max = intervalls_with_cols.iloc[j,0], intervalls_with_cols.iloc[j,1]
                    else:
                        wl_min, wl_max = intervalls.iloc[j,0], intervalls.iloc[j,1]
                    ax.axvspan(wl_min, wl_max, color = '#00ff00', alpha = 0.5, lw = 0)

        if model_type == 'PLS':
            ax1.scatter(colnames[np.array(Reg.sel_ratio_.index)], np.mean(X_train, axis = 0).iloc[np.array(Reg.sel_ratio_.index)],
                            color = '#7ab0c7', label = 'Important variables')
            ax2.scatter(colnames[Reg.sel_ratio_.index], np.mean(Reg.pretreated_spectra_, axis = 0)[np.array(Reg.sel_ratio_.index)],
                            color = '#7ab0c7', label = 'Important variables')
            ax1.legend()
            ax2.legend()
        return fig
    
    with c8: ## visualize raw and preprocessed spectra, plus the selected intervals (for TPE-iPLS)
        if model_type == 'TPE-iPLS':
            st.write('-- Important spectral regions used for model creation --')
            st.table(intervalls_with_cols)
        st.write('-- Visualization of the spectral regions used for model creation --')
        imp_fig = prep_important(change = st.session_state.counter, model_type = model_type, model_hash = hash_)
        st.pyplot(imp_fig)

    # Display CV results
    numbers_dict = {1: "One", 2: "Two",3: "Three",4: "Four",5: "Five",
                    6: "Six",7: "Seven",8: "Eight",9: "Nine",10: "Ten"}
    st.subheader(f" {numbers_dict[nb_folds]}-Fold Cross-Validation results")
    
    @st.cache_data(show_spinner =False)
    def cv_display(change):
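        # builds the overlaid (fig1) and per-fold (fig0) measured-vs-predicted CV scatter plots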
        fig1 = px.scatter(Reg.cv_data_[0], x = 'Measured', y = 'Predicted' , trendline = 'ols', color = 'Folds', symbol = 'Folds',
                color_discrete_sequence=px.colors.qualitative.G10)
        fig1.add_shape(type = 'line', x0 = .95 * min(Reg.cv_data_[0].loc[:,'Measured']), x1 = 1.05 * max(Reg.cv_data_[0].loc[:,'Measured']),
                        y0 = .95 * min(Reg.cv_data_[0].loc[:,'Measured']), y1 = 1.05 * max(Reg.cv_data_[0].loc[:,'Measured']), line = dict(color = 'black', dash = "dash"))
        fig1.update_traces(marker_size = 7, showlegend=False)
        
        fig0 = px.scatter(Reg.cv_data_[0], x ='Measured', y = 'Predicted' , trendline = 'ols', color = 'Folds', symbol = "Folds", facet_col = 'Folds',facet_col_wrap = 1,
                color_discrete_sequence = px.colors.qualitative.G10, text = 'index', width = 800, height = 1000)
        fig0.update_traces(marker_size = 8, showlegend = False)
        return fig0, fig1
    fig0, fig1 = cv_display(change= Reg.cv_data_)
    
    cv1, cv2 = st.columns([2, 2])
    with cv2:
        cv_results = DataFrame(Reg.CV_results_).round(4)# CV table
        st.write('-- Cross-Validation Summary--')
        st.write(cv_results.astype(str).style.map(lambda _: "background-color: #cecece;", subset = (cv_results.index.drop(['sd', 'mean', 'cv']), slice(None))))
        
        st.write('-- Out-of-Fold Predictions Visualization (All in one) --')
        st.plotly_chart(fig1, use_container_width = True)

    with cv1:
        st.write('-- Out-of-Fold Predictions Visualization (Separate plots) --')
        st.plotly_chart(fig0, use_container_width=True)
    

    ###################################################    BEGIN : Model Diagnosis    ####################################################
st.subheader("III - Model Diagnosis", divider='blue')
if Reg:
    # signal preprocessing results preparation for latex report
    prep_para = Reg.best_hyperparams_.copy()
    if model_type != reg_algo[2]:
        prep_para.pop('n_components')
        for i in ['deriv', 'polyorder']:
            v = Reg.best_hyperparams_[i]
            # human-readable ordinal labels: 0, 1st, 2nd, 3rd, 4th, ...
            suffix = {1: 'st', 2: 'nd', 3: 'rd'}.get(v, 'th')
            prep_para[i] = '0' if v == 0 else f"{v}{suffix}"
    
    # regression plot and residuals plot
    measured_vs_predicted = reg_plot([y_train, y_test], [yc, yt], train_idx = train_index, test_idx = test_index)
    residuals_plot = resid_plot([y_train, y_test], [yc, yt], train_idx = train_index, test_idx = test_index)
    
    M7, M8 = st.columns([2,2])
    with M7:
        st.write('Predicted vs Measured values')
        st.pyplot(measured_vs_predicted)
    
    with M8:
        st.write('Residuals plot')
        st.pyplot(residuals_plot)

###################################################      END : Model Diagnosis   #######################################################
    
###################################################    BEGIN : Download results    #######################################################
##########################################################################################################################################
##########################################################################################################################################
if Reg:
    zip_data = ""
    st.header('Download the analysis results')
    st.write("**Note:** Please check the box only after you have finished processing your data and are satisfied with the results. Checking the box prematurely may slow down the app and could lead to crashes.")
    decis = st.checkbox("Yes, I want to download the results")
    if decis:
        @st.cache_data(show_spinner = False)
        def export_report(change):
            match model_type:
                case 'PLS':
                    latex_report = report.report('Predictive model development', file_name, stats, list(prep_para.values()), model_type, model_per, cv_results)

                case 'LW-PLS':
                    latex_report = report.report('Predictive model development', file_name, stats,
                                                 list({key: Reg.best_hyperparams_[key] for key in ['deriv', 'normalization', 'polyorder', 'window_length'] if key in Reg.best_hyperparams_}.values()), model_type, model_per, cv_results)

                case 'TPE-iPLS':
                    latex_report = report.report('Predictive model development', file_name, stats,
                                                 list({key: Reg.best_hyperparams_[key] for key in ['deriv', 'normalization', 'polyorder', 'window_length'] if key in Reg.best_hyperparams_}.values()), model_type, model_per, cv_results)

                case _:
                    st.warning('Data processing has not been performed or finished yet!', icon = "⚠️")

        @st.cache_data(show_spinner =False)
        def preparing_results_for_downloading(change):
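            # copy the input data, pickle the model, and save every figure under report/out for zipping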
            match file:
                # load csv file
                case 'csv':
                    xfile.to_csv('report/out/dataset/'+ xcal_csv.name, sep = ';', encoding = 'utf-8', mode = 'a')
                    yfile.to_csv('report/out/dataset/'+ ycal_csv.name, sep = ';', encoding = 'utf-8', mode = 'a')
                case 'dx':
                    with open('report/out/dataset/'+data_file.name, 'w') as dd:
                        dd.write(dxdata)
                                    
            with open('./report/out/model/'+ model_type + '.pkl','wb') as f:# export model
                dump(reg_model, f)
            figpath ='./report/out/figures/'
            spectra_plot.savefig(figpath + "spectra_plot.png")
            target_plot.savefig(figpath + "histogram.png")
            imp_fig.savefig(figpath + "variable_importance.png")
            fig1.write_image(figpath + "meas_vs_pred_cv_all.png")
            fig0.write_image(figpath + "meas_vs_pred_cv_onebyone.png")
            measured_vs_predicted.savefig(figpath + 'measured_vs_predicted.png')
            residuals_plot.savefig(figpath + 'residuals_plot.png')
            
            if model_type == 'TPE-iPLS': # export selected wavelengths
                wlfilename = './report/out/model/' + model_type + '-selected_wavelengths.xlsx'
                selected_wls = concat([intervalls_with_cols.T, Reg.selected_features_], axis = 0, ignore_index = True).T
                selected_wls.columns = ['wl_from', 'wl_to', 'idx_from', 'idx_to']
                selected_wls.to_excel(wlfilename)
            
            export_report(change = hash_)
            if Path("./report/report.tex").exists():
                report.generate_report(change = hash_)
            if Path("./report/report.pdf").exists():
                move("./report/report.pdf", "./report/out/report.pdf")
            
            pklfile = {'model_': Reg.model_,"model_type" : model_type, 'data':{'raw-spectra':spectra,'target':y, 'training_data_idx':train_index,'testing_data_idx':test_index},
                    'spec-preprocessing':{"normalization": Reg.best_hyperparams_['normalization'], 'SavGol(polyorder,window_length,deriv)': [Reg.best_hyperparams_["polyorder"],
                                                                                                                                               Reg.best_hyperparams_['window_length'],
                                                                                                                                               Reg.best_hyperparams_['deriv']]}}
            if model_type == 'TPE-iPLS': # export selected wavelengths
                pklfile['selected-wls'] = {'idx':Reg.selected_features_.T , "wls":intervalls_with_cols }
            elif model_type == 'LW-PLS': # export LWPLS best model parameters
                pklfile['selected-wls'] = {'idx':None, "wls":None }
                pklfile['lwpls_params'] = Reg.best_hyperparams_
            else:
                pklfile['selected-wls'] = {'idx':None, "wls":None }
                    
            with open('./report/out/file_system.pkl', "wb") as pkl:
                dump(pklfile, pkl)

            return change
        preparing_results_for_downloading(change = hash_)
        
    
        @st.cache_data(show_spinner =False)
        def tempdir(change):
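            # zip report/out inside a temp directory so the download button can stream the archive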
            with TemporaryDirectory(prefix = "results", dir = "./report") as temp_dir: # create a temp directory
                tempdirname = os.path.split(temp_dir)[1]

                if len(os.listdir('./report/out/figures/'))>2:
                    make_archive(base_name="./report/Results", format="zip", base_dir="out", root_dir = "./report")# create a zip file
                    move("./report/Results.zip", f"./report/{tempdirname}/Results.zip")# put the inside the temp dir
                    with open(f"./report/{tempdirname}/Results.zip", "rb") as f:
                        zip_data = f.read()
            return tempdirname, zip_data

        try:
            tempdirname, zip_data = tempdir(change = hash_)
        except Exception:
            pass # zip_data stays '' so the download button remains disabled
    date_time = datetime.now().strftime('%y%m%d%H%M')
    disabled_down = zip_data == '' # disable the download button until the archive exists
    st.download_button(label = 'Download', data = zip_data, file_name = f'Nirs_Workflow_{date_time}_Reg_.zip', mime ="application/zip",
                args = None, kwargs = None,type = "primary",use_container_width = True, disabled = disabled_down)


    delete_files(keep = ['.py', '.pyc','.bib'])