TorchDiffEqPack.odesolver package

TorchDiffEqPack.odesolver.ode_solver module

TorchDiffEqPack.odesolver.ode_solver.odesolve(func, y0, options, return_solver=False, **kwargs)

General-purpose differentiable ODE solver that makes no special effort to save memory. Compared with "odeint" in the "torchdiffeq" package, "odesolve" removes the step-size search from the back-propagation computation graph and instead records all accepted steps.

Limitations: the memory cost still grows with integration time. Use "odesolve_adjoint" or "odesolve_adjoint_sym12" to save memory.
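The idea described above, searching for step sizes first and then integrating with only the accepted ones, can be sketched in plain Python. This is a toy illustration with a scalar state and an embedded Euler/Heun pair, not the library's implementation; the function names "heun_euler_step", "find_accepted_steps" and "replay" are hypothetical. In a differentiable setting, only the replay pass would need to be recorded for back-propagation, and its length (hence memory) grows with the number of accepted steps.

```python
import math


def heun_euler_step(f, t, y, h):
    """One embedded Euler/Heun step: returns the 2nd-order value and an error estimate."""
    k1 = f(t, y)
    k2 = f(t + h, y + h * k1)
    y_low = y + h * k1                 # Euler, order 1
    y_high = y + 0.5 * h * (k1 + k2)   # Heun, order 2
    return y_high, abs(y_high - y_low)


def find_accepted_steps(f, y0, t0, t1, h=0.1, rtol=1e-6, atol=1e-8):
    """Pass 1: adaptive step-size search; rejected trials are discarded."""
    t, y, steps = t0, y0, []
    while t1 - t > 1e-12:
        h = min(h, t1 - t)             # never step past t1
        y_new, err = heun_euler_step(f, t, y, h)
        tol = atol + rtol * max(abs(y), abs(y_new))
        if err <= tol:                 # accept: remember the step size
            steps.append(h)
            t, y = t + h, y_new
        # grow or shrink the next trial step, clamped to [0.2, 2.0]
        factor = 0.9 * math.sqrt(tol / err) if err > 0 else 2.0
        h *= min(2.0, max(0.2, factor))
    return steps


def replay(f, y0, t0, steps):
    """Pass 2: integrate with the recorded steps only (no search, no rejections)."""
    t, y = t0, y0
    for h in steps:
        y, _ = heun_euler_step(f, t, y, h)
        t += h
    return y


# Example: dy/dt = -y on [0, 1], exact solution exp(-t)
decay = lambda t, y: -y
accepted = find_accepted_steps(decay, 1.0, 0.0, 1.0)
y_end = replay(decay, 1.0, 0.0, accepted)
```

The length of "accepted" grows with the integration interval, which is the memory limitation noted above.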

How to use:


# import path follows the module name documented above
from TorchDiffEqPack.odesolver.ode_solver import odesolve

# configure solver options
options = {}
options.update({'method': 'Dopri5'})
options.update({'h': None})
options.update({'t0': 0.0})
options.update({'t1': 1.0})
options.update({'rtol': 1e-7})
options.update({'atol': 1e-8})
options.update({'print_neval': False})
options.update({'neval_max': 1000000})
options.update({'safety': None})
# t_list: list of float evaluation times, defined by the caller
options.update({'t_eval': t_list})
options.update({'interpolation_method': 'cubic'})
options.update({'regenerate_graph': False})
# func (the ODE) and y0 (the initial condition) are defined by the caller
out = odesolve(func, y0, options = options)
  • 'method': string, must be one of ['euler', 'rk2', 'rk4', 'rk12', 'rk23', 'dopri5', 'ode23s', 'sym12async', 'fixedstep_sym12async']. Use 'ode23s' for stiff ODEs. Fixed-stepsize solvers include 'euler', 'rk2', 'rk4' and 'fixedstep_sym12async'; adaptive-stepsize solvers include 'rk12', 'rk23' and 'dopri5'.

  • 'h': float, initial stepsize for integration. Must be specified for fixed-stepsize solvers; for adaptive solvers it can be set to None, in which case the solver automatically determines the initial stepsize

  • 't0' : float, initial time for integration

  • 't1': float, end time for integration

  • 'rtol': float or list of floats (must be the same length as y0), relative tolerance for integration, typically set to 1e-5 or 1e-6 for dopri5

  • 'atol': float or list of floats (must be the same length as y0), absolute tolerance for integration, typically set to 1e-6 or 1e-7 for dopri5

  • 'print_neval': bool, whether to print the number of function evaluations; False is recommended

  • 'neval_max': int, maximum number of function evaluations, a limit that matters when encountering stiff problems; typically set to 500000 (5e5)

  • 't_eval': list of float, evaluation time points; if None, only the value at time t1 is output

  • out = odesolve(func, y0, options = options): func is the ODE; y0 is the initial condition, which can be either a tensor or a tuple of tensors
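The constraints in the option list above can be checked before calling the solver. The following is a minimal sketch of a hypothetical helper, "make_options", that is not part of TorchDiffEqPack: it rejects unknown method names, requires 'h' for fixed-stepsize solvers, and checks the length of list-valued tolerances. The "n_states" parameter (standing in for the length of y0) and the case-insensitive name comparison are assumptions made for illustration; the defaults for 'rtol' and 'atol' are the typical values quoted above.

```python
# Solver names are taken from the option list above; the grouping is for this sketch only.
FIXED_STEP = {"euler", "rk2", "rk4", "fixedstep_sym12async"}
OTHER = {"rk12", "rk23", "dopri5", "ode23s", "sym12async"}


def make_options(method, t0, t1, h=None, rtol=1e-6, atol=1e-7,
                 t_eval=None, n_states=None):
    """Build an options dict, rejecting combinations the option list rules out."""
    # Assumption: names are compared case-insensitively for validation only;
    # the method string itself is passed through unchanged.
    name = method.lower()
    if name not in FIXED_STEP | OTHER:
        raise ValueError(f"unknown method: {method!r}")
    if name in FIXED_STEP and h is None:
        raise ValueError(f"{method!r} is a fixed-stepsize solver and needs 'h'")
    for key, tol in (("rtol", rtol), ("atol", atol)):
        # list-valued tolerances must match the length of y0 (n_states here)
        if isinstance(tol, (list, tuple)) and n_states is not None:
            if len(tol) != n_states:
                raise ValueError(f"{key!r} must have length {n_states}")
    return {"method": method, "h": h, "t0": t0, "t1": t1,
            "rtol": rtol, "atol": atol, "t_eval": t_eval}


# Adaptive solver: 'h' may stay None
options = make_options("Dopri5", t0=0.0, t1=1.0, t_eval=[0.5, 1.0])

# Fixed-stepsize solver: 'h' is required
options_rk4 = make_options("rk4", t0=0.0, t1=1.0, h=0.01)
```

The returned dict covers only the keys documented in the list; any further keys (such as 'print_neval' or 'neval_max') can be added with options.update(...) as in the usage example.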