update solve function

This commit is contained in:
wyq 2022-09-25 20:36:30 +08:00
parent 1c79a5b405
commit 7fb965af98
4 changed files with 12 additions and 11 deletions

View File

@ -1,6 +1,6 @@
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<MeteoInfo File="milconfig.xml" Type="configurefile">
<Path OpenPath="D:\Working\MIScript\Jython\mis\plot_types\3d\jogl\surf">
<Path OpenPath="D:\Working\MIScript\Jython\mis\common_math\linalg">
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\toolbox\miml"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\toolbox\miml\regression"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\plot_types\3d"/>
@ -14,8 +14,8 @@
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\io\micaps"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\array"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\common_math"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\common_math\linalg"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\plot_types\3d\jogl\surf"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\common_math\linalg"/>
</Path>
<File>
<OpenedFiles>

View File

@ -51,14 +51,8 @@ def solve(a, b):
Solution to the system a x = b. Returned shape is identical to ``b``.
"""
_assert_2d(a)
r_2d = False
if b.ndim == 2:
b = b.flatten()
r_2d = True
x = LinalgUtil.solve(a.asarray(), b.asarray())
r = NDArray(x)
if r_2d:
r = r.reshape((len(r),1))
return r
def solve_triangular(a, b, lower=False):

View File

@ -45,9 +45,16 @@ public class LinalgUtil {
public static Array solve(Array a, Array b) {
Matrix ma = MatrixUtil.arrayToMatrix(a);
Matrix.LU lu = ma.lu();
double[] bb = (double[]) ArrayUtil.copyToNDJavaArray_Double(b);
double[] x = lu.solve(bb);
Array r = Array.factory(DataType.DOUBLE, b.getShape(), x);
Array r;
if (b.getRank() == 1) {
double[] bb = (double[]) ArrayUtil.copyToNDJavaArray_Double(b);
double[] x = lu.solve(bb);
r = Array.factory(DataType.DOUBLE, b.getShape(), x);
} else {
Matrix mb = MatrixUtil.arrayToMatrix(b);
lu.solve(mb);
r = MatrixUtil.matrixToArray(mb);
}
return r;
}