import numpy as np
from pandas import DataFrame
from sklearn.cluster import KMeans, HDBSCAN, AffinityPropagation  # HDBSCAN requires scikit-learn >= 1.3



#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~  kmeans ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 
class Sk_Kmeans:
    """K-Means clustering for Samples selection.

    Returns:
        inertia_ (DataFrame): DataFrame with ...
        x (DataFrame): Initial data
        clu (DataFrame): Cluster name for each sample
        model.cluster_centers_ (DataFrame): Coordinates of the center of each cluster
    """
    def __init__(self, x, max_clusters):
        """Initiate the KMeans class.

        Args:
            x (DataFrame): the original reduced data to cluster
            max_cluster (Int): the max number of desired clusters.
        """
        self.x = x
        self.max_clusters = max_clusters

        # Fit one KMeans model per candidate number of clusters and store its inertia
        self.inertia = DataFrame()
        for i in range(1, max_clusters + 1):
            model = KMeans(n_clusters=i, init='k-means++', random_state=42)
            model.fit(x)
            self.inertia[f'{i}_clust'] = [model.inertia_]
        self.inertia.index = ['inertia']

    @property
    def inertia_(self):
        return self.inertia
    
    @property
    def suggested_n_clusters_(self):
        # Relative inertia drop (in %) when going from i+1 to i+2 clusters
        values = []
        s = self.inertia.to_numpy().ravel()
        for i in range(self.max_clusters - 1):
            values.append((s[i] - s[i+1]) * 100 / s[i])

        # Elbow heuristic: keep adding clusters as long as the inertia drop exceeds 5%
        drops = np.where(np.array(values) > 5)[0]
        if drops.size == 0:  # no drop above 5%: fall back to 2 clusters (assumption, avoids an empty max())
            return 2
        return int(np.max(drops)) + 2
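
    # Illustrative walk-through of the heuristic above (assumed numbers): for inertias
    # [100, 50, 30, 27] the relative drops are [50%, 40%, 10%]; the last drop above 5%
    # is at index 2, so 2 + 2 = 4 clusters would be suggested.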
    
    @property
    def fit_optimal_(self):
        # Fit KMeans with the suggested number of clusters and label each sample
        model = KMeans(n_clusters=self.suggested_n_clusters_, init='k-means++', random_state=42)
        model.fit(self.x)
        yp = model.predict(self.x) + 1
        clu = [f'cluster#{i}' for i in yp]

        return self.x, clu, model.cluster_centers_
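
    # Minimal usage sketch (assumes a hypothetical `reduced_df` DataFrame of reduced features):
    #   kmeans = Sk_Kmeans(reduced_df, max_clusters=10)
    #   kmeans.inertia_                 # inertia for 1..10 clusters
    #   kmeans.suggested_n_clusters_    # elbow-based suggestion
    #   x, clu, centers = kmeans.fit_optimal_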
    




#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~  hdbscan ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
class Hdbscan:
    """Runs an automatically optimized sklearn.HDBSCAN clustering on dimensionality reduced space.

    The HDBSCAN_scores_ @Property returns the cluster number of each sample (_labels) and the DBCV best score.

    Returns:
        _labels (DataFrame): DataFrame with the cluster belonging number for each sample
        _hdbscan_score (float): a float with the best DBCV score after optimization

    Examples:
        - clustering = HDBSCAN((data)
        - scores = clustering.HDBSCAN_scores_

    """
    def __init__(self, data):
        """Initiate the HDBSCAN calculation

        Args:
            data (DataFrame): the Dimensionality reduced space, raw result of the UMAP.fit()
            param_dist (dictionary): the HDBSCAN optimization parameters to test
            _score (DataFrame): is a dataframe with the DBCV value for each combination of param_dist. We search for the higher value to then compute an HDBSCAN with the best parameters.
        """
        # Really fast
        self._param_dist = {'min_samples': [8],
                            'min_cluster_size': [10],
                            'metric': ['euclidean'],  # ,'manhattan'
                            }
        # Medium
        # self._param_dist = {'min_samples': [1,10],
        #     'min_cluster_size':[5,50],
        #     'metric' : ['euclidean','manhattan'],
        #     }
        # Complete
        # self._param_dist = {'min_samples': [1,5,10,],
        #       'min_cluster_size':[5,25,50,],
        #       'metric' : ['euclidean','manhattan'],
        #       }

        self._clusterable_embedding = data

        # RandomizedSearchCV not working...
        # def scoring(model, clusterable_embedding):
        #     label = HDBSCAN().fit_predict(clusterable_embedding)
        #     hdbscan_score = DBCV(clusterable_embedding, label, dist_function=euclidean)
        #     return hdbscan_score
        # tunning = RandomizedSearchCV(estimator=HDBSCAN(), param_distributions=param_dist,  scoring=scoring)
        # tunning.fit(clusterable_embedding)
        # return tunning

        # compute optimization. Test each combination of parameters and store DBCV score into _score.
        # self._score = DataFrame()
        # for i in self._param_dist.get('min_samples'):
        #     for j in self._param_dist.get('min_cluster_size'):
        #         self._ij_label = HDBSCAN(min_samples=i, min_cluster_size=j).fit_predict(self._clusterable_embedding)
        #         self._ij_hdbscan_score = self.DBCV(self._clusterable_embedding, self._ij_label,)# dist_function=euclidean)
        #         self._score.at[i,j] = self._ij_hdbscan_score
        # get the best DBCV score
        # self._hdbscan_bscore  = max(self._score.max())
        # find the coordinates of the best clustering parameters and run HDBSCAN below
        # self._bparams = np.where(self._score == self._hdbscan_bscore)
        # run HDBSCAN with best params

        # self.best_hdbscan = HDBSCAN(min_samples=self._param_dist['min_samples'][self._bparams[0][0]], min_cluster_size=self._param_dist['min_cluster_size'][self._bparams[1][0]], metric=self._param_dist['metric'][self._bparams[1][0]], store_centers="medoid", )
        self.best_hdbscan = HDBSCAN(min_samples=self._param_dist['min_samples'][0],
                                    min_cluster_size=self._param_dist['min_cluster_size'][0],
                                    metric=self._param_dist['metric'][0],
                                    store_centers="medoid")
        self.best_hdbscan.fit_predict(self._clusterable_embedding)
        self._labels = self.best_hdbscan.labels_
        self._centers = self.best_hdbscan.medoids_


    # def DBCV(self, X, labels, dist_function=euclidean):
    #     """
    #     Implementation of Density-Based Clustering Validation "DBCV"
    #
    #     Citation: Moulavi, Davoud, et al. "Density-based clustering validation."
    #     Proceedings of the 2014 SIAM International Conference on Data Mining.
    #     Society for Industrial and Applied Mathematics, 2014.
    #
    #     Density Based clustering validation
    #
    #     Args:
    #         X (np.ndarray): ndarray with dimensions [n_samples, n_features]
    #             data to check validity of clustering
    #         labels (np.array): clustering assignments for data X
    #         dist_function (func): function to determine distance between objects
    #             func args must be [np.array, np.array] where each array is a point
    #
    #     Returns:
    #         cluster_validity (float): score in range[-1, 1] indicating validity of clustering assignments
    #     """
    #     graph = self._mutual_reach_dist_graph(X, labels, dist_function)
    #     mst = self._mutual_reach_dist_MST(graph)
    #     cluster_validity = self._clustering_validity_index(mst, labels)
    #     return cluster_validity
    #
    #
    # def _core_dist(self, point, neighbors, dist_function):
    #     """
    #     Computes the core distance of a point.
    #     Core distance is the inverse density of an object.
    #
    #     Args:
    #         point (np.array): array of dimensions (n_features,)
    #             point to compute core distance of
    #         neighbors (np.ndarray): array of dimensions (n_neighbors, n_features):
    #             array of all other points in object class
    #         dist_function (func): function to determine distance between objects
    #             func args must be [np.array, np.array] where each array is a point
    #
    #     Returns: core_dist (float)
    #         inverse density of point
    #     """
    #     n_features = np.shape(point)[0]
    #     n_neighbors = np.shape(neighbors)[0]
    #
    #     distance_vector = cdist(point.reshape(1, -1), neighbors)
    #     distance_vector = distance_vector[distance_vector != 0]
    #     numerator = ((1/distance_vector)**n_features).sum()
    #     core_dist = (numerator / (n_neighbors - 1)) ** (-1/n_features)
    #     return core_dist
    #
    # def _mutual_reachability_dist(self, point_i, point_j, neighbors_i,
    #                               neighbors_j, dist_function):
    #     """.
    #     Computes the mutual reachability distance between points
    #
    #     Args:
    #         point_i (np.array): array of dimensions (n_features,)
    #             point i to compare to point j
    #         point_j (np.array): array of dimensions (n_features,)
    #             point i to compare to point i
    #         neighbors_i (np.ndarray): array of dims (n_neighbors, n_features):
    #             array of all other points in object class of point i
    #         neighbors_j (np.ndarray): array of dims (n_neighbors, n_features):
    #             array of all other points in object class of point j
    #         dist_function (func): function to determine distance between objects
    #             func args must be [np.array, np.array] where each array is a point
    #
    #     Returns:
    #         mutual_reachability (float)
    #         mutual reachability between points i and j
    #
    #     """
    #     core_dist_i = self._core_dist(point_i, neighbors_i, dist_function)
    #     core_dist_j = self._core_dist(point_j, neighbors_j, dist_function)
    #     dist = dist_function(point_i, point_j)
    #     mutual_reachability = np.max([core_dist_i, core_dist_j, dist])
    #     return mutual_reachability
    #
    #
    # def _mutual_reach_dist_graph(self, X, labels, dist_function):
    #     """
    #     Computes the mutual reach distance complete graph.
    #     Graph of all pair-wise mutual reachability distances between points
    #
    #     Args:
    #         X (np.ndarray): ndarray with dimensions [n_samples, n_features]
    #             data to check validity of clustering
    #         labels (np.array): clustering assignments for data X
    #         dist_function (func): function to determine distance between objects
    #             func args must be [np.array, np.array] where each array is a point
    #
    #     Returns: graph (np.ndarray)
    #         array of dimensions (n_samples, n_samples)
    #         Graph of all pair-wise mutual reachability distances between points.
    #
    #     """
    #     n_samples = np.shape(X)[0]
    #     graph = []
    #     counter = 0
    #     for row in range(n_samples):
    #         graph_row = []
    #         for col in range(n_samples):
    #             point_i = X[row]
    #             point_j = X[col]
    #             class_i = labels[row]
    #             class_j = labels[col]
    #             members_i = self._get_label_members(X, labels, class_i)
    #             members_j = self._get_label_members(X, labels, class_j)
    #             dist = self._mutual_reachability_dist(point_i, point_j,
    #                                              members_i, members_j,
    #                                              dist_function)
    #             graph_row.append(dist)
    #         counter += 1
    #         graph.append(graph_row)
    #     graph = np.array(graph)
    #     return graph
    #
    #
    # def _mutual_reach_dist_MST(self, dist_tree):
    #     """
    #     Computes minimum spanning tree of the mutual reach distance complete graph
    #
    #     Args:
    #         dist_tree (np.ndarray): array of dimensions (n_samples, n_samples)
    #             Graph of all pair-wise mutual reachability distances
    #             between points.
    #
    #     Returns: minimum_spanning_tree (np.ndarray)
    #         array of dimensions (n_samples, n_samples)
    #         minimum spanning tree of all pair-wise mutual reachability
    #             distances between points.
    #     """
    #     mst = minimum_spanning_tree(dist_tree).toarray()
    #     return mst + np.transpose(mst)
    #
    #
    # def _cluster_density_sparseness(self, MST, labels, cluster):
    #     """
    #     Computes the cluster density sparseness, the minimum density
    #         within a cluster
    #
    #     Args:
    #         MST (np.ndarray): minimum spanning tree of all pair-wise
    #             mutual reachability distances between points.
    #         labels (np.array): clustering assignments for data X
    #         cluster (int): cluster of interest
    #
    #     Returns: cluster_density_sparseness (float)
    #         value corresponding to the minimum density within a cluster
    #     """
    #     indices = np.where(labels == cluster)[0]
    #     cluster_MST = MST[indices][:, indices]
    #     cluster_density_sparseness = np.max(cluster_MST)
    #     return cluster_density_sparseness
    #
    #
    # def _cluster_density_separation(self, MST, labels, cluster_i, cluster_j):
    #     """
    #     Computes the density separation between two clusters, the maximum
    #         density between clusters.
    #
    #     Args:
    #         MST (np.ndarray): minimum spanning tree of all pair-wise
    #             mutual reachability distances between points.
    #         labels (np.array): clustering assignments for data X
    #         cluster_i (int): cluster i of interest
    #         cluster_j (int): cluster j of interest
    #
    #     Returns: density_separation (float):
    #         value corresponding to the maximum density between clusters
    #     """
    #     indices_i = np.where(labels == cluster_i)[0]
    #     indices_j = np.where(labels == cluster_j)[0]
    #     shortest_paths = csgraph.dijkstra(MST, indices=indices_i)
    #     relevant_paths = shortest_paths[:, indices_j]
    #     density_separation = np.min(relevant_paths)
    #     return density_separation
    #
    #
    # def _cluster_validity_index(self, MST, labels, cluster):
    #     """
    #     Computes the validity of a cluster (validity of assignments)
    #
    #     Args:
    #         MST (np.ndarray): minimum spanning tree of all pair-wise
    #             mutual reachability distances between points.
    #         labels (np.array): clustering assignments for data X
    #         cluster (int): cluster of interest
    #
    #     Returns: cluster_validity (float)
    #         value corresponding to the validity of cluster assignments
    #     """
    #     min_density_separation = np.inf
    #     for cluster_j in np.unique(labels):
    #         if cluster_j != cluster:
    #             cluster_density_separation = self._cluster_density_separation(MST,
    #                                                                      labels,
    #                                                                      cluster,
    #                                                                      cluster_j)
    #             if cluster_density_separation < min_density_separation:
    #                 min_density_separation = cluster_density_separation
    #     cluster_density_sparseness = self._cluster_density_sparseness(MST,
    #                                                              labels,
    #                                                              cluster)
    #     numerator = min_density_separation - cluster_density_sparseness
    #     denominator = np.max([min_density_separation, cluster_density_sparseness])
    #     cluster_validity = numerator / denominator
    #     return cluster_validity
    #
    #
    # def _clustering_validity_index(self, MST, labels):
    #     """
    #     Computes the validity of all clustering assignments for a
    #     clustering algorithm
    #
    #     Args:
    #         MST (np.ndarray): minimum spanning tree of all pair-wise
    #             mutual reachability distances between points.
    #         labels (np.array): clustering assignments for data X
    #
    #     Returns: validity_index (float):
    #         score in range[-1, 1] indicating validity of clustering assignments
    #     """
    #     n_samples = len(labels)
    #     validity_index = 0
    #     for label in np.unique(labels):
    #         fraction = np.sum(labels == label) / float(n_samples)
    #         cluster_validity = self._cluster_validity_index(MST, labels, label)
    #         validity_index += fraction * cluster_validity
    #     return validity_index
    #
    #
    # def _get_label_members(self, X, labels, cluster):
    #     """
    #     Helper function to get samples of a specified cluster.
    #
    #     Args:
    #         X (np.ndarray): ndarray with dimensions [n_samples, n_features]
    #             data to check validity of clustering
    #         labels (np.array): clustering assignments for data X
    #         cluster (int): cluster of interest
    #
    #     Returns: members (np.ndarray)
    #         array of dimensions (n_samples, n_features) of samples of the
    #         specified cluster.
    #     """
    #     indices = np.where(labels == cluster)[0]
    #     members = X[indices]
    #     return members

    @property
    def centers_(self):
        # return self._labels, self._hdbscan_bscore, self._centers
        return self._centers

    @property
    def labels_(self):
        labels = [f'cluster#{i+1}' if i != -1 else 'Non clustered' for i in self._labels]
        return labels

    @property
    def non_clustered(self):
        labels = [f'cluster#{i+1}' if i != -1 else 'Non clustered' for i in self._labels]
        non_clustered = np.where(np.array(labels) == 'Non clustered')[0]
        return non_clustered
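
    # Minimal usage sketch (assumes a hypothetical `embedding`, e.g. a UMAP-reduced DataFrame):
    #   clustering = Hdbscan(embedding)
    #   clustering.labels_          # 'cluster#n' per sample, 'Non clustered' for noise
    #   clustering.centers_         # medoid of each cluster
    #   clustering.non_clustered    # indices of the noise samples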



#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~  ap ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
class AP:
    """Affinity Propagation clustering for sample selection.

    Returns:
        fit_optimal_ (tuple): the input data, the cluster name for each sample, and the
            coordinates of the cluster centers
    """
    def __init__(self, X):
        ## input matrix
        self.__x = np.array(X)

        # Fit the Affinity Propagation model
        self.M = AffinityPropagation(damping=0.5, max_iter=200, convergence_iter=15, copy=True, preference=None,
                                     affinity='euclidean', verbose=False, random_state=None)
        self.M.fit(self.__x)
        self.yp = self.M.predict(self.__x) + 1

    @property
    def fit_optimal_(self):
        clu = [f'cluster#{i}' for i in self.yp]
        return self.__x, clu, self.M.cluster_centers_
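
    # Minimal usage sketch (assumes a hypothetical `reduced_df` DataFrame of reduced features):
    #   ap = AP(reduced_df)
    #   x, clu, centers = ap.fit_optimal_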