add lambertw function

This commit is contained in:
wyq 2025-08-07 15:46:16 +08:00
parent 7594c6c1f6
commit bcfd50f1e3
13 changed files with 417 additions and 89 deletions

View File

@ -14,6 +14,24 @@ import java.util.List;
public class JythonUtil {
/**
* Convert jython complex to java complex
* @param v Jython complex
* @return Java complex
*/
public static Complex toComplex(PyComplex v) {
return new Complex(v.real, v.imag);
}
/**
* Convert java complex to jython complex
* @param v Java complex
* @return Jython complex
*/
public static PyComplex toComplex(Complex v) {
return new PyComplex(v.real(), v.imag());
}
/**
* Convert PyComplex value to ArrayComplex
* @param data PyComplex value

View File

@ -1,32 +1,32 @@
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<MeteoInfo File="milconfig.xml" Type="configurefile">
<Path OpenPath="D:\Working\MIScript\Jython\mis\common_math\integrate">
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\plot_types\funny"/>
<Path OpenPath="D:\Working\MIScript\Jython\mis\common_math\special">
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\plot_types\3d\jogl"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\plot_types\3d\jogl\plot"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\meteo"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\meteo\calc"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\io\geotiff"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\common_math\ndimage"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\common_math"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\array\complex"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\array"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\dataframe"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\dataset"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\io\web"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\io"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\io\micaps"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\others"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\test"/>
<RecentFolder Folder="D:\Working\MIScript\mywork\music"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\chart"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\chart\legend"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\array"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\meteo\calc"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\common_math\integrate"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\common_math"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\common_math\special"/>
</Path>
<File>
<OpenedFiles>
<OpenedFile File="D:\Working\MIScript\Jython\mis\common_math\integrate\odeint_lorenz.py"/>
<OpenedFile File="D:\Working\MIScript\Jython\mis\common_math\integrate\solve_ivp_celestial_motion.py"/>
<OpenedFile File="D:\Working\MIScript\Jython\mis\common_math\integrate\solve_ivp_Lotka-Volterra.py"/>
<OpenedFile File="D:\Working\MIScript\Jython\mis\meteo\calc\mixing_ratio_from_relative_humidity.py"/>
<OpenedFile File="D:\Working\MIScript\Jython\mis\common_math\special\airy.py"/>
<OpenedFile File="D:\Working\MIScript\Jython\mis\common_math\special\lambertw.py"/>
</OpenedFiles>
<RecentFiles>
<RecentFile File="D:\Working\MIScript\Jython\mis\common_math\integrate\odeint_lorenz.py"/>
<RecentFile File="D:\Working\MIScript\Jython\mis\common_math\integrate\solve_ivp_celestial_motion.py"/>
<RecentFile File="D:\Working\MIScript\Jython\mis\common_math\integrate\solve_ivp_Lotka-Volterra.py"/>
<RecentFile File="D:\Working\MIScript\Jython\mis\meteo\calc\mixing_ratio_from_relative_humidity.py"/>
<RecentFile File="D:\Working\MIScript\Jython\mis\common_math\special\airy.py"/>
<RecentFile File="D:\Working\MIScript\Jython\mis\common_math\special\lambertw.py"/>
</RecentFiles>
</File>
<Font>
@ -34,5 +34,5 @@
</Font>
<LookFeel DockWindowDecorated="true" LafDecorated="true" Name="FlatDarkLaf"/>
<Figure DoubleBuffering="true"/>
<Startup MainFormLocation="-6,-6" MainFormSize="1292,764"/>
<Startup MainFormLocation="-6,0" MainFormSize="1322,806"/>
</MeteoInfo>

View File

@ -100,10 +100,7 @@ class NDArray(object):
#deal with Ellipsis
if Ellipsis in indices:
n = 0
for ii in indices:
if ii is not None:
n += 1;
n = self.ndim - len(indices) + 1
indices1 = []
for ii in indices:

View File

@ -83,17 +83,38 @@ def nonzero(a):
return tuple(r)
def where(condition):
def where(condition, *args):
"""
Return elements, either from x or y, depending on condition.
If only condition is given, return condition.nonzero().
:param condition: (*array_like*) Input array.
Parameters
----------
condition : `array_like, bool`
Where True, yield x, otherwise yield y.
:returns: (*tuple*) Indices of elements that are non-zero.
x, y : `array_like`
Values from which to choose. x, y and condition need to be broadcastable to some shape.
Returns
-------
`array`
An array with elements from x where condition is True, and elements from y elsewhere.
"""
return nonzero(condition)
if len(args) == 0:
return nonzero(condition)
x = args[0]
y = args[1]
if isinstance(condition, bool):
return x if condition else y
else:
condition = asarray(condition)
x = asarray(x)
y = asarray(y)
r = ArrayUtil.where(condition._array, x._array, y._array)
return NDArray(r)
def searchsorted(a, v, side='left', sorter=None):

View File

@ -1292,7 +1292,10 @@ def any(x, axis=None):
:returns: (*array_like*) Any result
"""
if isinstance(x, list):
if isinstance(x, bool):
return x
if isinstance(x, (list, tuple)):
x = array(x)
return x.any(axis)

View File

@ -188,7 +188,7 @@ class _RichResult(dict):
return sorted(omit_redundant(d.items()), key=key)
if self.keys():
return _dict_formatter(self, sorter=item_sorter)
return list(self.keys()).__repr__()
else:
return self.__class__.__name__ + "()"

View File

@ -2,9 +2,11 @@ from ._gamma import *
from ._basic import *
from ._erf import *
from ._airy import *
from ._lambertw import *
__all__ = []
__all__.extend(_basic.__all__)
__all__.extend(_gamma.__all__)
__all__.extend(_erf.__all__)
__all__.extend(_airy.__all__)
__all__ += ['lambertw']

View File

@ -0,0 +1,155 @@
from ..core import numeric as np
from org.meteoinfo.math.special import LambertW
from org.meteoinfo.jython import JythonUtil
def lambertw(z, k=0, tol=1e-8):
r"""
lambertw(z, k=0, tol=1e-8)
Lambert W function.
The Lambert W function `W(z)` is defined as the inverse function
of ``w * exp(w)``. In other words, the value of ``W(z)`` is
such that ``z = W(z) * exp(W(z))`` for any complex number
``z``.
The Lambert W function is a multivalued function with infinitely
many branches. Each branch gives a separate solution of the
equation ``z = w exp(w)``. Here, the branches are indexed by the
integer `k`.
Parameters
----------
z : array_like
Input argument.
k : int, optional
Branch index.
tol : float, optional
Evaluation tolerance.
Returns
-------
w : array
`w` will have the same shape as `z`.
See Also
--------
wrightomega : the Wright Omega function
Notes
-----
All branches are supported by `lambertw`:
* ``lambertw(z)`` gives the principal solution (branch 0)
* ``lambertw(z, k)`` gives the solution on branch `k`
The Lambert W function has two partially real branches: the
principal branch (`k = 0`) is real for real ``z > -1/e``, and the
``k = -1`` branch is real for ``-1/e < z < 0``. All branches except
``k = 0`` have a logarithmic singularity at ``z = 0``.
**Possible issues**
The evaluation can become inaccurate very close to the branch point
at ``-1/e``. In some corner cases, `lambertw` might currently
fail to converge, or can end up on the wrong branch.
**Algorithm**
Halley's iteration is used to invert ``w * exp(w)``, using a first-order
asymptotic approximation (O(log(w)) or `O(w)`) as the initial estimate.
The definition, implementation and choice of branches is based on [2]_.
References
----------
.. [1] https://en.wikipedia.org/wiki/Lambert_W_function
.. [2] Corless et al, "On the Lambert W function", Adv. Comp. Math. 5
(1996) 329-359.
https://cs.uwaterloo.ca/research/tr/1993/03/W.pdf
Examples
--------
The Lambert W function is the inverse of ``w exp(w)``:
>>> from mipylib.numeric.special import lambertw
>>> w = lambertw(1)
>>> w
(0.56714329040978384+0j)
>>> w * np.exp(w)
(1.0+0j)
Any branch gives a valid inverse:
>>> w = lambertw(1, k=3)
>>> w
(-2.8535817554090377+17.113535539412148j)
>>> w*np.exp(w)
(1.0000000000000002+1.609823385706477e-15j)
**Applications to equation-solving**
The Lambert W function may be used to solve various kinds of
equations. We give two examples here.
First, the function can be used to solve implicit equations of the
form
:math:`x = a + b e^{c x}`
for :math:`x`. We assume :math:`c` is not zero. After a little
algebra, the equation may be written
:math:`z e^z = -b c e^{a c}`
where :math:`z = c (a - x)`. :math:`z` may then be expressed using
the Lambert W function
:math:`z = W(-b c e^{a c})`
giving
:math:`x = a - W(-b c e^{a c})/c`
For example,
>>> a = 3
>>> b = 2
>>> c = -0.5
The solution to :math:`x = a + b e^{c x}` is:
>>> x = a - lambertw(-b*c*np.exp(a*c))/c
>>> x
(3.3707498368978794+0j)
Verify that it solves the equation:
>>> a + b*np.exp(c*x)
(3.37074983689788+0j)
The Lambert W function may also be used find the value of the infinite
power tower :math:`z^{z^{z^{\ldots}}}`:
>>> def tower(z, n):
... if n == 0:
... return z
... return z ** tower(z, n-1)
...
>>> tower(0.5, 100)
0.641185744504986
>>> -lambertw(-np.log(0.5)) / np.log(0.5)
(0.64118574450498589+0j)
"""
LambertW.setEPS(tol)
if isinstance(z, (int, float)):
z = complex(z)
if isinstance(z, complex):
r = LambertW.Wk(JythonUtil.toComplex(z), k)
return JythonUtil.toComplex(r)
else:
r = LambertW.Wk(z, k)
return np.array(r)

View File

@ -0,0 +1,103 @@
package org.meteoinfo.math.special;
import org.meteoinfo.ndarray.Array;
import org.meteoinfo.ndarray.Complex;
import org.meteoinfo.ndarray.DataType;
import org.meteoinfo.ndarray.IndexIterator;
/**
* Implementation of an algorithm for the Lambert W
*/
public class LambertW {
public static int MAXIT = 15;
public static double EPS = 1e-15;
public static void setEPS(double eps) {
EPS = eps;
}
/** main branch W₀(z) */
public static Complex W0(Complex z) { return eval(z, 0); }
/** main branch W₀(z) */
public static Array W0(Array z) {
Array r = Array.factory(DataType.COMPLEX, z.getShape());
IndexIterator iterR = r.getIndexIterator();
IndexIterator iterZ = z.getIndexIterator();
while (iterR.hasNext()) {
iterR.setComplexNext(eval(iterZ.getComplexNext(), 0));
}
return r;
}
/** other branch Wₖ(z) */
public static Complex Wk(Complex z, int k) { return eval(z, k); }
/** other branch Wₖ(z) */
public static Array Wk(Array z, int k) {
Array r = Array.factory(DataType.COMPLEX, z.getShape());
IndexIterator iterR = r.getIndexIterator();
IndexIterator iterZ = z.getIndexIterator();
while (iterR.hasNext()) {
iterR.setComplexNext(eval(iterZ.getComplexNext(), k));
}
return r;
}
/* ---------- core iteration ---------- */
private static Complex eval(Complex z, int k) {
if (z.real() == Double.NEGATIVE_INFINITY || z.imag() == Double.NEGATIVE_INFINITY
|| Double.isNaN(z.real()) || Double.isNaN(z.imag()))
return new Complex(Double.NaN, Double.NaN);
/* deal special points */
if (k == 0 && z.real() == -1.0 / Math.E && z.imag() == 0.0)
return new Complex(-1, 0);
if (k == 0 && z.real() == 0 && z.imag() == 0)
return new Complex(0, 0);
Complex w = initialGuess(z, k);
for (int iter = 0; iter < MAXIT; iter++) {
Complex e = w.exp();
Complex f = w.multiply(e).subtract(z);
Complex df = e.multiply(w.add(new Complex(1, 0)));
Complex ddf = e.multiply(w.add(new Complex(2, 0)));
Complex delta = f.multiply(df).divide(df.multiply(df).subtract(f.multiply(ddf).multiply(0.5)));
w = w.subtract(delta);
if (delta.abs() < EPS * w.abs())
break;
}
return w;
}
/* ---------- initial guess ---------- */
private static Complex initialGuess(Complex z, int k) {
if (k == 0) return initialGuess0(z);
else return initialGuessK(z, k);
}
/* main branch W₀(z) initial value */
private static Complex initialGuess0(Complex z) {
double x = z.real(), y = z.imag();
if (Math.abs(y) < 1e-15 && Math.abs(x) <= 0.5) {
/* from SciPy cephes/lambertw.c Pade approximation */
if (x == 0) return new Complex(0, 0);
double L1 = Math.log(1 + x);
// first order correction: w L1 / (1 + L1)
return new Complex(L1 / (1 + L1), 0);
}
// approximate: w log(z) (k=0)
return z.log();
}
/* other branch Wₖ(z), k≠0 initial value */
private static Complex initialGuessK(Complex z, int k) {
Complex logZ = z.log();
Complex twoPiK = new Complex(0, 2 * Math.PI * k);
return logZ.add(twoPiK);
}
}

View File

@ -1761,66 +1761,6 @@ public class ArrayUtil {
}
}
/**
* Return the indices of the elements that are non-zero.
*
* @param a Input array
* @return Indices
*/
public static List<Array> nonzero(Array a) {
List<List<Integer>> r = new ArrayList<>();
int ndim = a.getRank();
for (int i = 0; i < ndim; i++) {
r.add(new ArrayList<Integer>());
}
IndexIterator iterA = a.getIndexIterator();
int[] counter;
double v;
while (iterA.hasNext()) {
v = iterA.getDoubleNext();
if (!Double.isNaN(v) && v != 0) {
counter = iterA.getCurrentCounter();
for (int j = 0; j < ndim; j++) {
r.get(j).add(counter[j]);
}
}
}
if (r.get(0).isEmpty()) {
return null;
}
List<Array> ra = new ArrayList<>();
for (int i = 0; i < ndim; i++) {
ra.add(ArrayUtil.array(r.get(i), null));
}
return ra;
}
/**
* Return the flat indices of the elements that are non-zero.
*
* @param a Input array
* @return Flat indices
*/
public static Array flatNonZero(Array a) {
List<Integer> r = new ArrayList<>();
IndexIterator iterA = a.getIndexIterator();
int[] counter;
double v;
int i = 0;
while (iterA.hasNext()) {
v = iterA.getDoubleNext();
if (!Double.isNaN(v) && v != 0) {
r.add(i);
}
i += 1;
}
return ArrayUtil.array(r, DataType.INT);
}
// </editor-fold>
// <editor-fold desc="Output">
@ -3021,6 +2961,95 @@ public class ArrayUtil {
return r;
}
/**
* Return the indices of the elements that are non-zero.
*
* @param a Input array
* @return Indices
*/
public static List<Array> nonzero(Array a) {
List<List<Integer>> r = new ArrayList<>();
int ndim = a.getRank();
for (int i = 0; i < ndim; i++) {
r.add(new ArrayList<Integer>());
}
IndexIterator iterA = a.getIndexIterator();
int[] counter;
double v;
while (iterA.hasNext()) {
v = iterA.getDoubleNext();
if (!Double.isNaN(v) && v != 0) {
counter = iterA.getCurrentCounter();
for (int j = 0; j < ndim; j++) {
r.get(j).add(counter[j]);
}
}
}
if (r.get(0).isEmpty()) {
return null;
}
List<Array> ra = new ArrayList<>();
for (int i = 0; i < ndim; i++) {
ra.add(ArrayUtil.array(r.get(i), null));
}
return ra;
}
/**
* Return the flat indices of the elements that are non-zero.
*
* @param a Input array
* @return Flat indices
*/
public static Array flatNonZero(Array a) {
List<Integer> r = new ArrayList<>();
IndexIterator iterA = a.getIndexIterator();
int[] counter;
double v;
int i = 0;
while (iterA.hasNext()) {
v = iterA.getDoubleNext();
if (!Double.isNaN(v) && v != 0) {
r.add(i);
}
i += 1;
}
return ArrayUtil.array(r, DataType.INT);
}
/**
* Return elements chosen from x or y depending on condition.
*
* @param a Input condition array
* @param x X array for true condition
* @param y Y array for false condition
* @return An array with elements from x where condition is True, and elements from y elsewhere.
*/
public static Array where(Array a, Array x, Array y) {
DataType dataType = ArrayMath.commonType(x.getDataType(), y.getDataType());
Array r = Array.factory(dataType, a.getShape());
IndexIterator iterA = a.getIndexIterator();
IndexIterator iterX = x.getIndexIterator();
IndexIterator iterY = y.getIndexIterator();
IndexIterator iterR = r.getIndexIterator();
while (iterR.hasNext()) {
double v = iterA.getDoubleNext();
if (!Double.isNaN(v) && v != 0) {
iterR.setObjectNext(iterX.getObjectNext());
iterY.next();
} else {
iterR.setObjectNext(iterY.getObjectNext());
iterX.next();
}
}
return r;
}
// </editor-fold>
// <editor-fold desc="Statistics">
/**