import streamlit as st
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ predictions histogram ~~~~~~~~~~~~~~~~~~~~~~~~~~


@st.cache_data
def pred_hist(pred, bins=12):
    """Return a wide, minimally styled histogram figure of ``pred``.

    Parameters
    ----------
    pred : array-like
        Values to histogram; passed straight to ``Axes.hist``.
    bins : int or sequence, default 12
        Forwarded to ``Axes.hist`` (12 matches the previously hard-coded value).

    Returns
    -------
    matplotlib.figure.Figure
    """
    fig, ax = plt.subplots(1, 1, figsize=(15, 3), tight_layout=True)

    # Dash-dot background grid
    ax.grid(color='grey', linestyle='-.', linewidth=0.5, alpha=0.6)
    # Hide the axes frame (spines)
    for side in ('top', 'bottom', 'left', 'right'):
        ax.spines[side].set_visible(False)
    # Hide tick marks; tick labels stay visible
    ax.xaxis.set_ticks_position('none')
    ax.yaxis.set_ticks_position('none')
    # Padding between the axes and the tick labels
    ax.xaxis.set_tick_params(pad=5)
    ax.yaxis.set_tick_params(pad=10)

    ax.hist(pred, bins=bins)
    return fig


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ spectra & category counts ~~~~~~~~~~~~~~~~~~~~~~~~~~
@st.cache_data
def plot_spectra(specdf=None, color=None, cmap=None, xunits=None, yunits=None, mean=False):
    """Plot each row of ``specdf`` as a curve, optionally colored by category.

    Parameters
    ----------
    specdf : pandas.DataFrame
        Spectra. The frame is transposed before plotting, so every row becomes
        one line and the columns form the x axis.
    color : pandas.Series, optional
        Category of each spectrum. Its index is used with ``specdf.loc`` to
        select the rows of each category.
    cmap : dict, optional
        Mapping category -> matplotlib color. If ``color`` or ``cmap`` is None,
        every spectrum is drawn in blue and no legend is added.
    xunits, yunits : str, optional
        Axis labels.
    mean : bool, default False
        If True, overlay the column-wise mean spectrum as a thick black line.

    Returns
    -------
    matplotlib.figure.Figure
    """
    fig, ax = plt.subplots(figsize=(30, 7))

    if color is None or cmap is None:
        specdf.T.plot(legend=False, ax=ax, color="blue")
    else:
        # Empty proxy lines give exactly one legend entry per category. The
        # legend is built once, before the spectra are drawn, so the many
        # per-spectrum lines added below never appear in it.
        for key, value in cmap.items():
            ax.plot([], [], color=value, label=str(key))
        ax.legend()

        for key, value in cmap.items():
            idx = color.index[color == key].tolist()
            if not idx:
                # Category listed in cmap but absent from `color`: plotting an
                # empty frame would make pandas raise "no numeric data to plot".
                continue
            specdf.loc[idx].T.plot(legend=False, ax=ax, color=value)

    if mean:
        specdf.mean().plot(legend=False, ax=ax, color="black", linewidth=5)

    ax.set_xlabel(xunits, fontsize=30)
    ax.set_ylabel(yunits, fontsize=30)
    ax.margins(x=0)
    fig.tight_layout()
    return fig


@st.cache_data
def barhplot(metadf, cmap):
    """Horizontal bar chart of row counts per category of ``metadf``'s first column.

    Bars follow the key order of ``cmap``, with the first key at the top, and
    are colored with ``cmap``'s values. pandas raises ``KeyError`` if a
    ``cmap`` key does not occur in the data.

    Returns
    -------
    matplotlib.figure.Figure
    """
    group_col = metadf.columns[0]
    per_category = metadf.groupby(group_col).size()
    ordered = per_category.loc[cmap.keys()]

    fig, ax = plt.subplots(figsize=(10, 5))
    ax.barh(ordered.index, ordered.values, color=cmap.values())
    # Put the first cmap entry at the top instead of the bottom.
    ax.invert_yaxis()
    ax.set_xlabel('Count')
    ax.set_ylabel(str(group_col).capitalize())
    return fig


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Cal/val hist ~~~~~~~~~~~~~~~~~~~~~~~~~~
@st.cache_data
def hist(y, y_train, y_test, target_name='y'):
    """Overlay histograms (with KDE curves) of the total, calibration and validation targets.

    Parameters
    ----------
    y, y_train, y_test : array-like
        Target values of the whole data set, the calibration split and the
        validation split; each is passed to ``seaborn.histplot``.
    target_name : str, default 'y'
        Used for the x-axis label and as the prefix of each legend entry.

    Returns
    -------
    matplotlib.figure.Figure
    """
    name = str(target_name)
    fig, ax = plt.subplots(figsize=(5, 2))

    # (data, fill color, legend suffix), drawn in this order.
    layers = (
        (y, "#004e9e", " (Total)"),
        (y_train, "#2C6B6F", " (Cal)"),
        (y_test, "#d0f7be", " (Val)"),
    )
    for values, shade, suffix in layers:
        sns.histplot(values, color=shade, kde=True, label=name + suffix,
                     ax=ax, fill=True)

    ax.set_xlabel(name)
    ax.legend()
    fig.tight_layout()
    return fig


@st.cache_data
def reg_plot(meas, pred, train_idx, test_idx, trainplot=True):
    """Measured-vs-predicted scatter plot for the calibration and validation sets.

    Each plotted set gets a seaborn regression plot. Its legend label shows the
    ordinary-least-squares line ``Predicted = a0 + a1 x Measured``, with
    coefficients rounded to 2 decimals. A black 1:1 reference line is drawn
    across the plotted measured range (+/- 0.05).

    Points are annotated with their label from ``train_idx`` / ``test_idx``
    when their residual (measured - predicted) lies more than 3 standard
    deviations from that set's mean residual.

    Parameters
    ----------
    meas, pred : indexable of two array-likes
        Position 0 holds calibration values and position 1 validation values.
        They are only read, never modified.
    train_idx, test_idx : iterable
        Point labels, aligned positionally with the calibration / validation
        values.
    trainplot : bool, default True
        If False, only the validation set is fitted, drawn and annotated, and
        ``meas[0]`` / ``pred[0]`` are not accessed.

    Returns
    -------
    matplotlib.figure.Figure
    """
    # sklearn is not imported at module level; import it once per call here.
    from sklearn.linear_model import LinearRegression

    # Position 0 = calibration set, position 1 = validation set.
    names = {0: 'Cal', 1: 'Val'}
    colors = {0: "#2C6B6F", 1: '#d0f7be'}
    point_labels = {0: train_idx, 1: test_idx}
    sets = (0, 1) if trainplot else (1,)

    # Work on flattened local copies so the caller's containers are never
    # modified (this also lets callers pass tuples).
    x = {i: np.asarray(meas[i], dtype=float).reshape(-1) for i in sets}
    y = {i: np.asarray(pred[i], dtype=float).reshape(-1) for i in sets}

    fig, ax = plt.subplots(figsize=(12, 4))
    for i in sets:
        # Fit the same set that is drawn, so the legend equation always
        # describes the plotted points.
        model = LinearRegression().fit(x[i].reshape(-1, 1), y[i])
        a0 = np.round(model.intercept_, 2)
        a1 = np.round(model.coef_[0], 2)
        # Keep the f-string on one line: a line break inside {...} is a
        # SyntaxError before Python 3.12.
        label = f'{names[i]} (Predicted = {a0} + {a1} x Measured)'
        sns.regplot(x=x[i], y=y[i], color=colors[i], label=label,
                    scatter_kws={'edgecolor': 'black'}, ax=ax)

    # 1:1 reference line spanning every plotted measured value.
    measured = np.concatenate([x[i] for i in sets])
    lo = np.min(measured) - 0.05
    hi = np.max(measured) + 0.05
    ax.plot([lo, hi], [lo, hi], color='black')

    # Label outliers. The 3-sigma band is centred on the mean residual, so it
    # is symmetric for over- and under-prediction. Threshold is computed once per set.
    for i in sets:
        resid = x[i] - y[i]
        is_outlier = np.abs(resid - np.mean(resid)) > 3 * np.std(resid)
        for txt, xv, yv, flag in zip(point_labels[i], x[i], y[i], is_outlier):
            if flag:
                ax.annotate(txt, (xv, yv))

    ax.set_ylabel('Predicted values')
    ax.set_xlabel('Measured values')
    ax.legend()
    ax.margins(0)
    return fig

# Resid plot


@st.cache_data
def resid_plot(meas, pred, train_idx, test_idx, trainplot=True):
    """Scatter plot of residuals (measured - predicted) against predicted values.

    A dotted black line marks zero residual. Points are annotated with their
    label from ``train_idx`` / ``test_idx`` when their residual lies more than
    3 standard deviations from that set's mean residual. The y-axis is made
    symmetric, extending 10% beyond the largest finite absolute residual.

    Parameters
    ----------
    meas, pred : indexable of two array-likes
        Position 0 holds calibration values and position 1 validation values.
        They are combined positionally, with no pandas index alignment.
    train_idx, test_idx : iterable
        Point labels, aligned positionally with the calibration / validation
        values.
    trainplot : bool, default True
        If False, only the validation set is drawn and ``meas[0]`` /
        ``pred[0]`` are not accessed.

    Returns
    -------
    matplotlib.figure.Figure
    """
    # Position 0 = calibration set, position 1 = validation set.
    names = {0: 'Cal', 1: 'Val'}
    colors = {0: "#2C6B6F", 1: "#d0f7be"}
    point_labels = {0: train_idx, 1: test_idx}
    sets = (0, 1) if trainplot else (1,)

    fig, ax = plt.subplots(figsize=(12, 4))
    ax.axhline(y=0, c='black', linestyle=':')

    residuals = []
    for i in sets:
        # Flat float arrays: positional subtraction (no pandas index alignment)
        # and scalar annotation coordinates.
        predicted = np.asarray(pred[i], dtype=float).reshape(-1)
        resid = np.asarray(meas[i], dtype=float).reshape(-1) - predicted
        residuals.append(resid)
        sns.scatterplot(x=predicted, y=resid, color=colors[i],
                        label=names[i], edgecolor="black", ax=ax)

        # The 3-sigma band is centred on the mean residual, so it is symmetric
        # for both signs. Threshold is computed once per set.
        is_outlier = np.abs(resid - np.mean(resid)) > 3 * np.std(resid)
        for txt, xv, yv, flag in zip(point_labels[i], predicted, resid, is_outlier):
            if flag:
                ax.annotate(txt, (xv, yv))

    # Symmetric limits from finite residuals only. Skip when none are finite
    # (NaN limits raise) or all are zero (degenerate limits).
    magnitudes = np.abs(np.concatenate(residuals))
    magnitudes = magnitudes[np.isfinite(magnitudes)]
    if magnitudes.size and magnitudes.max() > 0:
        lim = magnitudes.max() * 1.1
        ax.set_ylim(-lim, lim)

    ax.set_ylabel('Residuals')
    ax.set_xlabel('Predicted values')
    ax.legend()
    ax.margins(0)
    return fig