from Packages import *
st.set_page_config(page_title="NIRS Utils", page_icon=":goat:", layout="wide")
from utils import read_dx, DxRead, LinearPCA, Umap, find_col_index, Nmf, Sk_Kmeans, AP, KS, RDM, Hdbscan  # note: Hdbscan (used below) is assumed to live in utils alongside the other models
from mod import *
# HTML header banner for "CEFE - CNRS"
add_header()
add_sidebar(pages_folder)
local_css(css_file / "style_model.css") # load the model-page specific css

hash_ = ''
def p_hash(add):
    """Chain `add` into the global hash so cached steps depending on it are invalidated."""
    global hash_
    hash_ = hash_data(hash_ + str(add))
    return hash_
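
# Every widget choice that affects downstream computation is mixed into `hash_`;
# the cached steps below take `change = hash_` as their cache key, so st.cache_data
# re-runs them exactly when an upstream input changes, e.g.:
#   p_hash(dim_red_method)                              # mix the chosen method in
#   dr_model = dimensionality_reduction(change = hash_) # cache keyed on hash_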

# ####################################  Methods ##############################################
# remove temporary report files between runs
def delete_files(keep):
    """Delete files under report/ except the logo and files whose extension is in `keep`."""
    for root, dirs, files in os.walk('report/', topdown=False):
        for file in files:
            if file != 'logo_cefe.png' and not any(file.endswith(ext) for ext in keep):
                os.remove(os.path.join(root, file))

dirpath = Path('report/out/model')
if dirpath.exists() and dirpath.is_dir():
    shutil.rmtree(dirpath)

# algorithms available in our app
dim_red_methods = ['PCA', 'UMAP', 'NMF']      # list of dimensionality reduction algos
cluster_methods = ['Kmeans', 'HDBSCAN', 'AP'] # list of clustering algos
selec_strategy = ['center', 'random']

match st.session_state["interface"]:
    case 'simple':
        st.write(':red[Automated Simple Interface]')
        # hide_pages("Predictions")
        # Widget keys 37 and 38 belong to selectboxes whose options start with a ''
        # placeholder, so stored values are looked up in the placeholder-prefixed list.
        if 37 not in st.session_state:
            default_reduction_option = 1
        else:
            default_reduction_option = ([''] + dim_red_methods).index(st.session_state.get(37))
        if 38 not in st.session_state:
            default_clustering_option = 1
        else:
            default_clustering_option = ([''] + cluster_methods).index(st.session_state.get(38))
        if 102 not in st.session_state:
            default_sample_selection_option = 1
        else:
            default_sample_selection_option = selec_strategy.index(st.session_state.get(102))

    case 'advanced':
        default_reduction_option = 0
        default_clustering_option = 0
        default_sample_selection_option = 0
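
# In 'simple' mode the three widgets default to the first real option (or to the
# value previously stored in st.session_state); in 'advanced' mode they default to
# the '<Select>' placeholder so the user picks everything explicitly.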



################ clean the results dir #############
delete_files(keep = ['.py', '.pyc','.bib'])

# ####################################### page preamble #######################################
st.title("Calibration Subset Selection") # page title
st.markdown("Create a predictive model, then use it for predicting your target variable (chemical data) from NIRS spectra")
col2, col1 = st.columns([3, 1])
col2.image("./images/sample selection.png", use_column_width=True) # graphical abstract

################################### I - Data Loading and Visualization ########################################
files_format = ['csv', 'dx'] # supported file formats
# data file loader
file = col1.file_uploader("Data file", type=["csv", "dx"], help=" :mushroom: Select a csv matrix with samples as rows and wavelengths (lambdas) as columns", key=5)

## Preallocation of data structure
spectra = pd.DataFrame()
meta_data = pd.DataFrame()
tcr=pd.DataFrame()
sam=pd.DataFrame()
sam1=pd.DataFrame()
non_clustered = None
l1 = []
labels = []
color_palette = None
dr_model = None # dimensionality reduction model
cl_model = None # clustering model
selection = None
selection_number = "None"
samples_df_chem = pd.DataFrame()
selected_samples = []
selected_samples_idx = []

if not file:
    col1.info('Info: Please load a data file!')

else:
    extension = file.name.split(".")[-1]
    userfilename = file.name.replace(f".{extension}", '')

    match extension:
    ## Load .csv file
        case 'csv':
            with col1:
                psep = st.radio("Select csv separator:", options = [";", ","], horizontal=True, key=9)
                phdr = st.radio("Is there an index column in the csv?", options = ["no", "yes"], horizontal=True, key=31)

                if phdr == 'yes':
                    col = 0
                else:
                    col = False

                from io import StringIO
                stringio = StringIO(file.getvalue().decode("utf-8"))
                data_str = stringio.read()
                p_hash([data_str + str(file.name), psep, phdr])
                
                @st.cache_data
                def csv_loader(change):
                    """Read the uploaded csv; `change` only serves as the cache-invalidation key."""
                    imp = pd.read_csv(file, sep = psep, index_col = col)
                    spectra, md_df_st_ = col_cat(imp)
                    meta_data = md_df_st_
                    return spectra, md_df_st_, meta_data, imp

                try:
                    spectra, md_df_st_, meta_data, imp = csv_loader(change = hash_)
                    st.success("The data have been loaded successfully", icon="✅")
                except Exception:
                    st.error('''Error: The file format does not match the expected dialect settings.
                              To read the file correctly, please adjust the separator parameters.''')
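
                # Expected csv layout (per the uploader help text): samples as rows and
                # wavelengths/wavenumbers as columns; values below are purely illustrative, e.g.
                #        ; 400.0 ; 402.0 ; 404.0 ; ...
                #   s1   ; 0.512 ; 0.498 ; 0.487 ; ...
                #   s2   ; 0.530 ; 0.515 ; 0.501 ; ...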
                    

           


        ## Load .dx file
        case 'dx':
            with col1:
                # Create a temporary file to save the uploaded file
                with NamedTemporaryFile(delete=False, suffix=".dx") as tmp:
                    tmp.write(file.read())
                    tmp_path = tmp.name
                    with open(tmp.name, 'r') as dd:
                        dxdata = dd.read()
                        p_hash(str(dxdata)+str(file.name))
                        
                    ## load and parse the temp dx file
                    @st.cache_data
                    def dx_loader(change):
                        _, spectra, meta_data, md_df_st_ = read_dx(file = tmp_path)
                        # os.unlink(tmp_path) 
                        return _, spectra, meta_data, md_df_st_
                    _, spectra, meta_data, md_df_st_ = dx_loader(change = hash_)

                    st.success("The data have been loaded successfully", icon="✅")
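
                    # read_dx is fed a filesystem path (tmp_path) rather than the upload
                    # buffer, hence the NamedTemporaryFile with delete=False above; the
                    # commented os.unlink call suggests cleanup of tmp_path is deferred.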

################################################### END : I- Data loading and preparation ####################################################

################################################### BEGIN : visualize and split the data ####################################################
st.header("I - Spectral Data Visualization", divider='blue')
if not spectra.empty:
    p_hash(np.mean(spectra))
    n_samples = spectra.shape[0]
    nwl = spectra.shape[1]
    # retrieve columns name and rows name of the dataframe
    colnames = list(spectra.columns)
    rownames = [str(i) for i in list(spectra.index)]
    spectra.index = rownames

    @st.cache_data
    def spectra_visualize(change):
        """Plot the raw spectra; `change` only serves as the cache-invalidation key."""
        fig, ax = plt.subplots(figsize = (30,7))
        if extension == 'dx':
            lab = 'Wavenumber (1/cm)' if meta_data.loc[:,'xunits'][0] == '1/cm' else 'Wavelength (nm)'
            if lab == 'Wavenumber (1/cm)':
                spectra.T.plot(legend=False, ax = ax).invert_xaxis()
            else:
                spectra.T.plot(legend=False, ax = ax)
            ax.set_xlabel(lab, fontsize=18)
        else:
            spectra.T.plot(legend=False, ax = ax)
            ax.set_xlabel('Wavelength/Wavenumber', fontsize=18)

        ax.set_ylabel('Signal intensity', fontsize=18)
        plt.margins(x = 0)
        plt.tight_layout()

        data_info = pd.DataFrame({'Name': [file.name],
                                'Number of scanned samples': [n_samples]},
                                index = ['Input file'])

        # thin the lines so the exported report figure stays readable
        for line in ax.get_lines():
            line.set_linewidth(0.8)

        # shrink the figure and axis-label sizes for export to the report,
        # then restore the original on-screen size
        w, h = fig.get_size_inches()
        fig.set_size_inches(8, 3)
        ax.xaxis.label.set_size(9.5)
        ax.yaxis.label.set_size(9.5)
        plt.tight_layout()
        fig.set_size_inches(w, h) # reset the plot to its original size
        return fig, data_info
    fig_spectra, data_info = spectra_visualize(change = hash_)

    col1, col2 = st.columns([3, 1])
    with col1:
        st.pyplot(fig_spectra)

    with col2:
        st.info('Information on the loaded data file')
        st.write(data_info) ## table showing the number of samples in the data file

################################################### END : visualize and split the data ####################################################

############################## Exploratory data analysis ###############################
st.header("II - Exploratory Data Analysis-Multivariable Data Analysis", divider='blue')

###### 1- Dimensionality reduction ######
t = pd.DataFrame() # scores
p = pd.DataFrame() # loadings
if not spectra.empty:
    xc = standardize(spectra, center=True, scale=False)

    bb1, bb2, bb3, bb4, bb5, bb6, bb7 = st.columns([1,1,0.6,0.6,0.6,1.5,1.5])
    with bb1:
        dim_red_method = st.selectbox("Dimensionality reduction technique: ", options = ['']+dim_red_methods, index = default_reduction_option, key = 37, format_func = lambda x: x if x else "<Select>")
        if dim_red_method == '':
            st.info('Info: Select a dimensionality reduction technique!')
        p_hash(dim_red_method)


        if dim_red_method == "UMAP":
            if not meta_data.empty:
                filter = md_df_st_.columns.tolist()
                supervised = st.selectbox('Supervised UMAP by (optional):', options = ['']+filter, format_func = lambda x: x if x else "<Select>", key=108)
                umapsupervisor = None if supervised == '' else md_df_st_[supervised]

            else:
                supervised = st.selectbox('Supervised UMAP by:', options = ["Meta-data is not available"], disabled=True, format_func = lambda x: x if x else "<Select>", key=108)
                umapsupervisor = None
            p_hash(supervised)

        disablewidgets = not dim_red_method
        clus_method = st.selectbox("Clustering technique (optional): ", options = ['']+cluster_methods, index = default_clustering_option, key = 38, format_func = lambda x: x if x else "<Select>", disabled = disablewidgets)

        

        if dim_red_method:
            @st.cache_data
            def dimensionality_reduction(change):
                match dim_red_method:
                    case "PCA":
                        dr_model = LinearPCA(xc, Ncomp=8)
                    case "UMAP":
                        dr_model = Umap(numerical_data = MinMaxScale(spectra), cat_data = umapsupervisor)
                    case 'NMF':
                        dr_model = Nmf(spectra, Ncomp= 3)
                return dr_model
            
            dr_model = dimensionality_reduction(change = hash_)
            

        if dr_model:
            axis1 = bb3.selectbox("x-axis", options = dr_model.scores_.columns, index=0)
            axis2 = bb4.selectbox("y-axis", options = dr_model.scores_.columns, index=1)
            axis3 = bb5.selectbox("z-axis", options = dr_model.scores_.columns, index=2)
            axis = np.unique([axis1, axis2, axis3])
            p_hash(axis)
            t = dr_model.scores_.loc[:, axis]
            tcr = standardize(t)
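
            # The selected score columns are standardized (utils.standardize) so that,
            # presumably, each retained component weighs comparably in the clustering
            # distances computed on `tcr` below.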

###### II - clustering #######
if not t.empty:
    clustered = np.arange(n_samples)
    non_clustered = None

    if dim_red_method == 'UMAP':
        scores = st.container()
    else:
        scores, loadings= st.columns([3,3])

if not spectra.empty:
    sel_ratio = bb2.number_input('Enter the number/fraction of samples to be selected:', min_value=0.01, max_value=float("{:.2f}".format(spectra.shape[0])), value=0.20, format="%.2f", disabled = disablewidgets)
    if sel_ratio:
        p_hash(sel_ratio)
        if sel_ratio > 1.00:
            ratio = int(sel_ratio)
        else:
            ratio = int(sel_ratio * spectra.shape[0])
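
# Semantics of sel_ratio: a value > 1 is an absolute sample count, a value <= 1 is a
# fraction of the dataset, e.g. with 100 samples, sel_ratio = 0.20 -> ratio = 20,
# while sel_ratio = 15 -> ratio = 15.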
if dr_model and not clus_method:
    clus_method = bb2.radio('Select samples selection strategy:',
                    options = ['RDM', 'KS'],)
elif dr_model and clus_method:

    if clus_method in cluster_methods:
        selection = bb2.radio('Select samples selection strategy:',
                    options = selec_strategy, index = default_sample_selection_option, key=102, disabled = False)
    else:
        selection = bb2.radio('Select samples selection strategy:',
                    options = selec_strategy, horizontal=True, key=102, disabled = True)

if dr_model and sel_ratio:
    # Clustering
    match clus_method:
        case 'Kmeans':
            cl_model = Sk_Kmeans(tcr, max_clusters = ratio)
            data, labels, clu_centers = cl_model.fit_optimal_
            ncluster = clu_centers.shape[0]

        # 2- HDBSCAN clustering
        case 'HDBSCAN':
            cl_model = Hdbscan(np.array(tcr))
            labels, clu_centers, non_clustered = cl_model.labels_,cl_model.centers_, cl_model.non_clustered
            ncluster = len(clu_centers)

        # 3- Affinity propagation
        case 'AP':
            cl_model = AP(X = tcr)
            data, labels, clu_centers = cl_model.fit_optimal_
            ncluster = len(clu_centers)

        case 'KS':
            cl_model = KS(x = tcr, rset = ratio)

        case 'RDM':
            cl_model = RDM(x = tcr, rset = ratio)
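
    # Note: KS and RDM (presumably Kennard-Stone and random sampling over the
    # score space, as implemented in utils) build the calibration subset directly,
    # so the cluster-related fields below are filled with placeholder values.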

    if clus_method in ['KS', 'RDM']:
        _, selected_samples_idx = cl_model.calset
        labels = ["ind"] * n_samples
        ncluster = 1
        selection_number = 'None'
        selection = 'None'
    new_tcr = tcr.iloc[clustered,:]    
    

# #################################################### III - Samples selection using the reduced data presentation ######


if not labels:
    custom_color_palette = px.colors.qualitative.Plotly[:1]
else:
    num_clusters = len(np.unique(labels))
    custom_color_palette = px.colors.qualitative.Plotly[:num_clusters]
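
    # Sample selection strategies (applied when a clustering method is active):
    #   - 'center': keep the single sample closest to each cluster center;
    #   - 'random': despite its name, pick ratio/num_clusters representatives per
    #     cluster by fitting a secondary KMeans inside the cluster and keeping the
    #     samples nearest each sub-center (clusters smaller than the quota are
    #     kept whole).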
    if clus_method:
        match selection:
        # Strategy 0
            case 'center':
                # list samples at clusters centers - Use sklearn.metrics.pairwise_distances_argmin if you want more than 1 sample per cluster
                closest, _ = pairwise_distances_argmin_min(clu_centers, new_tcr)
                selected_samples_idx = np.array(new_tcr.index)[list(closest)]
                selected_samples_idx = selected_samples_idx.tolist()
                
            #### Strategy 1
            case 'random':
                selection_number = int(ratio/num_clusters)
                p_hash(selection_number)
                s = np.array(labels)[np.where(np.array(labels) !='Non clustered')[0]]
                for i in np.unique(s):
                    C = np.where(np.array(labels) == i)[0]
                    if C.shape[0] >= selection_number:
                        # scores.write(list(tcr.index)[labels== i])
                        km2 = KMeans(n_clusters = selection_number)
                        km2.fit(tcr.iloc[C,:])
                        clos, _ = pairwise_distances_argmin_min(km2.cluster_centers_, tcr.iloc[C,:])
                        selected_samples_idx.extend(tcr.iloc[C,:].iloc[list(clos)].index)
                    else:
                        selected_samples_idx.extend(new_tcr.iloc[C,:].index.to_list())
                # list indexes of selected samples for colored plot

# ################################      Plots visualization          ############################################


    ## Scores
if not t.empty:
    if meta_data.empty and clus_method in cluster_methods:
        filter = ['', clus_method]
    elif not meta_data.empty and clus_method in cluster_methods:
        filter = ['',clus_method] + md_df_st_.columns.tolist()
    elif not meta_data.empty and clus_method not in cluster_methods:
        filter = [''] + md_df_st_.columns.tolist()
    elif meta_data.empty and clus_method not in cluster_methods:
        filter = []

    with scores:
        st.write('Scores plot')
        tcr_plot = tcr.copy()
        colfilter = st.selectbox('Color by:', options= filter,format_func = lambda x: x if x else "<Select>")
        p_hash(colfilter)
        if colfilter in cluster_methods:
            tcr_plot[colfilter] = labels
        elif not meta_data.empty and colfilter in md_df_st_.columns.tolist():
            tcr_plot[f'{colfilter} :'] = list(map(str.lower,md_df_st_.loc[:,colfilter]))
        else:
            tcr_plot[f'{colfilter} :'] = ['sample'] * tcr_plot.shape[0]
        
        col_var_name = tcr_plot.columns.tolist()[-1]
        n_categories = len(np.unique(tcr_plot[col_var_name]))
        custom_color_palette = px.colors.qualitative.Plotly[:n_categories]

    with scores:
            if selected_samples_idx:# color selected samples
                t_selected = tcr_plot.iloc[selected_samples_idx,:]
            match t.shape[1]:
                case 3:
                    fig = px.scatter_3d(tcr_plot, x = axis[0], y = axis[1], z = axis[2], color = col_var_name ,color_discrete_sequence = custom_color_palette)
                    fig.update_traces(marker=dict(size=4))
                    if selected_samples_idx:# color selected samples
                        fig.add_scatter3d(x = t_selected.loc[:,axis[0]], y = t_selected.loc[:,axis[1]], z = t_selected.loc[:,axis[2]],
                                        mode ='markers', marker = dict(size = 5, color = 'black'), name = 'selected samples')
                    
                case 2:
                    fig = px.scatter(tcr_plot, x = axis[0], y = axis[1], color = col_var_name ,color_discrete_sequence = custom_color_palette)
                    if selected_samples_idx:# color selected samples
                        fig.add_scatter(x = t_selected.loc[:,axis[0]], y = t_selected.loc[:,axis[1]],
                                        mode ='markers', marker = dict(size = 5, color = 'black'), name = 'selected samples')

                
                case 1:
                    fig = px.scatter(tcr_plot, x = axis[0], y = [0]*tcr_plot.shape[0], color = col_var_name, color_discrete_sequence = custom_color_palette)
                    if selected_samples_idx: # color selected samples
                        fig.add_scatter(x = t_selected.loc[:,axis[0]], y = [0]*t_selected.shape[0],
                                        mode ='markers', marker = dict(size = 5, color = 'black'), name = 'selected samples')
                    fig.update_yaxes(visible=False)

            st.plotly_chart(fig, use_container_width = True)

            if labels:
                fig_export = {}
                # export 2D scores plot
                if len(axis)== 3:
                    comb = [i for i in combinations(np.arange(len(axis)), 2)]
                    subcap = ['a','b','c']
                    for i in range(len(comb)):
                        fig_= px.scatter(tcr_plot, x = axis[(comb[i][0])], y=axis[(comb[i][1])],color = labels if list(labels) else None,color_discrete_sequence = custom_color_palette)
                        fig_.add_scatter(x = t_selected.loc[:,axis[(comb[i][0])]], y = t_selected.loc[:,axis[(comb[i][1])]], mode ='markers', marker = dict(size = 5, color = 'black'),
                                    name = 'selected samples')
                        fig_.update_layout(font=dict(size=23))
                        fig_.add_annotation(text= f'({subcap[i]})', align='center', showarrow= False, xref='paper', yref='paper', x=-0.13, y= 1,
                                                    font= dict(color= "black", size= 35), bgcolor ='white', borderpad= 2, bordercolor= 'black', borderwidth= 3)
                        fig_.update_traces(marker=dict(size= 10), showlegend= False)
                        fig_export[f'scores_pc{comb[i][0]}_pc{comb[i][1]}'] = fig_
                        # fig_export.write_image(f'./report/out/figures/scores_pc{str(comb[i][0])}_pc{str(comb[i][1])}.png')
                else:
                    fig_export['fig'] = fig
            


if not spectra.empty:
    if dim_red_method in ['PCA','NMF']:
        with loadings:
            st.write('Loadings plot')
            p = dr_model.loadings_
            freq = pd.DataFrame(colnames, index=p.index)
            if extension =='dx':
                if meta_data.loc[:,'xunits'][0] == '1/cm':
                    freq.columns = ['Wavenumber (1/cm)']
                    xlab = "Wavenumber (1/cm)"
                    inv = 'reversed'
                else:
                    freq.columns = ['Wavelength (nm)']
                    xlab = 'Wavelength (nm)'
                    inv = None
            else:
                freq.columns = ['Wavelength/Wavenumber']
                xlab = 'Wavelength/Wavenumber'
                inv = None
                
            pp = pd.concat([p, freq], axis=1)
            #########################################
            df1 = pp.melt(id_vars=freq.columns)
            loadingsplot = px.line(df1, x=freq.columns[0], y='value', color='variable', color_discrete_sequence=px.colors.qualitative.Plotly)
            loadingsplot.update_layout(legend=dict(x=1, y=0, font=dict(family="Courier", size=12, color="black"),
                                        bordercolor="black", borderwidth=2))
            loadingsplot.update_layout(xaxis_title = xlab,yaxis_title = "Intensity" ,xaxis = dict(autorange= inv))

            
            st.plotly_chart(loadingsplot, use_container_width=True)
    
#############################################################################################################
    if dim_red_method == 'PCA':
        influence, hotelling = st.columns([3, 3])
        with influence:
            st.write('Influence plot')
            # Leverage
            Hat = t.to_numpy() @ np.linalg.inv(np.transpose(t.to_numpy()) @ t.to_numpy()) @ np.transpose(t.to_numpy())
            leverage = np.diag(Hat) / np.trace(Hat)
            tresh3 = 2 * tcr.shape[1] / n_samples
            # Loadings
            p = dr_model.loadings_.loc[:,axis]
            # Matrix reconstruction
            xp = np.dot(t,p.T)
            # Q residuals: the magnitude of the variation remaining in each sample after projection through the model
            residuals = np.diag(np.subtract(xc.to_numpy(), xp) @ np.subtract(xc.to_numpy(), xp).T)
            tresh4 = sc.stats.chi2.ppf(0.95, df = len(axis))  # upper-tail (95%) quantile, the conventional cutoff for large Q residuals
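
            # For PCA scores T (n x k): leverage h_i = [T (TᵀT)⁻¹ Tᵀ]_ii with the usual
            # warning cutoff 2k/n, and Q_i = ||x_i - x̂_i||² where x̂ = T Pᵀ is the
            # reconstruction from the retained components.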

            # color with metadata
            if colfilter == "":
                l1 = ["Samples"] * n_samples
            elif colfilter == clus_method:
                l1 = labels
            else:
                l1 = tcr_plot[f'{colfilter} :']
            tcr_plot["leverage"] = leverage
            tcr_plot["residuals"] = residuals
            influence_plot = px.scatter(data_frame =tcr_plot, x = "leverage", y = "residuals", color=col_var_name,
                                            color_discrete_sequence= custom_color_palette)
            influence_plot.add_scatter(x = leverage[selected_samples_idx] , y = residuals[selected_samples_idx],
                                       mode ='markers', marker = dict(size = 5, color = 'black'), name = 'selected samples')
            
            influence_plot.add_vline(x = tresh3, line_width = 1, line_dash = 'solid', line_color = 'red')
            influence_plot.add_hline(y=tresh4, line_width=1, line_dash='solid', line_color='red')
            influence_plot.update_layout(xaxis_title="Leverage", yaxis_title = "Q-residuals", font=dict(size=20), width=800, height=600)

            out3 = leverage > tresh3
            out4 = residuals > tresh4

            for i in range(n_samples):
                if out3[i]:
                    if not meta_data.empty:
                        ann =  meta_data.loc[:,'name'][i]
                    else:
                        ann = t.index[i]
                    influence_plot.add_annotation(dict(x = leverage[i], y = residuals[i], showarrow=True, text = str(ann),font= dict(color= "black", size= 15),
                                xanchor = 'auto', yanchor = 'auto'))
            
            influence_plot.update_traces(marker=dict(size= 6), showlegend= True)
            influence_plot.update_layout(font=dict(size=23), width=800, height=500)
            st.plotly_chart(influence_plot, use_container_width=True)


            
            for annotation in influence_plot.layout.annotations:
                annotation.font.size = 35
            influence_plot.update_layout(font=dict(size=23), width=800, height=600)
            influence_plot.update_traces(marker=dict(size= 10), showlegend= False)
            influence_plot.add_annotation(text= '(a)', align='center', showarrow= False, xref='paper', yref='paper', x=-0.125, y= 1,
                                             font= dict(color= "black", size= 35), bgcolor ='white', borderpad= 2, bordercolor= 'black', borderwidth= 3)
            # influence_plot.write_image('./report/out/figures/influence_plot.png', engine = 'kaleido')
        
        
        with hotelling:
            st.write('T²-Hotelling vs Q-residuals plot')
            # Hotelling T² distance per sample (note: this reassigns the `hotelling`
            # column container, which is safe since its `with` block is already entered)
            hotelling = t.var(axis = 1)
            # Q residuals (recomputed as above)
            residuals = np.diag(np.subtract(xc.to_numpy(), xp) @ np.subtract(xc.to_numpy(), xp).T)

            fcri = sc.stats.f.isf(0.05, 3, n_samples - 3)  # F(k, n-k) with k = 3 retained components
            tresh0 = (3 * (n_samples ** 2 - 1) * fcri) / (n_samples * (n_samples - 3))
            tresh1 = sc.stats.chi2.ppf(0.95, df = 3)  # upper-tail (95%) cutoff for Q residuals
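
            # Hotelling T² control limit for k retained components:
            #   T²_lim = k (n² - 1) / (n (n - k)) * F_α(k, n - k), here with k = 3, α = 0.05.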
            
            hotelling_plot = px.scatter(t, x = hotelling, y = residuals, color=labels if list(labels) else None,
                                            color_discrete_sequence= custom_color_palette)
            hotelling_plot.add_scatter(x = hotelling[selected_samples_idx] , y = residuals[selected_samples_idx],
                                       mode ='markers', marker = dict(size = 5, color = 'black'), name = 'selected samples')
            hotelling_plot.update_layout(xaxis_title="Hotelling-T² distance",yaxis_title="Q-residuals")
            hotelling_plot.add_vline(x=tresh0, line_width=1, line_dash='solid', line_color='red')
            hotelling_plot.add_hline(y=tresh1, line_width=1, line_dash='solid', line_color='red')

            out0 = hotelling > tresh0
            out1 = residuals > tresh1

            
            for i in range(n_samples):
                if out0[i]:
                    if not meta_data.empty:
                        ann =  meta_data.loc[:,'name'][i]
                    else:
                        ann = t.index[i]
                    hotelling_plot.add_annotation(dict(x = hotelling[i], y = residuals[i], showarrow=True, text = str(ann), font= dict(color= "black", size= 15),
                                xanchor = 'auto', yanchor = 'auto'))
                    
            hotelling_plot.update_traces(marker=dict(size= 6), showlegend= True)
            hotelling_plot.update_layout(font=dict(size=23), width=800, height=500)
            st.plotly_chart(hotelling_plot, use_container_width=True)


            for annotation in hotelling_plot.layout.annotations:
                annotation.font.size = 35
            hotelling_plot.update_layout(font=dict(size=23), width=800, height=600)
            hotelling_plot.update_traces(marker=dict(size= 10), showlegend= False)
            hotelling_plot.add_annotation(text= '(b)', align='center', showarrow= False, xref='paper', yref='paper', x=-0.125, y= 1,
                                             font= dict(color= "black", size= 35), bgcolor ='white', borderpad= 2, bordercolor= 'black', borderwidth= 3)
            # hotelling_plot.write_image("./report/out/figures/hotelling_plot.png", format="png")

st.header('III - Selected Samples for Reference Analysis', divider='blue')
if labels:
    sel, info = st.columns([3, 1])
    sel.write("Tabular identifiers of selected samples for reference analysis:")
    if selected_samples_idx:
        if meta_data.empty:
            sam1 = pd.DataFrame({'name': spectra.index[clustered][selected_samples_idx],
                                'cluster':np.array(labels)[clustered][selected_samples_idx]},
                                index = selected_samples_idx)
        else:
            sam1 = meta_data.iloc[clustered,:].iloc[selected_samples_idx,:]
            sam1.insert(loc=0, column='index', value=selected_samples_idx)
            sam1.insert(loc=1, column='cluster', value=np.array(labels)[selected_samples_idx])
        sam1.index = np.arange(len(selected_samples_idx))+1
        info.info(f'Information !\n - The total number of samples: {n_samples}.\n- The number of samples selected for reference analysis: {sam1.shape[0]}.\n - The proportion of samples selected for reference analysis: {round(sam1.shape[0]/n_samples*100)}%.')
        sam = sam1

        if clus_method == 'HDBSCAN':
            unclus = sel.checkbox("Include non clustered samples (for HDBSCAN clustering)", value=True)

            if selected_samples_idx:
                if unclus:
                    if meta_data.empty:
                        sam2 = pd.DataFrame({'name': spectra.index[non_clustered],
                                            'cluster':['Non clustered']*len(spectra.index[non_clustered])},
                                            index = spectra.index[non_clustered])
                    else :
                        sam2 = meta_data.iloc[non_clustered,:]
                        sam2.insert(loc=0, column='index', value= spectra.index[non_clustered])
                        sam2.insert(loc=1, column='cluster', value=['Non clustered']*len(spectra.index[non_clustered]))
                    
                    sam = pd.concat([sam1, sam2], axis = 0)
                    sam.index = np.arange(sam.shape[0])+1
                    info.info(f'- The number of Non-clustered samples: {sam2.shape[0]}.\n - The proportion of Non-clustered samples: {round(sam2.shape[0]/n_samples*100)}%')
        else:
            sam = sam1
        sel.write(sam)


if not sam.empty:
    Nb_ech = str(n_samples)
    nb_clu = str(sam1.shape[0])
    st.header('Download the analysis results')
    st.write("**Note:** Please check the box only after you have finished processing your data and are satisfied with the results. Checking the box prematurely may slow down the app and could lead to crashes.")
    decis = st.checkbox("Yes, I want to download the results")
    if decis:
        ###################################################
        # ## generate report
        @st.cache_data
        def export_report(change):
            latex_report = report.report('Representative subset selection', file.name, dim_red_method,
                                        clus_method, Nb_ech, ncluster, selection, selection_number, nb_clu,tcr, sam)

        
        @st.cache_data
        def preparing_results_for_downloading(change):
            match extension:
                # copy the original dataset into the report folder
                case 'csv':
                    imp.to_csv('report/out/dataset/' + file.name, sep = ';', encoding = 'utf-8', mode = 'a')
                case 'dx':
                    with open('report/out/dataset/' + file.name, 'w') as dd:
                        dd.write(dxdata)

            fig_spectra.savefig("./report/out/figures/spectra_plot.png", dpi=400) # export the spectra plot for the report

            if len(axis) == 3:
                for i in range(len(comb)):
                    fig_export[f'scores_pc{comb[i][0]}_pc{comb[i][1]}'].write_image(f'./report/out/figures/scores_pc{str(comb[i][0]+1)}_pc{str(comb[i][1]+1)}.png')
            elif len(axis) == 2:
                fig_export['fig'].write_image('./report/out/figures/scores_plot2D.png')
            elif len(axis) == 1:
                fig_export['fig'].write_image('./report/out/figures/scores_plot1D.png')
                    
            # Export the loadings plot
            if dim_red_method in ['PCA','NMF']:
                img = pio.to_image(loadingsplot, format="png")
                with open("./report/out/figures/loadings_plot.png", "wb") as f:
                    f.write(img)
            if dim_red_method == 'PCA': 
                hotelling_plot.write_image("./report/out/figures/hotelling_plot.png", format="png")
                influence_plot.write_image('./report/out/figures/influence_plot.png', engine = 'kaleido')
            
            sam.to_csv('./report/out/Selected_subset_for_calib_development.csv', sep = ';')

            export_report(change = hash_)
            if Path("./report/report.tex").exists():
                report.generate_report(change = hash_)
            if Path("./report/report.pdf").exists():
                shutil.move("./report/report.pdf", "./report/out/report.pdf")
            return change


        preparing_results_for_downloading(change = hash_)
        report.generate_report(change = hash_)

        

        import tempfile
        @st.cache_data
        def tempdir(change):
            """Zip report/out into a temp directory; return the temp dir name and the zip bytes."""
            with tempfile.TemporaryDirectory(prefix="results", dir="./report") as temp_dir: # create a temp directory
                tempdirname = os.path.split(temp_dir)[1]

                if len(os.listdir('./report/out/figures/')) >= 2:
                    shutil.make_archive(base_name="./report/Results", format="zip", base_dir="out", root_dir="./report") # zip the out/ folder
                    shutil.move("./report/Results.zip", f"./report/{tempdirname}/Results.zip") # move the zip into the temp dir
                    with open(f"./report/{tempdirname}/Results.zip", "rb") as f:
                        zip_data = f.read()
            return tempdirname, zip_data
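
        # The zip bytes are read into memory before the TemporaryDirectory context
        # closes (which deletes the directory), so st.download_button below can
        # serve them without touching the deleted path.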

        date_time = datetime.datetime.now().strftime('%y%m%d%H%M')
        try:
            tempdirname, zip_data = tempdir(change = hash_)
            st.download_button(label = 'Download', data = zip_data, file_name = f'Nirs_Workflow_{date_time}_SamSel_.zip', mime = "application/zip",
                        args = None, kwargs = None, type = "primary", use_container_width = True)
        except Exception:
            pass

        delete_files(keep = ['.py', '.pyc','.bib'])