from utils.eval_metrics import metrics
import numpy as np
from pandas import DataFrame

## try to automatically detect the field separator within the CSV
# def find_delimiter(filename):
#     import clevercsv
#     with open(filename, newline='') as csvfile:
#         delimiter = clevercsv.Sniffer().sniff(csvfile.read(100)).delimiter
#     # sniffer = csv.Sniffer()
#     # with open(filename) as fp:
#     #     delimiter = sniffer.sniff(fp.read(200)).delimiter
#     return delimiter

# def find_col_index(filename):
#     with open(filename) as fp:
#         lines = read_csv(fp, skiprows=3, nrows=3, index_col=False, sep=find_delimiter(filename))
#         col_index = 'yes' if lines.iloc[:,0].dtypes != np.float64 else 'no'
#     return col_index


# detection of columns categories and scaling
# def col_cat(data_import):
#     """detect numerical and categorical columns in the csv"""
#     # set first column as sample names
#     name_col = DataFrame(list(data_import.index), index = list(data_import.index))
#     # name_col=name_col.rename(columns = {0:'name'})
#     numerical_columns_list = []
#     categorical_columns_list = []
#     for i in data_import.columns:
#         if data_import[i].dtype == np.dtype("float64") or data_import[i].dtype == np.dtype("int64"):
#             numerical_columns_list.append(data_import[i])
#         else:
#             categorical_columns_list.append(data_import[i])
#     if len(numerical_columns_list) == 0:
#         empty = [0 for x in range(len(data_import))]
#         numerical_columns_list.append(empty)
#     if len(categorical_columns_list) > 0:
#         categorical_data = concat(categorical_columns_list, axis=1)
#         categorical_data.insert(0, 'name', name_col)
#     if len(categorical_columns_list) == 0:
#         categorical_data = DataFrame
#     # Create numerical data matrix from the numerical columns list and fill na with the mean of the column
#     numerical_data = concat(numerical_columns_list, axis=1)
#     numerical_data = numerical_data.apply(lambda x: x.fillna(x.mean())) #np.mean(x)))

#     return numerical_data, categorical_data

def fmt(x):
    """Return ``x`` unchanged, or the ``"<Select>"`` placeholder when ``x`` is falsy."""
    if not x:
        return "<Select>"
    return x


def update_counter():
    """
    Bump the ``"counter"`` entry of the Streamlit session state.

    The first call creates the entry with value 0; every later call adds 1.
    """
    import streamlit as st
    state = st.session_state
    if 'counter' in state:
        state["counter"] += 1
    else:
        state["counter"] = 0


def list_files(mypath, import_type):
    """
    List the saved model files of a given type in a directory.

    Parameters
    ----------
    mypath : str or path-like
        Directory to scan (not recursive).
    import_type : str
        Model-type suffix. Only regular files whose name ends with
        ``import_type + '.pkl'`` are returned.

    Returns
    -------
    list of str
        Matching file names in alphabetical order. When nothing matches, a
        single-element list holding a placeholder message is returned instead.
    """
    # These names were previously used without ever being imported (NameError).
    from os import listdir
    from os.path import isfile, join

    suffix = import_type + '.pkl'
    # Sorted so the listing is deterministic (os.listdir order is arbitrary).
    found = sorted(f for f in listdir(mypath) if f.endswith(suffix) and isfile(join(mypath, f)))
    return found if found else ['Please, create a model before - no model available yet']



from pandas import DataFrame
from sklearn.preprocessing import StandardScaler

def standardize(X: DataFrame, center: bool = True, scale: bool = False) -> DataFrame:
    """
    Centre and optionally scale the columns of ``X``.

    A scikit-learn ``StandardScaler`` is fitted on ``X`` itself and then
    applied to it. With the defaults (``center=True, scale=False``) the
    columns are only mean-centred. Pass ``scale=True`` as well to get a full
    z-score (unit variance per column).

    Parameters
    ----------
    X : DataFrame
        Data to transform; each column is a feature.
    center : bool, optional
        Subtract each column's mean (``StandardScaler(with_mean=...)``).
        Default True.
    scale : bool, optional
        Divide each column by its standard deviation
        (``StandardScaler(with_std=...)``). Default False.

    Returns
    -------
    DataFrame
        Transformed values with the same index and column labels as ``X``.
    """
    scaler = StandardScaler(with_mean=center, with_std=scale)
    transformed = scaler.fit_transform(X)
    return DataFrame(transformed, index=X.index, columns=X.columns)

######################################## Spectral preprocessing
def Detrend(X):
    """
    Remove the least-squares linear trend along the last axis of ``X``.

    This is a thin wrapper around ``scipy.signal.detrend(type='linear')`` with
    ``axis=-1``. For a 2-D samples x variables matrix, that detrends each row.

    Parameters
    ----------
    X : array-like or DataFrame
        Data to detrend. DataFrames are converted to arrays by scipy.

    Returns
    -------
    numpy.ndarray
        Residuals after removing the fitted line. The shape matches ``X``.
    """
    # `detrend` was previously used without being imported (NameError).
    from scipy.signal import detrend
    return detrend(X, axis=-1, type='linear', bp=0, overwrite_data=False)




def Snv(X: DataFrame) -> DataFrame:
    """
    Apply the Standard Normal Variate (SNV) transformation row by row.

    Each row of ``X`` (one sample / spectrum) is centred on its own mean and
    divided by its own population standard deviation (``ddof=0``). Every
    transformed row therefore has mean 0 and standard deviation 1.

    Previously the row std was used but the mean was the *global* mean of the
    whole matrix, so rows were not centred on their own mean.

    Parameters
    ----------
    X : DataFrame
        Data to transform; each row is a sample and each column a variable.

    Returns
    -------
    DataFrame
        SNV-transformed values with the same index and column labels as ``X``.
        A constant row has zero standard deviation and becomes NaN.
    """
    values = np.asarray(X, dtype=float)
    row_mean = values.mean(axis=1, keepdims=True)
    row_std = values.std(axis=1, keepdims=True)
    return DataFrame((values - row_mean) / row_std, index=X.index, columns=X.columns)


def No_transformation(X):
    """
    Identity preprocessing step.

    Returns the very same object that was passed in: no copy, no conversion.
    """
    unchanged = X
    return unchanged


######################################## Cross val split ############################
from typing import List, Dict, Tuple
import numpy as np
from pandas import DataFrame
from sklearn.linear_model import LinearRegression

class KF_CV:
    """
    Cross-validation helpers built on Kennard-Stone K-Fold splitting.

    Every method is a ``staticmethod``; the class only groups them. Folds are
    passed around as a dict mapping a fold name (``'Fold1'``, ``'Fold2'``, ...)
    to a numpy array of positional test-set indices, as produced by ``CV``.

    Methods
    -------
    CV(x, y, n_folds: int) -> Dict[str, np.ndarray]:
        Generates test set indices for each fold based on Kennard-Stone K-Fold.

    cross_val_predictor(model, folds, x, y) -> Dict[str, np.ndarray]:
        Refits the model on each training split and predicts the held-out fold.

    meas_pred_eq(y, ypcv, folds) -> Tuple[DataFrame, DataFrame]:
        Measured vs. predicted values with a per-fold OLS equation label, and
        the OLS slope/intercept of each fold.

    metrics_cv(y, ypcv, folds) -> Tuple[DataFrame, DataFrame]:
        Regression metrics per fold, and the same table with mean, standard
        deviation and coefficient of variation (in %) across folds.

    cv_scores(y, ypcv, folds) -> Tuple[DataFrame, DataFrame]:
        Same result as ``metrics_cv``; kept for backward compatibility.
    """

    @staticmethod
    def CV(x, y, n_folds: int) -> Dict[str, np.ndarray]:
        """
        Generate the test-set indices of each fold using Kennard-Stone K-Fold.

        Parameters
        ----------
        x : array-like
            Feature matrix used by the Kennard-Stone splitter.
        y : array-like
            Target variable, forwarded to ``split``.
        n_folds : int
            Number of folds.

        Returns
        -------
        Dict[str, np.ndarray]
            ``{'Fold1': test_idx_1, ..., 'Fold<n_folds>': test_idx_n}``.
        """
        from kennard_stone import KFold as ks_KFold
        kf = ks_KFold(n_splits=n_folds, device='cpu')
        # Run the split once. The previous version re-ran the whole
        # Kennard-Stone split inside a loop, once per fold.
        test_sets = [i_test for _, i_test in kf.split(x, y)]
        return {f'Fold{i + 1}': test_sets[i] for i in range(n_folds)}

    @staticmethod
    def cross_val_predictor(model, folds: Dict[str, np.ndarray], x, y) -> Dict[str, np.ndarray]:
        """
        Cross-validate ``model`` and return the predictions of every fold.

        For each fold, the model is refitted on all samples *outside* the
        fold and then predicts the samples inside it. The same ``model``
        object is refitted in place for each fold.

        Parameters
        ----------
        model : estimator object
            Anything exposing ``fit(X, y)`` and ``predict(X)``.
        folds : Dict[str, np.ndarray]
            Fold name -> positional test-set indices (see ``CV``).
        x : array-like
            Feature matrix.
        y : array-like
            Target variable.

        Returns
        -------
        Dict[str, np.ndarray]
            Fold name -> predictions for that fold's test samples.
        """
        x = np.array(x)
        y = np.array(y)
        yp = {}
        for fname, test_idx in folds.items():
            model.fit(np.delete(x, test_idx, axis=0), np.delete(y, test_idx, axis=0))
            yp[fname] = model.predict(x[test_idx])
        return yp

    @staticmethod
    def meas_pred_eq(y, ypcv: Dict[str, np.ndarray], folds: Dict[str, np.ndarray]) -> Tuple[DataFrame, DataFrame]:
        """
        Build a measured-vs-predicted table with a per-fold regression equation.

        For each fold, an OLS line ``Predicted = intercept + slope * Measured``
        is fitted. Its equation is used as the fold's label in the ``'Folds'``
        column.

        Parameters
        ----------
        y : array-like
            Target variable.
        ypcv : Dict[str, np.ndarray]
            Fold name -> predictions (1-D) for that fold.
        folds : Dict[str, np.ndarray]
            Fold name -> positional test-set indices.

        Returns
        -------
        Tuple[DataFrame, DataFrame]
            - Columns ``Predicted``, ``Measured``, ``Folds`` and ``index``,
              indexed by the sample's positional index.
            - Rows ``Slope`` and ``Intercept``, one column per fold.
        """
        from pandas import concat

        y = np.array(y)
        per_fold = {}
        coeff = {}
        for i, (fname, test_idx) in enumerate(folds.items()):
            measured = y[test_idx]
            predicted = ypcv[fname]
            ols = LinearRegression().fit(DataFrame(measured), predicted.reshape(-1, 1))
            slope, intercept = ols.coef_[0][0], ols.intercept_[0]

            r = DataFrame()
            r['Predicted'] = predicted
            r['Measured'] = measured
            r.index = test_idx
            label = (f'{fname} (Predicted = {np.round(intercept, 2)} + '
                     f'{np.round(slope, 2)} x Measured)')
            r['Folds'] = [label] * r.shape[0]
            per_fold[i] = r
            coeff[fname] = [slope, intercept]

        # Concatenating a dict gives a (fold position, sample index) MultiIndex.
        # Keep only the sample index, both as a column and as the row index.
        data = concat(per_fold, axis=0)
        data['index'] = data.index.get_level_values(1).to_numpy()
        data.index = data['index']
        coeff = DataFrame(coeff, index=['Slope', 'Intercept'])
        return data, coeff

    @staticmethod
    def metrics_cv(y, ypcv: Dict[str, np.ndarray], folds: Dict[str, np.ndarray]) -> Tuple[DataFrame, DataFrame]:
        """
        Compute regression metrics for every fold, plus summary statistics.

        Parameters
        ----------
        y : array-like
            Target variable.
        ypcv : Dict[str, np.ndarray]
            Fold name -> predictions for that fold.
        folds : Dict[str, np.ndarray]
            Fold name -> positional test-set indices.

        Returns
        -------
        Tuple[DataFrame, DataFrame]
            - One row per fold, one column per metric.
            - The same rows plus ``mean``, ``sd`` and ``cv`` (``100 * sd / mean``)
              computed over the fold rows only.
        """
        y = np.array(y)
        scores = DataFrame({fname: metrics().reg_(y[test_idx], ypcv[fname])
                            for fname, test_idx in folds.items()})
        # Work on a copy so the summary columns never feed back into the stats.
        summary = scores.copy()
        mean = scores.mean(axis=1)
        sd = scores.std(axis=1)
        summary['mean'] = mean
        summary['sd'] = sd
        summary['cv'] = 100 * sd / mean
        return scores.T, summary.T

    @staticmethod
    def cv_scores(y, ypcv: Dict[str, np.ndarray], folds: Dict[str, np.ndarray]) -> Tuple[DataFrame, DataFrame]:
        """
        Compute fold-wise metric scores with summary statistics.

        Equivalent to ``metrics_cv``. The previous implementation aliased its
        summary table to the per-fold table, so ``sd`` and ``cv`` were
        contaminated by the freshly added ``mean``/``sd`` columns. The first
        returned frame also leaked those summary rows.

        Returns
        -------
        Tuple[DataFrame, DataFrame]
            See ``metrics_cv``.
        """
        return KF_CV.metrics_cv(y, ypcv, folds)

    
    # ### Return ycv
    # @staticmethod
    # def ycv(model, x, y, n_folds:int):
    #     ycv = np.zeros(y.shape[0])
    #     f, idx,_,_ = KF_CV.cross_val_predictor(model, x,y, n_folds)
    #     for i in f.keys():
    #         ycv[idx[i]] = f[i]            
    #     return ycv


### Selectivity ratio
def sel_ratio(model, x):
    """
    Compute selectivity ratios (SR) and keep variables above an F threshold.

    Steps, as implemented below:
      1. Target-projection weights ``w``: ``model.coef_.T`` scaled to unit norm.
      2. Scores ``t = x @ w``.
      3. Loadings ``p = x.T @ t / (t.T @ t)``.
      4. Explained part ``x_tp = t @ p.T``. The explained variance per variable
         is the squared column norm of ``x_tp``.
      5. Residual variance per variable: squared column norm of
         ``(x - column means of x) - x_tp``.
      6. ``SR = explained / residual``. A variable is kept when
         ``SR > F.ppf(0.05, m - 2, m - 3)``, where ``m`` is the number of
         columns of ``x``.

    NOTE(review): step 2 uses the uncentred ``x`` while step 5 uses the
    column-centred ``x``. This is consistent only if ``x`` is already centred —
    confirm. Using the lower-tail 0.05 quantile with degrees of freedom taken
    from the variable count (not the sample count) should also be checked
    against the intended SR test.

    Parameters
    ----------
    model : fitted estimator
        Must expose ``coef_``. Presumably a single-target scikit-learn
        PLS/linear model with a 2-D ``coef_`` — TODO confirm.
    x : array-like
        Data matrix; converted to a DataFrame. Columns are variables.

    Returns
    -------
    DataFrame
        One column ``'sr'`` holding the retained variables, indexed by their
        0-based column position in ``x``.
    """
    from scipy.stats import f

    x = DataFrame(x)
    coef_t = model.coef_.T
    w_tp = coef_t / np.linalg.norm(coef_t)
    t_tp = np.array(x @ w_tp)
    p_tp = np.array(x.T) @ np.array(t_tp) / (t_tp.T @ t_tp)
    x_tp = t_tp @ p_tp.T
    explained = np.linalg.norm(x_tp, axis=0) ** 2
    residual = np.linalg.norm(np.array(x - x.mean()) - x_tp, axis=0) ** 2
    sr = DataFrame(explained / residual, index=x.columns, columns=['sr'])

    n_vars = sr.shape[0]
    threshold = f.ppf(0.05, n_vars - 2, n_vars - 3)
    # 1-D row mask (same rows the original 2-D (m, 1) mask selected).
    keep = (sr['sr'] > threshold).to_numpy()
    sr.index = np.arange(x.shape[1])
    return sr.iloc[keep, :]





#####################################
from typing import List
from pathlib import Path

class HandleItems:
    """
    Static helpers for housekeeping of files and directories on disk.

    Methods
    -------
    delete_files(keep: List[str]):
        Remove every file under ``report/`` except ``logo_cefe.png`` and files
        whose name ends with one of the given endings.

    delete_dir(delete: List[str]):
        Recursively remove each listed path that is an existing directory.

    create_dir(path: List[str]):
        Create each listed directory, with its parents, when the path does not
        exist yet.
    """

    @staticmethod
    def delete_files(keep: List[str]):
        """
        Delete files anywhere under the relative ``report`` directory.

        Directories themselves are left in place. Nothing happens if
        ``report`` does not exist.

        Parameters
        ----------
        keep : List[str]
            Filename endings (e.g. ``'.tex'``) to preserve. The file
            ``logo_cefe.png`` is always preserved as well.
        """
        from os import walk, remove, path

        protected_name = 'logo_cefe.png'
        for folder, _subdirs, filenames in walk(Path("report"), topdown=False):
            for name in filenames:
                if name == protected_name:
                    continue
                if any(name.endswith(suffix) for suffix in keep):
                    continue
                remove(path.join(folder, name))

    @staticmethod
    def delete_dir(delete: List[str]):
        """
        Recursively remove the given directories.

        Parameters
        ----------
        delete : List[str]
            Paths to remove. Paths that are missing or are not directories
            are skipped.
        """
        from shutil import rmtree
        for target in map(Path, delete):
            # is_dir() is False for missing paths, so it covers the existence check.
            if target.is_dir():
                rmtree(target)

    @staticmethod
    def create_dir(path: List[str]):
        """
        Create the given directories, including any missing parents.

        Parameters
        ----------
        path : List[str]
            Directory paths. A path that already exists, as a directory or
            as anything else, is left untouched.
        """
        for target in map(Path, path):
            if target.exists():
                continue
            target.mkdir(parents=True, exist_ok=True)