Skip to content
Snippets Groups Projects
1-samples_selection.py 31.1 KiB
Newer Older
  • Learn to ignore specific revisions
  • from Packages import *
    st.set_page_config(page_title="NIRS Utils", page_icon=":goat:", layout="wide")
    from Modules import *
    
    
    # empty temp figures
    
    DIANE's avatar
    DIANE committed
    for i in ['Report/figures','Report/datasets']:
        repertoire_a_vider = Path(i)
        if os.path.exists(repertoire_a_vider):
            for fichier in os.listdir(repertoire_a_vider):
                chemin_fichier = os.path.join(repertoire_a_vider, fichier)
                if os.path.isfile(chemin_fichier) or os.path.islink(chemin_fichier):
                    os.unlink(chemin_fichier)
                elif os.path.isdir(chemin_fichier):
                    shutil.rmtree(chemin_fichier)
    
    
    # HTML pour le bandeau "CEFE - CNRS"
    
    #load specific model page css
    local_css(css_file / "style_model.css")
    
    DIANE's avatar
    DIANE committed
    
    # algorithms available in our app
    
    dim_red_methods=['', 'PCA','UMAP', 'NMF']  # List of dimensionality reduction algos
    
    cluster_methods = ['', 'Kmeans','HDBSCAN', 'AP', 'KS', 'RDM'] # List of clustering algos
    
    selec_strategy = ['center','random']
    
    DIANE's avatar
    DIANE committed
    match st.session_state["interface"]:
        case 'simple':
            st.write(':red[Automated Simple Interface]')
            # hide_pages("Predictions")
            if 37 not in st.session_state:
                default_reduction_option = 1
            else:
                default_reduction_option = dim_red_methods.index(st.session_state.get(37))
            if 38 not in st.session_state:
                default_clustering_option = 1
            else:
                default_clustering_option = cluster_methods.index(st.session_state.get(38))
            if 102 not in st.session_state:
                default_sample_selection_option = 1
            else:
                default_sample_selection_option = selec_strategy.index(st.session_state.get(102))
            
        case'advanced':
            default_reduction_option = 0
            default_clustering_option = 0
            default_sample_selection_option = 0
    
    ################################### I - Data Loading and Visualization ########################################
    
    DIANE's avatar
    DIANE committed
    date_time = datetime.datetime.now().strftime('_%y_%m_%d_%H_%M_')
    
    st.title("Calibration Subset Selection")
    
    col2, col1 = st.columns([3, 1])
    
    DIANE's avatar
    DIANE committed
    col2.image("./images/sample selection.png", use_column_width=True)
    
    ## Preallocation of data structure
    
    DIANE's avatar
    DIANE committed
    spectra = pd.DataFrame()
    meta_data = pd.DataFrame()
    tcr=pd.DataFrame()
    sam=pd.DataFrame()
    sam1=pd.DataFrame()
    selected_samples = pd.DataFrame()
    
    DIANE's avatar
    DIANE committed
    non_clustered = None
    
    DIANE's avatar
    DIANE committed
    l1 = []
    
    DIANE's avatar
    DIANE committed
    labels = []
    color_palette = None
    dr_model = None # dimensionality reduction model
    cl_model = None # clustering model
    
    selection = None
    selection_number = None
    
    DIANE's avatar
    DIANE committed
    
    
    DIANE's avatar
    DIANE committed
    # loader for datafile
    
    data_file = col1.file_uploader("Data file", type=["csv","dx"], help=" :mushroom: select a csv matrix with samples as rows and lambdas as columns", key=5)
    
    if not data_file:
        col1.warning('⚠️ Please load data file !')
    else:
    
    DIANE's avatar
    DIANE committed
        # Retrieve the extension of the file
    
    DIANE's avatar
    DIANE committed
        # test = data_file.name[data_file.name.find('.'):]
        
        
        extension = data_file.name.split(".")[-1]
        userfilename = data_file.name.replace(f".{extension}", '')
    
        match extension:
    
    DIANE's avatar
    DIANE committed
        ## Load .csv file
    
    DIANE's avatar
    DIANE committed
            case 'csv':
    
    DIANE's avatar
    DIANE committed
                    # Select list for CSV delimiter
    
    DIANE's avatar
    DIANE committed
                    psep = st.radio("Select csv separator - _detected_: " + str(find_delimiter('data/'+data_file.name)), options=[";", ","], index=[";", ","].index(str(find_delimiter('data/'+data_file.name))),horizontal=True, key=9)
    
    DIANE's avatar
    DIANE committed
                        # Select list for CSV header True / False
    
    DIANE's avatar
    DIANE committed
                    phdr = st.radio("indexes column in csv? - _detected_: " + str(find_col_index('data/'+data_file.name)), options=["no", "yes"], index=["no", "yes"].index(str(find_col_index('data/'+data_file.name))),horizontal=True, key=31)
    
    DIANE's avatar
    DIANE committed
                    if phdr == 'yes':
                        col = 0
                    else:
                        col = False
                    imp = pd.read_csv(data_file, sep=psep, index_col=col)
    
    DIANE's avatar
    DIANE committed
                    imp.to_csv("./Report/datasets/"+data_file.name,sep = ';', encoding='utf-8', mode='a')
                    
    
    DIANE's avatar
    DIANE committed
                    # spectra = col_cat(imp)[0]
                    # meta_data = col_cat(imp)[1]
                    spectra, md_df_st_ = col_cat(imp)
                    meta_data = md_df_st_
    
    DIANE's avatar
    DIANE committed
                    st.success("The data have been loaded successfully", icon="")
    
    DIANE's avatar
    DIANE committed
            ## Load .dx file
    
    DIANE's avatar
    DIANE committed
            case 'dx':
    
    DIANE's avatar
    DIANE committed
                # Create a temporary file to save the uploaded file
                with NamedTemporaryFile(delete=False, suffix=".dx") as tmp:
                    tmp.write(data_file.read())
    
    DIANE's avatar
    DIANE committed
                    with open(tmp.name, 'r') as dd:
                        dxdata = dd.read()
                    with open('Report/datasets/'+data_file.name, 'w') as dd:
                        dd.write(dxdata)
    
    DIANE's avatar
    DIANE committed
                    tmp_path = tmp.name
                    with col1:
                        _, spectra, meta_data, md_df_st_ = read_dx(file = tmp_path)
                        st.success("The data have been loaded successfully", icon="")
                os.unlink(tmp_path)
    
    maimouni.mouhcine's avatar
    maimouni.mouhcine committed
        
    
    DIANE's avatar
    DIANE committed
    ## Visualize spectra
    
    st.header("I - Spectral Data Visualization", divider='blue')
    
    DIANE's avatar
    DIANE committed
    if not spectra.empty:
    
        n_samples = spectra.shape[0]
        nwl = spectra.shape[1]
    
        # retrieve columns name and rows name of spectra
        colnames = list(spectra.columns)
        rownames = [str(i) for i in list(spectra.index)]
        spectra.index = rownames
    
    DIANE's avatar
    DIANE committed
            fig, ax = plt.subplots(figsize = (30,7))
    
    DIANE's avatar
    DIANE committed
            if extension =='dx':
    
    DIANE's avatar
    DIANE committed
                lab = ['Wavenumber (1/cm)' if meta_data.loc[:,'xunits'][0] == '1/cm' else 'Wavelength (nm)']
                if lab[0] =='Wavenumber (1/cm)':
                    spectra.T.plot(legend=False, ax = ax).invert_xaxis()
                else :
                    spectra.T.plot(legend=False, ax = ax)
                ax.set_xlabel(lab[0], fontsize=18)
    
    DIANE's avatar
    DIANE committed
            else:
    
    DIANE's avatar
    DIANE committed
                spectra.T.plot(legend=False, ax = ax)
                ax.set_xlabel('Wavelength/Wavenumber', fontsize=18)
            
            ax.set_ylabel('Signal intensity', fontsize=18)
            plt.margins(x = 0)
            plt.tight_layout()
    
            st.pyplot(fig)
    
    DIANE's avatar
    DIANE committed
            
    
            # update lines size
            for line in ax.get_lines():
                line.set_linewidth(0.8)  # Set the desired line width here
    
    
    DIANE's avatar
    DIANE committed
            # Update the size of plot axis for exprotation to report
            l, w = fig.get_size_inches()
            fig.set_size_inches(8, 3)
            for label in (ax.get_xticklabels()+ax.get_yticklabels()):
    
                ax.xaxis.label.set_size(9.5)
                ax.yaxis.label.set_size(9.5)
    
    DIANE's avatar
    DIANE committed
            plt.tight_layout()
            fig.savefig("./Report/figures/spectra_plot.png", dpi=400) ## Export report
            fig.set_size_inches(l, w)# reset the plot size to its original size
            data_info = pd.DataFrame({'Name': [data_file.name],
    
                                    'Number of scanned samples': [n_samples]},
    
    DIANE's avatar
    DIANE committed
                                      index = ['Input file'])
    
        with col1:
            st.info('Information on the loaded data file')
    
    DIANE's avatar
    DIANE committed
            st.write(data_info) ## table showing the number of samples in the data file
    
    DIANE's avatar
    DIANE committed
    
    
    ############################## Exploratory data analysis ###############################
    
    DIANE's avatar
    DIANE committed
    st.header("II - Exploratory Data Analysis-Multivariable Data Analysis", divider='blue')
    
    ###### 1- Dimensionality reduction ######
    
    DIANE's avatar
    DIANE committed
    t = pd.DataFrame # scores
    p = pd.DataFrame # loadings
    if not spectra.empty:
    
        bb1, bb2, bb3, bb4, bb5, bb6, bb7 = st.columns([1,1,0.6,0.6,0.6,1.5,1.5])
        dim_red_method = bb1.selectbox("Dimensionality reduction techniques: ", options = dim_red_methods, index = default_reduction_option, key = 37)
        clus_method = bb2.selectbox("Clustering/sampling techniques: ", options = cluster_methods, index = default_clustering_option, key = 38)
    
    DIANE's avatar
    DIANE committed
        xc = standardize(spectra, center=True, scale=False)
    
    DIANE's avatar
    DIANE committed
        
        match dim_red_method:
            case "":
                    bb1.warning('⚠️ Please choose an algorithm !')
            
            case "PCA":
    
    DIANE's avatar
    DIANE committed
                @st.cache_data
                def dr_model_(change):
                    dr_model = LinearPCA(xc, Ncomp=8)
                    return dr_model
                dr_model = dr_model_(change = hash_data(xc))
    
    DIANE's avatar
    DIANE committed
    
            case "UMAP":
                if not meta_data.empty:
                    filter = md_df_st_.columns
                    filter = filter.insert(0, 'Nothing')
                    col = bb1.selectbox('Supervised UMAP by:', options= filter, key=108)
                    if col == 'Nothing':
                        supervised = None
                    else:
                        supervised = md_df_st_[col]
    
    DIANE's avatar
    DIANE committed
                    supervised = None
    
    DIANE's avatar
    DIANE committed
                @st.cache_data
                def dr_model_(change):
                    dr_model = Umap(numerical_data = MinMaxScale(spectra), cat_data = supervised)
                    return dr_model
                dr_model = dr_model_(change = hash_data(spectra))
                    
    
    DIANE's avatar
    DIANE committed
            case 'NMF':
    
    DIANE's avatar
    DIANE committed
                @st.cache_data
                def dr_model_(change):
                    dr_model = Nmf(spectra, Ncomp= 3)
                    return dr_model
                dr_model = dr_model_(change = hash_data(spectra))
            
    
            axis1 = bb3.selectbox("x-axis", options = dr_model.scores_.columns, index=0)
            axis2 = bb4.selectbox("y-axis", options = dr_model.scores_.columns, index=1)
            axis3 = bb5.selectbox("z-axis", options = dr_model.scores_.columns, index=2)
    
    DIANE's avatar
    DIANE committed
    
    
            t = pd.concat([dr_model.scores_.loc[:,axis1], dr_model.scores_.loc[:,axis2], dr_model.scores_.loc[:,axis3]], axis = 1)
    
    
    
    DIANE's avatar
    DIANE committed
    
    
    DIANE's avatar
    DIANE committed
    ###### II - clustering #######
    
    if not t.empty:
    
    DIANE's avatar
    DIANE committed
        clustered = np.arange(n_samples)
        non_clustered = None
    
    
        if dim_red_method == 'UMAP':
            scores = st.container()
        else:
            scores, loadings= st.columns([3,3])
    
    
    DIANE's avatar
    DIANE committed
        tcr = standardize(t)
    
    DIANE's avatar
    DIANE committed
        
    
    DIANE's avatar
    DIANE committed
        # Clustering
        match clus_method:
            case '':
                bb2.warning('⚠️ Please choose an algothithm !')
            case 'Kmeans':
                cl_model = Sk_Kmeans(tcr, max_clusters = 25)
                ncluster = scores.number_input(min_value=2, max_value=25, value=cl_model.suggested_n_clusters_, label = 'Select the desired number of clusters')  
                data, labels, clu_centers = cl_model.fit_optimal(nclusters = ncluster)
    
            # 2- HDBSCAN clustering
            case 'HDBSCAN':
                optimized_hdbscan = Hdbscan(np.array(tcr))
                all_labels, clu_centers = optimized_hdbscan.HDBSCAN_scores_
                labels = [f'cluster#{i+1}' if i !=-1 else 'Non clustered' for i in all_labels]
                ncluster = len(clu_centers)
                non_clustered = np.where(np.array(labels) == 'Non clustered')[0]
    
            # 3- Affinity propagation
            case 'AP':
                cl_model = AP(X = tcr)
                data, labels, clu_centers = cl_model.fit_optimal_
                ncluster = len(clu_centers)
    
            case 'KS':
                rset = scores.number_input(min_value=0, max_value=100, value=20, label = 'The ratio of data to be sampled (%)')
                cl_model = KS(x = tcr, rset = rset)
                calset = cl_model.calset
                labels = ["ind"]*n_samples
                ncluster = "1"
                selection_number = 'None'
    
            case 'RDM':
                rset = scores.number_input(min_value=0, max_value=100, value=20, label = 'The ratio of data to be sampled (%)')
                cl_model = RDM(x = tcr, rset = rset)
                calset = cl_model.calset
                labels = ["ind"]*n_samples
                ncluster = "1"
                selection_number = 'None'            
    
    DIANE's avatar
    DIANE committed
    
    
    DIANE's avatar
    DIANE committed
        new_tcr = tcr.iloc[clustered,:]    
        
    
    DIANE's avatar
    DIANE committed
    #################################################### III - Samples selection using the reduced data preentation ######
    
    DIANE's avatar
    DIANE committed
    samples_df_chem = pd.DataFrame
    selected_samples = []
    selected_samples_idx = []
    
    
    if not labels:
        custom_color_palette = px.colors.qualitative.Plotly[:1]
    elif labels:
    
        num_clusters = len(np.unique(labels))
        custom_color_palette = px.colors.qualitative.Plotly[:num_clusters]
    
    DIANE's avatar
    DIANE committed
        if clus_method:
    
    DIANE's avatar
    DIANE committed
            if clus_method in ['KS', 'RDM']:
    
                selected_samples_idx = calset[1]
                selection = 'None'
    
            else:
    
                selection = scores.radio('Select samples selection strategy:',
                                            options = selec_strategy, index = default_sample_selection_option, key=102)
    
    DIANE's avatar
    DIANE committed
            
            match selection:
    
    DIANE's avatar
    DIANE committed
                case 'center':
                    # list samples at clusters centers - Use sklearn.metrics.pairwise_distances_argmin if you want more than 1 sample per cluster
                    closest, _ = pairwise_distances_argmin_min(clu_centers, new_tcr)
                    selected_samples_idx = np.array(new_tcr.index)[list(closest)]
                    selected_samples_idx = selected_samples_idx.tolist()
                    
                #### Strategy 1
                case 'random':
                    selection_number = scores.number_input('How many samples per cluster?',
                                                            min_value = 1, step=1, value = 3)
                    s = np.array(labels)[np.where(np.array(labels) !='Non clustered')[0]]
                    for i in np.unique(s):
                        C = np.where(np.array(labels) == i)[0]
                        if C.shape[0] >= selection_number:
                            # scores.write(list(tcr.index)[labels== i])
                            km2 = KMeans(n_clusters = selection_number)
                            km2.fit(tcr.iloc[C,:])
                            clos, _ = pairwise_distances_argmin_min(km2.cluster_centers_, tcr.iloc[C,:])
                            selected_samples_idx.extend(tcr.iloc[C,:].iloc[list(clos)].index)
                        else:
                            selected_samples_idx.extend(new_tcr.iloc[C,:].index.to_list())
                        # list indexes of selected samples for colored plot    
    
    DIANE's avatar
    DIANE committed
    
    
    ################################      Plots visualization          ############################################
    
    DIANE's avatar
    DIANE committed
    
    
    DIANE's avatar
    DIANE committed
        ## Scores
    
    if not t.empty:
        with scores:
    
    DIANE's avatar
    DIANE committed
            fig1, ((ax1, ax2),(ax3,ax4)) = plt.subplots(2,2)
    
            st.write('Scores plot')
            # scores plot with clustering
    
            if list(labels) and meta_data.empty:
    
                fig = px.scatter_3d(tcr, x=axis1, y=axis2, z = axis3, color=labels ,color_discrete_sequence= custom_color_palette)
    
    DIANE's avatar
    DIANE committed
                sns.scatterplot(data = tcr, x = axis1, y =axis2 , hue = labels, ax = ax1)
    
            # scores plot with metadata
    
            elif len(list(labels)) == 0 and not meta_data.empty:
    
    DIANE's avatar
    DIANE committed
                filter = md_df_st_.columns
    
                col = st.selectbox('Color by:', options= filter)
    
    DIANE's avatar
    DIANE committed
                    fig = px.scatter_3d(tcr, x=axis1, y=axis2, z = axis3)
    
    DIANE's avatar
    DIANE committed
                    sns.scatterplot(data = tcr, x = axis1, y =axis2 , ax = ax1)
    
                    sns.scatterplot(data = tcr, x = axis2, y =axis3 , ax = ax2)
    
    DIANE's avatar
    DIANE committed
                    sns.scatterplot(data = tcr, x = axis1, y =axis3 , hue = list(map(str.lower,md_df_st_[col])), ax = ax3)
    
    DIANE's avatar
    DIANE committed
                    fig = px.scatter_3d(tcr, x=axis1, y=axis2, z = axis3, color = list(map(str.lower,md_df_st_[col])) )
                    sns.scatterplot(data = tcr, x = axis1, y =axis2 , hue = list(map(str.lower,md_df_st_[col])), ax = ax1)
                    sns.scatterplot(data = tcr, x = axis2, y =axis3 , hue = list(map(str.lower,md_df_st_[col])), ax = ax2)
                    sns.scatterplot(data = tcr, x = axis1, y =axis3 , hue = list(map(str.lower,md_df_st_[col])), ax = ax3)
    
            # color with scores and metadata
            elif len(list(labels)) > 0  and not meta_data.empty:
                if clus_method in cluster_methods[1:]:
                    filter = ['None', clus_method]
    
    DIANE's avatar
    DIANE committed
                    filter.extend(md_df_st_.columns)
    
    DIANE's avatar
    DIANE committed
                    filter = md_df_st_.columns.insert(0,'None')
    
                col = st.selectbox('Color by:', options= filter)
    
                    fig = px.scatter_3d(tcr, x=axis1, y=axis2, z = axis3)
    
    DIANE's avatar
    DIANE committed
                    sns.scatterplot(data = tcr, x = axis1, y =axis2 , ax = ax1)
    
    DIANE's avatar
    DIANE committed
                    fig = px.scatter_3d(tcr, x=axis1, y=axis2, z = axis3, color = labels)
    
    DIANE's avatar
    DIANE committed
                    sns.scatterplot(data = tcr, x = axis1, y =axis2 , ax = ax1)
    
    DIANE's avatar
    DIANE committed
                    fig = px.scatter_3d(tcr, x=axis1, y=axis2, z = axis3, color = list(map(str.lower,md_df_st_[col])))
                    sns.scatterplot(data = tcr, x = axis1, y =axis2 , hue = list(map(str.lower,md_df_st_[col])), ax = ax1)
                    sns.scatterplot(data = tcr, x = axis1, y =axis2 , hue = list(map(str.lower,md_df_st_[col])), ax = ax2)
                    sns.scatterplot(data = tcr, x = axis1, y =axis2 , hue = list(map(str.lower,md_df_st_[col])), ax = ax3)
    
                fig = px.scatter_3d(tcr, x=axis1, y=axis2, z = axis3, color=labels if list(labels) else None,color_discrete_sequence= custom_color_palette)
    
    DIANE's avatar
    DIANE committed
                sns.scatterplot(data = tcr, x = axis1, y =axis2 , ax = ax1)
    
            fig.update_traces(marker=dict(size=4))
    
    DIANE's avatar
    DIANE committed
    
            if selected_samples_idx:
                tt = tcr.iloc[selected_samples_idx,:]
    
                fig.add_scatter3d(x = tt.loc[:,axis1], y = tt.loc[:,axis2],z = tt.loc[:,axis3],
                                   mode ='markers', marker = dict(size = 5, color = 'black'),
    
    DIANE's avatar
    DIANE committed
                                  name = 'selected samples')        
    
            st.plotly_chart(fig, use_container_width = True)
    
    DIANE's avatar
    DIANE committed
            if labels:
    
                # export 2D scores plot
    
    DIANE's avatar
    DIANE committed
                comb = [i for i in combinations([1,2,3], 2)]
                subcap = ['a','b','c']
                for i in range(len(comb)):
    
                    fig_export = px.scatter(tcr, x = eval(f'axis{str(comb[i][0])}'), y=eval(f'axis{str(comb[i][1])}'),
                                                color = labels if list(labels) else None,
                                                color_discrete_sequence = custom_color_palette)
                    fig_export.add_scatter(x = tt.loc[:,eval(f'axis{str(comb[i][0])}')], y = tt.loc[:,eval(f'axis{str(comb[i][1])}')],
                                   mode ='markers', marker = dict(size = 5, color = 'black'),
                                  name = 'selected samples')
                    fig_export.update_layout(font=dict(size=23))
                    fig_export.add_annotation(text= f'({subcap[i]})', align='center', showarrow= False, xref='paper', yref='paper', x=-0.13, y= 1,
    
    DIANE's avatar
    DIANE committed
                                                 font= dict(color= "black", size= 35), bgcolor ='white', borderpad= 2, bordercolor= 'black', borderwidth= 3)
    
                    fig_export.update_traces(marker=dict(size= 10), showlegend= False)
                    fig_export.write_image(f'./Report/Figures/scores_pc{str(comb[i][0])}_pc{str(comb[i][1])}.png')
    
    DIANE's avatar
    DIANE committed
    if not spectra.empty:
    
    DIANE's avatar
    DIANE committed
        if dim_red_method in ['PCA','NMF']:
    
            with loadings:
                st.write('Loadings plot')
                p = dr_model.loadings_
    
                freq = pd.DataFrame(colnames, index=p.index)
    
    DIANE's avatar
    DIANE committed
                if extension =='dx':
    
    DIANE's avatar
    DIANE committed
                    if meta_data.loc[:,'xunits'][0] == '1/cm':
                        freq.columns = ['Wavenumber (1/cm)']
    
    DIANE's avatar
    DIANE committed
                        xlab = "Wavenumber (1/cm)"
                        inv = 'reversed'
    
    DIANE's avatar
    DIANE committed
                    else:
                        freq.columns = ['Wavelength (nm)']
    
    DIANE's avatar
    DIANE committed
                        xlab = 'Wavelength (nm)'
                        inv = None
    
    DIANE's avatar
    DIANE committed
                else:
                    freq.columns = ['Wavelength/Wavenumber']
    
    DIANE's avatar
    DIANE committed
                    xlab = 'Wavelength/Wavenumber'
                    inv = None
    
    DIANE's avatar
    DIANE committed
                    
                pp = pd.concat([p, freq], axis=1)
                #########################################
                df1 = pp.melt(id_vars=freq.columns)
                fig = px.line(df1, x=freq.columns, y='value', color='variable', color_discrete_sequence=px.colors.qualitative.Plotly)
    
                fig.update_layout(legend=dict(x=1, y=0, font=dict(family="Courier", size=12, color="black"),
                                            bordercolor="black", borderwidth=2))
    
    DIANE's avatar
    DIANE committed
                fig.update_layout(xaxis_title = xlab,yaxis_title = "Intensity" ,xaxis = dict(autorange= inv))
    
    
                st.plotly_chart(fig, use_container_width=True)
    
    
                # Export du graphique
                img = pio.to_image(fig, format="png")
    
    DIANE's avatar
    DIANE committed
                with open("./Report/figures/loadings_plot.png", "wb") as f:
    
    DIANE's avatar
    DIANE committed
    #############################################################################################################
    
    DIANE's avatar
    DIANE committed
        if dim_red_method == 'PCA':
    
            influence, hotelling = st.columns([3, 3])
    
            with influence:
                st.write('Influence plot')
    
    DIANE's avatar
    DIANE committed
                # Laverage
                Hat =  t.to_numpy() @ np.linalg.inv(np.transpose(t.to_numpy()) @ t.to_numpy()) @ np.transpose(t.to_numpy())
                leverage = np.diag(Hat) / np.trace(Hat)
    
    DIANE's avatar
    DIANE committed
                # Loadings
                p = pd.concat([dr_model.loadings_.loc[:,axis1], dr_model.loadings_.loc[:,axis2], dr_model.loadings_.loc[:,axis3]], axis = 1)
                # Matrix reconstruction
                xp = np.dot(t,p.T)
                # Q residuals: Q residuals represent the magnitude of the variation remaining in each sample after projection through the model
                residuals = np.diag(np.subtract(xc.to_numpy(), xp)@ np.subtract(xc.to_numpy(), xp).T)
                tresh4 = sc.stats.chi2.ppf(0.05, df = 3)
    
                # color with metadata
                if not meta_data.empty and clus_method:
                    if col == "None":
    
    DIANE's avatar
    DIANE committed
    
                    elif col == clus_method:
                        l1 = labels
                    
                    else:
                        l1 = list(map(str.lower,md_df_st_[col]))
    
                elif meta_data.empty and clus_method:                        
                    l1 = labels
    
                elif meta_data.empty and not clus_method:
    
    DIANE's avatar
    DIANE committed
                
                elif not meta_data.empty and not clus_method:
                    l1 = list(map(str.lower,md_df_st_[col]))
    
    
    DIANE's avatar
    DIANE committed
                fig = px.scatter(x = leverage, y = residuals, color=labels if list(labels) else None,
                                                color_discrete_sequence= custom_color_palette)
    
    DIANE's avatar
    DIANE committed
                fig.add_vline(x = tresh3, line_width = 1, line_dash = 'solid', line_color = 'red')
                fig.add_hline(y=tresh4, line_width=1, line_dash='solid', line_color='red')
    
    DIANE's avatar
    DIANE committed
                fig.update_layout(xaxis_title="Leverage", yaxis_title = "Q-residuals", font=dict(size=20), width=800, height=600)
    
    DIANE's avatar
    DIANE committed
    
                out3 = leverage > tresh3
                out4 = residuals > tresh4
    
    
    DIANE's avatar
    DIANE committed
                    if out3[i]:
                        if not meta_data.empty:
                            ann =  meta_data.loc[:,'name'][i]
                        else:
                            ann = t.index[i]
    
    DIANE's avatar
    DIANE committed
                        fig.add_annotation(dict(x = leverage[i], y = residuals[i], showarrow=True, text = ann,font= dict(color= "black", size= 15),
    
    DIANE's avatar
    DIANE committed
                                    xanchor = 'auto', yanchor = 'auto'))
    
    DIANE's avatar
    DIANE committed
                
                fig.update_traces(marker=dict(size= 6), showlegend= True)
                fig.update_layout(font=dict(size=23), width=800, height=500)
                st.plotly_chart(fig, use_container_width=True)
    
    DIANE's avatar
    DIANE committed
                
                for annotation in fig.layout.annotations:
                    annotation.font.size = 35
                fig.update_layout(font=dict(size=23), width=800, height=600)
                fig.update_traces(marker=dict(size= 10), showlegend= False)
    
                fig.add_annotation(text= '(a)', align='center', showarrow= False, xref='paper', yref='paper', x=-0.125, y= 1,
                                                 font= dict(color= "black", size= 35), bgcolor ='white', borderpad= 2, bordercolor= 'black', borderwidth= 3)
    
    DIANE's avatar
    DIANE committed
                fig.write_image('./Report/figures/influence_plot.png', engine = 'kaleido')
    
            with hotelling:
                st.write('T²-Hotelling vs Q-residuals plot')

                # "Hotelling" statistic: per-sample variance across the columns of `t`.
                # NOTE(review): this is not the classical T² (sum of squared scores scaled by the
                # score variances); `t` is presumably the score matrix — confirm this is intended.
                # Rebinding the name `hotelling` is harmless: the column context manager above has
                # already been entered.
                hotelling = t.var(axis=1)

                # Q residuals: magnitude of the variation remaining in each sample after projection
                # through the model, i.e. the squared row norm of (xc - xp). Summing squares row-wise
                # gives the same values as np.diag(E @ E.T) without building an n x n matrix.
                residuals = np.square(np.asarray(np.subtract(xc.to_numpy(), xp))).sum(axis=1)

                # 95% control limits; the literal 3 is the number of components p in the formulas.
                # T² limit: p (n-1)(n+1) / (n (n-p)) * F_0.05(p, n-p) -> F denominator dof is n-p.
                fcri = sc.stats.f.isf(0.05, 3, n_samples - 3)
                tresh0 = (3 * (n_samples ** 2 - 1) * fcri) / (n_samples * (n_samples - 3))
                # Q limit: upper 5% tail of chi²(3), the same tail as the T² limit
                # (ppf(0.05) is the lower 5% quantile, below 95% of that distribution).
                tresh1 = sc.stats.chi2.isf(0.05, df=3)

                fig = px.scatter(t, x=hotelling, y=residuals, color=labels if list(labels) else None,
                                 color_discrete_sequence=custom_color_palette)
                fig.update_layout(xaxis_title="Hotelling-T² distance", yaxis_title="Q-residuals")
                fig.add_vline(x=tresh0, line_width=1, line_dash='solid', line_color='red')
                fig.add_hline(y=tresh1, line_width=1, line_dash='solid', line_color='red')

                out0 = hotelling > tresh0
                # NOTE(review): `out1` (Q-residual outliers) is not used in this block.
                out1 = residuals > tresh1

                # Annotate every sample beyond the T² limit with its metadata name, or with its row
                # label in `t` when there is no metadata. Positional access keeps all sources aligned
                # with the positional `residuals` array.
                for i in np.flatnonzero(np.asarray(out0)):
                    if not meta_data.empty:
                        ann = meta_data['name'].iloc[i]
                    else:
                        ann = t.index[i]
                    fig.add_annotation(dict(x=hotelling.iloc[i], y=residuals[i], showarrow=True, text=ann,
                                            font=dict(color="black", size=15),
                                            xanchor='auto', yanchor='auto'))

                fig.update_traces(marker=dict(size=6), showlegend=True)
                fig.update_layout(font=dict(size=23), width=800, height=500)
                st.plotly_chart(fig, use_container_width=True)

                # Re-style the same figure for the static report export (bigger fonts/markers,
                # panel tag "(b)") and save it as a PNG.
                for annotation in fig.layout.annotations:
                    annotation.font.size = 35
                fig.update_layout(font=dict(size=23), width=800, height=600)
                fig.update_traces(marker=dict(size=10), showlegend=False)
                fig.add_annotation(text='(b)', align='center', showarrow=False, xref='paper', yref='paper',
                                   x=-0.125, y=1,
                                   font=dict(color="black", size=35), bgcolor='white', borderpad=2,
                                   bordercolor='black', borderwidth=3)
                fig.write_image("./Report/figures/hotelling_plot.png", format="png")
    
    
    st.header('III - Selected Samples for Reference Analysis', divider='blue')
    if labels:
        sel, info = st.columns([3, 1])
        sel.write("Tabular identifiers of selected samples for reference analysis:")
        if selected_samples_idx:
            # `selected_samples_idx` holds positions *within* the clustered subset, so every
            # per-sample array (names, metadata rows, cluster labels) is restricted to
            # `clustered` before being indexed with it.
            if meta_data.empty:
                sam1 = pd.DataFrame({'name': spectra.index[clustered][selected_samples_idx],
                                     'cluster': np.array(labels)[clustered][selected_samples_idx]},
                                    index=selected_samples_idx)
            else:
                sam1 = meta_data.iloc[clustered, :].iloc[selected_samples_idx, :].copy()
                sam1.insert(loc=0, column='index', value=selected_samples_idx)
                sam1.insert(loc=1, column='cluster', value=np.array(labels)[clustered][selected_samples_idx])
            sam1.index = np.arange(len(selected_samples_idx)) + 1
            info.info(f'Information !\n - The total number of samples: {n_samples}.\n- The number of samples selected for reference analysis: {sam1.shape[0]}.\n - The proportion of samples selected for reference analysis: {round(sam1.shape[0]/n_samples*100)}%.')
            sam = sam1

            # HDBSCAN (cluster_methods[2]) can leave samples unassigned; optionally append them.
            if clus_method == cluster_methods[2]:
                unclus = sel.checkbox("Include non clustered samples (for HDBSCAN clustering)", value=True)
                if unclus:
                    if meta_data.empty:
                        sam2 = pd.DataFrame({'name': spectra.index[non_clustered],
                                             'cluster': ['Non clustered'] * len(spectra.index[non_clustered])},
                                            index=spectra.index[non_clustered])
                    else:
                        sam2 = meta_data.iloc[non_clustered, :].copy()
                        sam2.insert(loc=0, column='index', value=spectra.index[non_clustered])
                        sam2.insert(loc=1, column='cluster', value=['Non clustered'] * len(spectra.index[non_clustered]))

                    sam = pd.concat([sam1, sam2], axis=0)
                    sam.index = np.arange(sam.shape[0]) + 1
                    info.info(f'- The number of Non-clustered samples: {sam2.shape[0]}.\n - The proportion of Non-clustered samples: {round(sam2.shape[0]/n_samples*100)}%')
    # Fingerprint of the current LaTeX report source.
    # NOTE(review): `filehash` is not read in the code that follows in this section —
    # confirm it is still needed.
    filehash = hash_data(Path('./Report/report.tex').read_text())
    
    # figs_list = os.listdir("./Report/figures")
    if data_file:
    
        # Summary figures are handed to the LaTeX report generator as strings.
        Nb_ech = f"{n_samples}"        # total number of samples
        nb_clu = f"{sam1.shape[0]}"    # number of samples selected for reference analysis
        # NOTE(review): `sam1` is only assigned in Section III when labels and selected samples
        # exist — confirm this branch cannot be reached without it.

        # Generate the report
        latex_report = report.report(
            'Representative subset selection', data_file.name, dim_red_method, clus_method,
            Nb_ech, ncluster, selection, selection_number, nb_clu, tcr, sam,
        )
        
        @st.cache_data
        def download_res(file, sam):
            """Package the subset-selection results into a zip archive placed in ./temp.

            The archive holds the ``figures`` folder of ./Report, ``sam`` written as a
            ';'-separated CSV, ./Report/report.pdf and the first entry of ./Report/datasets
            (when there is one). The archive name is dumped to ./temp/fname.json so the
            download section further down the page can find it.

            Parameters
            ----------
            file : uploaded data file; only ``file.name`` is used (the part before the
                first '.' goes into the archive name).
            sam : DataFrame of the samples selected for reference analysis.

            NOTE(review): this function is called for its file-system side effects but is
            wrapped in ``st.cache_data``; on a cache hit (same ``file``/``sam``) nothing
            below runs, so an archive removed by the ./temp clean-up is not rebuilt.
            """
            # Double quotes outside / single quotes inside: reusing the same quote inside an
            # f-string replacement field is only valid from Python 3.12 on.
            zipname = f"results{date_time}subset_selection_{file.name.split('.')[0]}.zip"  # name of the zipfile

            # Record the archive name for the download button.
            with open('./temp/fname.json', 'w') as f:
                json.dump(zipname, f)

            # make_archive appends ".zip" itself: strip exactly that extension rather than
            # everything after the first '.', which would break if date_time contained a dot.
            archive_base = os.path.splitext(zipname)[0]
            shutil.make_archive(base_name=archive_base, format="zip", root_dir="./Report", base_dir="figures")

            # mode='a' is meant to add the CSV as a member of the archive created just above.
            sam.to_csv("./" + zipname, sep=';', encoding='utf-8', mode='a',
                       compression=dict(method='zip',
                                        archive_name=f"selected subset for reference analysis_{userfilename}_{date_time}_.csv"))

            # List the datasets folder once; its order is arbitrary and only the first entry is shipped.
            dataset_files = os.listdir("./Report/datasets")
            with zipfile.ZipFile("./" + zipname, 'a') as newzip:
                newzip.write("./Report/report.pdf", arcname="report.pdf")
                if dataset_files:
                    newzip.write("./Report/datasets/" + dataset_files[0], arcname=dataset_files[0])

            # Move the finished archive into ./temp, where the download section looks for it.
            shutil.move('./' + zipname, './temp/' + zipname)
    
        # Fingerprint of the inputs that shape the report. The LaTeX report is recompiled (and
        # the results archive rebuilt) only when this fingerprint differs from the stored one.
        report_inputs = ''.join(str(part) for part in (data_file.name, dim_red_method, clus_method, Nb_ech, tcr.astype(str)))
        report_hash = hash_data(report_inputs)

        myfilepdf = Path("./Report/report.pdf")
        # The real hash is stored on the very first run as well; storing a placeholder there
        # would trigger a second, redundant compilation on the next rerun.
        if st.session_state.get('htest') != report_hash:
            st.session_state['htest'] = report_hash
            report.compile_latex(change=report_hash)
            if myfilepdf.is_file():
                download_res(file=data_file, sam=sam)

        # Housekeeping: when ./temp holds more than three archives, delete the oldest one.
        list_of_files = glob.glob(r"./temp/*.zip")
        if len(list_of_files) > 3:
            oldest_file = min(list_of_files, key=os.path.getctime)
            os.remove(oldest_file)
            list_of_files.remove(oldest_file)

        # Offer the archive only when one exists and its recorded name matches the newest file.
        # Nothing may exist yet, e.g. when the PDF did not compile and download_res never ran.
        if list_of_files and os.path.isfile('./temp/fname.json'):
            recent_file = max(list_of_files, key=os.path.getctime)
            with open('./temp/fname.json', 'r') as f:
                zipname = json.load(f)
            if os.path.split(recent_file)[1] == os.path.split(zipname)[1]:
                with open("./temp/" + zipname, "rb") as fp:
                    st.download_button('Download', data=fp, file_name=zipname, mime="application/zip",
                                       type="primary", use_container_width=True)