from utils.data_parsing import meta_st
from common import *
# Configure the browser tab (title, icon) and use the wide page layout.
st.set_page_config(page_title="NIRS Utils", page_icon=":goat:", layout="wide")


# layout
# Build the shared page chrome. `UiComponents`, `pages_folder`, `css_file` and
# `image_path` are not defined in this file (presumably provided by
# `from common import *` — TODO confirm).
UiComponents(pagespath=pages_folder, csspath=css_file, imgpath=image_path,
             header=True, sidebar=True, bgimg=False, colborders=True)
st.header("Calibration Subset Selection")  # page title
st.markdown(
    "Select a representative subset of samples for NIR calibration development.")
# Two columns with a 3:1 width ratio: c1 shows the illustration, c2 hosts the
# upload widgets, the loading messages and the data summary further down.
c1, c2 = st.columns([3, 1])
c1.image("./images/sample selection.png",
         use_column_width=True)  # graphical abstract


# empty temp figures
# Report directory locations. No live statement in the visible part of this
# file reads them; they only appear in the commented-out export section.
report_path = Path("report")
report_path_rel = Path("./report")

# ~~~~~~~~~~~~~~~~ clean the analysis results dir ~~~~~~~~~~~~~~~~
# Runs on every script execution. The exact scope of the deletion is decided by
# `HandleItems` (defined elsewhere); from the arguments it looks like files
# with the listed extensions are kept and the model results directory is
# removed — NOTE(review): confirm against the helper's implementation.
HandleItems.delete_files(keep=['.py', '.pyc', '.bib', '.tex'])
HandleItems.delete_dir(delete=['report/results/model'])


################################### I - Data Loading and Visualization ########################################
# loader for datafile
file = c2.file_uploader("Data file", type=[
                        "csv", "dx"], help=" :mushroom: select a csv matrix with samples as rows and lambdas as columns")

# Preallocation of data structure
# Every name below is given an "empty" value so that the later
# `if not <frame>.empty:` / `if selected:` guards are well defined even when
# no file has been uploaded yet.
spectra = DataFrame()             # spectral matrix (samples as rows)
meta_data = DataFrame()           # non-spectral columns returned by the parser
md_df_st_ = DataFrame()           # meta data after `meta_st` processing
tcr = DataFrame()                 # standardized scores
sam = DataFrame()
sam1 = DataFrame()
selected = []                     # identifiers of the selected samples
l1 = []
color_palette = None
dr_model = None  # dimensionality reduction model
cl_model = None  # clustering model
selection = None
selection_number = "None"
# An instance, not the DataFrame class itself, so that `.empty` is a bool.
samples_df_chem = DataFrame()
selected_samples = []
selected_samples_idx = []

if not file:
    c2.info('Info: Please load data file !')

else:
    # extension = file.name.split(".")[-1]
    userfilename = file.name.replace(f".{file.name.split(".")[-1]}", '')

    match file.name.split(".")[-1]:
        # Load .csv file
        case 'csv':
            with c2:
                # ~~~~~~~~ select file dialect
                c2_1, c2_2 = st.columns([.5, .5])
                with c2_1:
                    dec = st.radio('decimal:', options=[
                                   ".", ","], horizontal=True)
                    sep = st.radio("separator:", options=[
                                   ";", ","], horizontal=True)
                with c2_2:
                    hdr = st.radio("header: ", options=[
                                   "yes", "no"], horizontal=True)
                    names = st.radio("samples name:", options=[
                                     "yes", "no"], horizontal=True)

                hdr = 0 if hdr == "yes" else None
                names = 0 if names == "yes" else None
                hash_ = ObjectHash(current=None, add=[
                                   file.getvalue(), hdr, names, dec, sep])

                # ~~~~~~~~ read the csv file
                try:
                    # spectra = read_csv(file, decimal=dec, sep=sep, index_col=names)
                    spectra, meta_data = csv_parser(
                        path=file, decimal=dec, separator=sep, index_col=names, header=hdr, change=hash_)
                    if spectra.shape[1] > 20:
                        st.success(
                            "The data have been loaded successfully and spectral data was successfully detected, you might need to tune dialect.", icon="✅")
                    else:
                        st.warning(
                            "The data have been loaded successfully and but spectral data was not detected.")

                except:
                    st.error('''Error: The format of the file does not correspond to the expected dialect settings.
                              To read the file correctly, please adjust the dialect parameters.''')

        # Load .dx file
        case 'dx':
            with c2:
                with NamedTemporaryFile(delete=False, suffix=".dx") as tmp:
                    tmp.write(file.read())
                    dxfile = tmp.name
                    hash_ = ObjectHash(current=None, add=file.getvalue())

                try:
                    from utils.data_parsing import jcamp_parser
                    spectra, _, meta_data = jcamp_parser(
                        path=dxfile, include=['x_block', 'meta'], change=hash_)
                    st.success(
                        "The data have been loaded successfully", icon="✅")
                except:
                    st.error(
                        '''Error: an issue was encontered while parsing the uploaded file.''')

# Make the sample IDs unique: every occurrence of a repeated ID gets a
# "#<n>" suffix, where n is its 1-based occurrence number.
if not spectra.empty:
    if len(spectra.index) > len(set(spectra.index)):
        c2.warning(
            "Duplicate sample IDs found. Suffixes (#1, #2, ...) have been added to duplicate IDs.")
        # Store the original (non-unique) IDs as a meta-data column
        meta_data['names'] = spectra.index
        # Flag every occurrence of a repeated ID (keep=False marks all of them)
        mask = spectra.index.duplicated(keep=False)
        # For the flagged IDs, replace the label by "<id>#<n>" (#1, #2, etc.);
        # labels where ~mask is True are left unchanged
        spectra.index = spectra.index.where(~mask,
                                            spectra.groupby(spectra.index).cumcount().add(1).astype(str).radd(spectra.index.astype(str) + '#'))

# Align the meta data with the spectra, prepare a colour palette for the
# categorical variables, and reject data sets that are not multivariable.
if not spectra.empty:
    if not meta_data.empty:
        # Both frames share the same string sample IDs from here on.
        meta_data.index = [str(i) for i in spectra.index]
        md_df_st_ = meta_st(meta_data)

        if md_df_st_.shape[1] > 0:
            import random
            import matplotlib.colors as mcolors

            n_colors = 30
            # Evenly spaced hues
            hues = np.linspace(0, 1, n_colors, endpoint=False)
            colorslist = [mcolors.rgb2hex(plt.cm.hsv(hue)) for hue in hues]
            # Fixed seed: the shuffled palette is the same on every rerun.
            random.seed(42)
            random.shuffle(colorslist)

        else:
            colorslist = None
    # NOTE: when there is no meta data, `colorslist` is deliberately left
    # undefined; the scores-plot section tests for its existence.

    if spectra.select_dtypes(include=['float']).shape[1] < 50:
        c2.warning(
            'Error: Your data is not multivariable, check the number of variables in your data or well tune the dialect.')
        # Reset to an empty *instance* (not the DataFrame class) so that the
        # `spectra.empty` guards below evaluate a real boolean.
        spectra = DataFrame()


# Record the data dimensions and show a short summary next to the uploader.
if not spectra.empty:
    n_specs = spectra.shape[0]  # n_samples
    nwls = spectra.shape[1]  # nwl
    wls = list(spectra.columns)  # colnames
    spectra.index = [str(i) for i in list(spectra.index)]

    # NOTE(review): this name shadows the builtin `id`; it is kept unchanged
    # in case other parts of the project read it.
    id = spectra.index  # rownames

    with c2:
        st.write('Data summary:')
        # f-strings: concatenating a str with the int counts raised TypeError.
        st.write(f'- the number of spectra: {n_specs}')
        st.write(f'- the number of wavelengths: {nwls}')
        st.write(
            f'- the number of categorical variables: {meta_data.shape[1]}')
################################################### END : I- Data loading and preparation ####################################################


################################################### BEGIN : visualize and split the data ####################################################
st.subheader("I - Spectral Data Visualization", divider='blue')
if not spectra.empty:
    c3, c4 = st.columns([3, 1])

    # Right column: choose the categorical variable used to colour the spectra.
    with c4:
        st.info('Color spectra based on a categorical variable (for easier visualization: only relevant variables with fewer than 60 categories are displated in the dropdown list.)')
        # The leading '' entry stands for "no colouring".
        filter = [''] + md_df_st_.columns.to_list()
        no_categories = len(filter) == 1
        specs_col = st.selectbox('Color by:', options=filter, format_func=fmt,
                                 disabled=no_categories)
        if no_categories:
            st.write("No categorical variable was provided!")

    # Left column: the spectra, coloured by category when one is selected.
    with c3:
        if specs_col == '':
            cmap = None
            fig_spectra = plot_spectra(
                spectra, color=None, cmap=None,
                xunits='Wavelength/Wavenumber', yunits="Signal intensity")
        else:
            # One palette colour per distinct category.
            categories = set(md_df_st_[specs_col])
            cmap = dict(zip(categories, colorslist[:len(categories)]))
            fig_spectra = plot_spectra(
                spectra, color=md_df_st_[specs_col], cmap=cmap,
                xunits='Wavelength/Wavenumber', yunits="Signal intensity")
        st.pyplot(fig_spectra)

    # Right column again: bar chart of the category sizes.
    with c4:
        if specs_col != '':
            st.write('The distribution of samples across categories')
            barh = barhplot(md_df_st_[[specs_col]], cmap=cmap)
            st.pyplot(barh)
        elif len(filter) > 1:
            st.write("No categorical variable was selected!")

    # Advanced interface only: restrict the analysis to a range of variables
    # (by column position) and show the mean spectrum of the kept range.
    if st.session_state.interface == 'advanced':
        with c3:
            values = st.slider('Select a range of values',
                               min_value=0, max_value=nwls, value=(0, nwls))
            hash_ = ObjectHash(current=hash_, add=values)

            window = slice(values[0], values[1])
            spectra = spectra.iloc[:, window]
            wls = wls[window]
            nwls = spectra.shape[1]

            mean_spectrum_fig = plot_spectra(
                spectra.mean(), xunits='Wavelength/Wavenumber', yunits="Signal intensity")
            st.pyplot(mean_spectrum_fig)

        # st.selectbox('Variable', options= [''], disabled=True if len(colfilter)>1, else False)
        # st.write(data_info) ## table showing the number of samples in the data file

################################################### END : visualize and split the data ####################################################


############################## Exploratory data analysis ###############################
st.subheader(
    "II - Exploratory Data Analysis-Multivariable Data Analysis", divider='blue')
# ~~~~~~~~~~~~~~ algorithms available on our app ~~~~~~~~~~~~~~~~
# The lists offered to the user depend on the interface mode stored in the
# session state; no list is defined for any other mode.
interface_mode = st.session_state["interface"]
if interface_mode == 'simple':
    dim_red_methods = ['PCA']
    cluster_methods = ['']
    seltechs = ['random']

elif interface_mode == 'advanced':
    # Dimensionality reduction algorithms
    dim_red_methods = ['PCA', 'UMAP', 'NMF']
    # Clustering algorithms
    cluster_methods = ['KMEANS', 'HDBSCAN', 'AP']
    # Sample selection strategies
    seltechs = ['random', 'kennard-stone', 'meta-medoids', 'meta-ks']

###### 1- Dimensionality reduction ######
t = DataFrame  # scores
p = DataFrame  # loadings
if not spectra.empty:
    xc = standardize(spectra, center=True, scale=False)

    c5, c6, c7, c8, c9, c10, c11 = st.columns([1, 1, 0.6, 0.6, 0.6, 1.5, 1.5])
    with c5:
        # select a dimensionality reduction algorithm
        dim_red_method = st.selectbox("Dimensionality reduction techniques: ",
                                      options=['']+dim_red_methods if len(dim_red_methods) > 2 else dim_red_methods, format_func=fmt,
                                      disabled=False if len(dim_red_methods) > 2 else True)
        hash_ = ObjectHash(current=hash_, add=dim_red_method)

        match dim_red_method:
            case '':
                st.info('Info: Select a dimensionality reduction technique!')

            case 'UMAP':
                supervised = st.selectbox('Supervised UMAP by(optional):', options=filter,
                                          format_func=fmt, disabled=False if len(filter) > 1 else True)
                umapsupervisor = None if supervised == '' else md_df_st_[
                    supervised]
                hash_ = ObjectHash(current=hash_, add=umapsupervisor)

        # select a clustering reduction algorithm
        disablewidgets = [False if (
            dim_red_method and st.session_state.interface == 'advanced') else True][0]
        clus_method = st.selectbox("Clustering techniques(optional): ",
                                   options=[
                                       ''] + cluster_methods if len(cluster_methods) > 2 else cluster_methods,
                                   key=38, format_func=fmt, disabled=disablewidgets)

        # if disablewidgets == False and dim_red_method in dim_red_methods:
        #     inf = st.info('Info: Select a clustering technique!')

        if dim_red_method:
            @st.cache_data
            def dimensionality_reduction(dim_red_method, change):
                match dim_red_method:
                    case "PCA":
                        from utils.dim_reduction import LinearPCA
                        dr_model = LinearPCA(xc, Ncomp=8)
                    case "UMAP":
                        from utils.dim_reduction import Umap
                        dr_model = Umap(numerical_data=spectra,
                                        cat_data=umapsupervisor)
                    case 'NMF':
                        from utils.dim_reduction import Nmf
                        dr_model = Nmf(spectra, Ncomp=3)
                return dr_model
            dr_model = dimensionality_reduction(dim_red_method, change=hash_)

        if dr_model:
            axis1 = c7.selectbox(
                "x-axis", options=dr_model.scores_.columns, index=0)
            axis2 = c8.selectbox(
                "y-axis", options=dr_model.scores_.columns, index=1)
            axis3 = c9.selectbox(
                "z-axis", options=dr_model.scores_.columns, index=2)
            axis = np.unique([axis1, axis2, axis3])

            t = dr_model.scores_.loc[:, axis]
            t.index = spectra.index
            tcr = standardize(t)

# Containers for the scores plot (c12) and, except for UMAP, the loadings
# plot (c13).
if not t.empty:
    if dim_red_method == 'UMAP':
        c12 = st.container()
    else:
        c12, c13 = st.columns([3, 3])


if not spectra.empty:
    with c6:
        sel_ratio = st.number_input('Enter the number/fraction of samples to be selected:', min_value=0.01,
                                    max_value=float(spectra.shape[0]), value=0.20,
                                    format="%.2f", disabled=disablewidgets)
        # Values of 1 or more are read as an absolute number of samples,
        # values below 1 as a fraction of the data set. The boundary value
        # 1.00 previously matched neither branch and left `ratio` undefined.
        if sel_ratio >= 1.00:
            ratio = int(sel_ratio)
        else:
            # At least one sample, even when the fraction rounds down to 0.
            ratio = max(1, int(sel_ratio * spectra.shape[0]))
        # NOTE(review): the return value is discarded here, unlike the other
        # ObjectHash calls that rebind `hash_` — confirm this is intended.
        ObjectHash(current=hash_, add=ratio)

        if dr_model:
            if clus_method:
                # With clustering, every strategy of the current interface is
                # offered.
                seltech = st.radio('Select samples selection strategy:',
                                   options=seltechs,
                                   disabled=clus_method not in cluster_methods)
            else:
                # Without clustering only the two cluster-free strategies
                # apply.
                seltech = st.radio('Select samples selection strategy:', options=[
                                   'random', 'kennard-stone'],
                                   disabled=st.session_state.interface == 'simple')


if not t.empty:
    # ~~~~~~~~~~~~~~~~~~~~~~~ II- Clustering ~~~~~~~~~~~~~~~~~~~~~~~~~~
    if clus_method:
        from utils.clustering import clustering
        labels, n_clusters = clustering(X=tcr, method=clus_method)

    # ~~~~~~  III - Samples selection based on the reduced data presentation ~~~~~~~
    from utils.samsel import selection_method
    # NOTE(review): return value discarded — confirm this is intended.
    ObjectHash(current=hash_, add=seltech)

    # `labels` only exists when a clustering technique was chosen, so the
    # branch is driven by `clus_method` directly rather than by inspecting
    # globals().
    if not clus_method:
        custom_color_palette = px.colors.qualitative.Plotly[:1]
        selected = selection_method(X=tcr, method=seltech, rset=ratio)

    else:
        custom_color_palette = px.colors.qualitative.Plotly[:n_clusters]
        selected = []
        # Select inside each cluster separately; samples labelled
        # 'Non clustered' are skipped.
        for cluster_id in set(labels.index):
            if cluster_id == 'Non clustered':
                continue
            members = labels.loc[cluster_id].values.ravel()
            # NOTE(review): a per-cluster `rset_meta` (.5 or 1 depending on
            # the cluster size) used to be computed here but was never passed
            # on; the constant .4 is what has always been used. Confirm which
            # value is intended.
            selected += selection_method(X=tcr.loc[members, :], method=seltech,
                                         rset=ratio, rset_meta=.4)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ results visualization ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Scores plot
if not t.empty:
    if clus_method:
        filter[0] = clus_method
        desactivatelist = True if len(filter) <= 1 else False
    else:
        desactivatelist = True if len(filter) <= 1 else False
    with c12:
        st.write('Scores plot')

        if len(axis) == 1:
            tcr['1d'] = np.random.uniform(-.5, .5, tcr.shape[0])

        colfilter = st.selectbox('Color by :', options=filter,
                                 format_func=fmt, disabled=desactivatelist)
        ObjectHash(colfilter)

        if colfilter:
            if colfilter not in cluster_methods:  # case meta variable
                cmap = dict(
                    zip(set(md_df_st_[colfilter]), colorslist[:len(set(md_df_st_[colfilter]))]))
                tcr['color'] = md_df_st_.loc[:, colfilter]

            elif colfilter in cluster_methods:  # case clustering
                if 'colorslist' not in globals():
                    n_colors = len(set(labels.index))
                    # Evenly spaced hues
                    hues = np.linspace(0, 1, n_colors, endpoint=False)
                    st.write(555)
                    st.write(hues)
                    st.write(555)
                    import random
                    random.seed(42)
                    import matplotlib.colors as mcolors

                    colorslist = [mcolors.rgb2hex(
                        plt.cm.hsv(hue)) for hue in hues]
                    random.shuffle(colorslist)

                cmap = dict(
                    zip(set(labels.index), colorslist[:len(set(labels.index))]))
                tcr['color'] = labels.index
        else:
            cmap = {'Sample': "#7ab0c7"}
            tcr['color'] = ['Sample'] * tcr.shape[0]

        # start visualization
        match t.shape[1]:
            case 3:
                hover1 = {'sample:': tcr.index, 'color': False,
                          axis[0]: False, axis[1]: False, axis[2]: False}
                fig = px.scatter_3d(tcr, x=axis[0], y=axis[1], z=axis[2], color='color',
                                    color_discrete_map=cmap, hover_data=hover1)
                fig.add_scatter3d(x=tcr.loc[selected, axis[0]], y=tcr.loc[selected, axis[1]], z=tcr.loc[selected, axis[2]],
                                  mode='markers', marker=dict(size=5, color='black'),
                                  name='selected samples', hovertext=tcr.loc[selected, :].index)

            case 2:
                hover1 = {'sample:': tcr.index, 'color': False,
                          axis[0]: False, axis[1]: False}
                fig = px.scatter(tcr, x=axis[0], y=axis[1], color='color',
                                 color_discrete_map=cmap, hover_data=hover1)
                fig.add_scatter(x=tcr.loc[selected, axis[0]], y=tcr.loc[selected, axis[1]],
                                mode='markers', marker=dict(size=5, color='black'),
                                name='selected samples', hovertext=tcr.loc[selected, :].index)

            case 1:
                hover1 = {'sample:': tcr.index, 'color': False,
                          '1d': False, axis[0]: False}
                yy = np.random.uniform(-.5, .5, tcr.shape[0])
                fig = px.scatter(tcr, x=axis[0], y='1d', color="color",
                                 color_discrete_map=cmap, hover_data=hover1)

                fig.add_scatter(x=tcr.loc[selected, axis[0]], y=tcr.loc[selected, '1d'],
                                mode='markers', marker=dict(size=5, color='black'),
                                name='selected samples',
                                hovertext=tcr.loc[selected, :].index)
                fig.update_layout(yaxis_range=[-1.6, 1.6])
                fig.update_yaxes(visible=False)

        st.plotly_chart(fig, use_container_width=True)


if not spectra.empty:
    # Loadings are only drawn for the techniques handled here (PCA and NMF).
    if dim_red_method in ['PCA', 'NMF']:
        with c13:
            st.write('Loadings plot')

            # Label of the spectral axis: read from the file's 'xunits' meta
            # field for .dx uploads, generic for .csv uploads.
            file_ext = file.name.split(".")[-1]
            if file_ext == 'dx':
                first_xunits = meta_data.loc[:, 'xunits'].iloc[0]
                if first_xunits == '1/cm':
                    xlab = ["Wavenumbers (1/cm)"]
                else:
                    xlab = ['Wavelengths (nm)']
            elif file_ext == 'csv':
                xlab = ['Wavelength/Wavenumber']

            # One row per variable, one column per component; the spectral
            # axis is appended and the table reshaped to long format.
            p = dr_model.loadings_.T
            freq = DataFrame(wls, columns=xlab, index=p.index)
            wide = concat([p, freq], axis=1)
            df1 = wide.melt(id_vars=freq.columns,
                            var_name='Loadings:', value_name='Value')

            loadingsplot = px.line(df1, x=xlab, y='Value', color='Loadings:',
                                   color_discrete_sequence=px.colors.qualitative.Plotly)
            legend_style = dict(x=1, y=0,
                                font=dict(family="Courier",
                                          size=12, color="black"),
                                bordercolor="black", borderwidth=2)
            loadingsplot.update_layout(legend=legend_style,
                                       xaxis_title=xlab[0], yaxis_title='Value')

            st.plotly_chart(loadingsplot, use_container_width=True)

# #############################################################################################################
    if dim_red_method == 'PCA':
        c14, c15 = st.columns([3, 3])
        with c14:
            st.write('Influence plot')
            # Q residuals: Q residuals represent the magnitude of the variation remaining in each sample after projection through the model
            # Reconstruct the centred spectra from the displayed components
            # and sum the squared reconstruction errors per sample. The
            # row-wise sum of squares equals diag(E @ E.T) without building
            # the n_samples x n_samples product.
            p = p.loc[:, axis]
            xp = np.dot(t, p.T)
            resid = np.subtract(xc.values, xp)
            tcr["residuals"] = np.einsum('ij,ij->i', resid, resid)

            # Leverage
            # Tr(T(T'T)^(-1)T'): #reference :Introduction to Multi- and Megavariate Data Analysis using Projection Methods (PCA and PLS),
            # L. Eriksson, E. Johansson, N. Kettaneh-Wold and S. Wold, Umetrics 1999, p. 466
            # Only the diagonal of the hat matrix T (T'T)^-1 T' is needed, so
            # it is computed row by row; its trace is the sum of that
            # diagonal. The values are positional and tcr shares the sample
            # order of t.
            scores = t.loc[:, axis].values
            gram_inv = np.linalg.inv(scores.T @ scores)
            hat_diag = ((scores @ gram_inv) * scores).sum(axis=1)
            tcr["leverage"] = hat_diag / hat_diag.sum()

            # compute thresholds
            # NOTE(review): tcr.shape[1] counts every column of tcr at this
            # point (components plus helper columns such as 'color',
            # 'residuals', 'leverage'), not only the components — confirm the
            # intended numerator.
            tresh3 = 2 * tcr.shape[1]/n_specs
            from scipy.stats import chi2
            # NOTE(review): ppf(0.05) is the lower 5 % quantile; an upper
            # control limit would normally use ppf(0.95) / isf(0.05).
            tresh4 = chi2.ppf(0.05, df=len(axis))

            # Retrieve the index names of the samples exceeding both limits
            exceed_lev = tcr[(tcr['leverage'] > tresh3) & (
                tcr['residuals'] > tresh4)].index.tolist()

            # plot results
            influence_plot = px.scatter(tcr, x="leverage", y="residuals", color='color',
                                        color_discrete_map=cmap, hover_data=hover1)
            influence_plot.add_scatter(x=tcr.loc[selected, "leverage"], y=tcr.loc[selected, "residuals"],
                                       mode='markers', marker=dict(size=5, color='black'),
                                       name='selected samples', hovertext=tcr.loc[selected, :].index)
            influence_plot.add_vline(
                x=tresh3, line_width=1, line_dash='dash', line_color='red')
            influence_plot.add_hline(
                y=tresh4, line_width=1, line_dash='dash', line_color='red')

            # add labels for the outliers
            for i in exceed_lev:
                influence_plot.add_annotation(dict(x=tcr['leverage'].loc[i], y=tcr['residuals'].loc[i], showarrow=True,
                                                   text=i, font=dict(color="black", size=15), xanchor='auto', yanchor='auto'))

            influence_plot.update_traces(marker=dict(size=6), showlegend=True)
            influence_plot.update_layout(xaxis_title="Leverage", yaxis_title="Q-residuals",
                                         font=dict(size=20), width=800, height=600)
            st.plotly_chart(influence_plot, use_container_width=True)


#             influence_plot.update_traces(marker=dict(size= 6), showlegend= True)
#             influence_plot.update_layout(font=dict(size=23), width=800, height=500)
#             for annotation in influence_plot.layout.annotations:
#                 annotation.font.size = 35
#             influence_plot.update_traces(marker=dict(size= 10), showlegend= False)
#             influence_plot.add_annotation(text= '(a)', align='center', showarrow= False, xref='paper', yref='paper', x=-0.125, y= 1,
#                                              font= dict(color= "black", size= 35), bgcolor ='white', borderpad= 2, bordercolor= 'black', borderwidth= 3)
#             # influence_plot.write_image('./report/results/figures/influence_plot.png', engine = 'kaleido')

        with c15:
            st.write('T²-Hotelling vs Q-residuals plot')
            # Number of components actually displayed (1 to 3); used instead
            # of a hard-coded 3 so that the statistic and its limits stay
            # consistent with the influence plot.
            n_comp = len(axis)

            # Hotelling T²: sum over the components of the squared score
            # divided by the score *variance* (not the standard deviation),
            # which is the statistic the F-based limit below applies to.
            tcr['hotelling'] = (t**2 / t.var()).sum(axis=1)

            # compute thresholds
            from scipy.stats import f, chi2
            # 95 % limit: A(N²-1) / (N(N-A)) * F(0.05; A, N-A)
            fcri = f.isf(0.05, n_comp, n_specs - n_comp)
            tresh0 = (n_comp * (n_specs ** 2 - 1) * fcri) / \
                (n_specs * (n_specs - n_comp))
            # NOTE(review): ppf(0.05) is the lower 5 % quantile; an upper
            # control limit would normally use ppf(0.95) / isf(0.05).
            tresh1 = chi2.ppf(0.05, df=n_comp)

            # Retrieve the index names of the samples exceeding both limits
            exceed_hot = tcr[(tcr['hotelling'] > tresh0) & (
                tcr['residuals'] > tresh1)].index.tolist()

            # plot results
            hotelling_plot = px.scatter(tcr, x='hotelling', y='residuals', color="color",
                                        color_discrete_map=cmap, hover_data=hover1)
            hotelling_plot.add_scatter(x=tcr.loc[selected, 'hotelling'], y=tcr.loc[selected, 'residuals'],
                                       mode='markers', marker=dict(size=5, color='black'),
                                       name='selected samples', hovertext=tcr.loc[selected, :].index)
            hotelling_plot.add_vline(
                x=tresh0, line_width=1, line_dash='dash', line_color='red')
            hotelling_plot.add_hline(
                y=tresh1, line_width=1, line_dash='dash', line_color='red')

            # add labels for the outliers
            for i in exceed_hot:
                hotelling_plot.add_annotation(dict(x=tcr['hotelling'].loc[i], y=tcr['residuals'].loc[i], showarrow=True, text=i,
                                                   font=dict(color="black", size=15), xanchor='auto', yanchor='auto'))

            hotelling_plot.update_traces(marker=dict(size=6), showlegend=True)
            # Single layout call carrying the values that were effectively
            # applied before (a second call used to override font and height).
            hotelling_plot.update_layout(xaxis_title="Hotelling-T² distance", yaxis_title="Q-residuals",
                                         font=dict(size=23), width=800, height=500)
            st.plotly_chart(hotelling_plot, use_container_width=True)


#             # for annotation in hotelling_plot.layout.annotations:
#             #     annotation.font.size = 35
#             # hotelling_plot.update_layout(font=dict(size=23), width=800, height=600)
#             # hotelling_plot.update_traces(marker=dict(size= 10), showlegend= False)
#             # hotelling_plot.add_annotation(text= '(b)', align='center', showarrow= False, xref='paper', yref='paper', x=-0.125, y= 1,
#             #                                  font= dict(color= "black", size= 35), bgcolor ='white', borderpad= 2, bordercolor= 'black', borderwidth= 3)
# #             # hotelling_plot.write_image("./report/results/figures/hotelling_plot.png", format="png")

st.subheader('III - Selected Samples for Reference Analysis', divider='blue')
if selected:
    c16, c17 = st.columns([3, 1])

    # Left column: table of the selected samples.
    with c16:
        st.write("Tabular identifiers of selected samples for reference analysis:")

        if 'labels' in globals():
            # Clustering case: list each selected sample with its cluster.
            # `labels` is indexed by cluster label and holds the sample names
            # in its 'names' column.
            labels['cluster'] = labels.index
            labels.index = labels['names']
            cluster_of = dict(zip(labels['names'], labels['cluster']))
            result = DataFrame({'names': selected,
                                'cluster': [cluster_of.get(s) for s in selected]},
                               index=selected)
            # The table used to be built but never displayed.
            st.write(result)

        else:
            if not meta_data.empty:
                if 'name' in meta_data.columns:
                    subset = meta_data.drop('name', axis=1).loc[selected]
                else:
                    subset = meta_data.loc[selected]
            else:
                subset = DataFrame(selected, columns=['names'])
            st.write(subset)

    # Right column: summary figures and category distribution of the subset.
    with c17:
        # Drop the first entry of the colour options (the clustering technique
        # or the empty "no colouring" entry) so that only meta-data variables
        # remain.
        if clus_method in filter:
            filter.remove(clus_method)
        # f-string: concatenating str with the int counts raised TypeError.
        st.info(f'Information !\r\n - The total number of samples: {n_specs}.\r\n- The number of samples selected for reference analysis: '
                f'{len(selected)}.\r\n - The proportion of samples selected for reference analysis: {round(len(selected)/n_specs*100)}%.')
        selected_col = st.selectbox('Color by:  ', options=filter, format_func=fmt,
                                    disabled=True if len(filter) == 1 else False)
        if selected_col:
            selected_meta = md_df_st_.loc[selected]
            selected_categories = set(selected_meta[selected_col])
            cmap2 = dict(
                zip(selected_categories, colorslist[:len(selected_categories)]))
            st.write('The distribution of selected samples across categories')

            barhsel = barhplot(selected_meta[[selected_col]], cmap=cmap2)
            st.pyplot(barhsel)


#         if meta_data.empty:
#             # clustered: a list of ints
#             # sam1 = DataFrame({'name': selected_samples_idx,
#             #                     'cluster':np.array(labels)[selected_samples_idx]},
#             #                     index = selected_samples_idx)
#             st.write(selected_samples_idx)
#             st.write(clustered)
#         else:
#             sam1 = meta_data.iloc[clustered,:].loc[selected_samples_idx,:]
#             sam1.insert(loc=0, column='index', value=selected_samples_idx)
#             sam1.insert(loc=1, column='cluster', value=np.array(labels)[selected_samples_idx])
#         sam1.index = np.arange(len(selected_samples_idx))+1
#         sam = sam1

#         if clus_method =='HDBSCAN':
#             with c16:
#                 unclus = st.checkbox("Include non clustered samples (for HDBSCAN clustering)", value=True)

#             if selected_samples_idx:
#                 if unclus:
#                     if meta_data.empty:
#                         sam2 = DataFrame({'name': spectra.index[non_clustered],
#                                             'cluster':['Non clustered']*len(spectra.index[non_clustered])},
#                                             index = spectra.index[non_clustered])
#                     else :
#                         sam2 = meta_data.iloc[non_clustered,:]
#                         sam2.insert(loc=0, column='index', value= spectra.index[non_clustered])
#                         sam2.insert(loc=1, column='cluster', value=['Non clustered']*len(spectra.index[non_clustered]))

#                     sam = concat([sam1, sam2], axis = 0)
#                     sam.index = np.arange(sam.shape[0])+1
#                     with c17:
#                         st.info(f'- The number of Non-clustered samples: {sam2.shape[0]}.\n - The proportion of Non-clustered samples: {round(sam2.shape[0]/n_specs*100)}%')
#         else:
#             sam = sam1
#         with c16:
#             st.write(sam)


# if not sam.empty:
#     zip_data = ""
#     Nb_ech = str(n_specs)
#     nb_clu = str(sam1.shape[0])
#     st.subheader('Download the analysis results')
#     st.write("**Note:** Please check the box only after you have finished processing your data and are satisfied with the results. Checking the box prematurely may slow down the app and could lead to crashes.")
#     decis = st.checkbox("Yes, I want to download the results")
#     if decis:
#         ###################################################
#         # ## generate report
#         @st.cache_data
#         def export_report(change):
#             latex_report = report.report('Representative subset selection', file.name, dim_red_method,
#                                         clus_method, Nb_ech, ncluster, selection, selection_number, nb_clu,tcr, sam)

#         @st.cache_data
#         def preparing_results_for_downloading(change):
#             # path_to_report = Path("report")############################### i am here
#             match file.name.split(".")[-1]:
#                 # load csv file
#                 case 'csv':
#                     imp.to_csv('report/results/dataset/'+ file.name, sep = ';', encoding = 'utf-8', mode = 'a')
#                 case 'dx':
#                     with open('report/results/dataset/'+file.name, 'w') as dd:
#                         dd.write(dxdata)

#             fig_spectra.savefig(report_path_rel/"results/figures/spectra_plot.png", dpi = 400) ## Export report

#             if len(axis) == 3:
#                 for i in range(len(comb)):
#                     fig_export[f'scores_pc{comb[i][0]}_pc{comb[i][1]}'].write_image(report_path_rel/f'results/figures/scores_pc{str(comb[i][0]+1)}_pc{str(comb[i][1]+1)}.png')
#             elif len(axis)==2 :
#                 fig_export['fig'].write_image(report_path_rel/'results/figures/scores_plot2D.png')
#             elif len(axis)==1 :
#                 fig_export['fig'].write_image(report_path_rel/'results/figures/scores_plot1D.png')

#             # Export du graphique
#             if dim_red_method in ['PCA','NMF']:
#                 import plotly.io as pio
#                 img = pio.to_image(loadingsplot, format="png")
#                 with open(report_path_rel/"results/figures/loadings_plot.png", "wb") as f:
#                     f.write(img)
#             if dim_red_method == 'PCA':
#                 hotelling_plot.write_image(report_path_rel/"results/figures/hotelling_plot.png", format="png")
#                 influence_plot.write_image(report_path_rel/'results/figures/influence_plot.png', engine = 'kaleido')

#             sam.to_csv(report_path_rel/'results/Selected_subset_for_calib_development.csv', sep = ';')
#             export_report(change = hash_)
#             if Path(report_path_rel/"report.tex").exists():
#                 report.generate_report(change = hash_)
#             if Path(report_path_rel/"report.pdf").exists():
#                 move(report_path_rel/"report.pdf", "./report/results/report.pdf")
#             return change


#         preparing_results_for_downloading(change = hash_)
#         report.generate_report(change = hash_)


#         @st.cache_data
#         def tempdir(change):
#             from tempfile import TemporaryDirectory
#             with  TemporaryDirectory( prefix="results", dir="./report") as temp_dir:# create a temp directory
#                 tempdirname = os.path.split(temp_dir)[1]

#                 if len(os.listdir(report_path_rel/'results/figures/'))>=2:

#                     make_archive(base_name= report_path_rel/"Results", format="zip", base_dir="results", root_dir = "./report")# create a zip file
#                     move(report_path_rel/"Results.zip", f"./report/{tempdirname}/Results.zip")# put the inside the temp dir
#                     with open(report_path_rel/f"{tempdirname}/Results.zip", "rb") as f:
#                         zip_data = f.read()
#             return tempdirname, zip_data

#         try :
#             tempdirname, zip_data = tempdir(change = hash_)
#             # st.download_button(label = 'Download', data = zip_data, file_name = f'Nirs_Workflow_{date_time}_SamSel_.zip', mime ="application/zip",
#             #             args = None, kwargs = None,type = "primary",use_container_width = True)
#         except:
#             pass
#     date_time = datetime.now().strftime('%y%m%d%H%M')
#     disabled_down = True if zip_data == '' else False
#     st.download_button(label = 'Download', data = zip_data, file_name = f'Nirs_Workflow_{date_time}_SamSel_.zip', mime ="application/zip",
#                 args = None, kwargs = None,type = "primary",use_container_width = True, disabled = disabled_down)


#     HandleItems.delete_files(keep = ['.py', '.pyc','.bib'])