add odeint function

This commit is contained in:
wyq 2023-02-10 16:54:26 +08:00
parent 7423d8dfb9
commit 7fac130689
17 changed files with 256 additions and 39 deletions

View File

@ -336,6 +336,13 @@
</option>
</inspection_tool>
<inspection_tool class="PyInterpreterInspection" enabled="false" level="WARNING" enabled_by_default="false" />
<inspection_tool class="PyPep8NamingInspection" enabled="true" level="WEAK WARNING" enabled_by_default="true">
<option name="ignoredErrors">
<list>
<option value="N802" />
</list>
</option>
</inspection_tool>
<inspection_tool class="PyUnresolvedReferencesInspection" enabled="true" level="WARNING" enabled_by_default="true">
<option name="ignoredIdentifiers">
<list>

View File

@ -1239,8 +1239,6 @@ public class GLPlot extends Plot {
this.updateMatrix(gl);
//gl.glColor3f(0.0f, 0.0f, 0.0f);
//Draw base
if (this.drawBase) {
this.drawBase(gl);

View File

@ -67,7 +67,7 @@ import java.util.zip.ZipInputStream;
public static String getVersion(){
String version = GlobalUtil.class.getPackage().getImplementationVersion();
if (version == null || version.equals("")) {
version = "3.5.4";
version = "3.5.5";
}
return version;
}

View File

@ -1,30 +1,34 @@
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<MeteoInfo File="milconfig.xml" Type="configurefile">
<Path OpenPath="D:\Working\MIScript\Jython\mis\plot_types\wind">
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\test"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\io"/>
<Path OpenPath="D:\Working\MIScript\Jython\mis\plot_types\3d\jogl\bar">
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\io\radar"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\plot_types\3d\jogl\scatter"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\plot_types\3d\jogl\quiver"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\common_math"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\plot_types\3d\jogl\surf"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\plot_types\3d\3d_earth"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\common_math\interpolate"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\plot_types\3d\jogl\isosurface"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\plot_types\3d\jogl"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\plot_types\3d\jogl\particles"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\plot_types"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\plot_types\wind"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\plot_types\3d\3d_earth"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\plot_types\3d\jogl"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\plot_types\3d\jogl\contour"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\common_math"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\plot_types\3d\jogl\bar"/>
</Path>
<File>
<OpenedFiles>
<OpenedFile File="D:\Working\MIScript\Jython\mis\plot_types\3d\jogl\surf\surf_pumpkin_1.py"/>
<OpenedFile File="D:\Working\MIScript\Jython\mis\common_math\interpolate\nearest_2d_radius.py"/>
<OpenedFile File="D:\Working\MIScript\Jython\mis\plot_types\3d\3d_earth\scatter_sphere.py"/>
<OpenedFile File="D:\Working\MIScript\Jython\mis\common_math\integrate\odeint_1.py"/>
<OpenedFile File="D:\Working\MIScript\Jython\mis\common_math\integrate\odeint_lorenz.py"/>
</OpenedFiles>
<RecentFiles>
<RecentFile File="D:\Working\MIScript\Jython\mis\plot_types\3d\jogl\surf\surf_pumpkin_1.py"/>
<RecentFile File="D:\Working\MIScript\Jython\mis\common_math\interpolate\nearest_2d_radius.py"/>
<RecentFile File="D:\Working\MIScript\Jython\mis\plot_types\3d\3d_earth\scatter_sphere.py"/>
<RecentFile File="D:\Working\MIScript\Jython\mis\common_math\integrate\odeint_1.py"/>
<RecentFile File="D:\Working\MIScript\Jython\mis\common_math\integrate\odeint_lorenz.py"/>
</RecentFiles>
</File>
<Font>
@ -32,5 +36,5 @@
</Font>
<LookFeel DockWindowDecorated="true" LafDecorated="true" Name="FlatDarkLaf"/>
<Figure DoubleBuffering="true"/>
<Startup MainFormLocation="-7,-7" MainFormSize="1293,685"/>
<Startup MainFormLocation="-7,0" MainFormSize="1367,792"/>
</MeteoInfo>

View File

@ -12,11 +12,12 @@ from .basic import coriolis_parameter
from .. import constants
__all__ = [
'cdiff','divergence','vorticity','advection','absolute_vorticity','potential_vorticity_barotropic',
'ageostrophic_wind','frontogenesis','geostrophic_wind','montgomery_streamfunction',
'potential_vorticity_baroclinic','shearing_deformation','storm_relative_helicity',
'stretching_deformation','total_deformation','inertial_advective_wind','q_vector'
]
'cdiff', 'divergence', 'vorticity', 'advection', 'absolute_vorticity', 'potential_vorticity_barotropic',
'ageostrophic_wind', 'frontogenesis', 'geostrophic_wind', 'montgomery_streamfunction',
'potential_vorticity_baroclinic', 'shearing_deformation', 'storm_relative_helicity',
'stretching_deformation', 'total_deformation', 'inertial_advective_wind', 'q_vector'
]
def cdiff(a, dimidx):
"""
@ -33,6 +34,7 @@ def cdiff(a, dimidx):
else:
return NDArray(r)
def vorticity(u, v, dx=None, dy=None, x_dim=-1, y_dim=-2):
r"""Calculate the vertical vorticity of the horizontal wind.
@ -106,6 +108,7 @@ def vorticity(u, v, dx=None, dy=None, x_dim=-1, y_dim=-2):
else:
return r
def vorticity_bak(u, v, x=None, y=None):
"""
Calculates the vertical component of the curl (ie, vorticity). The data should be lon/lat projection.
@ -141,6 +144,7 @@ def vorticity_bak(u, v, x=None, y=None):
else:
return NDArray(r)
def divergence(u, v, dx=None, dy=None, x_dim=-1, y_dim=-2):
r"""Calculate the horizontal divergence of the horizontal wind.
@ -251,6 +255,7 @@ def divergence_bak(u, v, x=None, y=None):
else:
return NDArray(r)
def shearing_deformation(u, v, dx=None, dy=None, x_dim=-1, y_dim=-2):
r"""Calculate the shearing deformation of the horizontal wind.
@ -288,6 +293,7 @@ def shearing_deformation(u, v, dx=None, dy=None, x_dim=-1, y_dim=-2):
dvdx = first_derivative(v, delta=dx, axis=x_dim)
return dvdx + dudy
def stretching_deformation(u, v, dx=None, dy=None, x_dim=-1, y_dim=-2):
r"""Calculate the stretching deformation of the horizontal wind.
@ -325,6 +331,7 @@ def stretching_deformation(u, v, dx=None, dy=None, x_dim=-1, y_dim=-2):
dvdy = first_derivative(v, delta=dy, axis=y_dim)
return dudx - dvdy
def total_deformation(u, v, dx=None, dy=None, x_dim=-1, y_dim=-2):
r"""Calculate the total deformation of the horizontal wind.
@ -365,10 +372,11 @@ def total_deformation(u, v, dx=None, dy=None, x_dim=-1, y_dim=-2):
"""
dudy, dudx = gradient(u, deltas=(dy, dx), axes=(y_dim, x_dim))
dvdy, dvdx = gradient(v, deltas=(dy, dx), axes=(y_dim, x_dim))
return np.sqrt((dvdx + dudy)**2 + (dudx - dvdy)**2)
return np.sqrt((dvdx + dudy) ** 2 + (dudx - dvdy) ** 2)
def advection(scalar, u=None, v=None, w=None, dx=None, dy=None, dz=None, x_dim=-1,
y_dim=-2, vertical_dim=-3):
y_dim=-2, vertical_dim=-3):
r"""
Calculate the advection of a scalar field by the wind.
@ -416,6 +424,7 @@ def advection(scalar, u=None, v=None, w=None, dx=None, dy=None, dz=None, x_dim=-
if wind is not None
)
def frontogenesis(potential_temperature, u, v, dx=None, dy=None, x_dim=-1, y_dim=-2):
r"""Calculate the 2D kinematic frontogenesis of a temperature field.
@ -468,7 +477,7 @@ def frontogenesis(potential_temperature, u, v, dx=None, dy=None, x_dim=-1, y_dim
ddx_theta = first_derivative(potential_temperature, delta=dx, axis=x_dim)
# Compute the magnitude of the potential temperature gradient
mag_theta = np.sqrt(ddx_theta**2 + ddy_theta**2)
mag_theta = np.sqrt(ddx_theta ** 2 + ddy_theta ** 2)
# Get the shearing, stretching, and total deformation of the wind field
shrd = shearing_deformation(u, v, dx, dy, x_dim=x_dim, y_dim=y_dim)
@ -484,6 +493,7 @@ def frontogenesis(potential_temperature, u, v, dx=None, dy=None, x_dim=-1, y_dim
return 0.5 * mag_theta * (tdef * np.cos(2 * beta) - div)
def geostrophic_wind(height, dx=None, dy=None, latitude=None, x_dim=-1, y_dim=-2):
r"""Calculate the geostrophic wind given from the height or geopotential.
@ -525,6 +535,7 @@ def geostrophic_wind(height, dx=None, dy=None, latitude=None, x_dim=-1, y_dim=-2
dhdx = first_derivative(height, delta=dx, axis=x_dim)
return -norm_factor * dhdy, norm_factor * dhdx
def ageostrophic_wind(height, u, v, dx=None, dy=None, latitude=None, x_dim=-1, y_dim=-2):
r"""Calculate the ageostrophic wind given from the height or geopotential.
@ -571,6 +582,7 @@ def ageostrophic_wind(height, u, v, dx=None, dy=None, latitude=None, x_dim=-1, y
)
return u - u_geostrophic, v - v_geostrophic
def montgomery_streamfunction(height, temperature):
r"""Compute the Montgomery Streamfunction on isentropic surfaces.
@ -608,6 +620,7 @@ def montgomery_streamfunction(height, temperature):
from . import dry_static_energy
return dry_static_energy(height, temperature)
def storm_relative_helicity(height, u, v, depth, bottom=None, storm_u=None, storm_v=None):
r"""Calculate storm relative helicity.
@ -671,6 +684,7 @@ def storm_relative_helicity(height, u, v, depth, bottom=None, storm_u=None, stor
return positive_srh, negative_srh, positive_srh + negative_srh
def absolute_vorticity(u, v, dx=None, dy=None, latitude=None, x_dim=-1, y_dim=-2):
"""Calculate the absolute vorticity of the horizontal wind.
@ -708,8 +722,9 @@ def absolute_vorticity(u, v, dx=None, dy=None, latitude=None, x_dim=-1, y_dim=-2
relative_vorticity = vorticity(u, v, dx=dx, dy=dy, x_dim=x_dim, y_dim=y_dim)
return relative_vorticity + f
def potential_vorticity_baroclinic(potential_temperature, pressure, u, v,
dx=None, dy=None, latitude=None, x_dim=-1, y_dim=-2, vertical_dim=-3):
dx=None, dy=None, latitude=None, x_dim=-1, y_dim=-2, vertical_dim=-3):
r"""Calculate the baroclinic potential vorticity.
.. math:: PV = -g \left(\frac{\partial u}{\partial p}\frac{\partial \theta}{\partial y}
@ -795,8 +810,9 @@ def potential_vorticity_baroclinic(potential_temperature, pressure, u, v,
return -constants.g * (dudp * dthetady - dvdp * dthetadx
+ avor * dthetadp)
def potential_vorticity_barotropic(height, u, v, dx=None, dy=None, latitude=None,
x_dim=-1, y_dim=-2):
x_dim=-1, y_dim=-2):
r"""Calculate the barotropic (Rossby) potential vorticity.
.. math:: PV = \frac{f + \zeta}{H}
@ -837,8 +853,9 @@ def potential_vorticity_barotropic(height, u, v, dx=None, dy=None, latitude=None
avor = absolute_vorticity(u, v, dx, dy, latitude, x_dim=x_dim, y_dim=y_dim)
return avor / height
def inertial_advective_wind(u, v, u_geostrophic, v_geostrophic, dx=None, dy=None,
latitude=None, x_dim=-1, y_dim=-2):
latitude=None, x_dim=-1, y_dim=-2):
r"""Calculate the inertial advective wind.
.. math:: \frac{\hat k}{f} \times (\vec V \cdot \nabla)\hat V_g
@ -906,8 +923,9 @@ def inertial_advective_wind(u, v, u_geostrophic, v_geostrophic, dx=None, dy=None
return u_component, v_component
def q_vector(u, v, temperature, pressure, dx=None, dy=None,
static_stability=1, x_dim=-1, y_dim=-2):
static_stability=1, x_dim=-1, y_dim=-2):
r"""Calculate Q-vector at a given pressure level using the u, v winds and temperature.
.. math:: \vec{Q} = (Q_1, Q_2)
@ -969,4 +987,4 @@ def q_vector(u, v, temperature, pressure, dx=None, dy=None,
q1 = -mpconsts.Rd / (pressure * static_stability) * (dudx * dtempdx + dvdx * dtempdy)
q2 = -mpconsts.Rd / (pressure * static_stability) * (dudy * dtempdx + dvdy * dtempdy)
return q1.to_base_units(), q2.to_base_units()
return q1.to_base_units(), q2.to_base_units()

View File

@ -9,8 +9,9 @@ from mipylib.geolib import Geod
from ..interpolate import interpolate_1d
from ..cbook import broadcast_indices
__all__ = ['resample_nn_1d','nearest_intersection_idx','first_derivative','gradient',
'lat_lon_grid_deltas','get_layer_heights', 'find_bounding_indices']
__all__ = ['resample_nn_1d', 'nearest_intersection_idx', 'first_derivative', 'gradient',
'lat_lon_grid_deltas', 'get_layer_heights', 'find_bounding_indices']
def resample_nn_1d(a, centers):
"""Return one-dimensional nearest-neighbor indexes based on user-specified centers.
@ -33,6 +34,7 @@ def resample_nn_1d(a, centers):
ix.append(index)
return ix
def nearest_intersection_idx(a, b):
"""Determine the index of the point just before two lines with common x values.
Parameters
@ -55,6 +57,7 @@ def nearest_intersection_idx(a, b):
return sign_change_idx
def _remove_nans(*variables):
"""Remove NaNs from arrays that cause issues with calculations.
Takes a variable number of arguments and returns masked arrays in the same
@ -73,6 +76,7 @@ def _remove_nans(*variables):
ret.append(v[~mask])
return ret
def get_layer_heights(height, depth, *args, **kwargs):
"""Return an atmospheric layer from upper air data with the requested bottom and depth.
@ -162,6 +166,7 @@ def get_layer_heights(height, depth, *args, **kwargs):
ret.append(datavar)
return ret
def find_bounding_indices(arr, values, axis, from_below=True):
"""Find the indices surrounding the values within arr along axis.
@ -242,6 +247,7 @@ def find_bounding_indices(arr, values, axis, from_below=True):
return above, below, good
def _greater_or_close(a, value, **kwargs):
r"""Compare values for greater or close to boolean masks.
@ -283,13 +289,17 @@ def _less_or_close(a, value, **kwargs):
"""
return (a < value) | np.isclose(a, value, **kwargs)
def make_take(ndims, slice_dim):
"""Generate a take function to index in a particular dimension."""
def take(indexer):
return tuple(indexer if slice_dim % ndims == i else slice(None) # noqa: S001
for i in range(ndims))
return take
def _broadcast_to_axis(arr, axis, ndim):
"""Handle reshaping coordinate array to have proper dimensionality.
This puts the values along the specified axis.
@ -300,6 +310,7 @@ def _broadcast_to_axis(arr, axis, ndim):
arr = arr.reshape(*new_shape)
return arr
def lat_lon_grid_deltas(longitude, latitude, x_dim=-1, y_dim=-2, geod=None):
r"""
Calculate the actual delta between grid points that are in latitude/longitude format.
@ -362,6 +373,7 @@ def lat_lon_grid_deltas(longitude, latitude, x_dim=-1, y_dim=-2, geod=None):
return dx, dy
def _process_gradient_args(f, axes, coordinates, deltas):
"""Handle common processing of arguments for gradient and gradient-like functions."""
axes_given = axes is not None
@ -389,6 +401,7 @@ def _process_gradient_args(f, axes, coordinates, deltas):
raise ValueError('Must specify either "coordinates" or "deltas" for value positions '
'when "f" is not a DataArray.')
def _process_deriv_args(f, axis, x, delta):
"""Handle common processing of arguments for derivative functions."""
n = f.ndim
@ -416,6 +429,7 @@ def _process_deriv_args(f, axis, x, delta):
return n, axis, delta
def first_derivative(f, axis=None, x=None, delta=None):
"""Calculate the first derivative of a grid of values.
Works for both regularly-spaced data and grids with varying spacing.
@ -497,6 +511,7 @@ def first_derivative(f, axis=None, x=None, delta=None):
return data
def gradient(f, axes=None, coordinates=None, deltas=None):
"""Calculate the gradient of a grid of values.
Works for both regularly-spaced data, and grids with varying spacing.
@ -535,4 +550,4 @@ def gradient(f, axes=None, coordinates=None, deltas=None):
"""
pos_kwarg, positions, axes = _process_gradient_args(f, axes, coordinates, deltas)
return tuple(first_derivative(f, axis=axis, **{pos_kwarg: positions[ind]})
for ind, axis in enumerate(axes))
for ind, axis in enumerate(axes))

View File

@ -13,10 +13,11 @@ from . import optimize
from . import signal
from . import spatial
from . import special
from . import integrate
__all__ = ['linalg', 'fitting', 'random', 'ma', 'stats', 'interpolate', 'optimize', 'signal', 'spatial',
'special']
'special', 'integrate']
__all__.extend(['__version__'])
__all__.extend(core.__all__)
__all__.extend(lib.__all__)
__all__.extend(['griddata'])
__all__.extend(['griddata'])

View File

@ -0,0 +1,3 @@
from ._ode import *
__all__ = _ode.__all__

View File

@ -0,0 +1,52 @@
# coding=utf-8
from org.meteoinfo.math.integrate import ODEEquations, IntegrateUtil
from ..core import numeric as np
__all__ = ['odeint']
class ODE(ODEEquations):
def __init__(self, f):
"""
Initialize
:param f: Jython function
"""
self.f = f
self._args = list(f.__code__.co_varnames)[2:]
self._args = tuple(self._args)
self.order = len(self._args)
def doComputeDerivatives(self, y, t):
args = tuple(self.getParameters())
return self.f(y, t, *args)
def odeint(func, y0, t, args=()):
"""
Integrate a system of ordinary differential equations.
:param func: (callable(y, t, ) ) Computes the derivative of y at t.
:param y0: (*array*) Initial condition on y (can be a vector).
:param t: (*array*) A sequence of time points for which to solve for y. The initial value point should
be the first element of this sequence.
:param args: (*tuple*) Extra arguments to pass to function.
:return: Array containing the value of y for each desired time in t.
"""
func = ODE(func)
if len(args) > 0:
func.setParameters(args)
if isinstance(y0, (tuple, list)):
y0 = np.array(y0)
ndim = len(y0)
func.setDimension(ndim)
if isinstance(t, (tuple, list)):
t = np.array(t)
r = IntegrateUtil.odeIntegrate(func, y0.asarray(), t.asarray())
return np.NDArray(r)

View File

@ -1,20 +1,18 @@
from org.meteoinfo.math.optimize import OptimizeUtil, ParamUnivariateFunction
from org.apache.commons.math4.legacy.fitting.leastsquares import LeastSquaresBuilder
from org.apache.commons.math4.legacy.fitting.leastsquares import LevenbergMarquardtOptimizer
import warnings
from ..core import numeric as np
from ..linalg import cholesky, solve_triangular, svd
from ..lib._util import _lazywhere
from ..linalg import solve_triangular
# __all__ = ['fsolve', 'leastsq', 'fixed_point', 'curve_fit']
__all__ = ['curve_fit', 'fixed_point']
#__all__ = ['fsolve', 'leastsq', 'fixed_point', 'curve_fit']
__all__ = ['curve_fit','fixed_point']
def _check_func(checker, argname, thefunc, x0, args, numinputs,
output_shape=None):
res = np.atleast_1d(thefunc(*((x0[:numinputs],) + args)))
if (output_shape is not None) and (np.shape(res) != output_shape):
if (output_shape[0] != 1):
if output_shape[0] != 1:
if len(output_shape) > 1:
if output_shape[1] == 1:
return shape(res)
@ -29,6 +27,7 @@ def _check_func(checker, argname, thefunc, x0, args, numinputs,
raise TypeError(msg)
return np.shape(res), res.dtype
def _wrap_func(func, xdata, ydata, transform):
if transform is None:
def func_wrapped(params):
@ -49,6 +48,7 @@ def _wrap_func(func, xdata, ydata, transform):
return solve_triangular(transform, func(xdata, *params) - ydata, lower=True)
return func_wrapped
def _wrap_jac(jac, xdata, transform):
if transform is None:
def jac_wrapped(params):
@ -61,6 +61,7 @@ def _wrap_jac(jac, xdata, transform):
return solve_triangular(transform, np.asarray(jac(xdata, *params)), lower=True)
return jac_wrapped
def _initialize_feasible(lb, ub):
p0 = np.ones_like(lb)
lb_finite = np.isfinite(lb)
@ -77,6 +78,7 @@ def _initialize_feasible(lb, ub):
return p0
class UniFunc(ParamUnivariateFunction):
def __init__(self, f):
"""
@ -93,6 +95,7 @@ class UniFunc(ParamUnivariateFunction):
args = tuple(self.getParameters())
return self.f(x, *args)
def curve_fit(f, xdata, ydata, p0=None, npoint=5, step=0.1):
"""
Use non-linear least squares to fit a function, f, to data.
@ -140,6 +143,7 @@ def curve_fit(f, xdata, ydata, p0=None, npoint=5, step=0.1):
return r
# def curve_fit(f, xdata, ydata, p0=None, npoint=5, step=0.1):
# """
# Use non-linear least squares to fit a function, f, to data.
@ -202,6 +206,7 @@ def _del2(p0, p1, d):
def _relerr(actual, desired):
return (actual - desired) / desired
def _fixed_point_helper(func, x0, args, xtol, maxiter, use_accel):
p0 = x0
for _ in range(maxiter):
@ -257,6 +262,6 @@ def fixed_point(func, x0, args=(), xtol=1e-8, maxiter=500, method='del2'):
array([ 1.4920333 , 1.37228132])
"""
use_accel = {'del2': True, 'iteration': False}[method]
#x0 = _asarray_validated(x0, as_inexact=True)
# x0 = _asarray_validated(x0, as_inexact=True)
x0 = np.asarray(x0)
return _fixed_point_helper(func, x0, args, xtol, maxiter, use_accel)
return _fixed_point_helper(func, x0, args, xtol, maxiter, use_accel)

View File

@ -0,0 +1,60 @@
package org.meteoinfo.math.integrate;
import org.apache.commons.math4.legacy.ode.ContinuousOutputModel;
import org.apache.commons.math4.legacy.ode.FirstOrderDifferentialEquations;
import org.apache.commons.math4.legacy.ode.FirstOrderIntegrator;
import org.apache.commons.math4.legacy.ode.nonstiff.DormandPrince853Integrator;
import org.apache.commons.math4.legacy.ode.sampling.StepHandler;
import org.apache.commons.math4.legacy.ode.sampling.StepInterpolator;
import org.meteoinfo.ndarray.Array;
import org.meteoinfo.ndarray.DataType;
import org.meteoinfo.ndarray.IndexIterator;
import java.io.*;
public class IntegrateUtil {
/**
* Integrate a system of ordinary differential equations
*
* @param equations Computes the derivative of y at t
* @param y0 Initial condition on y
* @param t A sequence of time points for which to solve for y
* @return Array containing the value of y for each desired time in t, with the initial value y0 in the first row.
*/
public static Array odeIntegrate(FirstOrderDifferentialEquations equations, Array y0, Array t) throws IOException, ClassNotFoundException {
y0 = y0.copyIfView();
t = t.copyIfView();
int nt = (int) t.getSize();
int ny0 = (int) y0.getSize();
FirstOrderIntegrator integrator = new DormandPrince853Integrator(1.0e-8, 100.0, 1.0e-10, 1.0e-10);
double[] y0v = (double[]) y0.getStorage();
double[] yDot = new double[ny0];
integrator.addStepHandler(new ContinuousOutputModel());
integrator.integrate(equations, t.getDouble(0), y0v, t.getDouble(nt - 1), yDot);
ByteArrayOutputStream bos = new ByteArrayOutputStream();
ObjectOutputStream oos = new ObjectOutputStream(bos);
for (StepHandler handler : integrator.getStepHandlers()) {
oos.writeObject(handler);
}
ByteArrayInputStream bis = new ByteArrayInputStream(bos.toByteArray());
ObjectInputStream ois = new ObjectInputStream(bis);
ContinuousOutputModel cm = (ContinuousOutputModel) ois.readObject();
Array r = Array.factory(DataType.DOUBLE, new int[]{nt, ny0});
IndexIterator iter = r.getIndexIterator();
for (int i = 0; i < nt; i++) {
cm.setInterpolatedTime(t.getDouble(i));
double[] interpolatedY = cm.getInterpolatedState();
for (double v : interpolatedY) {
iter.setDoubleNext(v);
}
}
return r;
}
}

View File

@ -0,0 +1,54 @@
package org.meteoinfo.math.integrate;
import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
import org.apache.commons.math4.legacy.exception.MaxCountExceededException;
import org.apache.commons.math4.legacy.ode.FirstOrderDifferentialEquations;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
public class ODEEquations implements FirstOrderDifferentialEquations {
private List<Double> parameters = new ArrayList<>();
private int dimension;
/**
* Get parameters
* @return Parameters
*/
public List<Double> getParameters() {
return this.parameters;
}
/**
* Set parameters
* @param value Parameters
*/
public void setParameters(List<Double> value) {
this.parameters = value;
}
@Override
public int getDimension() {
return dimension;
}
public void setDimension(int value) {
this.dimension = value;
}
@Override
public void computeDerivatives(double v, double[] y, double[] yDot) throws MaxCountExceededException, DimensionMismatchException {
List<Double> yd = doComputeDerivatives(Arrays.stream(y).boxed().collect(Collectors.toList()), v);
int n = yd.size();
for (int i = 0; i < n; i++) {
yDot[i] = yd.get(i);
}
}
public List<Double> doComputeDerivatives(List<Double> y, double t) {
return Arrays.asList(0.);
}
}