multiple_shooting_adjoint package
See https://github.com/juntang-zhuang/TorchDiffEqPack/blob/master/test_code/multiple_shooting_adjoint
Define the MultipleShoot class
import torch
import math
import numpy as np
from TorchDiffEqPack import odesolve
from torch import nn
class MultipleShoot(nn.Module):
    """Multiple-shooting fit of an ODE to an observed trajectory.

    The observation sequence is cut into chunks. Every chunk is integrated
    from its own learnable initial value (stored in ``self.intermediates``),
    and ``get_loss`` adds a penalty on the gap between the last predicted
    state of one chunk and the learnable initial value of the next chunk.

    ``prepare_intermediate`` must be called before ``fit_and_grad`` and
    ``get_loss``: it is the only method that creates ``self.num_chunks`` and
    ``self.intermediates``.
    """

    def __init__(self, ode_func, chunk_length=10, observation_length=100, ODE_options=None,
                 smooth_penalty=1.0, time_interval=1.0):
        """
        :param ode_func: The ODE functions, dy/dt = func(t,y)
        :param chunk_length: the observation is divided into chunks, each of length chunk_length
        :param observation_length: total length of observation (this determines how many
            intermediate initial values need to be specified as extra parameters to update);
            overwritten by ``prepare_intermediate``
        :param ODE_options: dict of options for the ODE solver; ``None`` means an empty dict
        :param smooth_penalty: weight of the chunk-boundary mismatch term in ``get_loss``
        :param time_interval: stored on the instance; not used by any method of this class
        """
        super(MultipleShoot, self).__init__()
        self.odefunc = ode_func
        self.chunk_length = chunk_length
        self.observation_length = observation_length
        # A None default used to crash later on ``self.ODE_options.update``.
        self.ODE_options = {} if ODE_options is None else ODE_options
        self.smooth_penalty = smooth_penalty
        self.time_interval = time_interval

    def prepare_intermediate(self, observations):
        """Create one learnable initial value per chunk from the observations.

        :param observations: tensor of shape num_time_points x N, where N is the
            dimension of the hidden state y
        """
        observation_length = int(observations.shape[0])
        self.observation_length = observation_length
        # calculate the number of chunks
        # NOTE(review): when (observation_length - 1) is a multiple of chunk_length the
        # last chunk built in fit_and_grad holds a single time point (t0 == t1) --
        # confirm the solver handles that case.
        self.num_chunks = math.ceil(float(observation_length) / float(self.chunk_length))
        # create the list of intermediate initial values
        self.intermediates = nn.ParameterList()
        for i in range(self.num_chunks):
            # detach().clone() gives the parameter its own storage. Wrapping the
            # indexed view directly would alias ``observations``, so in-place
            # optimizer updates would overwrite the observed data.
            start_value = observations[i * self.chunk_length, :].detach().clone()
            self.intermediates.append(nn.Parameter(start_value, requires_grad=True))

    def fit_and_grad(self, observations, time_points):
        """Integrate every chunk from its learnable initial value.

        :param observations: tensor indexed as ``observations[time, :]``
        :param time_points: list of time points, one per observation
        :return: ``(prediction_chunks, data_chunks)``, two lists with one entry per chunk
        """
        assert isinstance(time_points, list), "time_points must be of type list"
        # check number of time points match observation
        assert self.observation_length == len(time_points), "Number of time points mismatch observation"

        # Split observations and times into chunks. Each chunk runs one sample past
        # its nominal end (clipped to the sequence length), so consecutive chunks
        # share a boundary sample.
        data_chunks, time_chunks = [], []
        for i in range(self.num_chunks):
            start = i * self.chunk_length
            stop = min(start + self.chunk_length + 1, self.observation_length)
            data_chunks.append(observations[start:stop, :])
            time_chunks.append(time_points[start:stop])

        # fit data chunk by chunk
        prediction_chunks = []
        for i, time_chunk in enumerate(time_chunks):
            # Work on a copy so the options dict handed in by the caller is not
            # rewritten with per-chunk time settings.
            chunk_options = dict(self.ODE_options)
            chunk_options.update({'t0': time_chunk[0], 't1': time_chunk[-1], 't_eval': time_chunk})
            prediction_chunks.append(
                odesolve(self.odefunc, y0=self.intermediates[i], options=chunk_options)
            )
        return prediction_chunks, data_chunks

    def get_loss(self, prediction_chunks, data_chunks):
        """Return observation loss plus weighted chunk-boundary mismatch loss.

        :param prediction_chunks: per-chunk predictions, indexed as ``chunk[time, :]``
        :param data_chunks: per-chunk observations, same length as ``prediction_chunks``
        :return: ``observation_loss + smooth_penalty * mismatch_loss``
        """
        assert len(prediction_chunks) == len(data_chunks), "Length of data_chunks and prediction_chunks must match"

        # loss between prediction and observation
        observation_loss = 0.0
        for data, prediction in zip(data_chunks, prediction_chunks):
            observation_loss = observation_loss + torch.mean((data - prediction) ** 2)

        # loss in mismatch between end of a chunk's prediction and the next chunk's initial value
        mismatch_loss = 0.0
        for i in range(self.num_chunks - 1):
            chunk_end = prediction_chunks[i][-1, :]
            next_start = self.intermediates[i + 1]
            mismatch_loss = mismatch_loss + torch.mean((chunk_end - next_start) ** 2)

        loss = observation_loss + mismatch_loss * self.smooth_penalty
        # float() accepts both tensors and plain floats; with a single chunk
        # mismatch_loss is still the Python float 0.0, which has no .item().
        print('Observation loss: {}, smoothness loss {}'.format(float(observation_loss), float(mismatch_loss)))
        return loss
Examples with a linear dynamical system, see https://github.com/juntang-zhuang/TorchDiffEqPack/blob/master/test_code/multiple_shooting_adjoint/test_multiple_shoot.py
For an example with the Lotka-Volterra equation, see https://github.com/juntang-zhuang/TorchDiffEqPack/blob/master/test_code/multiple_shooting_adjoint/test_multiple_shoot_lotka.py
For an example with a nonlinear Lotka-Volterra equation, see https://github.com/juntang-zhuang/TorchDiffEqPack/blob/master/test_code/multiple_shooting_adjoint/test_multiple_shoot_lotka_sigmoid.py