#default_exp train
Train
API details.
#hide
from nbdev.showdoc import *
#export
import os, sys, json, math, itertools
import pandas as pd, numpy as np
import warnings
# from tqdm import tqdm
from tqdm.notebook import tqdm
import torch
from MIOFlow.utils import sample, generate_steps
from MIOFlow.losses import MMD_loss, OT_loss, Density_loss, Local_density_loss
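`generate_steps` turns the ordered list of time groups into the consecutive `(t0, t1)` pairs that the local loss iterates over. A small sketch of the expected behaviour (not exported; the exact output shown is an assumption based on how `steps` is used below):

```python
# Sketch (not exported): assumes `generate_steps` pairs consecutive time groups,
# as its use inside `train` below implies.
groups = [1.0, 2.0, 3.0, 4.0, 5.0]
steps = generate_steps(groups)   # expected: [(1.0, 2.0), (2.0, 3.0), (3.0, 4.0), (4.0, 5.0)]
for t0, t1 in steps:
    print(f'local step: predict the distribution at t={t1} from the one at t={t0}')
```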
def train(
    model, df, groups, optimizer, n_batches=20,
    criterion=MMD_loss(),
    use_cuda=False,

    sample_size=(100, ),
    sample_with_replacement=False,

    local_loss=True,
    global_loss=False,

    hold_one_out=False,
    hold_out='random',
    apply_losses_in_time=True,

    top_k=5,
    hinge_value=0.01,
    use_density_loss=True,
    # use_local_density=False,
    lambda_density=1.0,

    autoencoder=None,
    use_emb=True,
    use_gae=False,

    use_gaussian: bool = True,
    add_noise: bool = False,
    noise_scale: float = 0.1,

    logger=None,

    use_penalty=False,
    lambda_energy=1.0,

    reverse: bool = False
):
'''
MIOFlow training loop
Notes:
- The argument `model` must have a method `forward` that accepts two arguments
in its function signature:
```python
model.forward(x, t)
```
where `x` is the input tensor and `t` is a `torch.Tensor` of time points (floats).
- The training loop is divided into two parts: local (predict t+1 from t) and global (predict the entire trajectory from the first time point).
Arguments:
model (nn.Module): the initialized pytorch ODE model.
df (pd.DataFrame): the DataFrame from which to extract batch data.
groups (list): the list of the numerical groups in the data, e.g.
`[1.0, 2.0, 3.0, 4.0, 5.0]`, if the data has five groups.
optimizer (torch.optim): an optimizer initialized with the model's parameters.
n_batches (int): Defaults to `20`. The number of batches; in each batch every consecutive pair
of groups is randomly sampled.
criterion (Callable | nn.Loss): Defaults to `MMD_loss()`. The loss function used to compare predictions with target samples.
use_cuda (bool): Defaults to `False`. Whether or not to send the model and data to cuda.
sample_size (tuple): Defaults to `(100, )`. The number of points to sample from each group at each step.
sample_with_replacement (bool): Defaults to `False`. Whether or not to sample data points with replacement.
local_loss (bool): Defaults to `True`. Whether or not to use a local loss in the model.
See notes for more detail.
global_loss (bool): Defaults to `False`. Whether or not to use a global loss in the model.
hold_one_out (bool): Defaults to `False`. Whether or not to hold one time point (and the pairs
it appears in, e.g. t_1 to t_2) out when computing the losses.
hold_out (str | int): Defaults to `"random"`. Which time point to hold out when calculating the
global loss.
apply_losses_in_time (bool): Defaults to `True`. Applies the losses and does backpropagation
as soon as a loss is calculated. See notes for more detail.
top_k (int): Defaults to `5`. The k for the k-NN used in the density loss.
hinge_value (float): Defaults to `0.01`. The hinge value for density loss.
use_density_loss (bool): Defaults to `True`. Whether or not to add density regularization.
lambda_density (float): Defaults to `1.0`. The weight for density loss.
autoencoder (NoneType|nn.Module): Defaults to `None`. The full Geodesic Autoencoder.
use_emb (bool): Defaults to `True`. Whether or not to use the embedding model.
use_gae (bool): Defaults to `False`. Whether or not to use the full Geodesic AutoEncoder.
use_gaussian (bool): Defaults to `True`. Whether to use Gaussian (`torch.randn`) or uniform (`torch.rand`) noise.
add_noise (bool): Defaults to `False`. Whether or not to add noise to the sampled points.
noise_scale (float): Defaults to `0.1`. How much to scale the noise by.
logger (NoneType|Logger): Defaults to `None`. The logger to record information.
use_penalty (bool): Defaults to `False`. Whether or not to use the energy penalty $L_e$ (norm of the derivative) during training.
lambda_energy (float): Defaults to `1.0`. The weight of the energy penalty.
reverse (bool): Defaults to `False`. Whether to train the model backwards in time (the `groups` are reversed).
'''
    if autoencoder is None and (use_emb or use_gae):
        use_emb = False
        use_gae = False
        warnings.warn('\'autoencoder\' is \'None\', but \'use_emb\' or \'use_gae\' is True, both will be set to False.')

    noise_fn = torch.randn if use_gaussian else torch.rand
    def noise(data):
        return noise_fn(*data.shape).cuda() if use_cuda else noise_fn(*data.shape)

    # Create the indices for the steps that should be used
    steps = generate_steps(groups)

    if reverse:
        groups = groups[::-1]
        steps = generate_steps(groups)

    # Storage variables for losses
    batch_losses = []
    globe_losses = []
    if hold_one_out and hold_out in groups:
        groups_ho = [g for g in groups if g != hold_out]
        local_losses = {f'{t0}:{t1}':[] for (t0, t1) in generate_steps(groups_ho) if hold_out not in [t0, t1]}
    else:
        local_losses = {f'{t0}:{t1}':[] for (t0, t1) in steps}

    density_fn = Density_loss(hinge_value)  # if not use_local_density else Local_density_loss()

    # Send model to cuda and specify it as training mode
    if use_cuda:
        model = model.cuda()

    model.train()
    for batch in tqdm(range(n_batches)):
        # apply local loss
        if local_loss and not global_loss:
            # for storing the local loss without calling `.item()` so `loss.backward()` can still be used
            batch_loss = []
            if hold_one_out:
                groups = [g for g in groups if g != hold_out]  # TODO: Currently does not work if hold_out='random'. Do to_ignore before.
                steps = generate_steps(groups)
            for step_idx, (t0, t1) in enumerate(steps):
                if hold_out in [t0, t1] and hold_one_out:  # TODO: This `if` can be deleted since `groups` no longer includes the held-out timepoint,
                    continue                               # i.e. it is always False.
                optimizer.zero_grad()

                # sampling, predicting, and evaluating the loss.
                # sample data
                data_t0 = sample(df, t0, size=sample_size, replace=sample_with_replacement, to_torch=True, use_cuda=use_cuda)
                data_t1 = sample(df, t1, size=sample_size, replace=sample_with_replacement, to_torch=True, use_cuda=use_cuda)
                time = torch.Tensor([t0, t1]).cuda() if use_cuda else torch.Tensor([t0, t1])

                if add_noise:
                    data_t0 += noise(data_t0) * noise_scale
                    data_t1 += noise(data_t1) * noise_scale
                if autoencoder is not None and use_gae:
                    data_t0 = autoencoder.encoder(data_t0)
                    data_t1 = autoencoder.encoder(data_t1)
                # prediction
                data_tp = model(data_t0, time)

                if autoencoder is not None and use_emb:
                    data_tp, data_t1 = autoencoder.encoder(data_tp), autoencoder.encoder(data_t1)
                # loss between prediction and sample t1
                loss = criterion(data_tp, data_t1)

                if use_density_loss:
                    density_loss = density_fn(data_tp, data_t1, top_k=top_k)
                    loss += lambda_density * density_loss

                if use_penalty:
                    penalty = sum(model.norm)
                    loss += lambda_energy * penalty

                # apply local loss as we calculate it
                if apply_losses_in_time and local_loss:
                    loss.backward()
                    optimizer.step()
                    model.norm = []
                # save loss in storage variables
                local_losses[f'{t0}:{t1}'].append(loss.item())
                batch_loss.append(loss)

            # convert the local losses into a tensor of len(steps)
            batch_loss = torch.Tensor(batch_loss).float()
            if use_cuda:
                batch_loss = batch_loss.cuda()

            if not apply_losses_in_time:
                batch_loss.backward()
                optimizer.step()

            # store average / sum of local losses for training
            ave_local_loss = torch.mean(batch_loss)
            sum_local_loss = torch.sum(batch_loss)

            batch_losses.append(ave_local_loss.item())
        # apply global loss
        elif global_loss and not local_loss:
            optimizer.zero_grad()

            # sampling, predicting, and evaluating the loss.
            # sample data
            data_ti = [
                sample(
                    df, group, size=sample_size, replace=sample_with_replacement,
                    to_torch=True, use_cuda=use_cuda
                )
                for group in groups
            ]
            time = torch.Tensor(groups).cuda() if use_cuda else torch.Tensor(groups)

            if add_noise:
                data_ti = [
                    data + noise(data) * noise_scale for data in data_ti
                ]
            if autoencoder is not None and use_gae:
                data_ti = [autoencoder.encoder(data) for data in data_ti]
            # prediction
            data_tp = model(data_ti[0], time, return_whole_sequence=True)
            if autoencoder is not None and use_emb:
                data_tp = [autoencoder.encoder(data) for data in data_tp]
                data_ti = [autoencoder.encoder(data) for data in data_ti]

            # ignoring one time point
            to_ignore = None  # TODO: This assignment of `to_ignore` could be moved to the beginning of the function.
            if hold_one_out and hold_out == 'random':
                to_ignore = np.random.choice(groups)
            elif hold_one_out and hold_out in groups:
                to_ignore = hold_out
            elif hold_one_out:
                raise ValueError('Unknown group to hold out')
            else:
                pass

            loss = sum([
                criterion(data_tp[i], data_ti[i]) for i in range(1, len(groups))
                if groups[i] != to_ignore
            ])

            if use_density_loss:
                density_loss = density_fn(data_tp, data_ti, groups, to_ignore, top_k)
                loss += lambda_density * density_loss

            if use_penalty:
                penalty = sum([model.norm[-(i+1)] for i in range(1, len(groups))
                               if groups[i] != to_ignore])
                loss += lambda_energy * penalty

            loss.backward()
            optimizer.step()
            model.norm = []

            globe_losses.append(loss.item())
        elif local_loss and global_loss:
            # NOTE: weighted local / global loss has been removed to improve runtime
            raise NotImplementedError()
        else:
            raise ValueError('A form of loss must be specified.')

    print_loss = globe_losses if global_loss else batch_losses
    if logger is None:
        tqdm.write(f'Train loss: {np.round(np.mean(print_loss), 5)}')
    else:
        logger.info(f'Train loss: {np.round(np.mean(print_loss), 5)}')
    return local_losses, batch_losses, globe_losses
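Below is a minimal usage sketch of `train` (not exported). The `ToyODE` model, the toy DataFrame, and the `samples` column name are illustrative assumptions rather than part of MIOFlow itself; any model exposing `forward(x, t)` and a `norm` list (read only when `use_penalty=True`) satisfies the interface documented above.

```python
# A minimal usage sketch (hypothetical model and data, not MIOFlow's own example).
import torch.nn as nn

class ToyODE(nn.Module):
    def __init__(self, dim=2):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, 32), nn.Tanh(), nn.Linear(32, dim))
        self.norm = []  # energy terms; `train` only reads these when use_penalty=True

    def forward(self, x, t, return_whole_sequence=False):
        # crude Euler roll-out between the requested time points
        xs = [x]
        for t0, t1 in zip(t[:-1], t[1:]):
            dx = self.net(xs[-1])
            self.norm.append((dx ** 2).mean())
            xs.append(xs[-1] + (t1 - t0) * dx)
        return torch.stack(xs) if return_whole_sequence else xs[-1]

# Toy DataFrame: assumes `sample` looks up the time group in a `samples` column
# (as the `samples_key='samples'` default later in this notebook suggests)
# and returns the remaining feature columns.
df = pd.DataFrame({
    'samples': np.repeat([1.0, 2.0, 3.0], 100),
    'd1': np.random.randn(300),
    'd2': np.random.randn(300),
})
groups = sorted(df['samples'].unique().tolist())

model = ToyODE(dim=2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
local_losses, batch_losses, globe_losses = train(
    model, df, groups, optimizer, n_batches=5, sample_size=(50, )
)
```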
#export
from MIOFlow.utils import generate_steps
import torch.nn as nn
from tqdm.notebook import tqdm
import numpy as np
def train_ae(
    model, df, groups, optimizer,
    n_epochs=60, criterion=nn.MSELoss(), dist=None, recon=True,
    use_cuda=False, sample_size=(100, ),
    sample_with_replacement=False,
    noise_min_scale=0.09,
    noise_max_scale=0.15,
    hold_one_out: bool = False,
    hold_out='random'
):
    """
Geodesic Autoencoder training loop.
Notes:
- We can train only the encoder to fit the geodesic distance (recon=False), or the full Geodesic Autoencoder (recon=True),
i.e. matching the distance and reconstruction of the inputs.
Arguments:
model (nn.Module): the initialized pytorch Geodesic Autoencoder model.
df (pd.DataFrame): the DataFrame from which to extract batch data.
groups (list): the list of the numerical groups in the data, e.g.
`[1.0, 2.0, 3.0, 4.0, 5.0]`, if the data has five groups.
optimizer (torch.optim): an optimizer initialized with the model's parameters.
n_epochs (int): Defaults to `60`. The number of training epochs.
criterion (torch.nn): Defaults to `nn.MSELoss()`. The criterion to minimize.
dist (NoneType|Class): Defaults to `None`. A distance class with a `fit(X)` method that computes the pairwise distances in a dataset `X`.
recon (bool): Defaults to `True`. Whether or not to apply the reconstruction loss.
use_cuda (bool): Defaults to `False`. Whether or not to send the model and data to cuda.
sample_size (tuple): Defaults to `(100, )`.
sample_with_replacement (bool): Defaults to `False`. Whether or not to sample data points with replacement.
noise_min_scale (float): Defaults to `0.09`. The minimum noise scale.
noise_max_scale (float): Defaults to `0.15`. The maximum noise scale. The actual scale is sampled uniformly between these two bounds each epoch.
hold_one_out (bool): Defaults to `False`. Whether or not to hold out a timepoint during training.
hold_out (str|int): Defaults to `'random'`. The timepoint to hold out, either a specific element of `groups` or a random one.
"""
    steps = generate_steps(groups)
    losses = []

    model.train()
    for epoch in tqdm(range(n_epochs)):
        # ignoring one time point
        to_ignore = None
        if hold_one_out and hold_out == 'random':
            to_ignore = np.random.choice(groups)
        elif hold_one_out and hold_out in groups:
            to_ignore = hold_out
        elif hold_one_out:
            raise ValueError('Unknown group to hold out')
        else:
            pass

        # Training
        optimizer.zero_grad()
        noise_scale = torch.FloatTensor(1).uniform_(noise_min_scale, noise_max_scale)
        data_ti = torch.vstack([sample(df, group, size=sample_size, replace=sample_with_replacement, to_torch=True, use_cuda=use_cuda) for group in groups if group != to_ignore])
        noise = (noise_scale*torch.randn(data_ti.size())).cuda() if use_cuda else noise_scale*torch.randn(data_ti.size())

        encode_dt = model.encoder(data_ti + noise)
        recon_dt = model.decoder(encode_dt) if recon else None

        if recon:
            loss_recon = criterion(recon_dt, data_ti)
            loss = loss_recon

            if epoch % 50 == 0:
                tqdm.write(f'Train loss recon: {np.round(np.mean(loss_recon.item()), 5)}')

        if dist is not None:
            dist_geo = dist.fit(data_ti.cpu().numpy())
            dist_geo = torch.from_numpy(dist_geo).float().cuda() if use_cuda else torch.from_numpy(dist_geo).float()
            dist_emb = torch.cdist(encode_dt, encode_dt)**2
            loss_dist = criterion(dist_emb, dist_geo)
            loss = loss_recon + loss_dist if recon else loss_dist

            if epoch % 50 == 0:
                tqdm.write(f'Train loss dist: {np.round(np.mean(loss_dist.item()), 5)}')

        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return losses
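A minimal usage sketch of `train_ae` (not exported), reusing the toy `df` and `groups` from the `train` sketch above. The `ToyGeodesicAE` model and the Euclidean `dist` object are illustrative assumptions; any module with `encoder`/`decoder` attributes and any object with a `fit(X)` method returning a pairwise-distance matrix fits the documented interface.

```python
# A minimal usage sketch (hypothetical autoencoder and distance object, not MIOFlow's own example).
# Plain Euclidean distances stand in for a geodesic / diffusion distance here.
from scipy.spatial.distance import pdist, squareform

class ToyGeodesicAE(nn.Module):
    def __init__(self, dim=2, emb_dim=2):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(dim, 32), nn.ReLU(), nn.Linear(32, emb_dim))
        self.decoder = nn.Sequential(nn.Linear(emb_dim, 32), nn.ReLU(), nn.Linear(32, dim))

class EuclideanDist:
    def fit(self, X):
        # (n, n) matrix of pairwise distances; a geodesic distance would go here instead
        return squareform(pdist(X))

gae = ToyGeodesicAE(dim=2)
gae_optimizer = torch.optim.AdamW(gae.parameters(), lr=1e-3)
ae_losses = train_ae(
    gae, df, groups, gae_optimizer,
    n_epochs=100, dist=EuclideanDist(), recon=True, sample_size=(50, )
)
```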
#export
from MIOFlow.plots import plot_comparision, plot_losses
from MIOFlow.eval import generate_plot_data
def training_regimen(
n_local_epochs, n_epochs, n_post_local_epochs,
exp_dir,
# BEGIN: train params
    model, df, groups, optimizer, n_batches=20,
    criterion=MMD_loss(), use_cuda=False,

    hold_one_out=False, hold_out='random',

    hinge_value=0.01, use_density_loss=True,
    top_k=5, lambda_density=1.0,

    autoencoder=None, use_emb=True, use_gae=False,
    sample_size=(100, ),
    sample_with_replacement=False,
    logger=None,
    add_noise=False, noise_scale=0.1, use_gaussian=True,
    use_penalty=False, lambda_energy=1.0,
    # END: train params

    steps=None, plot_every=None,
    n_points=100, n_trajectories=100, n_bins=100,
    local_losses=None, batch_losses=None, globe_losses=None,
    reverse_schema=True, reverse_n=4
):
    recon = use_gae and not use_emb
    if steps is None:
        steps = generate_steps(groups)
    if local_losses is None:
        if hold_one_out and hold_out in groups:
            groups_ho = [g for g in groups if g != hold_out]
            local_losses = {f'{t0}:{t1}':[] for (t0, t1) in generate_steps(groups_ho) if hold_out not in [t0, t1]}
            if reverse_schema:
                local_losses = {
                    **local_losses,
                    **{f'{t0}:{t1}':[] for (t0, t1) in generate_steps(groups_ho[::-1]) if hold_out not in [t0, t1]}
                }
        else:
            local_losses = {f'{t0}:{t1}':[] for (t0, t1) in generate_steps(groups)}
            if reverse_schema:
                local_losses = {
                    **local_losses,
                    **{f'{t0}:{t1}':[] for (t0, t1) in generate_steps(groups[::-1])}
                }
    if batch_losses is None:
        batch_losses = []
    if globe_losses is None:
        globe_losses = []

    reverse = False
    for epoch in tqdm(range(n_local_epochs), desc='Pretraining Epoch'):
        reverse = True if reverse_schema and epoch % reverse_n == 0 else False

        l_loss, b_loss, g_loss = train(
            model, df, groups, optimizer, n_batches,
            criterion=criterion, use_cuda=use_cuda,
            local_loss=True, global_loss=False, apply_losses_in_time=True,
            hold_one_out=hold_one_out, hold_out=hold_out,
            hinge_value=hinge_value,
            use_density_loss=use_density_loss,
            top_k=top_k, lambda_density=lambda_density,
            autoencoder=autoencoder, use_emb=use_emb, use_gae=use_gae, sample_size=sample_size,
            sample_with_replacement=sample_with_replacement, logger=logger,
            add_noise=add_noise, noise_scale=noise_scale, use_gaussian=use_gaussian,
            use_penalty=use_penalty, lambda_energy=lambda_energy, reverse=reverse
        )
        for k, v in l_loss.items():
            local_losses[k].extend(v)
        batch_losses.extend(b_loss)
        globe_losses.extend(g_loss)

        if plot_every is not None and epoch % plot_every == 0:
            generated, trajectories = generate_plot_data(
                model, df, n_points, n_trajectories, n_bins,
                sample_with_replacement=sample_with_replacement, use_cuda=use_cuda,
                samples_key='samples', logger=logger,
                autoencoder=autoencoder, recon=recon
            )
            plot_comparision(
                df, generated, trajectories,
                palette='viridis', df_time_key='samples',
                save=True, path=exp_dir,
                file=f'2d_comparision_local_{epoch}.png',
                x='d1', y='d2', z='d3', is_3d=False
            )
    for epoch in tqdm(range(n_epochs), desc='Epoch'):
        reverse = True if reverse_schema and epoch % reverse_n == 0 else False

        l_loss, b_loss, g_loss = train(
            model, df, groups, optimizer, n_batches,
            criterion=criterion, use_cuda=use_cuda,
            local_loss=False, global_loss=True, apply_losses_in_time=True,
            hold_one_out=hold_one_out, hold_out=hold_out,
            hinge_value=hinge_value,
            use_density_loss=use_density_loss,
            top_k=top_k, lambda_density=lambda_density,
            autoencoder=autoencoder, use_emb=use_emb, use_gae=use_gae, sample_size=sample_size,
            sample_with_replacement=sample_with_replacement, logger=logger,
            add_noise=add_noise, noise_scale=noise_scale, use_gaussian=use_gaussian,
            use_penalty=use_penalty, lambda_energy=lambda_energy, reverse=reverse
        )
        for k, v in l_loss.items():
            local_losses[k].extend(v)
        batch_losses.extend(b_loss)
        globe_losses.extend(g_loss)

        if plot_every is not None and epoch % plot_every == 0:
            generated, trajectories = generate_plot_data(
                model, df, n_points, n_trajectories, n_bins,
                sample_with_replacement=sample_with_replacement, use_cuda=use_cuda,
                samples_key='samples', logger=logger,
                autoencoder=autoencoder, recon=recon
            )
            plot_comparision(
                df, generated, trajectories,
                palette='viridis', df_time_key='samples',
                save=True, path=exp_dir,
                file=f'2d_comparision_local_{n_local_epochs}_global_{epoch}.png',
                x='d1', y='d2', z='d3', is_3d=False
            )
    for epoch in tqdm(range(n_post_local_epochs), desc='Posttraining Epoch'):
        reverse = True if reverse_schema and epoch % reverse_n == 0 else False

        l_loss, b_loss, g_loss = train(
            model, df, groups, optimizer, n_batches,
            criterion=criterion, use_cuda=use_cuda,
            local_loss=True, global_loss=False, apply_losses_in_time=True,
            hold_one_out=hold_one_out, hold_out=hold_out,
            hinge_value=hinge_value,
            use_density_loss=use_density_loss,
            top_k=top_k, lambda_density=lambda_density,
            autoencoder=autoencoder, use_emb=use_emb, use_gae=use_gae, sample_size=sample_size,
            sample_with_replacement=sample_with_replacement, logger=logger,
            add_noise=add_noise, noise_scale=noise_scale, use_gaussian=use_gaussian,
            use_penalty=use_penalty, lambda_energy=lambda_energy, reverse=reverse
        )
        for k, v in l_loss.items():
            local_losses[k].extend(v)
        batch_losses.extend(b_loss)
        globe_losses.extend(g_loss)

        if plot_every is not None and epoch % plot_every == 0:
            generated, trajectories = generate_plot_data(
                model, df, n_points, n_trajectories, n_bins,
                sample_with_replacement=sample_with_replacement, use_cuda=use_cuda,
                samples_key='samples', logger=logger,
                autoencoder=autoencoder, recon=recon
            )
            plot_comparision(
                df, generated, trajectories,
                palette='viridis', df_time_key='samples',
                save=True, path=exp_dir,
                file=f'2d_comparision_local_{n_local_epochs}_global_{n_epochs}_post_{epoch}.png',
                x='d1', y='d2', z='d3', is_3d=False
            )
    if reverse_schema:
        _temp = {}
        if hold_one_out:
            for (t0, t1) in generate_steps([g for g in groups if g != hold_out]):
                a = f'{t0}:{t1}'
                b = f'{t1}:{t0}'
                _temp[a] = []
                for i, value in enumerate(local_losses[a]):
                    if i % reverse_n == 0:
                        _temp[a].append(local_losses[b].pop(0))
                        _temp[a].append(value)
                    else:
                        _temp[a].append(value)
            local_losses = _temp
        else:
            for (t0, t1) in generate_steps(groups):
                a = f'{t0}:{t1}'
                b = f'{t1}:{t0}'
                _temp[a] = []
                for i, value in enumerate(local_losses[a]):
                    if i % reverse_n == 0:
                        _temp[a].append(local_losses[b].pop(0))
                        _temp[a].append(value)
                    else:
                        _temp[a].append(value)
            local_losses = _temp

    if plot_every is not None:
        plot_losses(
            local_losses, batch_losses, globe_losses,
            save=True, path=exp_dir,
            file=f'losses_l{n_local_epochs}_e{n_epochs}_ple{n_post_local_epochs}.png'
        )

    return local_losses, batch_losses, globe_losses
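Finally, a minimal usage sketch of `training_regimen` (not exported), reusing the toy `model`, `df`, `groups`, and `optimizer` from the `train` sketch above; the output directory name is an illustrative assumption.

```python
# A minimal usage sketch (hypothetical output directory, not MIOFlow's own example).
import os

exp_dir = 'toy_mioflow_output'   # illustrative path; plots are only written if plot_every is set
os.makedirs(exp_dir, exist_ok=True)

local_losses, batch_losses, globe_losses = training_regimen(
    n_local_epochs=2, n_epochs=5, n_post_local_epochs=1,
    exp_dir=exp_dir,
    model=model, df=df, groups=groups, optimizer=optimizer,
    n_batches=5, sample_size=(50, ),
    plot_every=None,        # set to an int to save comparison plots during training
    reverse_schema=False,   # skip the periodic backwards-in-time passes in this sketch
)
```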