add cross function

This commit is contained in:
wyq 2025-09-20 11:10:48 +08:00
parent d3a1b71637
commit 42153c5bc0
10 changed files with 204 additions and 33 deletions

View File

@ -69,7 +69,7 @@ import java.util.zip.ZipInputStream;
public static String getVersion() {
String version = GlobalUtil.class.getPackage().getImplementationVersion();
if (version == null || version.equals("")) {
version = "4.1.2";
version = "4.1.3";
}
return version;
}

View File

@ -13,7 +13,6 @@ import org.meteoinfo.ndarray.util.DataTypeUtil;
import java.text.DecimalFormat;
import java.time.format.DateTimeFormatter;
import java.util.Formatter;
import java.util.Locale;
/**
@ -251,8 +250,8 @@ public class Column {
* @param s Input string
* @return Result object
*/
public Object convertStringTo(String s) {
return DataTypeUtil.convertStringTo(s, dataType, format);
public Object convertFromString(String s) {
return DataTypeUtil.convertFromString(s, dataType, format);
}
@Override

View File

@ -2764,7 +2764,7 @@ public class DataFrame implements Iterable {
}
} else {
for (String s : indexValues) {
indexData.add(DataTypeUtil.convertStringTo(s, idxDT, null));
indexData.add(DataTypeUtil.convertFromString(s, idxDT, null));
}
index = Index.factory(indexData);
index.updateFormat();
@ -2784,7 +2784,7 @@ public class DataFrame implements Iterable {
String v;
for (int j = 0; j < vv.size(); j++) {
v = (String) vv.get(j);
data.setObject(j * colNum + i, col.convertStringTo(v));
data.setObject(j * colNum + i, col.convertFromString(v));
}
}
df = new DataFrame(data, index, cols);
@ -2800,7 +2800,7 @@ public class DataFrame implements Iterable {
String v;
for (int j = 0; j < vv.size(); j++) {
v = (String) vv.get(j);
a.setObject(j, col.convertStringTo(v));
a.setObject(j, col.convertFromString(v));
}
data.add(a);
}
@ -3051,7 +3051,7 @@ public class DataFrame implements Iterable {
}
} else {
for (String s : indexValues) {
indexData.add(DataTypeUtil.convertStringTo(s, idxDT, null));
indexData.add(DataTypeUtil.convertFromString(s, idxDT, null));
}
index = Index.factory(indexData);
index.updateFormat();
@ -3071,7 +3071,7 @@ public class DataFrame implements Iterable {
String v;
for (int j = 0; j < vv.size(); j++) {
v = (String) vv.get(j);
data.setObject(j * colNum + i, col.convertStringTo(v));
data.setObject(j * colNum + i, col.convertFromString(v));
}
}
df = new DataFrame(data, index, cols);
@ -3087,7 +3087,7 @@ public class DataFrame implements Iterable {
String v;
for (int j = 0; j < vv.size(); j++) {
v = (String) vv.get(j);
a.setObject(j, col.convertStringTo(v));
a.setObject(j, col.convertFromString(v));
}
data.add(a);
}

View File

@ -1,32 +1,30 @@
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<MeteoInfo File="milconfig.xml" Type="configurefile">
<Path OpenPath="D:\Working\MIScript\Jython\mis\ascii">
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\common_math\interpolate"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\array\slice"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\array"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\meteo"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\meteo\calc"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\plot_types\wind"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\io\burf"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\plot_types"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\plot_types\funny"/>
<Path OpenPath="D:\Working\MIScript\Jython\mis\common_math\linalg">
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\io\micaps"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\map\geoshow"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\map"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\io"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\chart\text"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\chart\latex"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\chart"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\chart\legend"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\map\geoshow"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\map\maskout"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\map"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\map\topology"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\ascii"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\plot_types\funny"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\common_math"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\common_math\linalg"/>
</Path>
<File>
<OpenedFiles>
<OpenedFile File="D:\Working\MIScript\Jython\mis\meteo\calc\mixed_layer_cape_cin.py"/>
<OpenedFile File="D:\Working\MIScript\Jython\mis\io\micaps\mdfs_10.py"/>
<OpenedFile File="D:\Working\MIScript\Jython\mis\ascii\asciiread_160.py"/>
<OpenedFile File="D:\Working\MIScript\Jython\mis\common_math\linalg\cross.py"/>
</OpenedFiles>
<RecentFiles>
<RecentFile File="D:\Working\MIScript\Jython\mis\meteo\calc\mixed_layer_cape_cin.py"/>
<RecentFile File="D:\Working\MIScript\Jython\mis\io\micaps\mdfs_10.py"/>
<RecentFile File="D:\Working\MIScript\Jython\mis\ascii\asciiread_160.py"/>
<RecentFile File="D:\Working\MIScript\Jython\mis\common_math\linalg\cross.py"/>
</RecentFiles>
</File>
<Font>

View File

@ -12,7 +12,7 @@ from org.meteoinfo.math.stats import StatsUtil
from .. import core as np
__all__ = ['solve', 'cholesky', 'cond', 'det', 'lu', 'qr', 'svd', 'eig', 'eigvals', 'inv',
__all__ = ['cross', 'solve', 'cholesky', 'cond', 'det', 'lu', 'qr', 'svd', 'eig', 'eigvals', 'inv',
'lstsq', 'slogdet', 'solve_triangular', 'norm', 'pinv', 'LinAlgError']
@ -34,6 +34,76 @@ def _assert_2d(*arrays):
'two-dimensional' % a.ndim)
def cross(x1, x2, axis=-1):
"""
Returns the cross product of 3-element vectors.
If ``x1`` and/or ``x2`` are multi-dimensional arrays, then
the cross-product of each pair of corresponding 3-element vectors
is independently computed.
Parameters
----------
x1 : array_like
The first input array.
x2 : array_like
The second input array. Must be compatible with ``x1`` for all
non-compute axes. The size of the axis over which to compute
the cross-product must be the same size as the respective axis
in ``x1``.
axis : int, optional
The axis (dimension) of ``x1`` and ``x2`` containing the vectors for
which to compute the cross-product. Default: ``-1``.
Returns
-------
out : ndarray
An array containing the cross products.
Examples
--------
Vector cross-product.
>>> x = np.array([1, 2, 3])
>>> y = np.array([4, 5, 6])
>>> np.linalg.cross(x, y)
array([-3, 6, -3])
Multiple vector cross-products. Note that the direction of the cross
product vector is defined by the *right-hand rule*.
>>> x = np.array([[1,2,3], [4,5,6]])
>>> y = np.array([[4,5,6], [1,2,3]])
>>> np.linalg.cross(x, y)
array([[-3, 6, -3],
[ 3, -6, 3]])
>>> x = np.array([[1, 2], [3, 4], [5, 6]])
>>> y = np.array([[4, 5], [6, 1], [2, 3]])
>>> np.linalg.cross(x, y, axis=0)
array([[-24, 6],
[ 18, 24],
[-6, -18]])
"""
x1 = np.asanyarray(x1)
x2 = np.asanyarray(x2)
if x1.shape[axis] != 3 or x2.shape[axis] != 3:
raise ValueError(
"Both input arrays must be (arrays of) 3-dimensional vectors, "
"but they are {} and {} "
"dimensional instead.".format(x1.shape[axis], x2.shape[axis])
)
if x1.ndim == 1:
r = LinalgUtil.cross(x1._array, x2._array)
else:
r = LinalgUtil.cross(x1._array, x2._array, axis)
return np.NDArray(r)
def solve(a, b):
"""
Solve a linear matrix equation, or system of linear scalar equations.

View File

@ -13,12 +13,14 @@ import org.meteoinfo.math.blas.LinearAlgebra;
import org.meteoinfo.math.blas.SVDJob;
import org.meteoinfo.math.matrix.Matrix;
import org.meteoinfo.math.matrix.MatrixUtil;
import org.meteoinfo.ndarray.*;
import org.meteoinfo.ndarray.math.ArrayMath;
import org.meteoinfo.ndarray.math.ArrayUtil;
import org.meteoinfo.ndarray.Array;
import org.meteoinfo.ndarray.DataType;
import org.meteoinfo.math.blas.UPLO;
import java.nio.DoubleBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.logging.Logger;
/**
@ -29,6 +31,108 @@ public class LinalgUtil {
static Logger logger = Logger.getLogger("LinalgUtil class");
/**
* Returns the cross product of 3-element vectors.
* @param aa The first input array
* @param ba The second input array
* @return An array containing the cross products
*/
public static Array cross(Array aa, Array ba) {
DataType dataType = ArrayMath.commonType(aa.getDataType(), ba.getDataType());
Array r = Array.factory(dataType, aa.getShape());
double[] a = (double[]) aa.get1DJavaArray(double.class);
double[] b = (double[]) ba.get1DJavaArray(double.class);
r.setObject(0, a[1] * b[2] - a[2] * b[1]); // i component
r.setObject(1,a[2] * b[0] - a[0] * b[2]); // j component
r.setObject(2, a[0] * b[1] - a[1] * b[0]); // k component
return r;
}
/**
* Returns the cross product of 3-element vectors.
* @param aa The first input array
* @param ba The second input array
* @return An array containing the cross products
*/
private static Array cross(Array aa, Array ba, DataType dataType) {
Array r = Array.factory(dataType, aa.getShape());
double[] a = (double[]) aa.get1DJavaArray(double.class);
double[] b = (double[]) ba.get1DJavaArray(double.class);
r.setObject(0, a[1] * b[2] - a[2] * b[1]); // i component
r.setObject(1,a[2] * b[0] - a[0] * b[2]); // j component
r.setObject(2, a[0] * b[1] - a[1] * b[0]); // k component
return r;
}
/**
* Returns the cross product of 3-element vectors.
* @param a The first input array
* @param b The second input array
* @param axis The axis to calculate cross products
* @return An array containing the cross products
*/
public static Array cross(Array a, Array b, int axis) throws InvalidRangeException {
a = a.copyIfView();
b = b.copyIfView();
DataType dataType = ArrayMath.commonType(a.getDataType(), b.getDataType());
int[] shape = a.getShape();
Array r = Array.factory(dataType, shape);
int n = 1;
if (axis == -1) {
axis = shape.length - 1;
}
for (int i = 0; i < shape.length; i++) {
if (i != axis) {
n += n * shape[i];
}
}
Index indexR = r.getIndex();
Index index = a.getIndex();
int[] current;
for (int i = 0; i < r.getSize(); i++) {
current = indexR.getCurrentCounter();
if (current[axis] == 0) {
List<Range> ranges = new ArrayList<>();
for (int j = 0; j < shape.length; j++) {
if (j == axis) {
ranges.add(new Range(0, shape[j] - 1, 1));
} else {
ranges.add(new Range(current[j], current[j], 1));
}
}
Array rr = cross(a.section(ranges), b.section(ranges), dataType);
index.set(current);
r.setObject(index, rr.getObject(0));
index.setDim(axis, 1);
r.setObject(index, rr.getObject(1));
index.setDim(axis, 2);
r.setObject(index, rr.getObject(2));
}
indexR.incr();
}
return r;
}
/**
* Returns the cross product of 2-element vectors.
* @param a The first input array
* @param b The second input array
* @return The cross products
*/
public static double cross2D(Array a, Array b) {
a = a.copyIfView();
b = b.copyIfView();
double r = a.getDouble(0) * b.getDouble(1) - a.getDouble(1) * b.getDouble(0);
return r;
}
/**
* Matrix dot operator
* @param a Matrix a

View File

@ -909,7 +909,7 @@ public abstract class Array {
* if possible. Only for numeric types (byte, short, int, long, double,
* float)
*
* @return equivilent data in a ByteBuffer
* @return equivalent data in a ByteBuffer
*/
public ByteBuffer getDataAsByteBuffer() {
throw new UnsupportedOperationException();

View File

@ -525,7 +525,7 @@ public class DataTypeUtil {
* @param dateFormat Date format
* @return Converted data
*/
public static Object convertStringTo(String vStr, DataType dataType, String dateFormat) {
public static Object convertFromString(String vStr, DataType dataType, String dateFormat) {
if (vStr == null) {
switch (dataType) {
case INT:

View File

@ -39,7 +39,7 @@
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<java.version>1.8</java.version>
<revision>4.1.2</revision>
<revision>4.1.3</revision>
<maven.compiler.source>8</maven.compiler.source>
<maven.compiler.target>8</maven.compiler.target>
<maven.compiler.release>8</maven.compiler.release>