#default_exp train
from nbdev.showdoc import *
import os, sys, json, math, itertools
import pandas as pd, numpy as np
import warnings

# from tqdm import tqdm
from tqdm.notebook import tqdm

import torch

from MIOFlow.utils import sample, generate_steps
from MIOFlow.losses import MMD_loss, OT_loss, Density_loss, Local_density_loss

def train(
    model, df, groups, optimizer, n_batches=20,
    criterion=MMD_loss(),
    use_cuda=False,

    sample_size=(100, ),
    sample_with_replacement=False,

    local_loss=True,
    global_loss=False,

    hold_one_out=False,
    hold_out='random',
    apply_losses_in_time=True,

    top_k = 5,
    hinge_value = 0.01,
    use_density_loss=True,
    # use_local_density=False,

    lambda_density = 1.0,

    autoencoder=None,
    use_emb=True,
    use_gae=False,

    use_gaussian:bool=True,
    add_noise:bool=False,
    noise_scale:float=0.1,

    logger=None,

    use_penalty=False,
    lambda_energy=1.0,

    reverse:bool = False
):
    '''
    MIOFlow training loop.

    Notes:
        - The argument `model` must have a method `forward` that accepts two arguments
            in its function signature:
                model.forward(x, t)
            where, `x` is the input tensor and `t` is a `torch.Tensor` of time points (float).
        - The training loop is divided in two parts; local (predict t+1 from t), and
            global (predict the entire trajectory). Exactly one of `local_loss` /
            `global_loss` must be enabled.
        - NOTE(review): the signature and several statements of this function were
            reconstructed from its argument documentation and its call sites. The
            default of `criterion` (`MMD_loss()`) and of `noise_scale` (`0.1`) mirror
            `training_regimen`; please confirm against the upstream source.

    Arguments:
        model (nn.Module): the initialized pytorch ODE model.
        df (pd.DataFrame): the DataFrame from which to extract batch data.
        groups (list): the list of the numerical groups in the data, e.g.
            `[1.0, 2.0, 3.0, 4.0, 5.0]`, if the data has five groups.
        optimizer (torch.optim): an optimizer initialized with the model's parameters.
        n_batches (int): Defaults to `20`. The number of batches from which to randomly
            sample each consecutive pair of groups.
        criterion (Callable | nn.Loss): a loss function.
        use_cuda (bool): Defaults to `False`. Whether or not to send the model and data to cuda.
        sample_size (tuple): Defaults to `(100, )`.
        sample_with_replacement (bool): Defaults to `False`. Whether or not to sample
            data points with replacement.
        local_loss (bool): Defaults to `True`. Whether or not to use a local loss in the model.
        global_loss (bool): Defaults to `False`. Whether or not to use a global loss in the model.
        hold_one_out (bool): Defaults to `False`. Whether or not to hold one time point out.
        hold_out (str | int): Defaults to `"random"`. Which time point to hold out.
            `"random"` is only honoured by the global loss; with the local loss no
            time point is removed in that case.
        apply_losses_in_time (bool): Defaults to `True`. Applies the losses and does back
            propagation as soon as a local loss is calculated. When `False` the local
            losses of a batch are summed and back-propagated once per batch.
        top_k (int): Defaults to `5`. The k for the k-NN used in the density loss.
        hinge_value (float): Defaults to `0.01`. The hinge value for density loss.
        use_density_loss (bool): Defaults to `True`. Whether or not to add density regularization.
        lambda_density (float): Defaults to `1.0`. The weight for density loss.
        autoencoder (NoneType | nn.Module): Defaults to `None`. The full geodesic Autoencoder.
        use_emb (bool): Defaults to `True`. Whether or not to use the embedding model.
        use_gae (bool): Defaults to `False`. Whether or not to use the full Geodesic AutoEncoder.
        use_gaussian (bool): Defaults to `True`. Whether to use gaussian (`torch.randn`)
            or uniform (`torch.rand`) noise.
        add_noise (bool): Defaults to `False`. Whether or not to add noise.
        noise_scale (float): Defaults to `0.1`. How much to scale the noise by.
        logger (NoneType | Logger): Defaults to `None`. The logger to record information.
            When `None` the train loss is written through `tqdm.write`.
        use_penalty (bool): Defaults to `False`. Whether or not to use $L_e$ during
            training (norm of the derivative).
        lambda_energy (float): Defaults to `1.0`. The weight of the energy penalty.
        reverse (bool): Defaults to `False`. Whether to train time backwards.

    Returns:
        tuple: `(local_losses, batch_losses, globe_losses)` where `local_losses` maps
            `'t0:t1'` to the list of per-step loss values, `batch_losses` holds the mean
            local loss of every batch and `globe_losses` the global loss of every batch.

    Raises:
        NotImplementedError: if both `local_loss` and `global_loss` are set.
        ValueError: if neither loss is set, or if the global loss is asked to hold out
            a group that is neither `'random'` nor a member of `groups`.
    '''
    if autoencoder is None and (use_emb or use_gae):
        use_emb = False
        use_gae = False
        warnings.warn('\'autoencoder\' is \'None\', but \'use_emb\' or \'use_gae\' is True, both will be set to False.')

    # Validate the loss mode once instead of on every batch.
    if local_loss and global_loss:
        # NOTE: weighted local / global loss has been removed to improve runtime
        raise NotImplementedError()
    if not (local_loss or global_loss):
        raise ValueError('A form of loss must be specified.')

    noise_fn = torch.randn if use_gaussian else torch.rand

    def noise(data):
        sampled = noise_fn(*data.shape)
        return sampled.cuda() if use_cuda else sampled

    def clear_norm():
        # NOTE(review): assumes the model appends derivative norms to `model.norm`
        # on every forward pass; the list is emptied once consumed so the next
        # penalty only sees fresh entries. Confirm against the model class.
        if hasattr(model, 'norm'):
            model.norm = []

    if reverse:
        groups = groups[::-1]

    # Groups that take part in the local steps. Filtering is a no-op when
    # `hold_out` is not a member of `groups` (e.g. 'random').
    # TODO: Currently does not work if hold_out='random' for the local loss.
    kept_groups = [g for g in groups if g != hold_out] if hold_one_out else groups
    steps = generate_steps(kept_groups)

    # Storage variables for losses
    batch_losses = []
    globe_losses = []
    local_losses = {f'{t0}:{t1}': [] for (t0, t1) in steps}

    density_fn = Density_loss(hinge_value) # if not use_local_density else Local_density_loss()

    # Send model to cuda and specify it as training mode
    if use_cuda:
        model = model.cuda()
    model.train()

    for batch in tqdm(range(n_batches)):
        # apply local loss
        if local_loss:
            # live loss tensors (needed when the update is deferred) and their values
            step_losses = []
            step_values = []
            for (t0, t1) in steps:
                # sample data
                data_t0 = sample(df, t0, size=sample_size, replace=sample_with_replacement, to_torch=True, use_cuda=use_cuda)
                data_t1 = sample(df, t1, size=sample_size, replace=sample_with_replacement, to_torch=True, use_cuda=use_cuda)
                time = torch.Tensor([t0, t1]).cuda() if use_cuda else torch.Tensor([t0, t1])

                if add_noise:
                    data_t0 += noise(data_t0) * noise_scale
                    data_t1 += noise(data_t1) * noise_scale
                if autoencoder is not None and use_gae:
                    data_t0 = autoencoder.encoder(data_t0)
                    data_t1 = autoencoder.encoder(data_t1)

                # prediction
                data_tp = model(data_t0, time)

                if autoencoder is not None and use_emb:
                    data_tp, data_t1 = autoencoder.encoder(data_tp), autoencoder.encoder(data_t1)

                # loss between prediction and sample t1
                loss = criterion(data_tp, data_t1)

                if use_density_loss:
                    density_loss = density_fn(data_tp, data_t1, top_k=top_k)
                    loss += lambda_density * density_loss

                if use_penalty:
                    penalty = sum(model.norm)
                    loss += lambda_energy * penalty
                clear_norm()

                # apply local loss as we calculate it
                if apply_losses_in_time:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                # save loss in storage variables
                local_losses[f'{t0}:{t1}'].append(loss.item())
                step_values.append(loss.item())
                step_losses.append(loss)

            # Deferred update: back-propagate the sum of the live losses. Building a
            # new tensor from the values would detach them from the graph.
            if not apply_losses_in_time and step_losses:
                optimizer.zero_grad()
                sum(step_losses).backward()
                optimizer.step()

            # store average of local losses for this batch
            if step_values:
                batch_losses.append(float(np.mean(step_values)))

        # apply global loss
        else:
            optimizer.zero_grad()
            # sample data
            data_ti = [
                sample(
                    df, group, size=sample_size, replace=sample_with_replacement,
                    to_torch=True, use_cuda=use_cuda
                )
                for group in groups
            ]
            time = torch.Tensor(groups).cuda() if use_cuda else torch.Tensor(groups)

            if add_noise:
                data_ti = [
                    data + noise(data) * noise_scale for data in data_ti
                ]
            if autoencoder is not None and use_gae:
                data_ti = [autoencoder.encoder(data) for data in data_ti]

            # prediction
            data_tp = model(data_ti[0], time, return_whole_sequence=True)
            if autoencoder is not None and use_emb:
                data_tp = [autoencoder.encoder(data) for data in data_tp]
                data_ti = [autoencoder.encoder(data) for data in data_ti]

            # ignoring one time point
            to_ignore = None
            if hold_one_out and hold_out == 'random':
                to_ignore = np.random.choice(groups)
            elif hold_one_out and hold_out in groups:
                to_ignore = hold_out
            elif hold_one_out:
                raise ValueError('Unknown group to hold out')

            loss = sum([
                criterion(data_tp[i], data_ti[i])
                for i in range(1, len(groups))
                if groups[i] != to_ignore
            ])

            if use_density_loss:
                density_loss = density_fn(data_tp, data_ti, groups, to_ignore, top_k)
                loss += lambda_density * density_loss

            if use_penalty:
                penalty = sum([model.norm[-(i+1)] for i in range(1, len(groups))
                    if groups[i] != to_ignore])
                loss += lambda_energy * penalty
            clear_norm()

            loss.backward()
            optimizer.step()

            globe_losses.append(loss.item())

    print_loss = globe_losses if global_loss else batch_losses
    if logger is None:
        tqdm.write(f'Train loss: {np.round(np.mean(print_loss), 5)}')
    else:
        logger.info(f'Train loss: {np.round(np.mean(print_loss), 5)}')
    return local_losses, batch_losses, globe_losses
from MIOFlow.utils import generate_steps
import torch.nn as nn
from tqdm.notebook import tqdm
import numpy as np

def train_ae(
    model, df, groups, optimizer,
    n_epochs=60, criterion=nn.MSELoss(), dist=None, recon = True,
    use_cuda=False, sample_size=(100, ),
    sample_with_replacement=False,
    noise_min_scale=0.0,
    noise_max_scale=1.0,
    hold_one_out:bool=False,
    hold_out='random'
):
    """
    Geodesic Autoencoder training loop.

    Notes:
        - We can train only the encoder to fit the geodesic distance (recon=False), or
            the full geodesic Autoencoder (recon=True), i.e. matching the distance and
            reconstruction of the inputs.
        - NOTE(review): the trailing parameters of the signature were reconstructed
            from the argument documentation below; in particular the defaults of
            `noise_min_scale` / `noise_max_scale` are the documented ones. Please
            confirm against the upstream source.

    Arguments:
        model (nn.Module): the initialized pytorch Geodesic Autoencoder model. Must
            expose `encoder` and, when `recon=True`, `decoder`.
        df (pd.DataFrame): the DataFrame from which to extract batch data.
        groups (list): the list of the numerical groups in the data, e.g.
            `[1.0, 2.0, 3.0, 4.0, 5.0]`, if the data has five groups.
        optimizer (torch.optim): an optimizer initialized with the model's parameters.
        n_epochs (int): Defaults to `60`. The number of training epochs.
        criterion (torch.nn): Defaults to `nn.MSELoss()`. The criterion to minimize.
        dist (NoneType | Class): Defaults to `None`. The distance Class with a `fit(X)`
            method for a dataset `X`. Computes the pairwise distances in `X`.
        recon (bool): Defaults to `True`. Whether or not to apply the reconstruction loss.
        use_cuda (bool): Defaults to `False`. Whether or not to send the data to cuda.
        sample_size (tuple): Defaults to `(100, )`.
        sample_with_replacement (bool): Defaults to `False`. Whether or not to sample
            data points with replacement.
        noise_min_scale (float): Defaults to `0.0`. The minimum noise scale.
        noise_max_scale (float): Defaults to `1.0`. The maximum noise scale. The true
            scale is sampled between these two bounds for each epoch.
        hold_one_out (bool): Defaults to `False`. Whether or not to ignore a timepoint
            during training.
        hold_out (str | int): Defaults to `'random'`. The timepoint to hold out, either a
            specific element of `groups` or a random one (re-drawn every epoch).

    Returns:
        list: the total loss value of every epoch.

    Raises:
        ValueError: if neither a reconstruction nor a distance loss is requested, or if
            `hold_out` is neither `'random'` nor a member of `groups`.
    """
    # Without either term there is nothing to optimise; fail early and clearly
    # instead of hitting an undefined `loss` inside the loop.
    if not recon and dist is None:
        raise ValueError('A form of loss must be specified: set `recon=True` or provide `dist`.')

    losses = []

    model.train()
    for epoch in tqdm(range(n_epochs)):
        # ignoring one time point
        to_ignore = None
        if hold_one_out and hold_out == 'random':
            to_ignore = np.random.choice(groups)
        elif hold_one_out and hold_out in groups:
            to_ignore = hold_out
        elif hold_one_out:
            raise ValueError('Unknown group to hold out')

        # Training
        optimizer.zero_grad()
        noise_scale = torch.FloatTensor(1).uniform_(noise_min_scale, noise_max_scale)
        data_ti = torch.vstack([
            sample(df, group, size=sample_size, replace=sample_with_replacement, to_torch=True, use_cuda=use_cuda)
            for group in groups if group != to_ignore
        ])
        # draw the noise on the CPU and move it next to the data
        noise = (noise_scale * torch.randn(data_ti.size())).to(data_ti.device)

        encode_dt = model.encoder(data_ti + noise)

        loss = None
        if recon:
            recon_dt = model.decoder(encode_dt)
            loss_recon = criterion(recon_dt, data_ti)
            loss = loss_recon
            if epoch % 50 == 0:
                tqdm.write(f'Train loss recon: {np.round(np.mean(loss_recon.item()), 5)}')

        if dist is not None:
            # NOTE(review): the right-hand side of this assignment was missing; it is
            # assumed that `dist.fit` takes the clean samples as a numpy array and
            # returns a numpy array of pairwise distances — confirm.
            dist_geo = dist.fit(data_ti.detach().cpu().numpy())
            dist_geo = torch.from_numpy(dist_geo).float().to(data_ti.device)
            dist_emb = torch.cdist(encode_dt, encode_dt)**2
            loss_dist = criterion(dist_emb, dist_geo)
            loss = loss_recon + loss_dist if recon else loss_dist
            if epoch % 50 == 0:
                tqdm.write(f'Train loss dist: {np.round(np.mean(loss_dist.item()), 5)}')

        loss.backward()
        optimizer.step()

        losses.append(loss.item())
    return losses
from MIOFlow.plots import plot_comparision, plot_losses
from MIOFlow.eval import generate_plot_data

def training_regimen(
    n_local_epochs, n_epochs, n_post_local_epochs,
    exp_dir,

    # BEGIN: train params
    model, df, groups, optimizer, n_batches=20,
    criterion=MMD_loss(), use_cuda=False,

    hold_one_out=False, hold_out='random',
    hinge_value=0.01, use_density_loss=True,

    top_k = 5, lambda_density = 1.0,
    autoencoder=None, use_emb=True, use_gae=False,
    sample_size=(100, ),
    sample_with_replacement=False,
    logger=None,
    add_noise=False, noise_scale=0.1, use_gaussian=True,
    use_penalty=False, lambda_energy=1.0,
    # END: train params

    steps=None, plot_every=None,
    n_points=100, n_trajectories=100, n_bins=100,
    local_losses=None, batch_losses=None, globe_losses=None,
    reverse_schema=True, reverse_n=4
):
    '''
    Run the three-phase MIOFlow schedule: local pre-training, global training and
    local post-training, each phase calling `train` once per epoch.

    Notes:
        - NOTE(review): parts of this function were reconstructed. `exp_dir`,
            `sample_with_replacement` and `logger` are used by the body but their
            declarations were missing; `exp_dir` is placed right after the epoch
            counts and the other two inside the train params with defaults
            `False` / `None`. Confirm positions and defaults against callers.
        - NOTE(review): the plotting calls do not pass a file name, so the plotting
            helpers decide where inside `exp_dir` to write. Confirm whether a
            per-epoch `file=` argument is expected.

    Arguments:
        n_local_epochs (int): number of local pre-training epochs.
        n_epochs (int): number of global training epochs.
        n_post_local_epochs (int): number of local post-training epochs.
        exp_dir (str): directory passed as `path` to the plotting helpers.
        model, df, groups, optimizer, n_batches, criterion, use_cuda, hold_one_out,
        hold_out, hinge_value, use_density_loss, top_k, lambda_density, autoencoder,
        use_emb, use_gae, sample_size, sample_with_replacement, logger, add_noise,
        noise_scale, use_gaussian, use_penalty, lambda_energy: forwarded to `train`.
        steps: accepted for backward compatibility; not used by the schedule.
        plot_every (NoneType | int): Defaults to `None`. When set, a comparison plot is
            produced every `plot_every` epochs of each phase and a loss plot at the end.
        n_points, n_trajectories, n_bins (int): forwarded to `generate_plot_data`.
        local_losses (NoneType | dict), batch_losses (NoneType | list),
        globe_losses (NoneType | list): optional accumulators to continue from;
            created empty when `None`. Lists passed in are extended in place.
        reverse_schema (bool): Defaults to `True`. When set, every `reverse_n`-th epoch
            of each phase (starting with epoch 0) is trained backwards in time.
        reverse_n (int): Defaults to `4`. Period of the reversed epochs.

    Returns:
        tuple: `(local_losses, batch_losses, globe_losses)` accumulated over all phases.
            With `reverse_schema` the reverse-direction local losses are folded into
            the forward `'t0:t1'` keys.

    Raises:
        ValueError: if `reverse_schema` is set and `reverse_n` is smaller than 1.
    '''
    if reverse_schema and reverse_n < 1:
        raise ValueError('`reverse_n` must be a positive integer when `reverse_schema` is set.')

    recon = use_gae and not use_emb

    # Groups taking part in the local steps; filtering is a no-op when `hold_out`
    # is not a member of `groups` (e.g. 'random').
    kept_groups = [g for g in groups if g != hold_out] if hold_one_out else groups

    if local_losses is None:
        local_losses = {f'{t0}:{t1}': [] for (t0, t1) in generate_steps(kept_groups)}
        if reverse_schema:
            local_losses = {
                **local_losses,
                **{f'{t0}:{t1}': [] for (t0, t1) in generate_steps(kept_groups[::-1])}
            }
    if batch_losses is None:
        batch_losses = []
    if globe_losses is None:
        globe_losses = []

    def run_phase(n_phase_epochs, desc, use_local):
        # One phase of the schedule: train, accumulate losses, optionally plot.
        for epoch in tqdm(range(n_phase_epochs), desc=desc):
            reverse = bool(reverse_schema and epoch % reverse_n == 0)

            l_loss, b_loss, g_loss = train(
                model, df, groups, optimizer, n_batches,
                criterion = criterion, use_cuda = use_cuda,
                local_loss=use_local, global_loss=not use_local, apply_losses_in_time=True,
                hold_one_out=hold_one_out, hold_out=hold_out,
                # previously accepted by this function but never forwarded
                hinge_value=hinge_value,
                use_density_loss = use_density_loss,
                top_k = top_k, lambda_density = lambda_density,
                autoencoder = autoencoder, use_emb = use_emb, use_gae = use_gae, sample_size=sample_size,
                sample_with_replacement=sample_with_replacement, logger=logger,
                add_noise=add_noise, noise_scale=noise_scale, use_gaussian=use_gaussian,
                use_penalty=use_penalty, lambda_energy=lambda_energy, reverse=reverse
            )
            for key, values in l_loss.items():
                # tolerate accumulators supplied by the caller without every key
                local_losses.setdefault(key, []).extend(values)
            batch_losses.extend(b_loss)
            globe_losses.extend(g_loss)

            if plot_every is not None and epoch % plot_every == 0:
                generated, trajectories = generate_plot_data(
                    model, df, n_points, n_trajectories, n_bins,
                    sample_with_replacement=sample_with_replacement, use_cuda=use_cuda,
                    samples_key='samples', logger=logger,
                    autoencoder=autoencoder, recon=recon
                )
                plot_comparision(
                    df, generated, trajectories,
                    palette = 'viridis', df_time_key='samples',
                    save=True, path=exp_dir,
                    x='d1', y='d2', z='d3', is_3d=False
                )

    run_phase(n_local_epochs, 'Pretraining Epoch', True)
    run_phase(n_epochs, 'Epoch', False)
    run_phase(n_post_local_epochs, 'Posttraining Epoch', True)

    if reverse_schema:
        local_losses = _merge_reverse_losses(
            local_losses, generate_steps(kept_groups), reverse_n
        )

    if plot_every is not None:
        plot_losses(
            local_losses, batch_losses, globe_losses,
            save=True, path=exp_dir
        )

    return local_losses, batch_losses, globe_losses


def _merge_reverse_losses(local_losses, steps, reverse_n):
    # Fold the reverse-direction ('t1:t0') local losses into the forward ('t0:t1') keys.
    # NOTE(review): the merge rule was reconstructed — one reverse value is inserted
    # before every `reverse_n`-th forward value; reverse values left over after the
    # forward list is exhausted are discarded. Confirm the intended ordering.
    merged = {}
    for (t0, t1) in steps:
        forward_key = f'{t0}:{t1}'
        reverse_values = iter(local_losses.get(f'{t1}:{t0}', []))
        merged[forward_key] = []
        for i, value in enumerate(local_losses.get(forward_key, [])):
            if i % reverse_n == 0:
                # guard against a reverse list shorter than expected
                for reverse_value in itertools.islice(reverse_values, 1):
                    merged[forward_key].append(reverse_value)
            merged[forward_key].append(value)
    return merged