import numpy as np
from pandas import DataFrame
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from scipy.spatial.distance import cdist

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~  kmeans ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 

from sklearn.cluster import KMeans, AffinityPropagation, HDBSCAN
from sklearn.metrics import silhouette_score
import pandas as pd

def _sample_names(X):
    """Return the sample identifiers for ``X``.

    A DataFrame's index is used as-is; for any other array-like input,
    positional integers ``0 .. n_samples - 1`` are returned instead.
    """
    if isinstance(X, pd.DataFrame):
        return X.index
    return pd.RangeIndex(len(X))


def _label_frame(names, labels):
    """Build the result frame: column ``names`` holds sample ids, the index holds cluster tags.

    Label ``-1`` (noise / unassigned) becomes ``'Non clustered'``; any other
    label ``i`` becomes ``'cluster#<i+1>'``.
    """
    tags = ['Non clustered' if label == -1 else f'cluster#{label + 1}' for label in labels]
    return pd.DataFrame({'names': names}, index=tags)


def _count_clusters(labels):
    """Return the number of distinct cluster labels, not counting the ``-1`` noise label."""
    return len(set(labels) - {-1})


def _kmeans_labels(X, max_k, random_state):
    """Fit K-Means for k = 2..max_k and return the labels with the best silhouette score."""

    def fit_labels(k):
        model = KMeans(n_clusters=k, random_state=random_state, n_init=10, max_iter=300)
        return model.fit_predict(X)

    # silhouette_score only accepts 2 <= k <= n_samples - 1, so larger k cannot be scored.
    upper_k = min(max_k, len(X) - 1)
    best_labels = None
    best_score = -1
    for k in range(2, upper_k + 1):
        labels = fit_labels(k)
        score = silhouette_score(X, labels)
        # Strict '>' keeps the smallest k on ties.
        if score > best_score:
            best_score = score
            best_labels = labels
    if best_labels is None:
        # No k could be scored (max_k < 2 or too few samples): fall back to k = 2.
        best_labels = fit_labels(2)
    # With a fixed random_state, refitting at the best k would reproduce these
    # exact labels, so the stored labels are reused instead of fitting again.
    return best_labels


def clustering(X, method='kmeans', **kwargs):
    """
    Perform clustering on the given dataset using the specified method.

    Parameters
    ----------
    X : DataFrame or array-like, shape (n_samples, n_features)
        The input data for clustering. When X is a DataFrame its index is used
        as the sample names; otherwise positional integers are used.

    method : str, optional, default='kmeans'
        The clustering method to use (case-insensitive). Options are:
        - 'kmeans': K-Means clustering; k is chosen in [2, max_k] by silhouette score.
        - 'ap': Affinity Propagation clustering.
        - 'hdbscan': HDBSCAN clustering.

    kwargs : dict
        Additional arguments specific to the clustering method:
        - random_state (kmeans, ap): seed for the estimator, default 42.
        - max_k (kmeans): largest number of clusters to try, default 10.
        - min_samples (hdbscan): default 2.
        - min_cluster_size (hdbscan): default 5.
        - metric (hdbscan): default 'euclidean'.

    Returns
    -------
    tuple of (DataFrame, int)
        - A DataFrame with one row per sample. The column "names" holds the
          sample names, and the index holds the cluster tag of each sample:
          'cluster#<label+1>', or 'Non clustered' for samples labelled -1
          (HDBSCAN noise, or Affinity Propagation when it did not converge).
        - The number of clusters found, excluding the -1 label.

    Raises
    ------
    ValueError
        If ``method`` is not one of the supported options.
    """
    key = method.lower() if isinstance(method, str) else method
    random_state = kwargs.get('random_state', 42)

    if key == 'kmeans':
        labels = _kmeans_labels(X, kwargs.get('max_k', 10), random_state)

    elif key == 'ap':
        model = AffinityPropagation(random_state=random_state).fit(X)
        labels = model.predict(X)

    elif key == 'hdbscan':
        model = HDBSCAN(
            min_samples=kwargs.get('min_samples', 2),
            min_cluster_size=kwargs.get('min_cluster_size', 5),
            metric=kwargs.get('metric', 'euclidean'),
        )
        labels = model.fit_predict(X)

    else:
        raise ValueError(f"Unknown clustering method: {method}")

    return _label_frame(_sample_names(X), labels), _count_clusters(labels)