from Packages import *
# set_page_config must be the first Streamlit command executed on the page,
# hence its position between the two wildcard imports
st.set_page_config(page_title="NIRS Utils", page_icon=":goat:", layout="wide")
from Modules import *

# Empty the temporary figures folder before each run
repertoire_a_vider = Path('Report/figures')
if repertoire_a_vider.exists():
    for chemin_fichier in repertoire_a_vider.iterdir():
        if chemin_fichier.is_file() or chemin_fichier.is_symlink():
            chemin_fichier.unlink()
        elif chemin_fichier.is_dir():
            shutil.rmtree(chemin_fichier)
# HTML banner for "CEFE - CNRS"
add_header()
# load the page-specific CSS
local_css(css_file / "style_model.css")
add_sidebar(pages_folder)



# algorithms available in our app
dim_red_methods=['', 'PCA','UMAP', 'NMF']  # List of dimensionality reduction algos
cluster_methods = ['', 'Kmeans','HDBSCAN', 'AP', 'KS', 'RDM'] # List of clustering algos
selec_strategy = ['center','random']

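# The integer widget keys used below (37, 38, 102) belong to the dimensionality-reduction
# selectbox, the clustering selectbox and the sample-selection radio defined further down;
# probing st.session_state for them restores the user's previous choices when switching
# interface modes.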
match st.session_state["interface"]:
    case 'simple':
        st.write(':red[Automated Simple Interface]')
        # hide_pages("Predictions")
        if 37 not in st.session_state:
            default_reduction_option = 1
        else:
            default_reduction_option = dim_red_methods.index(st.session_state.get(37))
        if 38 not in st.session_state:
            default_clustering_option = 1
        else:
            default_clustering_option = cluster_methods.index(st.session_state.get(38))
        if 102 not in st.session_state:
            default_sample_selection_option = 1
        else:
            default_sample_selection_option = selec_strategy.index(st.session_state.get(102))
        
    case 'advanced':
        default_reduction_option = 0
        default_clustering_option = 0
        default_sample_selection_option = 0

################################### I - Data Loading and Visualization ########################################
st.title("Calibration Subset Selection")
col2, col1 = st.columns([3, 1])
col2.image("./images/graphical_abstract.jpg", use_column_width=True)  # path relative to the app root (assumed layout; the original hard-coded a local absolute path)
## Preallocation of data structure
spectra = pd.DataFrame()
meta_data = pd.DataFrame()
tcr=pd.DataFrame()
sam=pd.DataFrame()
sam1=pd.DataFrame()
selected_samples = pd.DataFrame()
non_clustered = None
l1 = []
labels = []
color_palette = None
dr_model = None # dimensionality reduction model
cl_model = None # clustering model
selection = None
selection_number = None

# loader for datafile
data_file = col1.file_uploader("Data file", type=["csv","dx"], help=" :mushroom: select a csv matrix with samples as rows and lambdas as columns", key=5)


if not data_file:
    col1.warning('⚠️ Please load a data file!')
else:
    # Retrieve the file extension (use the last dot so dotted file names are handled)
    test = data_file.name[data_file.name.rfind('.'):]
    match test:
    ## Load .csv file
        case '.csv':
            with col1:
                # Select list for CSV delimiter
                psep = st.radio("Select csv separator - _detected_: " + str(find_delimiter('data/'+data_file.name)), options=[";", ","], index=[";", ","].index(str(find_delimiter('data/'+data_file.name))),horizontal=True, key=9)
                # Select list for CSV header True / False
                phdr = st.radio("indexes column in csv? - _detected_: " + str(find_col_index('data/'+data_file.name)), options=["no", "yes"], index=["no", "yes"].index(str(find_col_index('data/'+data_file.name))),horizontal=True, key=31)
                if phdr == 'yes':
                    col = 0
                else:
                    col = False
                imp = pd.read_csv(data_file, sep=psep, index_col=col)
                spectra, md_df_st_ = col_cat(imp)
                meta_data = md_df_st_
                st.success("The data have been loaded successfully", icon="✅")
        ## Load .dx file
        case '.dx':
            # Create a temporary file to save the uploaded file
            with NamedTemporaryFile(delete=False, suffix=".dx") as tmp:
                tmp.write(data_file.read())
                tmp_path = tmp.name
                with col1:
                    _, spectra, meta_data, md_df_st_ = read_dx(file = tmp_path)
                    st.success("The data have been loaded successfully", icon="✅")
            os.unlink(tmp_path)


    
## Visualize spectra
st.header("I - Spectral Data Visualization", divider='blue')
if not spectra.empty:
    n_samples = spectra.shape[0]
    nwl = spectra.shape[1]
    # retrieve columns name and rows name of spectra
    colnames = list(spectra.columns)
    rownames = [str(i) for i in list(spectra.index)]
    spectra.index = rownames
    col2, col1 = st.columns([3, 1])
    with col2:
        fig, ax = plt.subplots(figsize = (30,7))
        if test == '.dx':
            # the x-axis label depends on the x-units stored in the .dx metadata
            lab = 'Wavenumber (1/cm)' if meta_data.loc[:,'xunits'][0] == '1/cm' else 'Wavelength (nm)'
            if lab == 'Wavenumber (1/cm)':
                spectra.T.plot(legend=False, ax = ax).invert_xaxis()
            else:
                spectra.T.plot(legend=False, ax = ax)
            ax.set_xlabel(lab, fontsize=18)
        else:
            spectra.T.plot(legend=False, ax = ax)
            ax.set_xlabel('Wavelength/Wavenumber', fontsize=18)
        
        ax.set_ylabel('Signal intensity', fontsize=18)
        plt.margins(x = 0)
        plt.tight_layout()
        st.pyplot(fig)
        
        # thinner lines for the figure exported to the report
        for line in ax.get_lines():
            line.set_linewidth(0.8)

        # Temporarily shrink the plot and its fonts for export to the report
        l, w = fig.get_size_inches()
        fig.set_size_inches(8, 3)
        for label in (ax.get_xticklabels() + ax.get_yticklabels()):
            label.set_fontsize(9.5)
        ax.xaxis.label.set_size(9.5)
        ax.yaxis.label.set_size(9.5)
        plt.tight_layout()
        fig.savefig("./Report/figures/spectra_plot.png", dpi=400) ## Export report
        fig.set_size_inches(l, w) # restore the original figure size
        data_info = pd.DataFrame({'Name': [data_file.name],
                                'Number of scanned samples': [n_samples]},
                                  index = ['Input file'])
    with col1:
        st.info('Information on the loaded data file')
        st.write(data_info) ## table showing the number of samples in the data file

############################## Exploratory data analysis ###############################
st.header("II - Exploratory Data Analysis-Multivariable Data Analysis", divider='blue')

###### 1- Dimensionality reduction ######
t = pd.DataFrame() # scores
p = pd.DataFrame() # loadings
if not spectra.empty:
    bb1, bb2, bb3, bb4, bb5, bb6, bb7 = st.columns([1,1,0.6,0.6,0.6,1.5,1.5])
    dim_red_method = bb1.selectbox("Dimensionality reduction techniques: ", options = dim_red_methods, index = default_reduction_option, key = 37)
    clus_method = bb2.selectbox("Clustering/sampling techniques: ", options = cluster_methods, index = default_clustering_option, key = 38)
    xc = standardize(spectra, center=True, scale=False)
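    # xc: mean-centered spectra (no scaling). PCA below is fitted on xc, UMAP on the
    # min-max-scaled spectra and NMF on the raw spectra.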
    
    match dim_red_method:
        case "":
            bb1.warning('⚠️ Please choose an algorithm!')
        
        case "PCA":
            dr_model = LinearPCA(xc, Ncomp=8)

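        # Supervised UMAP: if metadata are available, a selected metadata column can be
        # passed as the target so the embedding separates its categories.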
        case "UMAP":
            if not meta_data.empty:
                filter = md_df_st_.columns
                filter = filter.insert(0, 'Nothing')
                col = bb1.selectbox('Supervised UMAP by:', options= filter, key=108)
                if col == 'Nothing':
                    supervised = None
                else:
                    supervised = md_df_st_[col]
            else:
                supervised = None
            dr_model = Umap(numerical_data = MinMaxScale(spectra), cat_data = supervised)

        case 'NMF':
            dr_model = Nmf(spectra, Ncomp= 3)

    if dr_model:
        axis1 = bb3.selectbox("x-axis", options = dr_model.scores_.columns, index=0)
        axis2 = bb4.selectbox("y-axis", options = dr_model.scores_.columns, index=1)
        axis3 = bb5.selectbox("z-axis", options = dr_model.scores_.columns, index=2)

        t = pd.concat([dr_model.scores_.loc[:,axis1], dr_model.scores_.loc[:,axis2], dr_model.scores_.loc[:,axis3]], axis = 1)
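        # t holds the three user-selected score components; the clustering step and the
        # score plots below operate on this 3-column table.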



###### II - clustering #######
if not t.empty:
    clustered = np.arange(n_samples)
    non_clustered = None

    if dim_red_method == 'UMAP':
        scores = st.container()
    else:
        scores, loadings= st.columns([3,3])

    tcr = standardize(t)
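    # standardize the selected scores so each component contributes equally to the
    # distances used by the clustering algorithms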
    
    # Clustering
    match clus_method:
        case '':
            bb2.warning('⚠️ Please choose an algorithm!')
        case 'Kmeans':
            cl_model = Sk_Kmeans(tcr, max_clusters = 25)
            ncluster = scores.number_input(min_value=2, max_value=25, value=cl_model.suggested_n_clusters_, label = 'Select the desired number of clusters')  
            data, labels, clu_centers = cl_model.fit_optimal(nclusters = ncluster)

        # 2- HDBSCAN clustering
        case 'HDBSCAN':
            optimized_hdbscan = Hdbscan(np.array(tcr))
            all_labels, clu_centers = optimized_hdbscan.HDBSCAN_scores_
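            # HDBSCAN labels noise points -1; these are reported as 'Non clustered'
            # instead of being forced into a cluster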
            labels = [f'cluster#{i+1}' if i !=-1 else 'Non clustered' for i in all_labels]
            ncluster = len(clu_centers)
            non_clustered = np.where(np.array(labels) == 'Non clustered')[0]

        # 3- Affinity propagation
        case 'AP':
            cl_model = AP(X = tcr)
            data, labels, clu_centers = cl_model.fit_optimal_
            ncluster = len(clu_centers)

        case 'KS':
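            # KS and RDM are sampling rather than clustering methods (KS presumably
            # Kennard-Stone, RDM random sampling): they directly return a calibration
            # subset, so no cluster structure is produced.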
            rset = scores.number_input(min_value=0, max_value=100, value=20, label = 'The ratio of data to be sampled (%)')
            cl_model = KS(x = tcr, rset = rset)
            calset = cl_model.calset
            labels = ["ind"]*n_samples
            ncluster = "1"
            selection_number = 'None'

        case 'RDM':
            rset = scores.number_input(min_value=0, max_value=100, value=20, label = 'The ratio of data to be sampled (%)')
            cl_model = RDM(x = tcr, rset = rset)
            calset = cl_model.calset
            labels = ["ind"]*n_samples
            ncluster = "1"
            selection_number = 'None'            

    new_tcr = tcr.iloc[clustered,:]  # scores restricted to the clustered samples
    

#################################################### III - Samples selection using the reduced data presentation ######
samples_df_chem = pd.DataFrame()
selected_samples = []
selected_samples_idx = []

if not labels:
    custom_color_palette = px.colors.qualitative.Plotly[:1]
else:
    num_clusters = len(np.unique(labels))
    custom_color_palette = px.colors.qualitative.Plotly[:num_clusters]
    if clus_method:
        if clus_method in ['KS', 'RDM']:
            selected_samples_idx = calset[1]
            selection = 'None'
        else:
            selection = scores.radio('Select samples selection strategy:',
                                        options = selec_strategy, index = default_sample_selection_option, key=102)
        
        match selection:
        # Strategy 0
            case 'center':
                # list samples at clusters centers - Use sklearn.metrics.pairwise_distances_argmin if you want more than 1 sample per cluster
                closest, _ = pairwise_distances_argmin_min(clu_centers, new_tcr)
                selected_samples_idx = np.array(new_tcr.index)[list(closest)]
                selected_samples_idx = selected_samples_idx.tolist()
                
            #### Strategy 1
            case 'random':
                selection_number = scores.number_input('How many samples per cluster?',
                                                        min_value = 1, step=1, value = 3)
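                # within each cluster, a small KMeans with `selection_number` centers is
                # fitted and the samples closest to those centers are selected, spreading
                # the picks across the cluster rather than drawing them at random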
                s = np.array(labels)[np.where(np.array(labels) !='Non clustered')[0]]
                for i in np.unique(s):
                    C = np.where(np.array(labels) == i)[0]
                    if C.shape[0] >= selection_number:
                        km2 = KMeans(n_clusters = selection_number)
                        km2.fit(tcr.iloc[C,:])
                        clos, _ = pairwise_distances_argmin_min(km2.cluster_centers_, tcr.iloc[C,:])
                        selected_samples_idx.extend(tcr.iloc[C,:].iloc[list(clos)].index)
                    else:
                        selected_samples_idx.extend(new_tcr.iloc[C,:].index.to_list())
                    # list indexes of selected samples for colored plot    

################################      Plots visualization          ############################################


## Scores
if not t.empty:
    with scores:
        fig1, ((ax1, ax2),(ax3,ax4)) = plt.subplots(2,2)
        st.write('Scores plot')
        # scores plot with clustering
        if list(labels) and meta_data.empty:
            fig = px.scatter_3d(tcr, x=axis1, y=axis2, z = axis3, color=labels ,color_discrete_sequence= custom_color_palette)
            sns.scatterplot(data = tcr, x = axis1, y =axis2 , hue = labels, ax = ax1)
        # scores plot with metadata
        elif len(list(labels)) == 0 and not meta_data.empty:
            filter = md_df_st_.columns
            col = st.selectbox('Color by:', options= filter)
            if col == 0:
                fig = px.scatter_3d(tcr, x=axis1, y=axis2, z = axis3)
                sns.scatterplot(data = tcr, x = axis1, y =axis2 , ax = ax1)
                sns.scatterplot(data = tcr, x = axis2, y =axis3 , ax = ax2)
                sns.scatterplot(data = tcr, x = axis1, y =axis3 , hue = list(map(str.lower,md_df_st_[col])), ax = ax3)


            else:
                fig = px.scatter_3d(tcr, x=axis1, y=axis2, z = axis3, color = list(map(str.lower,md_df_st_[col])) )
                sns.scatterplot(data = tcr, x = axis1, y =axis2 , hue = list(map(str.lower,md_df_st_[col])), ax = ax1)
                sns.scatterplot(data = tcr, x = axis2, y =axis3 , hue = list(map(str.lower,md_df_st_[col])), ax = ax2)
                sns.scatterplot(data = tcr, x = axis1, y =axis3 , hue = list(map(str.lower,md_df_st_[col])), ax = ax3)

        # color with scores and metadata
        elif len(list(labels)) > 0  and not meta_data.empty:
            if clus_method in cluster_methods[1:]:
                filter = ['None', clus_method]
                filter.extend(md_df_st_.columns)
            else:
                filter = md_df_st_.columns.insert(0,'None')

            col = st.selectbox('Color by:', options= filter)
            if col == "None":
                fig = px.scatter_3d(tcr, x=axis1, y=axis2, z = axis3)
                sns.scatterplot(data = tcr, x = axis1, y =axis2 , ax = ax1)
            elif col == clus_method:
                fig = px.scatter_3d(tcr, x=axis1, y=axis2, z = axis3, color = labels)
                sns.scatterplot(data = tcr, x = axis1, y =axis2 , ax = ax1)
            else:
                fig = px.scatter_3d(tcr, x=axis1, y=axis2, z = axis3, color = list(map(str.lower,md_df_st_[col])))
                sns.scatterplot(data = tcr, x = axis1, y =axis2 , hue = list(map(str.lower,md_df_st_[col])), ax = ax1)
                sns.scatterplot(data = tcr, x = axis2, y =axis3 , hue = list(map(str.lower,md_df_st_[col])), ax = ax2)
                sns.scatterplot(data = tcr, x = axis1, y =axis3 , hue = list(map(str.lower,md_df_st_[col])), ax = ax3)

        else:
            fig = px.scatter_3d(tcr, x=axis1, y=axis2, z = axis3, color=labels if list(labels) else None,color_discrete_sequence= custom_color_palette)
            sns.scatterplot(data = tcr, x = axis1, y =axis2 , ax = ax1)
        fig.update_traces(marker=dict(size=4))

        if selected_samples_idx:
            tt = tcr.iloc[selected_samples_idx,:]
            fig.add_scatter3d(x = tt.loc[:,axis1], y = tt.loc[:,axis2],z = tt.loc[:,axis3],
                               mode ='markers', marker = dict(size = 5, color = 'black'),
                              name = 'selected samples')        
        st.plotly_chart(fig, use_container_width = True)

        if labels:
            # export 2D scores plots for every pairwise combination of the three axes
            axes = {1: axis1, 2: axis2, 3: axis3}
            comb = [i for i in combinations([1,2,3], 2)]
            subcap = ['a','b','c']
            for i in range(len(comb)):
                fig_export = px.scatter(tcr, x = axes[comb[i][0]], y = axes[comb[i][1]],
                                            color = labels if list(labels) else None,
                                            color_discrete_sequence = custom_color_palette)
                if selected_samples_idx:
                    fig_export.add_scatter(x = tt.loc[:, axes[comb[i][0]]], y = tt.loc[:, axes[comb[i][1]]],
                                   mode ='markers', marker = dict(size = 5, color = 'black'),
                                  name = 'selected samples')
                fig_export.update_layout(font=dict(size=23))
                fig_export.add_annotation(text= f'({subcap[i]})', align='center', showarrow= False, xref='paper', yref='paper', x=-0.13, y= 1,
                                             font= dict(color= "black", size= 35), bgcolor ='white', borderpad= 2, bordercolor= 'black', borderwidth= 3)
                fig_export.update_traces(marker=dict(size= 10), showlegend= False)
                fig_export.write_image(f'./Report/figures/scores_pc{comb[i][0]}_pc{comb[i][1]}.png')



if not spectra.empty:
    if dim_red_method in ['PCA','NMF']:
        with loadings:
            st.write('Loadings plot')
            p = dr_model.loadings_
            freq = pd.DataFrame(colnames, index=p.index)
            if test =='.dx':
                if meta_data.loc[:,'xunits'][0] == '1/cm':
                    freq.columns = ['Wavenumber (1/cm)']
                    xlab = "Wavenumber (1/cm)"
                    inv = 'reversed'
                else:
                    freq.columns = ['Wavelength (nm)']
                    xlab = 'Wavelength (nm)'
                    inv = None
            else:
                freq.columns = ['Wavelength/Wavenumber']
                xlab = 'Wavelength/Wavenumber'
                inv = None
                
            pp = pd.concat([p, freq], axis=1)
            #########################################
            df1 = pp.melt(id_vars=freq.columns)
            fig = px.line(df1, x=freq.columns[0], y='value', color='variable', color_discrete_sequence=px.colors.qualitative.Plotly)
            fig.update_layout(legend=dict(x=1, y=0, font=dict(family="Courier", size=12, color="black"),
                                        bordercolor="black", borderwidth=2))
            fig.update_layout(xaxis_title = xlab,yaxis_title = "Intensity" ,xaxis = dict(autorange= inv))

            
            st.plotly_chart(fig, use_container_width=True)
            

            # Export the plot
            img = pio.to_image(fig, format="png")
            with open("./Report/figures/loadings_plot.png", "wb") as f:
                f.write(img)
#############################################################################################################
    if dim_red_method == 'PCA':
        influence, hotelling = st.columns([3, 3])
        with influence:
            st.write('Influence plot')
            # Leverage: diagonal of the hat matrix H = T (TᵀT)⁻¹ Tᵀ, normalized by its trace
            Hat = t.to_numpy() @ np.linalg.inv(np.transpose(t.to_numpy()) @ t.to_numpy()) @ np.transpose(t.to_numpy())
            leverage = np.diag(Hat) / np.trace(Hat)
            tresh3 = 2 * tcr.shape[1]/n_samples  # rule of thumb: 2 * (number of retained components) / n
            # Loadings restricted to the three selected components
            p = pd.concat([dr_model.loadings_.loc[:,axis1], dr_model.loadings_.loc[:,axis2], dr_model.loadings_.loc[:,axis3]], axis = 1)
            # Matrix reconstruction from the retained scores and loadings
            xp = np.dot(t,p.T)
            # Q residuals: the magnitude of the variation remaining in each sample after projection through the model
            residuals = np.diag(np.subtract(xc.to_numpy(), xp)@ np.subtract(xc.to_numpy(), xp).T)
            tresh4 = sc.stats.chi2.ppf(0.95, df = 3)  # 95% chi-squared quantile used as the Q-residuals limit

            # color with metadata
            if not meta_data.empty and clus_method:
                if col == "None":
                    l1 = ["Samples"]* n_samples

                elif col == clus_method:
                    l1 = labels
                
                else:
                    l1 = list(map(str.lower,md_df_st_[col]))

            elif meta_data.empty and clus_method:                        
                l1 = labels

            elif meta_data.empty and not clus_method:
                l1 = ["Samples"]* n_samples
            
            elif not meta_data.empty and not clus_method:
                l1 = list(map(str.lower,md_df_st_[col]))

            fig = px.scatter(x = leverage, y = residuals, color = l1 if list(l1) else None,
                                            color_discrete_sequence= custom_color_palette)
            fig.add_vline(x = tresh3, line_width = 1, line_dash = 'solid', line_color = 'red')
            fig.add_hline(y=tresh4, line_width=1, line_dash='solid', line_color='red')
            fig.update_layout(xaxis_title="Leverage", yaxis_title = "Q-residuals", font=dict(size=20), width=800, height=600)

            out3 = leverage > tresh3
            out4 = residuals > tresh4

            for i in range(n_samples):
                if out3[i]:
                    if not meta_data.empty:
                        ann =  meta_data.loc[:,'name'][i]
                    else:
                        ann = t.index[i]
                    fig.add_annotation(dict(x = leverage[i], y = residuals[i], showarrow=True, text = ann,font= dict(color= "black", size= 15),
                                xanchor = 'auto', yanchor = 'auto'))
            
            fig.update_traces(marker=dict(size= 6), showlegend= True)
            fig.update_layout(font=dict(size=23), width=800, height=500)
            st.plotly_chart(fig, use_container_width=True)


            
            for annotation in fig.layout.annotations:
                annotation.font.size = 35
            fig.update_layout(font=dict(size=23), width=800, height=600)
            fig.update_traces(marker=dict(size= 10), showlegend= False)
            fig.add_annotation(text= '(a)', align='center', showarrow= False, xref='paper', yref='paper', x=-0.125, y= 1,
                                             font= dict(color= "black", size= 35), bgcolor ='white', borderpad= 2, bordercolor= 'black', borderwidth= 3)
            fig.write_image('./Report/figures/influence_plot.png', engine = 'kaleido')
        
        
        with hotelling:
            st.write('T²-Hotelling vs Q-residuals plot')
            # Hotelling
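            # T² is computed here as the per-sample variance across the three retained
            # score components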
            hotelling  = t.var(axis = 1)
            # Q residuals: the magnitude of the variation remaining in each sample after projection through the model
            residuals = np.diag(np.subtract(xc.to_numpy(), xp)@ np.subtract(xc.to_numpy(), xp).T)

            # Hotelling T² 95% limit for A = 3 retained components
            fcri = sc.stats.f.isf(0.05, 3, n_samples)
            tresh0 = (3 * (n_samples ** 2 - 1) * fcri) / (n_samples * (n_samples - 3))
            tresh1 = sc.stats.chi2.ppf(0.95, df = 3)  # 95% chi-squared quantile used as the Q-residuals limit
            
            fig = px.scatter(t, x = hotelling, y = residuals, color=labels if list(labels) else None,
                                            color_discrete_sequence= custom_color_palette)
            fig.update_layout(xaxis_title="Hotelling-T² distance",yaxis_title="Q-residuals")
            fig.add_vline(x=tresh0, line_width=1, line_dash='solid', line_color='red')
            fig.add_hline(y=tresh1, line_width=1, line_dash='solid', line_color='red')

            out0 = hotelling > tresh0
            out1 = residuals > tresh1

            
            for i in range(n_samples):
                if out0[i]:
                    if not meta_data.empty:
                        ann =  meta_data.loc[:,'name'][i]
                    else:
                        ann = t.index[i]
                    fig.add_annotation(dict(x = hotelling[i], y = residuals[i], showarrow=True, text = ann, font= dict(color= "black", size= 15),
                                xanchor = 'auto', yanchor = 'auto'))
                    
            fig.update_traces(marker=dict(size= 6), showlegend= True)
            fig.update_layout(font=dict(size=23), width=800, height=500)
            st.plotly_chart(fig, use_container_width=True)


            for annotation in fig.layout.annotations:
                annotation.font.size = 35
            fig.update_layout(font=dict(size=23), width=800, height=600)
            fig.update_traces(marker=dict(size= 10), showlegend= False)
            fig.add_annotation(text= '(b)', align='center', showarrow= False, xref='paper', yref='paper', x=-0.125, y= 1,
                                             font= dict(color= "black", size= 35), bgcolor ='white', borderpad= 2, bordercolor= 'black', borderwidth= 3)
            fig.write_image("./Report/figures/hotelling_plot.png", format="png")

st.header('III - Selected Samples for Reference Analysis', divider='blue')
if labels:
    sel, info = st.columns([3, 1])
    sel.write("Tabular identifiers of selected samples for reference analysis:")
    if selected_samples_idx:
        if meta_data.empty:
            sam1 = pd.DataFrame({'name': spectra.index[clustered][selected_samples_idx],
                                'cluster':np.array(labels)[clustered][selected_samples_idx]},
                                index = selected_samples_idx)
        else:
            sam1 = meta_data.iloc[clustered,:].iloc[selected_samples_idx,:]
            sam1.insert(loc=0, column='index', value=selected_samples_idx)
            sam1.insert(loc=1, column='cluster', value=np.array(labels)[selected_samples_idx])
        sam1.index = np.arange(len(selected_samples_idx))+1
        info.info(f'Information !\n - The total number of samples: {n_samples}.\n- The number of samples selected for reference analysis: {sam1.shape[0]}.\n - The proportion of samples selected for reference analysis: {round(sam1.shape[0]/n_samples*100)}%.')
        sam = sam1

        if clus_method == cluster_methods[2]:
            unclus = sel.checkbox("Include non clustered samples (for HDBSCAN clustering)", value=True)

            if selected_samples_idx:
                if unclus:
                    if meta_data.empty:
                        sam2 = pd.DataFrame({'name': spectra.index[non_clustered],
                                            'cluster':['Non clustered']*len(spectra.index[non_clustered])},
                                            index = spectra.index[non_clustered])
                    else :
                        sam2 = meta_data.iloc[non_clustered,:]
                        sam2.insert(loc=0, column='index', value= spectra.index[non_clustered])
                        sam2.insert(loc=1, column='cluster', value=['Non clustered']*len(spectra.index[non_clustered]))
                    
                    sam = pd.concat([sam1, sam2], axis = 0)
                    sam.index = np.arange(sam.shape[0])+1
                    info.info(f'- The number of Non-clustered samples: {sam2.shape[0]}.\n - The proportion of Non-clustered samples: {round(sam2.shape[0]/n_samples*100)}%')
        else:
            sam = sam1
        sel.write(sam)
        



if data_file:
    Nb_ech = str(n_samples)
    nb_clu = str(sam1.shape[0])
    ###############################
    st.header('Download Analysis Results', divider='blue')
    M9, M10 = st.columns([1,1])
    M10.info('The results are automatically converted into LaTeX code, a typesetting system renowned for\
                high-quality document formatting. LaTeX ensures that your data and findings are\
                    presented clearly, with precise formatting and organization.')

    items_download = M9.selectbox('To proceed, please choose the file or files you want to download from the list below:',
                    options = ['', 'Selected Subset', 'Report', 'Both Selected Subset & Report'], index=0,
                    format_func=lambda x: x if x else "<Select>", placeholder="Choose an option")


    ## Save model and download report

    date_time = datetime.datetime.strftime(datetime.date.today(), '_%Y_%m_%d_')

    if items_download:
        if M9.button('Download', type="primary"):
            match items_download:
                case 'Selected Subset':
                    sam.to_csv('./data/subset/selected subset.csv', sep = ";")

                case 'Report':
                    latex_report = report.report('Representative subset selection', data_file.name, dim_red_method, clus_method, Nb_ech, ncluster, selection, selection_number, nb_clu, tcr, sam)
                    report.compile_latex()

                case 'Both Selected Subset & Report':
                    sam.to_csv('./data/subset/selected subset.csv', sep = ";")
                    latex_report = report.report('Representative subset selection', data_file.name, dim_red_method, clus_method, Nb_ech, ncluster, selection, selection_number, nb_clu, tcr, sam)
                    report.compile_latex()
            M9.success('The selected item has been exported successfully!')