From edd165e7d4d9abfc878c4967ff060c3efc762bca Mon Sep 17 00:00:00 2001 From: barthes <nicolas.barthes@cefe.cnrs.fr> Date: Thu, 11 Apr 2024 17:51:09 +0200 Subject: [PATCH] added HDBSCAN clustering with optimization --- Class_Mod/UMAP_.py | 2 +- Packages.py | 5 ++--- app.py | 8 +++++++- pages/1-samples_selection.py | 20 ++++++++++++++++---- 4 files changed, 26 insertions(+), 9 deletions(-) diff --git a/Class_Mod/UMAP_.py b/Class_Mod/UMAP_.py index b563d4e..e9ae0dc 100644 --- a/Class_Mod/UMAP_.py +++ b/Class_Mod/UMAP_.py @@ -10,7 +10,7 @@ class Umap: self.x = scaled_values - self.model = UMAP(random_state=42, n_neighbors=20, n_components=4, min_dist=0.0) + self.model = UMAP(n_neighbors=20, n_components=4, min_dist=0.0,) # random_state=42,) self.model.fit(self.x) self.scores = self.model.transform(self.x) self.scores = pd.DataFrame(self.scores, index = self.numerical_data.index) diff --git a/Packages.py b/Packages.py index 68c0bf0..2f04530 100644 --- a/Packages.py +++ b/Packages.py @@ -17,8 +17,7 @@ from umap.umap_ import UMAP from sklearn.decomposition import PCA, NMF # Clustering -from sklearn.cluster import KMeans -#import hdbscan +from sklearn.cluster import KMeans, HDBSCAN # Modelling # import julia @@ -40,7 +39,7 @@ import plotly.express as px import matplotlib.pyplot as plt import seaborn as sns ### Important Metrics -from sklearn.metrics import pairwise_distances_argmin_min +from sklearn.metrics import pairwise_distances_argmin_min, adjusted_rand_score, adjusted_mutual_info_score ## Web app construction import streamlit as st diff --git a/app.py b/app.py index a68c036..8baa0d0 100644 --- a/app.py +++ b/app.py @@ -27,4 +27,10 @@ with st.container(): st.write("Samples selection (PCA, [UMAP](https://umap-learn.readthedocs.io/en/latest/how_umap_works.html), ...), Predictive Modelling ([Pinard](https://github.com/GBeurier/pinard), [LWPLSR](https://doi.org/10.1002/cem.3209), ...), and Predictions using your data (CSV or DX files) and/or PACE NIRS Database.") #st.image(img_general) st.markdown("### We could add documentation here ###") - + header1, header2, header3 = st.columns(3) + if header1.button("Samples Selection"): + st.switch_page('pages\\1-samples_selection.py') + if header2.button("Model Creation"): + st.switch_page('pages\\2-model_creation.py') + if header3.button("Predictions"): + st.switch_page('pages\\3-prediction.py') \ No newline at end of file diff --git a/pages/1-samples_selection.py b/pages/1-samples_selection.py index 42ce389..f494d1c 100644 --- a/pages/1-samples_selection.py +++ b/pages/1-samples_selection.py @@ -88,7 +88,7 @@ with container1: with container2: if sselectx_csv is not None: plot_type=['', 'PCA','UMAP', 'NMF'] - cluster_methods = ['', 'Kmeans','UMAP', 'AP'] + cluster_methods = ['', 'Kmeans','HDBSCAN', 'AP'] with pc: type_plot = st.selectbox("Dimensionality reduction techniques: ", options=plot_type, key=37) @@ -110,9 +110,13 @@ with container2: if type_cluster == 'Kmeans': cl = Sk_Kmeans(pd.concat([model.scores_.loc[:,axis1], model.scores_.loc[:,axis2], model.scores_.loc[:,axis3]], axis = 1), max_clusters = 30) + elif type_cluster == 'HDBSCAN': + from hdbscan import HDBSCAN_function + labels, hdbscan_score = HDBSCAN_function(data_import, min_cluster_size=10) + with scores: t = model.scores_ - if type_cluster in ['Kmeans','UMAP', 'AP']: + if type_cluster in ['AP', 'Kmeans']: st.write('Scree plot') fig2 = px.scatter(cl.inertia_.T, y = 'inertia') st.plotly_chart(fig2) @@ -123,14 +127,22 @@ with container2: st.write('Scores plot') fig = px.scatter_3d(data, x=axis1, y=axis2, z = axis3, color=colors) fig.update_traces(marker=dict(size=4)) + st.plotly_chart(fig) + + elif type_cluster in ['HDBSCAN']: + st.write('plot HDBSCAN clustering') + fig_hdbscan = px.scatter_3d(t, x=axis1, y=axis2, z = axis3, color=labels) + fig_hdbscan.update_traces(marker=dict(size=4)) + st.plotly_chart(fig_hdbscan) + st.write('DBCV score = ' + str(hdbscan_score)) + # st.dataframe(min_score.stack().agg(['min'])) else: fig = px.scatter_3d(t, x=axis1, y=axis2, z = axis3) fig.update_traces(marker=dict(size=4)) - - st.plotly_chart(fig) + st.plotly_chart(fig) if type_plot =='PCA': -- GitLab