import jcamp as jc
import numpy as np
from tempfile import NamedTemporaryFile


def jcamp_parser(path, include, change=None):
    """
    Parses a JCAMP-DX file and extracts spectral data, target concentrations,
    and metadata as per the specified `include` parameter.

    Parameters:
        path (str): The file path to the JCAMP-DX file to be parsed.
        include (list): Specifies which data blocks to include in the output.
                        Options are:
                          - 'x_block': Extract spectra.
                          - 'y_block': Extract target concentrations.
                          - 'meta': Extract metadata.
                          - 'all': Extract all available information (default).
        change: Unused; kept for interface compatibility with sibling parsers.

    Returns:
        tuple: (x_block, y_block, met)
            - x_block (DataFrame): Spectral data with samples as rows and wavelengths as columns.
            - y_block (DataFrame): Target concentrations with samples as rows and analytes as columns.
            - met (DataFrame): Metadata for each sample.
            Sections not requested via `include` are returned as empty DataFrames.
    """
    from pandas import DataFrame
    import re

    # Read the JCAMP-DX file; each child block corresponds to one sample.
    dxfile = jc.jcamp_readfile(path)
    nb = dxfile['blocks']
    list_of_blocks = dxfile['children']

    idx = []      # Sample names (block titles)
    metdata = {}  # Per-sample metadata records

    # Preallocate matrix for spectral data, sized from the first block's grid.
    if 'x_block' in include or 'all' in include:
        specs = np.zeros((nb, len(list_of_blocks[0]["y"])), dtype=float)

    # Initialize containers for target concentrations if requested.
    if 'y_block' in include or 'all' in include:
        targets_tuple = {}
        # Matches "(name,value,...)" tuples in the CONCENTRATIONS field.
        pattern = r"\(([^,]+),(\d+(\.\d+)?),([^)]+)"
        aa = list_of_blocks[0]['concentrations']
        # Drop "NCU" and "<<undef>>" entries before extracting analyte names.
        a = '\n'.join(line for line in aa.split('\n')
                      if "NCU" not in line and "<<undef>>" not in line)
        # Chemical element (analyte) names, taken from the first block.
        elements_name = [match[0] for match in re.findall(pattern, a)]

        def conc(sample=None, pattern=None):
            """Extract concentration values from one block; '0' is treated as missing (NaN)."""
            prep = '\n'.join(line for line in sample.split('\n')
                             if "NCU" not in line and "<<undef>>" not in line)
            # np.nan (lowercase): np.NaN was removed in NumPy 2.0.
            c = [np.nan if match[1] == '0' else np.float64(match[1])
                 for match in re.findall(pattern, prep)]
            return np.array(c)

    # Loop through all blocks in the file.
    for i in range(nb):
        block = list_of_blocks[i]
        idx.append(str(block['title']))  # Store sample name

        # Extract spectra if requested.
        if 'x_block' in include or 'all' in include:
            specs[i] = block['y']

        # Extract metadata if requested.
        if 'meta' in include or 'all' in include:
            # NOTE(review): the n_scans/resolution positions assume a fixed
            # '\n$$'-separated layout of the spectrometer field — confirm
            # against the instrument's export format.
            metdata[i] = {
                'name': block['title'],
                'origin': block['origin'],
                'date': block['date'],
                'spectrometer': block['spectrometer/data system'].split('\n$$')[0],
                'n_scans': block['spectrometer/data system'].split('\n$$')[6].split('=')[1],
                'resolution': block['spectrometer/data system'].split('\n$$')[8].split('=')[1],
                'xunits': block['xunits'],
                'yunits': block['yunits'],
                'firstx': block['firstx'],
                'lastx': block['lastx'],
                'npoints': block['npoints'],
            }

        # Extract target concentrations if requested.
        if 'y_block' in include or 'all' in include:
            targets_tuple[i] = conc(
                sample=block['concentrations'], pattern=pattern)

    # Create DataFrame for target concentrations.
    if 'y_block' in include or 'all' in include:
        y_block = DataFrame(targets_tuple).T
        y_block.columns = elements_name
        y_block.index = idx
    else:
        # Fix: previously assigned the DataFrame *class*; return an instance.
        y_block = DataFrame()

    # Create DataFrame for spectral data.
    if 'x_block' in include or 'all' in include:
        wls = list_of_blocks[0]["x"]  # Wavelengths/frequencies/range
        x_block = DataFrame(specs, columns=wls, index=idx).astype('float64')
    else:
        x_block = DataFrame()

    # Create DataFrame for metadata, dropping all-empty columns.
    if 'meta' in include or 'all' in include:
        m = DataFrame(metdata).T
        m.index = idx
        met = m.drop(m.columns[(m == '').all()], axis=1)
    else:
        met = DataFrame()

    return x_block, y_block, met


def csv_parser(path, decimal, separator, index_col, header, change=None):
    """
    Parse a CSV file and return two DataFrames: one with floating point columns
    and the other with non-floating point columns.

    Parameters:
    -----------
    path : str
        The file path to the CSV file to be read.

    decimal : str
        Character to recognize as decimal separator (e.g., '.' or ',').

    separator : str
        The character used to separate values in the CSV file (e.g., ',' or '\t').

    index_col : int or str, optional
        Column to set as the index of the DataFrame. Default is None.

    header : int, list of int, or None, optional
        Row(s) to use as the header. Default is 'infer'.

    change : optional
        Currently unused; accepted for interface compatibility with sibling
        parsers. NOTE(review): earlier docs claimed `change` was applied to
        the non-float columns — the implementation never did so.

    Returns:
    --------
    tuple
        A tuple containing two DataFrames:
        - The float-typed columns of the parsed file.
        - The remaining (non-float) columns, unmodified.

    Notes:
    ------
    - The file is read once with pandas, then columns are split by dtype
      using `select_dtypes`.
    """
    from pandas import read_csv

    df = read_csv(path, decimal=decimal, sep=separator,
                  index_col=index_col, header=header)

    # Split columns by dtype. Renamed from `float` to avoid shadowing the builtin.
    float_cols = df.select_dtypes(include='float')
    non_float = df.select_dtypes(exclude='float')

    return float_cols, non_float


def meta_st(df):
    """
    Preprocesses a DataFrame by retaining columns with between 2 and 59 unique
    values and converting string columns to uppercase.

    Parameters:
    -----------
    df : pandas.DataFrame
        The input DataFrame to be processed. The input is not modified.

    Returns:
    --------
    pandas.DataFrame
        A DataFrame that:
        - Retains columns with between 2 and 59 unique values.
        - Converts string (non-numeric) columns to uppercase.
        - Is empty if the input DataFrame is empty.

    Notes:
    ------
    - A column counts as "numeric" if it can be cast to float; only columns
      that fail that cast are uppercased.
    - Columns with fewer than 2 or more than 59 unique values are dropped.

    Example:
    --------
    import pandas as pd

    data = {
        'Name': ['alice', 'bob', 'charlie'],
        'Age': [25, 30, 35],
        'Country': ['usa', 'uk', 'canada'],
        'Score': [90.5, 88.0, 92.3],
        'IsActive': [True, False, True]
    }

    df = pd.DataFrame(data)
    result = meta_st(df)
    print(result)
    """
    import pandas as pd

    # Guard clause: nothing to do for an empty frame.
    if df.empty:
        return pd.DataFrame()

    # Work on a copy so the caller's DataFrame is not mutated in place.
    df = df.copy()

    for col in df.columns:
        try:
            # Probe whether the column is numeric-castable.
            df[[col]].astype('float')
        except (ValueError, TypeError):
            # Non-numeric column: uppercase its string values.
            # (Narrowed from a bare `except:` which also swallowed
            # KeyboardInterrupt/SystemExit.)
            df[[col]] = df[[col]].apply(lambda x: x.str.upper())

    # Retain columns with unique-value counts strictly between 1 and 60.
    return df.loc[:, (df.nunique() > 1) & (df.nunique() < 60)]

    # def parse(self):
    #     import pandas as pd

    #     dec_dia = ['.', ',']
    #     sep_dia = [',', ';']
    #     dec, sep = [], []

    #     with open(self.file, mode = 'r') as csvfile:
    #         lines = [csvfile.readline() for i in range(3)]
    #         for i in lines:
    #             for j in range(2):
    #                 dec.append(i.count(dec_dia[j]))
    #                 sep.append(i.count(sep_dia[j]))

    #     if dec[0] != dec[2]:
    #         header = 0
    #     else:
    #         header = 0

    #     semi = np.sum([sep[2*i+1] for i in range(3)])
    #     commas = np.sum([sep[2*i] for i in range(3)])

    #     if semi>commas:separator = ';'
    #     elif semi<commas: separator = ','

    #     elif semi ==0 and commas == 0: separator = ';'

    #     commasdec = np.sum([dec[2*i+1] for i in range(1,3)])
    #     dot = np.sum([dec[2*i] for i in range(1,3)])
    #     if commasdec>dot:decimal = ','
    #     elif commasdec<=dot:decimal = '.'

    #     if decimal == separator or len(np.unique(dec)) <= 2:
    #         decimal = "."

    #     df = pd.read_csv(self.file, decimal=decimal, sep=separator, header=None, index_col=None)
    #     try:
    #         rat = np.mean(df.iloc[0,50:60]/df.iloc[5,50:60])>10
    #         header = 0 if rat or np.nan else None
    #     except:
    #         header = 0

    #     from pandas.api.types import is_float_dtype

    #     if is_float_dtype(df.iloc[1:,0]):
    #         index_col = None
    #     else:
    #         try:
    #             te = df.iloc[1:,0].to_numpy().astype(float).dtype

    #         except:
    #             te = set(df.iloc[1:,0])

    #         if len(te) == df.shape[0]-1:
    #             index_col = 0
    #         elif len(te) < df.shape[0]-1:
    #             index_col = None
    #         else:
    #             index_col = None

    #     # index_col = 0 if len(set(df.iloc[1:,0])) == df.shape[0]-1 and is_float_dtype(df.iloc[:,0])==False else None
    #     df = pd.read_csv(self.file, decimal=decimal, sep=separator, header=header, index_col=index_col)
    #     # st.write(decimal, separator, index_col, header)

    #     if df.select_dtypes(exclude='float').shape[1] >0:
    #         non_float = df.select_dtypes(exclude='float')

    #     else:
    #         non_float = pd.DataFrame()

    #     if df.select_dtypes(include='float').shape[1] >0:
    #         float_data = df.select_dtypes(include='float')

    #     else:
    #         float_data = pd.DataFrame()
    #     return float_data, non_float


# ############## new function
# def csv_loader(file):
#     import clevercsv
#     import numpy as np
#     import pandas as pd

#     dec_dia = ['.',',']
#     sep_dia = [',',';']
#     dec, sep = [], []
#     with open(file, mode = 'r') as csvfile:
#         lines = [csvfile.readline() for i in range(3)]
#         for i in lines:
#             for j in range(2):
#                 dec.append(i.count(dec_dia[j]))
#                 sep.append(i.count(sep_dia[j]))

#     if dec[0] != dec[2]:
#         header = 0
#     else:
#         header = 0


#     semi = np.sum([sep[2*i+1] for i in range(3)])
#     commas = np.sum([sep[2*i] for i in range(3)])

#     if semi>commas:separator = ';'
#     elif semi<commas: separator = ','

#     elif semi ==0 and commas == 0: separator = ';'


#     commasdec = np.sum([dec[2*i+1] for i in range(1,3)])
#     dot = np.sum([dec[2*i] for i in range(1,3)])
#     if commasdec>dot:decimal = ','
#     elif commasdec<=dot:decimal = '.'

#     if decimal == separator or len(np.unique(dec)) <= 2:
#         decimal = "."

#     df = pd.read_csv(file, decimal=decimal, sep=separator, header=None, index_col=None)
#     try:
#         rat = np.mean(df.iloc[0,50:60]/df.iloc[5,50:60])>10
#         header = 0 if rat or np.nan else None
#     except:
#         header = 0

#     from pandas.api.types import is_float_dtype

#     if is_float_dtype(df.iloc[1:,0]):
#         index_col = None
#     else:
#         try:
#             te = df.iloc[1:,0].to_numpy().astype(float).dtype

#         except:
#             te = set(df.iloc[1:,0])

#         if len(te) == df.shape[0]-1:
#             index_col = 0
#         elif len(te) < df.shape[0]-1:
#             index_col = None
#         else:
#             index_col = None

#     # index_col = 0 if len(set(df.iloc[1:,0])) == df.shape[0]-1 and is_float_dtype(df.iloc[:,0])==False else None
#     df = pd.read_csv(file, decimal=decimal, sep=separator, header=header, index_col=index_col)
#     # st.write(decimal, separator, index_col, header)

#     if df.select_dtypes(exclude='float').shape[1] >0:
#         non_float = df.select_dtypes(exclude='float')

#     else:
#         non_float = pd.DataFrame()


#     if df.select_dtypes(include='float').shape[1] >0:
#         float_data = df.select_dtypes(include='float')

#     else:
#         float_data = pd.DataFrame()
#     return float_data, non_float