TorchDiffEqPack.odesolver_mem package

TorchDiffEqPack.odesolver_mem.adjoint module

TorchDiffEqPack.odesolver_mem.adjoint.odesolve_adjoint(func, y0, options=None)

Implementation of the ICML 2020 paper “Adaptive checkpoint adjoint method for accurate gradient estimation in Neural ODEs”

"odesolve_adjoint" implements the "Adaptive checkpoint adjoint" (ACA) method. It can be combined with general ODE solvers and guarantees the accuracy of the reverse-time trajectory. Compared with the "odesolve" method, its memory cost is lower, but it still increases slowly with integration time.

Limitation: "odesolve_adjoint" only outputs the end-time value. If you need to output values at multiple time points and back-propagate through them, the "odesolve" method is recommended. If you need outputs at t1, t2, ..., tn and your GPU memory is limited, a workaround is to treat [t1, t2], [t2, t3], ... as separate end-time problems and call "odesolve_adjoint" once for each chunk.


from TorchDiffEqPack import odesolve_adjoint
# configure training options
options = {}
options.update({'method': 'Dopri5'})
options.update({'h': None})
options.update({'t0': 0.0})
options.update({'t1': 1.0})
options.update({'rtol': 1e-7})
options.update({'atol': 1e-8})
options.update({'print_neval': False})
options.update({'neval_max': 1000000})
options.update({'t_eval':None})
options.update({'interpolation_method':'cubic'})
options.update({'regenerate_graph':False})
out = odesolve_adjoint(func, y0, options = options)

Arguments:

  • ‘t_eval’: Must be None; only the value at time t1 is output.

  • Other parameters are the same as in "odesolve".
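
The chunking workaround described in the limitation above can be sketched as follows. This is a minimal illustration, not part of the library: the helper name `solve_in_chunks` is hypothetical, and it assumes the solver is called as `solver(func, y0, options=options)` and returns the state at `t1`, as in the usage example above.

```python
def solve_in_chunks(solver, func, y0, times, base_options):
    """Call an end-time-only solver once per interval [times[i], times[i+1]].

    `solver` is any callable used as solver(func, y0, options=options),
    for example odesolve_adjoint (assumed to return the state at 't1').
    Returns a list with the state at times[1], times[2], ..., in order.
    """
    if len(times) < 2:
        raise ValueError("need at least two time points")

    outputs = []
    y = y0
    for t_start, t_end in zip(times[:-1], times[1:]):
        # Copy the options so the caller's dictionary is left untouched.
        options = dict(base_options)
        options.update({'t0': t_start, 't1': t_end, 't_eval': None})
        # The end state of this chunk is the initial state of the next one.
        y = solver(func, y, options=options)
        outputs.append(y)
    return outputs
```

For example, `solve_in_chunks(odesolve_adjoint, func, y0, [0.0, 0.5, 1.0], options)` would return the states at t = 0.5 and t = 1.0, with each chunk solved as its own end-time problem.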

TorchDiffEqPack.odesolver_mem.adjoint_mem module

TorchDiffEqPack.odesolver_mem.adjoint_mem.odesolve_adjoint_sym12(func, y0, options=None)

Implementation of the ICLR 2021 paper “MALI: a memory efficient asynchronous leapfrog integrator for Neural ODEs”

"odesolve_adjoint_sym12" implements the "MALI" method. It is limited to the asynchronous leapfrog integrator, which is a 2nd-order ODE solver. Its memory cost is constant w.r.t. integration time, which makes "odesolve_adjoint_sym12" suitable for large-scale systems such as FFJORD and ODE-CNN on ImageNet.

Limitation: "odesolve_adjoint_sym12" only outputs the end-time value. If you need to output values at multiple time points and back-propagate through them, the "odesolve" method is recommended. If you need outputs at t1, t2, ..., tn and your GPU memory is limited, a workaround is to treat [t1, t2], [t2, t3], ... as separate end-time problems and call "odesolve_adjoint_sym12" once for each chunk.
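
To illustrate why an asynchronous leapfrog integrator allows constant memory, the sketch below shows one leapfrog-style step on a state (z, v) together with its exact algebraic inverse. It is a scalar, fixed-step illustration under stated assumptions, not the library's implementation: the names `alf_step` and `alf_step_inverse` are hypothetical, `f(t, z)` is an assumed call convention, and `v` is an auxiliary velocity-like state that tracks dz/dt.

```python
def alf_step(f, t, z, v, h):
    """One asynchronous-leapfrog-style step of size h on the pair (z, v)."""
    z_mid = z + v * h / 2                  # half step in z using the current v
    v_new = 2 * f(t + h / 2, z_mid) - v    # reflect v around f at the midpoint
    z_new = z_mid + v_new * h / 2          # second half step with the new v
    return t + h, z_new, v_new


def alf_step_inverse(f, t_new, z_new, v_new, h):
    """Undo alf_step: recover (t, z, v) from (t_new, z_new, v_new)."""
    z_mid = z_new - v_new * h / 2              # undo the second half step
    v = 2 * f(t_new - h / 2, z_mid) - v_new    # undo the reflection of v
    z = z_mid - v * h / 2                      # undo the first half step
    return t_new - h, z, v
```

Because each step can be inverted from its output, earlier states can be recomputed during the backward pass (up to floating-point rounding) instead of being stored, which is the property that a constant memory cost relies on.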


from TorchDiffEqPack import odesolve_adjoint_sym12
# configure training options
options = {}
options.update({'method': 'sym12async'})
options.update({'h': None})
options.update({'t0': 0.0})
options.update({'t1': 1.0})
options.update({'rtol': 1e-3})
options.update({'atol': 1e-3})
options.update({'print_neval': False})
options.update({'neval_max': 1000000})
options.update({'t_eval':None})
options.update({'interpolation_method':'cubic'})
options.update({'regenerate_graph':False})
out = odesolve_adjoint_sym12(func, y0, options = options)

Arguments:

  • ‘t_eval’: Must be None; the output is the value at time t1.

  • Other parameters are the same as in "odesolve".