from common import *
st.set_page_config(page_title = "NIRS Utils", page_icon = ":goat:", layout = "wide")





# layout
UiComponents(pagespath = pages_folder, csspath= css_file,imgpath=image_path ,
             header=True, sidebar= True, bgimg=False, colborders=True)
# Running fingerprint of the user's inputs/settings; it is passed as the `change`
# argument of the loading / model-creation helpers and updated as data and options are chosen.
hash_ = ''

# Initialize the variable in session state if it doesn't exist for st.cache_data
# (the counter is folded into hash_ before model creation so that bumping it forces a re-run).
if 'counter' not in st.session_state:
    st.session_state.counter = 0
def increment():
    """Bump the session counter (used as the 're-model the data' button's on_click callback)."""
    st.session_state.counter += 1

# ####################################  Methods ##############################################
def delete_files(keep, path = 'report/', protected = ('logo_cefe.png',)):
    """Recursively delete generated files under *path*.

    A file is kept when its name is listed in *protected* or ends with one of
    the suffixes in *keep*; every other file below *path* is removed.
    Directories themselves are left in place.

    Args:
        keep: iterable of filename suffixes to preserve (e.g. ['.py', '.bib']).
        path: root directory to clean; defaults to the report folder.
        protected: exact file names that are never deleted.
    """
    # str.endswith accepts a tuple of suffixes, replacing the per-suffix any() loop.
    keep = tuple(keep)
    for root, _dirs, files in os.walk(path, topdown=False):
        for name in files:
            if name in protected or name.endswith(keep):
                continue
            os.remove(os.path.join(root, name))

class lw:
    """Result holder for the LW-PLSR outputs read back from the JSON results file.

    Attributes:
        model_: value stored under the 'model' key of the results dict.
        best_hyperparams_: value stored under the 'best_lwplsr_params' key.
        pred_data_: list with one ``json_normalize`` frame per key in *pred*,
            in the same order as *pred*.
    """

    def __init__(self, Reg_json, pred):
        results = Reg_json
        self.model_ = results['model']
        self.best_hyperparams_ = results['best_lwplsr_params']
        normalized = []
        for key in pred:
            normalized.append(json_normalize(results[key]))
        self.pred_data_ = normalized


################ clean the results dir #############
# Remove files left in report/ by previous runs (source and .bib files are kept),
# then make sure the output sub-directories used by the export step exist.
delete_files(keep = ['.py', '.pyc','.bib'])
for i in ['model', 'dataset', 'figures']:
    dirpath = Path('./report/out/')/i
    # exist_ok=True already tolerates an existing directory, so no exists() pre-check
    # (which would also be a check-then-create race) is needed.
    dirpath.mkdir(parents=True, exist_ok=True)
# ####################################### page preamble #######################################
st.header("Calibration Model Development") # page title
st.markdown("Create a predictive model, then use it for predicting your target variable (chemical data) from NIRS spectra")
c0, c1 = st.columns([1, .4])
c0.image("./images/model_creation.png", use_column_width = True) # graphical abstract

################################################################# Begin : I- Data loading and preparation ######################################
files_format = ['csv', 'dx'] # Supported files format
file = c1.radio('Select files format:', options = files_format,horizontal = True) # Select a file format

spectra = DataFrame() # preallocate the spectral data block
y = DataFrame() # preallocate the target(s) data block

match file:
    # load csv file
    case 'csv':
        from utils.data_parsing import CsvParser
        def read_csv(file = file, change = None, dec = None, sep= None, names = None, hdr = None):
            delete_files(keep = ['.py', '.pyc','.bib'])
            from utils.data_parsing import CsvParser
            par = CsvParser(file= file)
            par.parse(decimal = dec, separator = sep, index_col = names, header = hdr)
            return par.float, par.meta_data, par.meta_data_st_, par.df

        with c1:
            # Load X-block data
            xcal_csv = st.file_uploader("Select NIRS Data", type = "csv", help = " :mushroom: select a csv matrix with samples as rows and lambdas as columns")
            if xcal_csv:
                c1_1, c2_2 = st.columns([.5, .5])
                with c1_1:
                    decx = st.radio('decimal(x):', options= [".", ","], horizontal = True)
                    sepx = st.radio("separator(x):", options = [";", ","], horizontal = True)
                with c2_2:
                    phdrx = st.radio("header(x): ", options = ["yes", "no"], horizontal = True)
                    pnamesx = st.radio("samples name(x):", options = ["yes", "no"], horizontal = True)

                hdrx = 0 if phdrx =="yes" else None
                namesx = 0 if pnamesx =="yes" else None
                try:
                    spectra, meta_data, md_df_st_, xfile = read_csv(file= xcal_csv, change = hash_, dec = decx, sep = sepx, names =namesx, hdr = hdrx)
                    st.success('xfile has been loaded successfully')
                except:
                    st.error('Error: The xfile has not been loaded successfully, please consider tuning the dialect settings!')
                
            else:
                st.info('Info: Insert your spectral data file above!')
            



            # Load Y-block data
            ycal_csv = st.file_uploader("Select corresponding Chemical Data", type = "csv", help = " :mushroom: select a csv matrix with samples as rows and chemical values as a column")
            if ycal_csv:
                c1_1, c2_2 = st.columns([.5, .5])
                with c1_1:
                    decy = st.radio('decimal(y):', options= [".", ","], horizontal = True)
                    sepy = st.radio("separator(y):", options = [";", ","], horizontal = True)
                with c2_2:
                    phdry = st.radio("header(y): ", options = ["yes", "no"], horizontal = True)
                    pnamesy = st.radio("samples name(y):", options = ["yes", "no"], horizontal = True)

                hdry = 0 if phdry =="yes" else None
                namesy = 0 if pnamesy =="yes" else None
                try:
                    chem_data, meta_data, md_df_st_, yfile = read_csv(file= ycal_csv, change = hash_, dec = decy, sep = sepy, names =namesy, hdr = hdry)
                    st.success('yfile has been loaded successfully')
                except:
                    st.error('Error: The yfile has not been loaded successfully, please consider tuning the dialect settings!')

            else:
                st.info('Info: Insert your target data file above!')


            # AFTER LOADING BOTH X AND Y FILES
            if xcal_csv and ycal_csv:
                # create a str instance for storing the hash of both x and y data
                xy_str = ''
                from io import StringIO
                for i in ["xcal_csv", "ycal_csv"]:
                    stringio = StringIO(eval(f'{i}.getvalue().decode("utf-8")'))
                    xy_str += str(stringio.read())
                # p_hash([xy_str + str(xcal_csv.name) + str(ycal_csv.name), hdrx, sepx, hdry, sepy])
                hash_ = ObjectHash(current=hash_,add = xy_str)
                








            


                # xfile, yfile, file_name = csv_loader(change = hash_)
                # yfile =  read_csv(file= ycal_csv, change = hash_)



                if yfile.shape[1]>0 and xfile.shape[1]>0 :    
                    if 'chem_data' in globals():
                        if chem_data.shape[1] > 1:
                            yname = c1.selectbox('Select a target', options = ['']+chem_data.columns.tolist(), format_func = lambda x: x if x else "<Select>")
                            if yname:
                                y = chem_data.loc[:, yname]
                            else:
                                c1.info('Info: Select the target analyte from the drop down list!')
                        elif chem_data.shape[1] == 1:
                            y = chem_data.iloc[:, 0]
                            yname = chem_data.iloc[:, [0]].columns[0]
                        
                    ### warning
                    if not y.empty:
                        if spectra.shape[0] != y.shape[0]:
                            st.error('Error: X and Y have different sample size')
                            y = DataFrame
                            spectra = DataFrame

                else:
                    st.error('Error: The data has not been loaded successfully, please consider tuning the dialect settings!')
    
    # Load .dx file
    case 'dx':
        with c1:
            data_file = st.file_uploader("Select Data", type = ".dx", help = " :mushroom: select a dx file")
            if data_file:
                file_name = str(data_file.name)
                ## creating the temp file
                with NamedTemporaryFile(delete = False, suffix = ".dx") as tmp:
                    tmp.write(data_file.read())
                    tmp_path = tmp.name
                    with open(tmp.name, 'r') as dd:
                        dxdata = dd.read()
                        # p_hash(str(dxdata)+str(data_file.name))

                ## load and parse the temp dx file
                @st.cache_data
                def read_dx(tmp_path):
                    M = JcampParser(path = tmp_path)
                    M.parse()
                    # chem_data, spectra, meta_data, meta_data_st = read_dx(file =  tmp_path)    
                    # os.unlink(tmp_path)
                    return M.chem_data, M.specs_df_, M.meta_data, M.meta_data_st_
                chem_data, spectra, meta_data, meta_data_st = read_dx(tmp_path = tmp_path)
                
                if not spectra.empty:
                    st.success("Info: The data have been loaded successfully", icon = "✅")

                if chem_data.shape[1]>0:
                    yname = st.selectbox('Select the target analyte', options = ['']+chem_data.columns.tolist(), format_func = lambda x: x if x else "<Select>" )
                    if yname:
                        measured = chem_data.loc[:, yname] > 0
                        y = chem_data.loc[:, yname].loc[measured]
                        spectra = spectra.loc[measured]
                        
                        
                    else:
                        st.info('Info: Please select the target analyte from the dropdown list!')
                else:
                    st.warning('Warning: your file includes no target variables to model !', icon = "⚠️")


            else :
                st.info('Info: Load your file here!')
################################################### END : I- Data loading and preparation ####################################################






################################################### BEGIN : visualize and split the data ####################################################
st.subheader("I - Data visualization", divider = 'blue')
# Only runs once both the spectra (X) and the selected target (y) are available.
if not spectra.empty and not y.empty:
    # p_hash(y)
    # p_hash(np.mean(spectra))
    # Use the spectral column labels as x positions when they are numeric (integer/float kind);
    # otherwise fall back to positional indices 0..n_columns-1.
    if np.array(spectra.columns).dtype.kind in ['i', 'f']:
        colnames = spectra.columns
    else:
        colnames = np.arange(spectra.shape[1])

    # Train/test split; the resulting globals (and train_index/test_index) are reused by the
    # model-creation and diagnosis sections further down.
    from utils.miscellaneous import data_split
    X_train, X_test, y_train, y_test, train_index, test_index = data_split(x=spectra, y=y)
    


    #### insight on loaded data
    spectra_plot = plot_spectra(spectra, xunits = 'Wavelength/Wavenumber', yunits = "Signal intensity")
    target_plot = hist(y = y, y_train = y_train, y_test = y_test, target_name=yname)
    # One row of desc_stats output per subset (train / test / total), rounded to 2 decimals.
    from utils.miscellaneous import desc_stats
    stats = DataFrame([desc_stats(y_train), desc_stats(y_test), desc_stats(y)], index =['train', 'test', 'total'] ).round(2) 

    # fig1, ax1 = plt.subplots( figsize = (12, 3))
    # spectra.T.plot(legend = False, ax = ax1, linestyle = '-', linewidth = 0.6)
    # ax1.set_ylabel('Signal intensity')
    # ax1.margins(0)
    # plt.tight_layout()
    c2, c3 = st.columns([1, .4])
    with c2:
        st.pyplot(spectra_plot) ######## Loaded graph
        st.pyplot(target_plot)

    with c3:
        st.write('Loaded data summary')
        st.write(stats)

################################################### END : visualize and split the data #######################################################




# if 'model_type' not in st.session_state:
#     st.cache_data.model_type = ''

#     ###################################################     BEGIN : Create Model     ####################################################
model_type = None # initialize the selected regression algorithm
# Stays None unless a model is successfully created below; the result, diagnosis and
# download sections are all gated on `if Reg:`.
Reg = None  # initialize the regression model object
# intervalls_with_cols = DataFrame()

st.subheader("II - Model creation", divider = 'blue')
if not spectra.empty and not y.empty:
    c4, c5, c6 = st.columns([1, 1, 3])
    with c4:
        # select type of supervised modelling problem
        var_nature = ['Continuous', 'Categorical']
        mode = c4.radio("The nature of the target variable :", options = var_nature)
        # p_hash(mode)
        match mode:
            case "Continuous":
                reg_algo = ["", "PLS", "LW-PLS", "TPE-iPLS"]
                st.markdown(f'Example1: Quantifying the volume of nectar consumed by a pollinator during a foraging session.')
                st.markdown(f"Example2: Measure the sugar content, amino acids, or other compounds in nectar from different flower species.")
            case 'Categorical':
                reg_algo = ["", "PLS", "LW-PLS", "TPE-iPLS", 'LDA']
                st.markdown(f"Example1: Classifying pollinators into categories such as bees, butterflies, moths, and beetles.")
                st.markdown(f"Example2: Classifying plants based on their health status, such as healthy, stressed, or diseased, using NIR spectral data.")
    with c5:
        model_type = c5.selectbox("Choose a modelling algorithm:", options = reg_algo, key = 12, format_func = lambda x: x if x else "<Select>")
    
    with c6:
        st.markdown("-------------")
        match model_type:
            case "PLS":
                st.markdown("#### For further details on the PLS (Partial Least Squares) algorithm, check the following reference:")
                st.markdown('##### https://www.tandfonline.com/doi/abs/10.1080/03610921003778225')
                
            case "LW-PLS":
                st.markdown("#### For further details on the LW-PLS (Locally Weighted - Partial Least Squares) algorithm, check the following reference:")
                st.markdown('##### https://analyticalsciencejournals.onlinelibrary.wiley.com/doi/full/10.1002/cem.3117')
            
            case "TPE-iPLS":
                st.markdown("#### For further details on the TPE-iPLS (Tree-structured Parzen Estimator based interval-Partial Least Squares) algorithm, which is a wrapper method for interval selection, check the following references:")
                st.markdown("##### https://papers.nips.cc/paper_files/paper/2011/file/86e8f7ab32cfd12577bc2619bc635690-Paper.pdf")
                st.markdown('##### https://www.tandfonline.com/doi/abs/10.1080/03610921003778225')
                st.markdown('##### https://journals.sagepub.com/doi/abs/10.1366/0003702001949500')
        st.markdown("-------------")

    # if  model_type != st.session_state.model_type:
    #     st.session_state.model_type = model_type
    #     increment()
    
    # p_hash(model_type)


    # Training set preparation for cross-validation(CV)
    nb_folds = 3

    # Model creation-M20 columns
    with c5:
        @st.cache_data
        def RequestingModelCreation(change):
            # spectra_plot.savefig("./report/figures/spectra_plot.png")
            # target_plot.savefig("./report/figures/histogram.png")
            # st.session_state['hash_Reg'] = str(np.random.randint(2000000000))
            folds = KF_CV.CV(X_train, y_train, nb_folds)# split train data into nb_folds for cross_validation

            match model_type:
                case 'PLS':
                    from utils.regress import Plsr
                    Reg = Plsr(train = [X_train, y_train], test = [X_test, y_test], n_iter = 100, cv = nb_folds)
                    # reg_model = Reg.model_
                    rega = Reg.selected_features_

                case 'LW-PLS':
                    # export data to csv for Julia train/test
                    global x_train_np, y_train_np, x_test_np, y_test_np
                    data_to_work_with = ['x_train_np', 'y_train_np', 'x_test_np', 'y_test_np']
                    x_train_np, y_train_np, x_test_np, y_test_np = X_train.to_numpy(), y_train.to_numpy(), X_test.to_numpy(), y_test.to_numpy()
                    # Cross-Validation calculation
                    d = {}
                    for i in range(nb_folds):
                        d["xtr_fold{0}".format(i+1)], d["ytr_fold{0}".format(i+1)], d["xte_fold{0}".format(i+1)], d["yte_fold{0}".format(i+1)] = np.delete(x_train_np, folds[list(folds)[i]], axis=0), np.delete(y_train_np, folds[list(folds)[i]], axis=0), x_train_np[folds[list(folds)[i]]], y_train_np[folds[list(folds)[i]]]
                        data_to_work_with.append("xtr_fold{0}".format(i+1))
                        data_to_work_with.append("ytr_fold{0}".format(i+1))
                        data_to_work_with.append("xte_fold{0}".format(i+1))
                        data_to_work_with.append("yte_fold{0}".format(i+1))
                    # check best pre-treatment with a global PLSR model
                    from utils.regress import Plsr
                    preReg = Plsr(train = [X_train, y_train], test = [X_test, y_test], n_iter=100)
                    temp_path = Path('temp/')
                    with open(temp_path / "lwplsr_preTreatments.json", "w+") as outfile:
                        json.dump(preReg.best_hyperparams_, outfile)
                    # export Xtrain, Xtest, Ytrain, Ytest and all CV folds to temp folder as csv files
                    for i in data_to_work_with:
                        if 'fold' in i:
                            j = d[i]
                        else:
                            j = globals()[i]
                            # st.write(j)
                        np.savetxt(temp_path / str(i + ".csv"), j, delimiter=",")
                    open(temp_path / 'model', 'w').close()
                    # run Julia Jchemo as subprocess
                    import subprocess
                    subprocess_path = Path("utils/")
                    subprocess.run([f"{sys.executable}", subprocess_path / "lwplsr_call.py"])
                    # retrieve json results from Julia JChemo
                    try:
                        with open(temp_path / "lwplsr_outputs.json", "r") as outfile:
                            Reg_json = json.load(outfile)
                            # delete csv files
                            for i in data_to_work_with: os.unlink(temp_path / str(i + ".csv"))
                        # delete json file after import
                        os.unlink(temp_path / "lwplsr_outputs.json")
                        os.unlink(temp_path / "lwplsr_preTreatments.json")
                        os.unlink(temp_path / 'model')
                        # format result data into Reg object
                        pred = ['pred_data_train', 'pred_data_test']### keys of the dict
                        for i in range(nb_folds):
                            pred.append("CV" + str(i+1)) ### add cv folds keys to pred
                            
                        # global Reg
                        # Reg = type('obj', (object,), {'model_' : Reg_json['model'], 'best_hyperparams_' : Reg_json['best_lwplsr_params'],
                        #                             'pred_data_' : [json_normalize(Reg_json[i]) for i in pred]})
                        # global Reg
                        Reg = lw(Reg_json = Reg_json, pred = pred)
                        # reg_model = Reg.model_
                        Reg.CV_results_ = DataFrame()
                        Reg.cv_data_ = {'YpredCV' : {}, 'idxCV' : {}}
                        # set indexes to Reg.pred_data (train, test, folds idx)
                        for i in range(len(pred)):
                            Reg.pred_data_[i] = Reg.pred_data_[i].T.reset_index().drop(columns = ['index'])
                            if i == 0: # data_train
                                # Reg.pred_data_[i] = np.array(Reg.pred_data_[i])
                                Reg.pred_data_[i].index = list(y_train.index)
                                Reg.pred_data_[i] = Reg.pred_data_[i].iloc[:,0]
                            elif i == 1: # data_test
                                # Reg.pred_data_[i] = np.array(Reg.pred_data_[i])
                                Reg.pred_data_[i].index = list(y_test.index)
                                Reg.pred_data_[i] = Reg.pred_data_[i].iloc[:,0]
                            else:
                                # CVi
                                Reg.pred_data_[i].index = folds[list(folds)[i-2]]
                                # Reg.CV_results_ = concat([Reg.CV_results_, Reg.pred_data_[i]])
                                Reg.cv_data_['YpredCV']['Fold' + str(i-1)] = np.array(Reg.pred_data_[i]).reshape(-1)
                                Reg.cv_data_['idxCV']['Fold' + str(i-1)] = np.array(folds[list(folds)[i-2]]).reshape(-1)

                        Reg.CV_results_= KF_CV.metrics_cv(y = y_train, ypcv = Reg.cv_data_['YpredCV'], folds = folds)[1]
                        #### cross validation results print
                        Reg.best_hyperparams_print = Reg.best_hyperparams_
                        ## plots
                        Reg.cv_data_ = KF_CV().meas_pred_eq(y = np.array(y_train), ypcv = Reg.cv_data_['YpredCV'], folds = folds)
                        Reg.pretreated_spectra_ = preReg.pretreated_spectra_
                        
                        Reg.best_hyperparams_print = {**preReg.best_hyperparams_, **Reg.best_hyperparams_}
                        Reg.best_hyperparams_ = {**preReg.best_hyperparams_, **Reg.best_hyperparams_}

                        Reg.__hash__ = ObjectHash(current = hash_,add = Reg.best_hyperparams_print)
                    except FileNotFoundError as e:
                        Reg = None
                        for i in data_to_work_with: os.unlink(temp_path / str(i + ".csv"))

                case 'TPE-iPLS':
                    from utils.regress import TpeIpls
                    Reg = TpeIpls(train = [X_train, y_train], test=[X_test, y_test], n_intervall = s, n_iter=it, cv = nb_folds)
                    # reg_model = Reg.model_
                    
                    global intervalls, intervalls_with_cols
                    intervalls = Reg.selected_features_.T.copy()
                    intervalls_with_cols = Reg.selected_features_.T.copy().astype(str)
                    
                    for i in range(intervalls.shape[0]):
                        for j in range(intervalls.shape[1]):
                            intervalls_with_cols.iloc[i,j] = spectra.columns[intervalls.iloc[i,j]]
                    rega = Reg.selected_features_

                    st.session_state.intervalls = Reg.selected_features_.T
                    st.session_state.intervalls_with_cols = intervalls_with_cols
            return Reg
        




        if model_type:
            info = st.info('Info: The model is being created. This may take a few minutes.')
            if model_type == 'TPE-iPLS':# if model type is ipls then ask for the number of iterations and intervalls
                s = st.number_input(label = 'Enter the maximum number of intervals', min_value = 1, max_value = 6)
                it = st.number_input(label = 'Enter the number of iterations', min_value = 2, max_value = 500, value = 250)
            else:
                s, it = None, None
            hash_ = ObjectHash( current = hash_,add = str(s)+str(it))
                
            remodel_button = st.button('re-model the data', key=4, help=None, type="primary", use_container_width=True, on_click=increment)
            hash_ = ObjectHash(current = hash_, add = st.session_state.counter)
            Reg = RequestingModelCreation(change = hash_)
            reg_model = Reg.model_
            hash_ = hash(Reg)
        else:
            st.info('Info: Choose a modelling algorithm from the dropdown list!')
                
        if model_type:
            info.empty()
            if Reg:
                st.success('Success! Your model has been created and is ready to use.')
            else:
                st.error("Error: Model creation failed. Please try again.")
        
        if model_type:
            if model_type == 'TPE-iPLS':
                 if ('intervalls' and 'intervalls_with_cols') in st.session_state:
                    intervalls = st.session_state.intervalls
                    intervalls_with_cols = st.session_state.intervalls_with_cols

if Reg:
    # Fitted values on the training set (pred_data_[0]) and predictions on the test set (pred_data_[1])
    yc = Reg.pred_data_[0]
    yt = Reg.pred_data_[1]

    c7, c8 = st.columns([2 ,4])
    with c7:
        # Show and export the preprocessing methods
        st.write('-- Spectral preprocessing info --')
        st.write(Reg.best_hyperparams_print)

        @st.cache_data(show_spinner =False)
        def preprocessings(change):
            """Write Reg.best_hyperparams_ to report/out/Preprocessing.json.

            `change` is not read in the body; it only acts as the cache key.
            """
            with open('report/out/Preprocessing.json', "w") as outfile:
                json.dump(Reg.best_hyperparams_, outfile)
        preprocessings(change=hash_)

        # Show the model performance table (the same regression metrics are used for every algorithm)
        st.write("-- Model performance --")
        model_per = DataFrame(metrics(c = [y_train, yc], t = [y_test, yt], method = 'regression').scores_)
        st.dataframe(model_per)

    @st.cache_data(show_spinner =False)
    def prep_important(change, model_type, model_hash):
        """Plot the mean raw and mean pretreated training spectra.

        For TPE-iPLS the selected intervals are shaded on both axes; for PLS the
        variables listed in Reg.sel_ratio_ are marked. `change` and `model_hash`
        are not read in the body; together with `model_type` they form the cache key.

        Returns:
            The matplotlib figure with the two stacked axes.
        """
        fig, (ax1, ax2) = plt.subplots(2,1, figsize = (12, 4), sharex=True)
        ax1.plot(colnames, np.mean(X_train, axis = 0), color = 'black', label = 'Average spectrum (Raw)')
        ax2.plot(colnames, np.mean(Reg.pretreated_spectra_ , axis = 0), color = 'black', label = 'Average spectrum (Pretreated)')
        ax2.set_xlabel('Wavelengths')
        plt.tight_layout()

        if model_type == 'TPE-iPLS':
            # Bounds are spectral labels when the columns are numeric, column positions otherwise
            # (matching how `colnames` was built).
            numeric_cols = np.array(spectra.columns).dtype.kind in ['i','f']
            bounds = intervalls_with_cols if numeric_cols else intervalls

        for ax in (ax1, ax2):
            ax.grid(color = 'grey', linestyle = ':', linewidth = 0.2)
            ax.margins(x = 0)
            ax.legend(loc = 'upper right')
            ax.set_ylabel('Intensity')
            if model_type == 'TPE-iPLS':
                # Shade every interval actually returned by the model rather than assuming
                # exactly `s` rows exist (which could index out of range).
                for j in range(bounds.shape[0]):
                    lo, hi = bounds.iloc[j, 0], bounds.iloc[j, 1]
                    ax.axvspan(lo, hi, color = '#00ff00', alpha = 0.5, lw = 0)

        if model_type == 'PLS':
            ax1.scatter(colnames[np.array(Reg.sel_ratio_.index)], np.mean(X_train, axis = 0).iloc[np.array(Reg.sel_ratio_.index)],
                            color = '#7ab0c7', label = 'Important variables')
            ax2.scatter(colnames[Reg.sel_ratio_.index], np.mean(Reg.pretreated_spectra_, axis = 0)[np.array(Reg.sel_ratio_.index)],
                            color = '#7ab0c7', label = 'Important variables')
            ax1.legend()
            ax2.legend()
        return fig

    with c8:## Visualize raw,preprocessed spectra, and selected intervalls(in case of ipls)
        if model_type =='TPE-iPLS' :
            st.write('-- Important Spectral regions used for model creation --')
            st.table(intervalls_with_cols)
        st.write('-- Visualization of the spectral regions used for model creation --')
        imp_fig = prep_important(change = st.session_state.counter, model_type = model_type, model_hash = hash_)
        st.pyplot(imp_fig)

    # Display CV results
    numbers_dict = {1: "One", 2: "Two",3: "Three",4: "Four",5: "Five",
                    6: "Six",7: "Seven",8: "Eight",9: "Nine",10: "Ten"}
    st.subheader(f" {numbers_dict[nb_folds]}-Fold Cross-Validation results")

    @st.cache_data(show_spinner =False)
    def cv_display(change):
        """Build the out-of-fold measured-vs-predicted plotly figures from Reg.cv_data_[0].

        Returns:
            (fig0, fig1): one facet per fold, and all folds in a single plot with a
            dashed reference diagonal.
        """
        cv_frame = Reg.cv_data_[0]
        measured = cv_frame.loc[:,'Measured']
        fig1 = px.scatter(cv_frame, x = 'Measured', y = 'Predicted' , trendline = 'ols', color = 'Folds', symbol = 'Folds',
                color_discrete_sequence=px.colors.qualitative.G10)
        fig1.add_shape(type = 'line', x0 = .95 * min(measured), x1 = 1.05 * max(measured),
                        y0 = .95 * min(measured), y1 = 1.05 * max(measured), line = dict(color = 'black', dash = "dash"))
        fig1.update_traces(marker_size = 7, showlegend=False)

        fig0 = px.scatter(cv_frame, x ='Measured', y = 'Predicted' , trendline = 'ols', color = 'Folds', symbol = "Folds", facet_col = 'Folds',facet_col_wrap = 1,
                color_discrete_sequence = px.colors.qualitative.G10, text = 'index', width = 800, height = 1000)
        fig0.update_traces(marker_size = 8, showlegend = False)
        return fig0, fig1
    fig0, fig1 = cv_display(change= Reg.cv_data_)

    cv1, cv2 = st.columns([2, 2])
    with cv2:
        cv_results = DataFrame(Reg.CV_results_).round(4)# CV table
        st.write('-- Cross-Validation Summary--')
        # grey background on the per-fold rows (everything except the sd/mean/cv summary rows)
        st.write(cv_results.astype(str).style.map(lambda _: "background-color: #cecece;", subset = (cv_results.index.drop(['sd', 'mean', 'cv']), slice(None))))

        st.write('-- Out-of-Fold Predictions Visualization (All in one) --')
        st.plotly_chart(fig1, use_container_width = True)

    with cv1:
        st.write('-- Out-of-Fold Predictions Visualization (Separate plots) --')
        st.plotly_chart(fig0, use_container_width=True)
    

    ###################################################    BEGIN : Model Diagnosis    ####################################################
st.subheader("III - Model Diagnosis", divider='blue')
if Reg:
    # signal preprocessing results preparation for latex report
    prep_para = Reg.best_hyperparams_.copy()
    if model_type != reg_algo[2]:
        # Drop n_components from the preprocessing list; the default avoids a KeyError
        # when the model's hyper-parameters do not contain it.
        prep_para.pop('n_components', None)
        for i in ['deriv','polyorder']:
            order = Reg.best_hyperparams_[i]
            if order == 0:
                prep_para[i] = '0'
            elif order == 1:
                prep_para[i] = '1st'
            elif order > 1:
                # proper English ordinal suffix (2nd, 3rd, 4th, ...) instead of always "nd"
                suffix = {2: 'nd', 3: 'rd'}.get(order, 'th')
                prep_para[i] = f"{order}{suffix}"

    # reg plot and residuals plot (identical for every algorithm)
    measured_vs_predicted = reg_plot([y_train, y_test],[yc, yt], train_idx = train_index, test_idx = test_index)
    residuals_plot = resid_plot([y_train, y_test], [yc, yt], train_idx = train_index, test_idx = test_index)

    M7, M8 = st.columns([2,2])
    with M7:
        st.write('Predicted vs Measured values')
        st.pyplot(measured_vs_predicted)

    with M8:
        st.write('Residuals plot')
        st.pyplot(residuals_plot)
###################################################      END : Model Diagnosis   #######################################################
    
###################################################    BEGIN : Download results    #######################################################
##########################################################################################################################################
##########################################################################################################################################
if Reg:
    zip_data = ""
    st.header('Download the analysis results')
    st.write("**Note:** Please check the box only after you have finished processing your data and are satisfied with the results. Checking the box prematurely may slow down the app and could lead to crashes.")
    decis = st.checkbox("Yes, I want to download the results")
    if decis:
        @st.cache_data(show_spinner =False)
        def export_report(change):
            match model_type:
                case 'PLS':
                        latex_report = report.report('Predictive model development', file_name, stats, list(prep_para.values()), model_type, model_per, cv_results)
                        

                case 'LW-PLS':
                        latex_report = report.report('Predictive model development', file_name, stats,
                                                    list({key: Reg.best_hyperparams_[key] for key in ['deriv', 'normalization', 'polyorder', 'window_length'] if key in Reg.best_hyperparams_}.values()), model_type, model_per, cv_results)
                        
                case 'TPE-iPLS':
                        latex_report = report.report('Predictive model development', file_name, stats,
                                                    list({key: Reg.best_hyperparams_[key] for key in ['deriv', 'normalization', 'polyorder', 'window_length'] if key in Reg.best_hyperparams_}.values()), model_type, model_per, cv_results)
                        
                case _:
                    st.warning('Data processing has not been performed or finished yet!', icon = "⚠️")

        @st.cache_data(show_spinner =False)
        def preparing_results_for_downloading(change):
            match file:
                # load csv file
                case 'csv':
                    xfile.to_csv('report/out/dataset/'+ xcal_csv.name, sep = ';', encoding = 'utf-8', mode = 'a')
                    yfile.to_csv('report/out/dataset/'+ ycal_csv.name, sep = ';', encoding = 'utf-8', mode = 'a')
                case 'dx':
                    with open('report/out/dataset/'+data_file.name, 'w') as dd:
                        dd.write(dxdata)
                                    
            with open('./report/out/model/'+ model_type + '.pkl','wb') as f:# export model
                dump(reg_model, f)
            figpath ='./report/out/figures/'
            spectra_plot.savefig(figpath + "spectra_plot.png")
            target_plot.savefig(figpath + "histogram.png")
            imp_fig.savefig(figpath + "variable_importance.png")
            fig1.write_image(figpath + "meas_vs_pred_cv_all.png")
            fig0.write_image(figpath + "meas_vs_pred_cv_onebyone.png")
            measured_vs_predicted.savefig(figpath + 'measured_vs_predicted.png')
            residuals_plot.savefig(figpath + 'residuals_plot.png')
            # with open('report/out/Preprocessing.json', "w") as outfile:
            #     json.dump(Reg.best_hyperparams_, outfile)
            
            if model_type == 'TPE-iPLS': # export selected wavelengths
                wlfilename = './report/out/model/'+ model_type+'-selected_wavelengths.xlsx'
                all = concat([intervalls_with_cols.T, Reg.selected_features_], axis = 0,  ignore_index=True).T
                all.columns=['wl_from','wl_to','idx_from', 'idx_to']
                all.to_excel(wlfilename)
            
            export_report(change = hash_)
            if Path("./report/report.tex").exists():
                report.generate_report(change = hash_)
            if Path("./report/report.pdf").exists():
                move("./report/report.pdf", "./report/out/report.pdf")
            
            # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~            ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
            # pklfile = {'model_': Reg.model_,"model_type" : model_type, 'training_data':{'raw-spectra':spectra,'target':y, },
            #         'spec-preprocessing':{"normalization": Reg.best_hyperparams_['normalization'], 'SavGol(polyorder,window_length,deriv)': [Reg.best_hyperparams_["polyorder"],
            #                                                                                                                                    Reg.best_hyperparams_['window_length'],
            #                                                                                                                                    Reg.best_hyperparams_['deriv']]}}
            pklfile = {'model_': Reg.model_,"model_type" : model_type, 'data':{'raw-spectra':spectra,'target':y, 'training_data_idx':train_index,'testing_data_idx':test_index},
                    'spec-preprocessing':{"normalization": Reg.best_hyperparams_['normalization'], 'SavGol(polyorder,window_length,deriv)': [Reg.best_hyperparams_["polyorder"],
                                                                                                                                               Reg.best_hyperparams_['window_length'],
                                                                                                                                               Reg.best_hyperparams_['deriv']]}}
            if model_type == 'TPE-iPLS': # export selected wavelengths
                pklfile['selected-wls'] = {'idx':Reg.selected_features_.T , "wls":intervalls_with_cols }
            elif model_type == 'LW-PLS': # export LWPLS best model parameters
                pklfile['selected-wls'] = {'idx':None, "wls":None }
                pklfile['lwpls_params'] = Reg.best_hyperparams_
            else:
                pklfile['selected-wls'] = {'idx':None, "wls":None }
                    
            with open('./report/out/file_system.pkl', "wb") as pkl:
                dump(pklfile, pkl)

            return change
        preparing_results_for_downloading(change = hash_)
        
    
        @st.cache_data(show_spinner=False)
        def tempdir(change):
            """Zip ``./report/out`` and return ``(tempdirname, zip_data)``.

            The archive is written inside a fresh ``TemporaryDirectory`` under
            ``./report`` and read back into memory as bytes before that directory
            is removed. ``zip_data`` stays ``''`` when ``./report/out/figures/``
            does not hold more than two entries, because no archive is built then.
            ``change`` is only the ``st.cache_data`` key.
            """
            # Same empty sentinel the page checks before enabling the download
            # button. Previously this name was unbound when no archive was built,
            # so the return below raised UnboundLocalError.
            zip_data = ''
            with TemporaryDirectory(prefix="results", dir="./report") as temp_dir:
                tempdirname = os.path.split(temp_dir)[1]

                if len(os.listdir('./report/out/figures/')) > 2:
                    # Archive ./report/out (top-level folder "out" inside the zip)
                    # straight into the temporary directory; make_archive returns
                    # the path of the file it created.
                    archive_path = make_archive(base_name=os.path.join(temp_dir, "Results"), format="zip",
                                                root_dir="./report", base_dir="out")
                    with open(archive_path, "rb") as f:
                        zip_data = f.read()
            return tempdirname, zip_data

        try :
            tempdirname, zip_data = tempdir(change = hash_)
        except:
            pass
    date_time = datetime.now().strftime('%y%m%d%H%M')
    disabled_down = True if zip_data=='' else False
    st.download_button(label = 'Download', data = zip_data, file_name = f'Nirs_Workflow_{date_time}_Reg_.zip', mime ="application/zip",
                args = None, kwargs = None,type = "primary",use_container_width = True, disabled = disabled_down)


    delete_files(keep = ['.py', '.pyc','.bib'])