#default_exp ode
ODE
API details.
#hide
from nbdev.showdoc import *
#export
import os, math, numpy as np
import torch
import torch.nn as nn
def ode_solve(z0, t0, t1, f):
    """
    Simplest Euler ODE initial value solver
    """
    h_max = 0.05
    n_steps = math.ceil((abs(t1 - t0) / h_max).max().item())

    h = (t1 - t0) / n_steps
    t = t0
    z = z0

    for i_step in range(n_steps):
        z = z + h * f(z, t)
        t = t + h
    return z
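As a quick sanity check, the solver can be run on the scalar linear ODE dz/dt = -z, whose exact solution at t = 1 is exp(-1) ≈ 0.368. The snippet below is purely illustrative; the `decay` function and tensor shapes are example choices, not part of the exported module.
# Illustrative check (not exported): integrate dz/dt = -z from t=0 to t=1 with the Euler solver.
decay = lambda z, t: -z
z0 = torch.ones(1, 1)
z1 = ode_solve(z0, torch.tensor(0.), torch.tensor(1.), decay)
print(z1.item(), math.exp(-1.))  # Euler estimate (~0.358) vs. exact value (~0.368)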
class ODEF(nn.Module):
    def forward_with_grad(self, z, t, grad_outputs):
        """Compute f and a df/dz, a df/dp, a df/dt"""
        batch_size = z.shape[0]

        out = self.forward(z, t)

        a = grad_outputs
        adfdz, adfdt, *adfdp = torch.autograd.grad(
            (out,), (z, t) + tuple(self.parameters()), grad_outputs=(a),
            allow_unused=True, retain_graph=True
        )
        # grad sums gradients over batch items automatically, so we expand them back
        if adfdp is not None:
            adfdp = torch.cat([p_grad.flatten() for p_grad in adfdp]).unsqueeze(0)
            adfdp = adfdp.expand(batch_size, -1) / batch_size
        if adfdt is not None:
            adfdt = adfdt.expand(batch_size, 1) / batch_size
        return out, adfdz, adfdt, adfdp

    def flatten_parameters(self):
        p_shapes = []
        flat_parameters = []
        for p in self.parameters():
            p_shapes.append(p.size())
            flat_parameters.append(p.flatten())
        return torch.cat(flat_parameters)
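A concrete dynamics function only needs to subclass ODEF and implement forward(z, t). The LinearODEF below is a minimal sketch (the name and architecture are illustrative, not part of the exported module): it ignores t and applies a single learned linear map to z.
# Minimal example ODEF subclass (illustrative): autonomous linear dynamics dz/dt = W z.
class LinearODEF(ODEF):
    def __init__(self, dim):
        super(LinearODEF, self).__init__()
        self.lin = nn.Linear(dim, dim, bias=False)

    def forward(self, z, t):
        # z: (batch_size, dim); t is unused because the dynamics are time-independent
        return self.lin(z)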
class ODEAdjoint(torch.autograd.Function):
    @staticmethod
    def forward(ctx, z0, t, flat_parameters, func):
        assert isinstance(func, ODEF)
        bs, *z_shape = z0.size()
        time_len = t.size(0)

        with torch.no_grad():
            z = torch.zeros(time_len, bs, *z_shape).to(z0)
            z[0] = z0
            for i_t in range(time_len - 1):
                z0 = ode_solve(z0, t[i_t], t[i_t+1], func)
                z[i_t+1] = z0

        ctx.func = func
        ctx.save_for_backward(t, z.clone(), flat_parameters)
        return z
    @staticmethod
    def backward(ctx, dLdz):
        """
        dLdz shape: time_len, batch_size, *z_shape
        """
        func = ctx.func
        t, z, flat_parameters = ctx.saved_tensors
        time_len, bs, *z_shape = z.size()
        n_dim = np.prod(z_shape)
        n_params = flat_parameters.size(0)

        # Dynamics of the augmented system, to be integrated backwards in time
        def augmented_dynamics(aug_z_i, t_i):
            """
            Tensors here are temporal slices:
            t_i is a tensor of size (bs, 1)
            aug_z_i is a tensor of size (bs, n_dim*2 + n_params + 1)
            """
            z_i, a = aug_z_i[:, :n_dim], aug_z_i[:, n_dim:2*n_dim]  # ignore parameters and time

            # Unflatten z and a
            z_i = z_i.view(bs, *z_shape)
            a = a.view(bs, *z_shape)
            with torch.set_grad_enabled(True):
                t_i = t_i.detach().requires_grad_(True)
                z_i = z_i.detach().requires_grad_(True)
                func_eval, adfdz, adfdt, adfdp = func.forward_with_grad(z_i, t_i, grad_outputs=a)  # bs, *z_shape
                adfdz = adfdz.to(z_i) if adfdz is not None else torch.zeros(bs, *z_shape).to(z_i)
                adfdp = adfdp.to(z_i) if adfdp is not None else torch.zeros(bs, n_params).to(z_i)
                adfdt = adfdt.to(z_i) if adfdt is not None else torch.zeros(bs, 1).to(z_i)

            # Flatten f and adfdz
            func_eval = func_eval.view(bs, n_dim)
            adfdz = adfdz.view(bs, n_dim)
            return torch.cat((func_eval, -adfdz, -adfdp, -adfdt), dim=1)
        dLdz = dLdz.view(time_len, bs, n_dim)  # flatten dLdz for convenience
        with torch.no_grad():
            ## Create placeholders for output gradients
            # Previously computed backwards adjoints, to be adjusted by direct gradients
            adj_z = torch.zeros(bs, n_dim).to(dLdz)
            adj_p = torch.zeros(bs, n_params).to(dLdz)
            # In contrast to z and p we need to return gradients for all times
            adj_t = torch.zeros(time_len, bs, 1).to(dLdz)

            for i_t in range(time_len - 1, 0, -1):
                z_i = z[i_t]
                t_i = t[i_t]
                f_i = func(z_i, t_i).view(bs, n_dim)

                # Compute direct gradients
                dLdz_i = dLdz[i_t]
                dLdt_i = torch.bmm(torch.transpose(dLdz_i.unsqueeze(-1), 1, 2), f_i.unsqueeze(-1))[:, 0]

                # Adjusting adjoints with direct gradients
                adj_z += dLdz_i
                adj_t[i_t] = adj_t[i_t] - dLdt_i

                # Pack augmented variable
                aug_z = torch.cat((z_i.view(bs, n_dim), adj_z, torch.zeros(bs, n_params).to(z), adj_t[i_t]), dim=-1)

                # Solve augmented system backwards
                aug_ans = ode_solve(aug_z, t_i, t[i_t-1], augmented_dynamics)

                # Unpack solved backwards augmented system
                adj_z[:] = aug_ans[:, n_dim:2*n_dim]
                adj_p[:] += aug_ans[:, 2*n_dim:2*n_dim + n_params]
                adj_t[i_t-1] = aug_ans[:, 2*n_dim + n_params:]

                del aug_z, aug_ans

            ## Adjust 0 time adjoint with direct gradients
            # Compute direct gradients
            dLdz_0 = dLdz[0]
            dLdt_0 = torch.bmm(torch.transpose(dLdz_0.unsqueeze(-1), 1, 2), f_i.unsqueeze(-1))[:, 0]

            # Adjust adjoints
            adj_z += dLdz_0
            adj_t[0] = adj_t[0] - dLdt_0
        return adj_z.view(bs, *z_shape), adj_t, adj_p, None
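For reference, the augmented state packed in backward is $[z,\ a,\ a_p,\ a_t]$, and augmented_dynamics integrates the adjoint sensitivity equations of Chen et al. (2018), with $a(t) = \partial L / \partial z(t)$:

$$\frac{dz}{dt} = f(z, t), \qquad \frac{da}{dt} = -a^{\top}\frac{\partial f}{\partial z}, \qquad \frac{da_p}{dt} = -a^{\top}\frac{\partial f}{\partial p}, \qquad \frac{da_t}{dt} = -a^{\top}\frac{\partial f}{\partial t}$$

The direct gradients $\partial L/\partial z_{t_i}$ and $\partial L/\partial t_i$ from observed time points are added to the adjoints between integration segments.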
class NeuralODE(nn.Module):
    def __init__(self, func):
        super(NeuralODE, self).__init__()
        assert isinstance(func, ODEF)
        self.func = func

    def forward(self, z0, t=torch.Tensor([0., 1.]), return_whole_sequence=False):
        t = t.to(z0)
        z = ODEAdjoint.apply(z0, t, self.func.flatten_parameters(), self.func)
        if return_whole_sequence:
            return z
        else:
            return z[-1]
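Putting it together, a minimal usage sketch, assuming the illustrative LinearODEF defined earlier; batch size, state dimension, and time grid are arbitrary example values.
# Illustrative usage (not exported) of NeuralODE with the example LinearODEF dynamics.
func = LinearODEF(dim=2)
model = NeuralODE(func)

z0 = torch.randn(16, 2)                 # batch of 16 initial states
times = torch.linspace(0., 1., 10)      # 10 evaluation times

z_last = model(z0)                                      # state at t = 1 only
traj = model(z0, times, return_whole_sequence=True)     # (10, 16, 2) trajectory

loss = traj[-1].pow(2).mean()
loss.backward()   # gradients flow to func's parameters via ODEAdjoint.backward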