from pathlib import Path
from typing import List, Dict, Tuple
import numpy as np
from pandas import DataFrame
from scipy.signal import detrend
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import StandardScaler
from utils.eval_metrics import metrics

# try to automatically detect the field separator within the CSV
# def find_delimiter(filename):
#     import clevercsv
#     with open(filename, newline='') as csvfile:
#         delimiter = clevercsv.Sniffer().sniff(csvfile.read(100)).delimiter
#     # sniffer = csv.Sniffer()
#     # with open(filename) as fp:
#     #     delimiter = sniffer.sniff(fp.read(200)).delimiter
#     return delimiter

# def find_col_index(filename):
#     with open(filename) as fp:
#         lines = read_csv(fp, skiprows=3, nrows=3, index_col=False, sep=find_delimiter(filename))
#         col_index = 'yes' if lines.iloc[:,0].dtypes != np.float64 else 'no'
#     return col_index


# detection of columns categories and scaling
# def col_cat(data_import):
#     """detect numerical and categorical columns in the csv"""
#     # set first column as sample names
#     name_col = DataFrame(list(data_import.index), index = list(data_import.index))
#     # name_col=name_col.rename(columns = {0:'name'})
#     numerical_columns_list = []
#     categorical_columns_list = []
#     for i in data_import.columns:
#         if data_import[i].dtype == np.dtype("float64") or data_import[i].dtype == np.dtype("int64"):
#             numerical_columns_list.append(data_import[i])
#         else:
#             categorical_columns_list.append(data_import[i])
#     if len(numerical_columns_list) == 0:
#         empty = [0 for x in range(len(data_import))]
#         numerical_columns_list.append(empty)
#     if len(categorical_columns_list) > 0:
#         categorical_data = concat(categorical_columns_list, axis=1)
#         categorical_data.insert(0, 'name', name_col)
#     if len(categorical_columns_list) == 0:
#         categorical_data = DataFrame
#     # Create numerical data matrix from the numerical columns list and fill na with the mean of the column
#     numerical_data = concat(numerical_columns_list, axis=1)
#     numerical_data = numerical_data.apply(lambda x: x.fillna(x.mean())) #np.mean(x)))

#     return numerical_data, categorical_data

def fmt(x):
    """
    Returns the input value, or a placeholder string for falsy input.

    If the input `x` evaluates to a falsy value (e.g., `None`, `False`, `0`, `''`),
    the function returns the string "<Select>". Otherwise, it returns the value of `x` itself.

    Parameters:
    -----------
    x : any type
        The input value to be formatted. Can be any type (e.g., string, integer, etc.).

    Returns:
    --------
    any type
        If `x` is truthy, the function returns `x` unchanged. If `x` is falsy,
        it returns the string "<Select>".

    Example usage:
    --------------
    fmt("Hello")   # Returns: "Hello"
    fmt("")        # Returns: "<Select>"
    fmt(None)      # Returns: "<Select>"
    fmt(0)         # Returns: "<Select>"
    fmt(123)       # Returns: 123
    """
    return x if x else "<Select>"


def st_var(variable, initialize=True, update=False, type='increment'):
    """
    Manages a variable in the Streamlit session state, allowing it to be initialized, updated, 
    and retained across interactions.

    Parameters:
    -----------
    variable : str
        The name of the variable to store in Streamlit's session state.

    initialize : bool, optional, default=True
        If True, initializes the variable in the session state if it does not already exist.

    update : bool, optional, default=False
        If True, updates the variable: increments it by 1 for the 'increment' type,
        or toggles it for the 'boolean' type.

    type : str, optional, default='increment'
        The kind of variable to manage: 'increment' (an integer counter) or
        'boolean' (a True/False flag).

    Notes:
    ------
    - An 'increment' variable is initialized to `0` and a 'boolean' variable to `False`
      when first created, if not already in the session state.
    - If `update` is set to True, the function increments or toggles the variable's value
      each time it is called.

    Example usage:
    --------------
    # To initialize the variable
    st_var("counter", initialize=True)

    # To update the variable
    st_var("counter", update=True)
    """

    import streamlit as st

    # Initialize the variable if needed
    if initialize and variable not in st.session_state:
        if type == 'increment':
            st.session_state[variable] = 0
        elif type == 'boolean':
            st.session_state[variable] = False

    # Update the variable if needed
    if update:
        if type == 'increment':
            st.session_state[variable] += 1
        elif type == 'boolean':
            st.session_state[variable] = not st.session_state[variable]


def list_files(mypath, import_type):
    """
    Lists all files with a specific extension (based on `import_type`) in the given directory.

    The function searches for files in the directory specified by `mypath` and returns a list of file 
    names with a `.pkl` extension that match the `import_type`. If no such files are found, a message 
    is returned indicating that no models are available.

    Parameters:
    -----------
    mypath : str
        The path to the directory where the files are stored.

    import_type : str
        The model type to search for. The suffix `.pkl` is appended to this string to form
        the filename ending to match. For example, if `import_type` is 'svm', the function
        will look for files ending in `svm.pkl` (e.g., `model1.svm.pkl`).

    Returns:
    --------
    list
        A list of file names that match the given `import_type` and have the `.pkl` extension.
        If no matching files are found, a list containing a message is returned.

    Example usage:
    --------------
    # To list all SVM model files in the directory
    list_files("/models", "svm")

    # Output might be something like:
    # ['model1.svm.pkl', 'model2.svm.pkl']

    # If no model is found
    list_files("/models", "svm")

    # Output: ['Please, create a model before - no model available yet']
    """

    from os import listdir
    from os.path import isfile, join

    # List files whose names end with the import_type suffix followed by '.pkl'
    files = [f for f in listdir(mypath)
             if isfile(join(mypath, f)) and f.endswith(import_type + '.pkl')]

    # Return a message if no files are found
    if not files:
        files = ['Please, create a model before - no model available yet']

    return files


def standardize(X: DataFrame, center: bool = True, scale: bool = False) -> DataFrame:
    """
    Centers and/or scales the features of the input DataFrame.

    This function applies column-wise standardization to the input DataFrame. With the
    defaults (`center=True`, `scale=False`) the data is only mean-centered; setting both
    parameters to True yields full z-score normalization.

    Parameters
    ----------
    X : DataFrame
        A pandas DataFrame containing the data to be standardized. Each column represents a feature.

    center : bool, optional
        If True, the mean of each feature will be subtracted from the data. Default is True.

    scale : bool, optional
        If True, each feature will be scaled to unit variance. Default is False.

    Returns
    -------
    DataFrame
        A pandas DataFrame containing the standardized values, with the same indices and column names
        as the input DataFrame.
    """
    sk = StandardScaler(with_mean=center, with_std=scale)
    sc = DataFrame(sk.fit_transform(X), index=X.index, columns=X.columns)
    return sc
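
# A minimal usage sketch for `standardize` (the 4x2 frame below is made-up data):
#
#     df = DataFrame({'a': [1.0, 2.0, 3.0, 4.0], 'b': [10.0, 20.0, 30.0, 40.0]})
#     centered = standardize(df)              # mean-centered only (default)
#     zscored = standardize(df, scale=True)   # full z-score normalization
#     # each column of `zscored` has mean ~0 and (population) std ~1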

# Spectral preprocessing


def Detrend(X):
    """
    Removes a linear trend from each spectrum (row-wise) using scipy.signal.detrend.
    """
    c = detrend(X, axis=-1, type='linear', bp=0, overwrite_data=False)
    return c
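
# A quick sketch of `Detrend` (made-up signal): detrending a pure linear ramp
# leaves values that are ~0 everywhere.
#
#     sig = np.arange(10.0).reshape(1, -1)
#     print(Detrend(sig).round(6))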


def Snv(X: DataFrame) -> DataFrame:
    """
    Performs Standard Normal Variate (SNV) transformation on the input DataFrame.

    SNV standardizes each sample (row) individually: the mean of the spectrum is
    subtracted and the result is divided by the spectrum's standard deviation. The
    resulting DataFrame retains the original indices and column names.

    Parameters
    ----------
    X : DataFrame
        A pandas DataFrame containing the data to be transformed. Each row represents
        a spectrum and each column a feature.

    Returns
    -------
    DataFrame
        A pandas DataFrame containing the transformed values, with the same indices and
        column names as the input DataFrame.
    """
    xt = np.array(X).T
    # Each column of xt is one sample of X: center and scale per spectrum
    c = (xt - xt.mean(axis=0)) / xt.std(axis=0)
    return DataFrame(c.T, index=X.index, columns=X.columns)
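
# A minimal sketch of `Snv` on synthetic spectra (the 3x5 matrix is made-up data):
#
#     spectra = DataFrame(np.random.default_rng(0).random((3, 5)))
#     snv = Snv(spectra)
#     # after SNV every row has mean ~0 and (population) std ~1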


def No_transformation(X):
    """Identity preprocessing: returns the input unchanged."""
    return X


######################################## Cross val split ############################


class KF_CV:
    """
    A class for implementing cross-validation with Kennard-Stone fold generation.
    Provides methods for generating test set indices, cross-validating a model,
    calculating metrics, and analyzing predictions across folds.

    Methods
    -------
    CV(x, y, n_folds: int) -> Dict[str, np.ndarray]:
        Generates test set indices for each fold based on Kennard-Stone K-Fold.

    cross_val_predictor(model, folds: Dict[str, np.ndarray], x, y) -> Dict[str, np.ndarray]:
        Cross-validates the model, returning predictions for each fold.

    meas_pred_eq(y, ypcv: Dict[str, np.ndarray], folds: Dict[str, np.ndarray]) -> Tuple[DataFrame, DataFrame]:
        Analyzes predictions, returning dataframes for measured and predicted values
        with OLS regression equations and coefficients.

    metrics_cv(y, ypcv: Dict[str, np.ndarray], folds: Dict[str, np.ndarray]) -> Tuple[DataFrame, DataFrame]:
        Computes metrics for each fold, returning dataframes with metric scores per fold
        and summary statistics (mean, standard deviation, coefficient of variation).

    cv_scores(y, ypcv: Dict[str, np.ndarray], folds: Dict[str, np.ndarray]) -> Tuple[DataFrame, DataFrame]:
        Computes fold-wise metrics and provides a summary with mean, sd, and cv.
    """

    @staticmethod
    def CV(x, y, n_folds: int) -> Dict[str, np.ndarray]:
        """
        Generates test set indices for each fold using Kennard-Stone K-Fold.

        Parameters
        ----------
        x : array-like
            Feature matrix used for training.
        y : array-like
            Target variable.
        n_folds : int
            Number of folds for cross-validation.

        Returns
        -------
        Dict[str, np.ndarray]
            Dictionary where keys are fold names and values are numpy arrays
            containing indices of the test set for each fold.
        """
        from kennard_stone import KFold as ks_KFold
        test_folds = {}
        folds_name = ['Fold' + str(i + 1) for i in range(n_folds)]
        kf = ks_KFold(n_splits=n_folds, device='cpu')
        # Kennard-Stone splitting is deterministic: split once and collect the
        # test indices of each fold
        for i, (_, i_test) in enumerate(kf.split(x, y)):
            test_folds[folds_name[i]] = i_test
        return test_folds

    @staticmethod
    def cross_val_predictor(model, folds: Dict[str, np.ndarray], x, y) -> Dict[str, np.ndarray]:
        """
        Cross-validates the model and returns predictions for each fold.

        Parameters
        ----------
        model : estimator object
            Model to be cross-validated.
        folds : Dict[str, np.ndarray]
            Dictionary with fold names as keys and test set indices as values (from CV method).
        x : array-like
            Feature matrix.
        y : array-like
            Target variable.

        Returns
        -------
        Dict[str, np.ndarray]
            Dictionary where keys are fold names and values are the predicted
            target values for each fold.
        """
        x = np.array(x)
        y = np.array(y)
        yp = {}
        # For each fold: fit on all samples outside the fold, predict the fold itself
        for fname, idx in folds.items():
            model.fit(np.delete(x, idx, axis=0), np.delete(y, idx, axis=0))
            yp[fname] = model.predict(x[idx])
        return yp

    @staticmethod
    def meas_pred_eq(y, ypcv: Dict[str, np.ndarray], folds: Dict[str, np.ndarray]) -> Tuple[DataFrame, DataFrame]:
        """
        Computes and returns measured vs predicted data with regression equations.

        Parameters
        ----------
        y : array-like
            Target variable.
        ypcv : Dict[str, np.ndarray]
            Dictionary with fold names as keys and predicted values per fold as values.
        folds : Dict[str, np.ndarray]
            Dictionary with fold names as keys and test set indices per fold as values.

        Returns
        -------
        Tuple[DataFrame, DataFrame]
            - DataFrame with measured and predicted values and regression equation per fold.
            - DataFrame with regression coefficients (slope and intercept) for each fold.
        """
        cvcv = {}
        coeff = {}
        y = np.array(y)
        for i, Fname in enumerate(folds.keys()):
            r = DataFrame()
            r['Predicted'] = ypcv[Fname]
            r['Measured'] = y[folds[Fname]]
            ols = LinearRegression().fit(
                DataFrame(y[folds[Fname]]), ypcv[Fname].reshape(-1, 1))
            r.index = folds[Fname]
            r['Folds'] = [str(Fname) + ' (Predicted = ' + str(np.round(ols.intercept_[0], 2)) +
                          ' + ' + str(np.round(ols.coef_[0][0], 2)) + ' x Measured)'] * r.shape[0]
            cvcv[i] = r
            coeff[Fname] = [ols.coef_[0][0], ols.intercept_[0]]

        from pandas import concat
        data = concat(cvcv, axis=0)
        # Flatten the (fold, sample) MultiIndex back to the original sample index
        data['index'] = data.index.get_level_values(1)
        data.index = data['index']
        coeff = DataFrame(coeff, index=['Slope', 'Intercept'])
        return data, coeff

    @staticmethod
    def metrics_cv(y, ypcv: Dict[str, np.ndarray], folds: Dict[str, np.ndarray]) -> Tuple[DataFrame, DataFrame]:
        """
        Computes and returns evaluation metrics for each fold.

        Parameters
        ----------
        y : array-like
            Target variable.
        ypcv : Dict[str, np.ndarray]
            Dictionary with fold names as keys and predicted values per fold as values.
        folds : Dict[str, np.ndarray]
            Dictionary with fold names as keys and test set indices per fold as values.

        Returns
        -------
        Tuple[DataFrame, DataFrame]
            - DataFrame with metrics for each fold.
            - DataFrame with additional mean, standard deviation, and coefficient of variation.
        """
        y = np.array(y)
        e = {}
        for i in folds.keys():
            e[i] = metrics().reg_(y[folds[i]], ypcv[i])
        r = DataFrame(e)
        r_print = r.copy()
        r_print['mean'] = r.mean(axis=1)
        r_print['sd'] = r.std(axis=1)
        r_print['cv'] = 100 * r.std(axis=1) / r.mean(axis=1)
        return r.T, r_print.T

    @staticmethod
    def cv_scores(y, ypcv: Dict[str, np.ndarray], folds: Dict[str, np.ndarray]) -> Tuple[DataFrame, DataFrame]:
        """
        Computes and returns fold-wise evaluation scores with summary statistics.

        Parameters
        ----------
        y : array-like
            Target variable.
        ypcv : Dict[str, np.ndarray]
            Dictionary with fold names as keys and predicted values per fold as values.
        folds : Dict[str, np.ndarray]
            Dictionary with fold names as keys and test set indices per fold as values.

        Returns
        -------
        Tuple[DataFrame, DataFrame]
            - DataFrame with metric scores per fold.
            - DataFrame with metric scores along with mean, sd, and cv values.
        """
        y = np.array(y)
        e = {}
        for i in folds.keys():
            e[i] = metrics().reg_(y[folds[i]], ypcv[i])
        r = DataFrame(e)
        # Copy before adding summary columns, so the per-fold frame stays clean and
        # the sd/cv rows are not computed over the freshly added 'mean' column
        r_print = r.copy()
        r_print['mean'] = r.mean(axis=1)
        r_print['sd'] = r.std(axis=1)
        r_print['cv'] = 100 * r.std(axis=1) / r.mean(axis=1)
        return r.T, r_print.T

    # ### Return ycv
    # @staticmethod
    # def ycv(model, x, y, n_folds:int):
    #     ycv = np.zeros(y.shape[0])
    #     f, idx,_,_ = KF_CV.cross_val_predictor(model, x,y, n_folds)
    #     for i in f.keys():
    #         ycv[idx[i]] = f[i]
    #     return ycv
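
# A minimal end-to-end sketch of KF_CV on made-up regression data (any sklearn-style
# regressor with fit/predict should work):
#
#     from sklearn.linear_model import Ridge
#     rng = np.random.default_rng(0)
#     x = rng.random((60, 10))
#     y = x @ rng.random(10)
#     folds = KF_CV.CV(x, y, n_folds=3)                   # Kennard-Stone test indices
#     ypcv = KF_CV.cross_val_predictor(Ridge(), folds, x, y)
#     data, coeff = KF_CV.meas_pred_eq(y, ypcv, folds)    # measured vs predicted per fold
#     scores, summary = KF_CV.metrics_cv(y, ypcv, folds)  # fold metrics + mean/sd/cv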


# Selectivity ratio
def sel_ratio(model, x):
    """
    Computes the Selectivity Ratio (SR) for variable importance based on the provided regression model 
    and dataset. The SR is calculated as the ratio of explained variance to residual variance for each 
    feature, and it is used to identify significant features in the model.

    Parameters:
    -----------
    model : sklearn estimator
        A fitted model with the `coef_` attribute (e.g., linear regression, PCA, PLS) that contains the 
        coefficients used to predict the target variable.

    x : array-like or pandas DataFrame
        The dataset (features) for which the Selectivity Ratio is to be calculated. It should be a 2D array 
        or a pandas DataFrame where columns represent the features.

    Returns:
    --------
    pandas DataFrame
        A DataFrame containing the Selectivity Ratio (SR) for each feature. Features with SR greater than 
        a critical F-value are considered significant and are returned in the output DataFrame.

    Notes:
    ------
    The Selectivity Ratio (SR) is computed as:
        SR = qexpi / qres
    where:
        - qexpi is the explained variance for each feature.
        - qres is the residual variance for each feature.

    The critical F-value is taken as the 5th percentile of the F-distribution
    (`scipy.stats.f.ppf(0.05, ...)`), which serves as a threshold to decide whether a feature
    is statistically significant.

    Example usage:
    --------------
    # Assuming `model` is a fitted model and `x` is the dataset
    SR = sel_ratio(model, x)
    """

    from scipy.stats import f
    import numpy as np
    from pandas import DataFrame

    # Convert input dataset to DataFrame
    x = DataFrame(x)

    # Normalize the model's coefficients
    wtp = model.coef_.T / np.linalg.norm(model.coef_.T)

    # Calculate the scores (ttp)
    ttp = np.array(x @ wtp)

    # Calculate the projection matrix (ptp)
    ptp = np.array(x.T) @ np.array(ttp) / (ttp.T @ ttp)

    # Calculate the explained variance for each feature
    qexpi = np.linalg.norm(ttp @ ptp.T, axis=0) ** 2

    # Calculate residuals (e) and residual variance for each feature
    e = np.array(x - x.mean()) - ttp @ ptp.T
    qres = np.linalg.norm(e, axis=0) ** 2

    # Compute the selection ratio for each feature
    sr = DataFrame(qexpi / qres, index=x.columns, columns=['sr'])

    # Determine the critical value from the F-distribution
    fcr = f.ppf(0.05, sr.shape[0] - 2, sr.shape[0] - 3)

    # Identify features with SR greater than the critical value (1-D boolean mask)
    c = (sr['sr'] > fcr).to_numpy()

    # Reindex the result by feature position
    sr.index = np.arange(x.shape[1])

    # Return only features that pass the statistical test
    SR = sr.iloc[c, :]
    return SR
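
# A minimal usage sketch for `sel_ratio` (made-up data; PLSRegression is just one example
# of a fitted model exposing `coef_`):
#
#     from sklearn.cross_decomposition import PLSRegression
#     rng = np.random.default_rng(0)
#     x = rng.random((50, 20))
#     y = x[:, 0] + 0.1 * rng.random(50)
#     pls = PLSRegression(n_components=3).fit(x, y)
#     SR = sel_ratio(pls, x)   # rows are the features deemed significant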


#####################################


class HandleItems:
    """
    A utility class for managing files and directories, providing static methods to
    delete files, delete directories, and create directories based on given conditions.

    Methods
    -------
    delete_files(keep: List[str]):
        Deletes files from the "report" directory except specified files to keep.

    delete_dir(delete: List[str]):
        Deletes specified directories if they exist.

    create_dir(path: List[str]):
        Creates directories if they do not already exist.
    """

    @staticmethod
    def delete_files(keep: List[str]):
        """
        Deletes files in the "report" directory, except for those that match the
        specified extensions or the file 'logo_cefe.png'.

        Parameters
        ----------
        keep : List[str]
            A list of file extensions to keep in the directory. Files ending with any
            of these extensions will not be deleted.
        """
        from os import walk, remove, path

        # Walk through the directory
        for root, dirs, files in walk(Path("report"), topdown=False):
            for file in files:
                # Check if file should not be deleted
                if file != 'logo_cefe.png' and not any(file.endswith(ext) for ext in keep):
                    remove(path.join(root, file))

    @staticmethod
    def delete_dir(delete: List[str]):
        """
        Deletes specified directories if they exist.

        Parameters
        ----------
        delete : List[str]
            A list of directory paths to delete. Only directories that exist will be removed.
        """
        from shutil import rmtree
        for i in delete:
            dirpath = Path(i)
            if dirpath.exists() and dirpath.is_dir():
                rmtree(dirpath)

    @staticmethod
    def create_dir(path: List[str]):
        """
        Creates directories if they do not already exist.

        Parameters
        ----------
        path : List[str]
            A list of directory paths to create. Directories will only be created if
            they do not already exist.
        """
        for i in path:
            dirpath = Path(i)
            if not dirpath.exists():
                dirpath.mkdir(parents=True, exist_ok=True)
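

# A minimal usage sketch for HandleItems (the paths below are made-up examples):
#
#     HandleItems.create_dir(['report/figures', 'report/out/model'])
#     HandleItems.delete_files(keep=['.py', '.pkl'])   # 'logo_cefe.png' is always kept
#     HandleItems.delete_dir(['report/out'])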