TorchDiffEqPack.odesolver_mem package
TorchDiffEqPack.odesolver_mem.adjoint module
TorchDiffEqPack.odesolver_mem.adjoint.odesolve_adjoint(func, y0, options=None)
Implementation of the ICML 2020 paper "Adaptive checkpoint adjoint method for accurate gradient estimation in Neural ODEs".
"odesolve_adjoint" implements the "Adaptive checkpoint adjoint" (ACA) method. It can be combined with general ODE solver, and guarantees the accuracy of reverse-time trajectory; compared with the "odesolve" method, its memory cost is lower but still slowly increases with integration time.
Limitation: "odesolve_adjoint" only support outputs the end-time value, if you need to output values at multiple points and back-prop, the method "odesolve" is recommended. If you need to output values at t1, t2, ... tn, and your GPU memory is limited, a work-around is to treat [t1,t2], [t2, t3] as separate end-time state problem, and call "odesolve_adjoint" for each chunk.
from TorchDiffEqPack import odesolve_adjoint

# configure training options
options = {}
options.update({'method': 'Dopri5'})
options.update({'h': None})
options.update({'t0': 0.0})
options.update({'t1': 1.0})
options.update({'rtol': 1e-7})
options.update({'atol': 1e-8})
options.update({'print_neval': False})
options.update({'neval_max': 1000000})
options.update({'t_eval': None})
options.update({'interpolation_method': 'cubic'})
options.update({'regenerate_graph': False})

out = odesolve_adjoint(func, y0, options=options)
Arguments:
‘t_eval’: Must be None; only the value at time t1 is output.
Other parameters are the same as in "odesolve".
TorchDiffEqPack.odesolver_mem.adjoint_mem module
TorchDiffEqPack.odesolver_mem.adjoint_mem.odesolve_adjoint_sym12(func, y0, options=None)
Implementation of the ICLR 2021 paper "MALI: a memory efficient asynchronous leapfrog integrator for Neural ODEs".
"odesolve_adjoint_sym12" implements the "MALI" method. It is limited to the asynchronous leapfrog integrator, which is a 2nd-order ODE solver. Its memory cost is constant w.r.t. integration time. Its constant memory cost makes "odesolve_adjoint_sym12" suitable for large-scale systems such as FFJORD, ODE-CNN on ImageNet.
Limitation: "odesolve_adjoint_sym12" only support outputs the end-time value, if you need to output values at multiple points and back-prop, the method "odesolve" is recommended. If you need to output values at t1, t2, ... tn, and your GPU memory is limited, a work-around is to treat [t1,t2], [t2, t3] as separate end-time state problem, and call "odesolve_adjoint_sym12" for each chunk.
from TorchDiffEqPack import odesolve_adjoint_sym12

# configure training options
options = {}
options.update({'method': 'sym12async'})
options.update({'h': None})
options.update({'t0': 0.0})
options.update({'t1': 1.0})
options.update({'rtol': 1e-3})
options.update({'atol': 1e-3})
options.update({'print_neval': False})
options.update({'neval_max': 1000000})
options.update({'t_eval': None})
options.update({'interpolation_method': 'cubic'})
options.update({'regenerate_graph': False})

out = odesolve_adjoint_sym12(func, y0, options=options)
Arguments:
‘t_eval’: Must be None; the output is the value at time t1.
Other parameters are the same as in "odesolve".