add LinearAlgebra abstract class for BLAS engine switch

This commit is contained in:
wyq 2024-03-05 14:28:31 +08:00
parent 9e5394ae89
commit 7c87de22d4
9 changed files with 77 additions and 96 deletions

View File

@ -19,12 +19,12 @@
</Path>
<File>
<OpenedFiles>
<OpenedFile File="D:\Working\MIScript\Jython\mis\map\geoshow\map_2.py"/>
<OpenedFile File="D:\Working\MIScript\Jython\mis\map\projection\lcc_proj_gridlabel.py"/>
<OpenedFile File="D:\Working\MIScript\Jython\mis\plot_types\funny\bindundun.py"/>
<OpenedFile File="D:\Working\MIScript\Jython\mis\plot_types\funny\xuerongrong.py"/>
</OpenedFiles>
<RecentFiles>
<RecentFile File="D:\Working\MIScript\Jython\mis\map\geoshow\map_2.py"/>
<RecentFile File="D:\Working\MIScript\Jython\mis\map\projection\lcc_proj_gridlabel.py"/>
<RecentFile File="D:\Working\MIScript\Jython\mis\plot_types\funny\bindundun.py"/>
<RecentFile File="D:\Working\MIScript\Jython\mis\plot_types\funny\xuerongrong.py"/>
</RecentFiles>
</File>
<Font>
@ -32,5 +32,5 @@
</Font>
<LookFeel DockWindowDecorated="true" LafDecorated="true" Name="FlatDarkLaf"/>
<Figure DoubleBuffering="true"/>
<Startup MainFormLocation="-7,-7" MainFormSize="1293,685"/>
<Startup MainFormLocation="-7,0" MainFormSize="1426,817"/>
</MeteoInfo>

View File

@ -33,35 +33,6 @@ import java.util.logging.Logger;
* @author Haifeng Li
*/
public interface BLAS {
/** The default BLAS engine. */
BLAS engine = getInstance();
/**
* Creates an instance.
* @return a BLAS instance.
*/
static BLAS getInstance() {
BLAS mkl = MKL();
return mkl != null ? mkl : new org.meteoinfo.math.blas.openblas.OpenBLAS();
}
/**
* Creates an MKL instance.
* @return a BLAS instance of MKL.
*/
static BLAS MKL() {
Logger logger = Logger.getLogger("BLAS.class");
try {
Class<?> clazz = Class.forName("org.meteoinfo.math.blas.mkl.MKL");
logger.info("mkl module is available.");
return (BLAS) clazz.getDeclaredConstructor().newInstance();
} catch (Exception e) {
logger.info(String.format("Failed to create MKL instance: %s", e));
}
return null;
}
/**
* Sums the absolute values of the elements of a vector.

View File

@ -33,35 +33,6 @@ import java.util.logging.Logger;
* @author Haifeng Li
*/
public interface LAPACK {
/** The default LAPACK engine. */
LAPACK engine = getInstance();
/**
* Creates an instance.
* @return a LAPACK instance.
*/
static LAPACK getInstance() {
LAPACK mkl = MKL();
return mkl != null ? mkl : new org.meteoinfo.math.blas.openblas.OpenBLAS();
}
/**
* Creates an MKL instance.
* @return a LAPACK instance of MKL.
*/
static LAPACK MKL() {
Logger logger = Logger.getLogger("LAPACK.class");
try {
Class<?> clazz = Class.forName("org.meteoinfo.math.blas.mkl.MKL");
logger.info("mkl module is available.");
return (LAPACK) clazz.getDeclaredConstructor().newInstance();
} catch (Exception e) {
logger.info(String.format("Failed to create MKL instance: %s", e));
}
return null;
}
/**
* Solves a real system of linear equations.

View File

@ -0,0 +1,38 @@
package org.meteoinfo.math.blas;
import org.meteoinfo.math.blas.openblas.OpenBLAS;
import java.util.logging.Logger;
public abstract class LinearAlgebra implements BLAS, LAPACK {
public static LinearAlgebra engine = new OpenBLAS();
/**
* Set linear algebra engine
* @param engineName Engine name
*/
public static void setEngine(String engineName) {
if (engineName.equalsIgnoreCase("mkl")) {
LinearAlgebra la = MKL();
if (la != null) {
engine = la;
}
} else {
engine = new OpenBLAS();
}
}
static LinearAlgebra MKL() {
Logger logger = Logger.getLogger("LAPACK.class");
try {
Class<?> clazz = Class.forName("org.meteoinfo.math.blas.mkl.MKL");
logger.info("mkl module is available.");
return (LinearAlgebra) clazz.getDeclaredConstructor().newInstance();
} catch (Exception e) {
logger.info(String.format("Failed to create MKL instance: %s", e));
}
return null;
}
}

View File

@ -30,7 +30,7 @@ import static org.bytedeco.openblas.global.openblas.*;
*
* @author Haifeng Li
*/
public class OpenBLAS implements BLAS, LAPACK {
public class OpenBLAS extends LinearAlgebra {
@Override
public double asum(int n, double[] x, int incx) {
return cblas_dasum(n, x, incx);

View File

@ -9,6 +9,7 @@ import org.apache.commons.math4.legacy.fitting.leastsquares.*;
import org.apache.commons.math4.legacy.linear.*;
import org.apache.commons.math4.legacy.core.Pair;
import org.meteoinfo.math.blas.LAPACK;
import org.meteoinfo.math.blas.LinearAlgebra;
import org.meteoinfo.math.blas.SVDJob;
import org.meteoinfo.math.matrix.Matrix;
import org.meteoinfo.math.matrix.MatrixUtil;
@ -175,7 +176,7 @@ public class LinalgUtil {
U = new Matrix(m, m);
VT = new Matrix(n, n);
int info = LAPACK.engine.gesdd(W.layout(), SVDJob.ALL, m, n, W.getA(), W.ld(), DoubleBuffer.wrap(s), U.getA(), U.ld(), VT.getA(), VT.ld());
int info = LinearAlgebra.engine.gesdd(W.layout(), SVDJob.ALL, m, n, W.getA(), W.ld(), DoubleBuffer.wrap(s), U.getA(), U.ld(), VT.getA(), VT.ld());
if (info != 0) {
logger.severe(String.format("LAPACK GESDD error code: {%s}", info));
throw new ArithmeticException("LAPACK GESDD error code: " + info);
@ -184,7 +185,7 @@ public class LinalgUtil {
U = new Matrix(m, k);
VT = new Matrix(k, n);
int info = LAPACK.engine.gesdd(W.layout(), SVDJob.COMPACT, m, n, W.getA(), W.ld(), DoubleBuffer.wrap(s), U.getA(), U.ld(), VT.getA(), VT.ld());
int info = LinearAlgebra.engine.gesdd(W.layout(), SVDJob.COMPACT, m, n, W.getA(), W.ld(), DoubleBuffer.wrap(s), U.getA(), U.ld(), VT.getA(), VT.ld());
if (info != 0) {
logger.severe(String.format("LAPACK GESDD error code: {%s}", info));
throw new ArithmeticException("LAPACK GESDD error code: " + info);

View File

@ -1227,9 +1227,9 @@ public class Matrix extends DMatrix {
}
if (isSymmetric() && x == y) {
BLAS.engine.syr(layout(), uplo, m, alpha, DoubleBuffer.wrap(x), 1, A, ld);
LinearAlgebra.engine.syr(layout(), uplo, m, alpha, DoubleBuffer.wrap(x), 1, A, ld);
} else {
BLAS.engine.ger(layout(), m, n, alpha, DoubleBuffer.wrap(x), 1, DoubleBuffer.wrap(y), 1, A, ld);
LinearAlgebra.engine.ger(layout(), m, n, alpha, DoubleBuffer.wrap(x), 1, DoubleBuffer.wrap(y), 1, A, ld);
}
return this;
@ -1531,12 +1531,12 @@ public class Matrix extends DMatrix {
Matrix inv = eye(n);
int[] ipiv = new int[n];
if (isSymmetric()) {
int info = LAPACK.engine.sysv(lu.layout(), uplo, n, n, lu.A, lu.ld, IntBuffer.wrap(ipiv), inv.A, inv.ld);
int info = LinearAlgebra.engine.sysv(lu.layout(), uplo, n, n, lu.A, lu.ld, IntBuffer.wrap(ipiv), inv.A, inv.ld);
if (info != 0) {
throw new ArithmeticException("SYSV fails: " + info);
}
} else {
int info = LAPACK.engine.gesv(lu.layout(), n, n, lu.A, lu.ld, IntBuffer.wrap(ipiv), inv.A, inv.ld);
int info = LinearAlgebra.engine.gesv(lu.layout(), n, n, lu.A, lu.ld, IntBuffer.wrap(ipiv), inv.A, inv.ld);
if (info != 0) {
throw new ArithmeticException("GESV fails: " + info);
}
@ -1561,15 +1561,15 @@ public class Matrix extends DMatrix {
if (uplo != null) {
if (diag != null) {
if (alpha == 1.0 && beta == 0.0 && x == y) {
BLAS.engine.trmv(layout(), uplo, trans, diag, m, A, ld, y, 1);
LinearAlgebra.engine.trmv(layout(), uplo, trans, diag, m, A, ld, y, 1);
} else {
BLAS.engine.gemv(layout(), trans, m, n, alpha, A, ld, x, 1, beta, y, 1);
LinearAlgebra.engine.gemv(layout(), trans, m, n, alpha, A, ld, x, 1, beta, y, 1);
}
} else {
BLAS.engine.symv(layout(), uplo, m, alpha, A, ld, x, 1, beta, y, 1);
LinearAlgebra.engine.symv(layout(), uplo, m, alpha, A, ld, x, 1, beta, y, 1);
}
} else {
BLAS.engine.gemv(layout(), trans, m, n, alpha, A, ld, x, 1, beta, y, 1);
LinearAlgebra.engine.gemv(layout(), trans, m, n, alpha, A, ld, x, 1, beta, y, 1);
}
}
@ -1613,15 +1613,15 @@ public class Matrix extends DMatrix {
*/
public void mm(Transpose transA, Transpose transB, double alpha, Matrix B, double beta, Matrix C) {
if (isSymmetric() && transB == NO_TRANSPOSE && B.layout() == C.layout()) {
BLAS.engine.symm(C.layout(), LEFT, uplo, C.m, C.n, alpha, A, ld, B.A, B.ld, beta, C.A, C.ld);
LinearAlgebra.engine.symm(C.layout(), LEFT, uplo, C.m, C.n, alpha, A, ld, B.A, B.ld, beta, C.A, C.ld);
} else if (B.isSymmetric() && transA == NO_TRANSPOSE && layout() == C.layout()) {
BLAS.engine.symm(C.layout(), RIGHT, B.uplo, C.m, C.n, alpha, B.A, B.ld, A, ld, beta, C.A, C.ld);
LinearAlgebra.engine.symm(C.layout(), RIGHT, B.uplo, C.m, C.n, alpha, B.A, B.ld, A, ld, beta, C.A, C.ld);
} else {
if (C.layout() != layout()) transA = flip(transA);
if (C.layout() != B.layout()) transB = flip(transB);
int k = transA == NO_TRANSPOSE ? n : m;
BLAS.engine.gemm(layout(), transA, transB, C.m, C.n, k, alpha, A, ld, B.A, B.ld, beta, C.A, C.ld);
LinearAlgebra.engine.gemm(layout(), transA, transB, C.m, C.n, k, alpha, A, ld, B.A, B.ld, beta, C.A, C.ld);
}
}
@ -1759,7 +1759,7 @@ public class Matrix extends DMatrix {
public LU lu(boolean overwrite) {
Matrix lu = overwrite ? this : clone();
int[] ipiv = new int[Math.min(m, n)];
int info = LAPACK.engine.getrf(lu.layout(), lu.m, lu.n, lu.A, lu.ld, IntBuffer.wrap(ipiv));
int info = LinearAlgebra.engine.getrf(lu.layout(), lu.m, lu.n, lu.A, lu.ld, IntBuffer.wrap(ipiv));
if (info < 0) {
logger.severe(String.format("LAPACK GETRF error code: {%d}", info));
throw new ArithmeticException("LAPACK GETRF error code: " + info);
@ -1791,7 +1791,7 @@ public class Matrix extends DMatrix {
}
Matrix lu = overwrite ? this : clone();
int info = LAPACK.engine.potrf(lu.layout(), lu.uplo, lu.n, lu.A, lu.ld);
int info = LinearAlgebra.engine.potrf(lu.layout(), lu.uplo, lu.n, lu.A, lu.ld);
if (info != 0) {
logger.severe(String.format("LAPACK GETRF error code: {%d}", info));
throw new ArithmeticException("LAPACK GETRF error code: " + info);
@ -1817,7 +1817,7 @@ public class Matrix extends DMatrix {
public QR qr(boolean overwrite) {
Matrix qr = overwrite ? this : clone();
double[] tau = new double[Math.min(m, n)];
int info = LAPACK.engine.geqrf(qr.layout(), qr.m, qr.n, qr.A, qr.ld, DoubleBuffer.wrap(tau));
int info = LinearAlgebra.engine.geqrf(qr.layout(), qr.m, qr.n, qr.A, qr.ld, DoubleBuffer.wrap(tau));
if (info != 0) {
logger.severe(String.format("LAPACK GEQRF error code: {%d}", info));
throw new ArithmeticException("LAPACK GEQRF error code: " + info);
@ -1872,7 +1872,7 @@ public class Matrix extends DMatrix {
Matrix U = new Matrix(m, k);
Matrix VT = new Matrix(k, n);
int info = LAPACK.engine.gesdd(W.layout(), SVDJob.COMPACT, W.m, W.n, W.A, W.ld, DoubleBuffer.wrap(s), U.A, U.ld, VT.A, VT.ld);
int info = LinearAlgebra.engine.gesdd(W.layout(), SVDJob.COMPACT, W.m, W.n, W.A, W.ld, DoubleBuffer.wrap(s), U.A, U.ld, VT.A, VT.ld);
if (info != 0) {
logger.severe(String.format("LAPACK GESDD error code: {%s}", info));
throw new ArithmeticException("LAPACK GESDD error code: " + info);
@ -1883,7 +1883,7 @@ public class Matrix extends DMatrix {
Matrix U = new Matrix(1, 1);
Matrix VT = new Matrix(1, 1);
int info = LAPACK.engine.gesdd(W.layout(), SVDJob.NO_VECTORS, W.m, W.n, W.A, W.ld, DoubleBuffer.wrap(s), U.A, U.ld, VT.A, VT.ld);
int info = LinearAlgebra.engine.gesdd(W.layout(), SVDJob.NO_VECTORS, W.m, W.n, W.A, W.ld, DoubleBuffer.wrap(s), U.A, U.ld, VT.A, VT.ld);
if (info != 0) {
logger.severe(String.format("LAPACK GESDD error code: {}", info));
throw new ArithmeticException("LAPACK GESDD error code: " + info);
@ -1929,7 +1929,7 @@ public class Matrix extends DMatrix {
Matrix eig = overwrite ? this : clone();
if (isSymmetric()) {
double[] w = new double[n];
int info = LAPACK.engine.syevd(eig.layout(), vr ? EVDJob.VECTORS : EVDJob.NO_VECTORS, eig.uplo, n, eig.A, eig.ld, DoubleBuffer.wrap(w));
int info = LinearAlgebra.engine.syevd(eig.layout(), vr ? EVDJob.VECTORS : EVDJob.NO_VECTORS, eig.uplo, n, eig.A, eig.ld, DoubleBuffer.wrap(w));
if (info != 0) {
logger.severe(String.format("LAPACK SYEV error code: {%d}", info));
throw new ArithmeticException("LAPACK SYEV error code: " + info);
@ -1940,7 +1940,7 @@ public class Matrix extends DMatrix {
double[] wi = new double[n];
Matrix Vl = vl ? new Matrix(n, n) : new Matrix(1, 1);
Matrix Vr = vr ? new Matrix(n, n) : new Matrix(1, 1);
int info = LAPACK.engine.geev(eig.layout(), vl ? EVDJob.VECTORS : EVDJob.NO_VECTORS, vr ? EVDJob.VECTORS : EVDJob.NO_VECTORS, n, eig.A, eig.ld, DoubleBuffer.wrap(wr), DoubleBuffer.wrap(wi), Vl.A, Vl.ld, Vr.A, Vr.ld);
int info = LinearAlgebra.engine.geev(eig.layout(), vl ? EVDJob.VECTORS : EVDJob.NO_VECTORS, vr ? EVDJob.VECTORS : EVDJob.NO_VECTORS, n, eig.A, eig.ld, DoubleBuffer.wrap(wr), DoubleBuffer.wrap(wi), Vl.A, Vl.ld, Vr.A, Vr.ld);
if (info != 0) {
logger.severe(String.format("LAPACK GEEV error code: {%d}", info));
throw new ArithmeticException("LAPACK GEEV error code: " + info);
@ -2536,7 +2536,7 @@ public class Matrix extends DMatrix {
throw new RuntimeException("The matrix is singular.");
}
int ret = LAPACK.engine.getrs(lu.layout(), NO_TRANSPOSE, lu.n, B.n, lu.A, lu.ld, IntBuffer.wrap(ipiv), B.A, B.ld);
int ret = LinearAlgebra.engine.getrs(lu.layout(), NO_TRANSPOSE, lu.n, B.n, lu.A, lu.ld, IntBuffer.wrap(ipiv), B.A, B.ld);
if (ret != 0) {
logger.severe(String.format("LAPACK GETRS error code: {%d}", ret));
throw new ArithmeticException("LAPACK GETRS error code: " + ret);
@ -2645,7 +2645,7 @@ public class Matrix extends DMatrix {
throw new IllegalArgumentException(String.format("Row dimensions do not agree: A is %d x %d, but B is %d x %d", lu.m, lu.n, B.m, B.n));
}
int info = LAPACK.engine.potrs(lu.layout(), lu.uplo, lu.n, B.n, lu.A, lu.ld, B.A, B.ld);
int info = LinearAlgebra.engine.potrs(lu.layout(), lu.uplo, lu.n, B.n, lu.A, lu.ld, B.A, B.ld);
if (info != 0) {
logger.severe(String.format("LAPACK POTRS error code: {%d}", info));
throw new ArithmeticException("LAPACK POTRS error code: " + info);
@ -2727,7 +2727,7 @@ public class Matrix extends DMatrix {
int n = qr.n;
int k = Math.min(m, n);
Matrix Q = qr.clone();
int info = LAPACK.engine.orgqr(qr.layout(), m, n, k, Q.A, qr.ld, DoubleBuffer.wrap(tau));
int info = LinearAlgebra.engine.orgqr(qr.layout(), m, n, k, Q.A, qr.ld, DoubleBuffer.wrap(tau));
if (info != 0) {
logger.severe(String.format("LAPACK ORGRQ error code: {%d}", info));
throw new ArithmeticException("LAPACK ORGRQ error code: " + info);
@ -2768,13 +2768,13 @@ public class Matrix extends DMatrix {
int n = qr.n;
int k = Math.min(m, n);
int info = LAPACK.engine.ormqr(qr.layout(), LEFT, TRANSPOSE, B.nrows(), B.ncols(), k, qr.A, qr.ld, DoubleBuffer.wrap(tau), B.A, B.ld);
int info = LinearAlgebra.engine.ormqr(qr.layout(), LEFT, TRANSPOSE, B.nrows(), B.ncols(), k, qr.A, qr.ld, DoubleBuffer.wrap(tau), B.A, B.ld);
if (info != 0) {
logger.severe(String.format("LAPACK ORMQR error code: {%d}", info));
throw new IllegalArgumentException("LAPACK ORMQR error code: " + info);
}
info = LAPACK.engine.trtrs(qr.layout(), UPPER, NO_TRANSPOSE, NON_UNIT, qr.n, B.n, qr.A, qr.ld, B.A, B.ld);
info = LinearAlgebra.engine.trtrs(qr.layout(), UPPER, NO_TRANSPOSE, NON_UNIT, qr.n, B.n, qr.A, qr.ld, B.A, B.ld);
if (info != 0) {
logger.severe(String.format("LAPACK TRTRS error code: {%d}", info));

View File

@ -199,14 +199,14 @@ public class SymmMatrix extends DMatrix {
@Override
public void mv(Transpose trans, double alpha, double[] x, double beta, double[] y) {
BLAS.engine.spmv(layout(), uplo, n, alpha, AP, x, 1, beta, y, 1);
LinearAlgebra.engine.spmv(layout(), uplo, n, alpha, AP, x, 1, beta, y, 1);
}
@Override
public void mv(double[] work, int inputOffset, int outputOffset) {
DoubleBuffer xb = DoubleBuffer.wrap(work, inputOffset, n);
DoubleBuffer yb = DoubleBuffer.wrap(work, outputOffset, n);
BLAS.engine.spmv(layout(), uplo, n, 1.0f, DoubleBuffer.wrap(AP), xb, 1, 0.0f, yb, 1);
LinearAlgebra.engine.spmv(layout(), uplo, n, 1.0f, DoubleBuffer.wrap(AP), xb, 1, 0.0f, yb, 1);
}
@Override
@ -221,7 +221,7 @@ public class SymmMatrix extends DMatrix {
public BunchKaufman bk() {
SymmMatrix lu = clone();
int[] ipiv = new int[n];
int info = LAPACK.engine.sptrf(lu.layout(), lu.uplo, lu.n, lu.AP, ipiv);
int info = LinearAlgebra.engine.sptrf(lu.layout(), lu.uplo, lu.n, lu.AP, ipiv);
if (info < 0) {
logger.severe(String.format("LAPACK SPTRF error code: {%d}", info));
throw new ArithmeticException("LAPACK SPTRF error code: " + info);
@ -242,7 +242,7 @@ public class SymmMatrix extends DMatrix {
}
SymmMatrix lu = clone();
int info = LAPACK.engine.pptrf(lu.layout(), lu.uplo, lu.n, lu.AP);
int info = LinearAlgebra.engine.pptrf(lu.layout(), lu.uplo, lu.n, lu.AP);
if (info != 0) {
logger.severe(String.format("LAPACK PPTRF error code: {%d}", info));
throw new ArithmeticException("LAPACK PPTRF error code: " + info);
@ -366,7 +366,7 @@ public class SymmMatrix extends DMatrix {
throw new RuntimeException("The matrix is singular.");
}
int ret = LAPACK.engine.sptrs(lu.layout(), lu.uplo, lu.n, B.n, DoubleBuffer.wrap(lu.AP), IntBuffer.wrap(ipiv), B.A, B.ld);
int ret = LinearAlgebra.engine.sptrs(lu.layout(), lu.uplo, lu.n, B.n, DoubleBuffer.wrap(lu.AP), IntBuffer.wrap(ipiv), B.A, B.ld);
if (ret != 0) {
logger.severe(String.format("LAPACK GETRS error code: {%d}", ret));
throw new ArithmeticException("LAPACK GETRS error code: " + ret);
@ -474,7 +474,7 @@ public class SymmMatrix extends DMatrix {
throw new IllegalArgumentException(String.format("Row dimensions do not agree: A is %d x %d, but B is %d x %d", lu.n, lu.n, B.m, B.n));
}
int info = LAPACK.engine.pptrs(lu.layout(), lu.uplo, lu.n, B.n, DoubleBuffer.wrap(lu.AP), B.A, B.ld);
int info = LinearAlgebra.engine.pptrs(lu.layout(), lu.uplo, lu.n, B.n, DoubleBuffer.wrap(lu.AP), B.A, B.ld);
if (info != 0) {
logger.severe(String.format("LAPACK POTRS error code: {%d}", info));
throw new ArithmeticException("LAPACK POTRS error code: " + info);

View File

@ -32,7 +32,7 @@ import static org.bytedeco.openblas.global.openblas.LAPACKE_dorgrq;
*
* @author Haifeng Li
*/
public class MKL implements BLAS, LAPACK {
public class MKL extends LinearAlgebra {
@Override
public double asum(int n, double[] x, int incx) {
return cblas_dasum(n, x, incx);