update fft module to support fftn, ifftn functions

This commit is contained in:
wyq 2025-10-08 16:36:51 +08:00
parent f403ee0d40
commit cd03155677
5 changed files with 361 additions and 46 deletions

View File

@ -1,12 +1,10 @@
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<MeteoInfo File="milconfig.xml" Type="configurefile">
<Path OpenPath="D:\Working\MIScript\Jython\mis\plot_types\3d\3d_earth">
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\plot_types\funny"/>
<Path OpenPath="D:\Working\MIScript\Jython\mis\common_math\fft">
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\dataframe"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\io"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\io\matlab"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\common_math"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\plot_types"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\plot_types\contour"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\common_math\stats"/>
@ -16,15 +14,19 @@
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\plot_types\3d\jogl"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\plot_types\3d"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\plot_types\3d\3d_earth"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\common_math"/>
<RecentFolder Folder="D:\Working\MIScript\Jython\mis\common_math\fft"/>
</Path>
<File>
<OpenedFiles>
<OpenedFile File="D:\Working\MIScript\Jython\mis\plot_types\3d\jogl\surf\shitshape_ice_cream_cone.py"/>
<OpenedFile File="D:\Working\MIScript\Jython\mis\plot_types\3d\3d_earth\plot_cuace_earth_isosurface_streamline_1.py"/>
<OpenedFile File="D:\Working\MIScript\Jython\mis\common_math\fft\fftn_1.py"/>
<OpenedFile File="D:\Working\MIScript\Jython\mis\common_math\fft\ifftn_1.py"/>
<OpenedFile File="D:\Working\MIScript\Jython\mis\common_math\fft\ifftn_2.py"/>
</OpenedFiles>
<RecentFiles>
<RecentFile File="D:\Working\MIScript\Jython\mis\plot_types\3d\jogl\surf\shitshape_ice_cream_cone.py"/>
<RecentFile File="D:\Working\MIScript\Jython\mis\plot_types\3d\3d_earth\plot_cuace_earth_isosurface_streamline_1.py"/>
<RecentFile File="D:\Working\MIScript\Jython\mis\common_math\fft\fftn_1.py"/>
<RecentFile File="D:\Working\MIScript\Jython\mis\common_math\fft\ifftn_1.py"/>
<RecentFile File="D:\Working\MIScript\Jython\mis\common_math\fft\ifftn_2.py"/>
</RecentFiles>
</File>
<Font>

View File

@ -1,11 +1,13 @@
from org.meteoinfo.math.transform import FastFourierTransform, FastFourierTransform2D
from org.meteoinfo.math.transform import FastFourierTransform, FastFourierTransform2D, \
FastFourierTransformND
from .. import core as _nx
__all__ = ['fft', 'ifft', 'fft2', 'ifft2']
__all__ = ['fft', 'ifft', 'fft2', 'ifft2', 'fftn', 'ifftn']
def fft(a):
def fft(a, axis=-1, norm=None):
"""
Compute the one-dimensional discrete Fourier Transform.
@ -17,6 +19,13 @@ def fft(a):
----------
a : array_like
Input array, can be complex.
axis : int, optional
Axis over which to compute the FFT. If not given, the last axis is
used.
norm : {"backward", "ortho", "forward"}, optional
Normalization mode. Default is "backward".
Indicates which direction of the forward/backward pair of transforms
is scaled and with what normalization factor.
Returns
-------
@ -32,19 +41,22 @@ def fft(a):
"""
a = _nx.asarray(a)
jfft = FastFourierTransform()
r = jfft.apply(a._array)
if norm is not None:
jfft.setNormalization(norm)
r = jfft.apply(a._array, axis)
return _nx.NDArray(r)
def ifft(a):
def ifft(a, axis=-1, norm=None):
"""
Compute the one-dimensional inverse discrete Fourier Transform.
This function computes the inverse of the one-dimensional *n*-point
discrete Fourier transform computed by `fft`. In other words,
``ifft(fft(a)) == a`` to within numerical accuracy.
For a general description of the algorithm and definitions,
see `numpy.fft`.
For a general description of the algorithm and definitions.
The input should be ordered in the same way as is returned by `fft`,
i.e.,
@ -53,6 +65,13 @@ def ifft(a):
----------
a : array_like
Input array, can be complex.
axis : int, optional
Axis over which to compute the inverse DFT. If not given, the last
axis is used.
norm : {"backward", "ortho", "forward"}, optional
Normalization mode. Default is "backward".
Indicates which direction of the forward/backward pair of transforms
is scaled and with what normalization factor.
Returns
-------
@ -62,11 +81,15 @@ def ifft(a):
"""
a = _nx.asarray(a)
jfft = FastFourierTransform(True)
r = jfft.apply(a._array)
if norm is not None:
jfft.setNormalization(norm)
r = jfft.apply(a._array, axis)
return _nx.NDArray(r)
def fft2(a):
def fft2(a, axes=(-2, -1), norm=None):
"""
Compute the 2-dimensional discrete Fourier Transform.
@ -79,6 +102,15 @@ def fft2(a):
----------
a : array_like
Input array, can be complex
axes : sequence of ints, optional
Axes over which to compute the FFT. If not given, the last two
axes are used. A repeated index in `axes` means the transform over
that axis is performed multiple times. A one-element sequence means
that a one-dimensional FFT is performed. Default: ``(-2, -1)``.
norm : {"backward", "ortho", "forward"}, optional
Normalization mode. Default is "backward".
Indicates which direction of the forward/backward pair of transforms
is scaled and with what normalization factor.
Returns
-------
@ -87,12 +119,16 @@ def fft2(a):
indicated by `axes`, or the last two axes if `axes` is not given.
"""
a = _nx.asarray(a)
jfft = FastFourierTransform2D()
r = jfft.apply(a._array)
jfft = FastFourierTransformND()
if norm is not None:
jfft.setNormalization(norm)
r = jfft.apply(a._array, axes)
return _nx.NDArray(r)
def ifft2(a, s=None, axes=(-2, -1), norm=None, out=None):
def ifft2(a, axes=(-2, -1), norm=None):
"""
Compute the 2-dimensional inverse discrete Fourier Transform.
@ -113,6 +149,15 @@ def ifft2(a, s=None, axes=(-2, -1), norm=None, out=None):
----------
a : array_like
Input array, can be complex.
axes : sequence of ints, optional
Axes over which to compute the FFT. If not given, the last two
axes are used. A repeated index in `axes` means the transform over
that axis is performed multiple times. A one-element sequence means
that a one-dimensional FFT is performed. Default: ``(-2, -1)``.
norm : {"backward", "ortho", "forward"}, optional
Normalization mode. Default is "backward".
Indicates which direction of the forward/backward pair of transforms
is scaled and with what normalization factor.
Returns
-------
@ -121,7 +166,101 @@ def ifft2(a, s=None, axes=(-2, -1), norm=None, out=None):
indicated by `axes`, or the last two axes if `axes` is not given.
"""
a = _nx.asarray(a)
jfft = FastFourierTransform2D(True)
r = jfft.apply(a._array)
jfft = FastFourierTransformND(True)
if norm is not None:
jfft.setNormalization(norm)
r = jfft.apply(a._array, axes)
return _nx.NDArray(r)
def fftn(a, axes=None, norm=None):
"""
Compute the N-dimensional discrete Fourier Transform.
This function computes the *N*-dimensional discrete Fourier Transform over
any number of axes in an *M*-dimensional array by means of the Fast Fourier
Transform (FFT).
Parameters
----------
a : array_like
Input array, can be complex
axes : sequence of ints, optional
Axes over which to compute the FFT. If not given, the last ``len(s)``
axes are used, or all axes if `s` is also not specified.
Repeated indices in `axes` means that the transform over that axis is
performed multiple times.
norm : {"backward", "ortho", "forward"}, optional
Normalization mode (see `numpy.fft`). Default is "backward".
Indicates which direction of the forward/backward pair of transforms
is scaled and with what normalization factor.
Returns
-------
out : complex ndarray
The truncated or zero-padded input, transformed along the axes
indicated by `axes`, or the all axes if `axes` is not given.
"""
a = _nx.asarray(a)
jfft = FastFourierTransformND()
if norm is not None:
jfft.setNormalization(norm)
if axes is None:
r = jfft.apply(a._array)
else:
r = jfft.apply(a._array, axes)
return _nx.NDArray(r)
def ifftn(a, axes=None, norm=None):
"""
Compute the n-dimensional inverse discrete Fourier Transform.
This function computes the inverse of the n-dimensional discrete Fourier
Transform over any number of axes in an M-dimensional array by means of
the Fast Fourier Transform (FFT). In other words, ``ifftn(fftn(a)) == a``
to within numerical accuracy. By default, the inverse transform is
computed over the last two axes of the input array.
The input, analogously to `ifft`, should be ordered in the same way as is
returned by `fftn`, i.e. it should have the term for zero frequency
in the low-order corner of the two axes, the positive frequency terms in
the first half of these axes, the term for the Nyquist frequency in the
middle of the axes and the negative frequency terms in the second half of
both axes, in order of decreasingly negative frequency.
Parameters
----------
a : array_like
Input array, can be complex.
axes : sequence of ints, optional
Axes over which to compute the IFFT. If not given, the last ``len(s)``
axes are used, or all axes if `s` is also not specified.
Repeated indices in `axes` means that the inverse transform over that
axis is performed multiple times.
norm : {"backward", "ortho", "forward"}, optional
Normalization mode (see `numpy.fft`). Default is "backward".
Indicates which direction of the forward/backward pair of transforms
is scaled and with what normalization factor.
Returns
-------
out : complex ndarray
The truncated or zero-padded input, transformed along the axes
indicated by `axes`, or the last two axes if `axes` is not given.
"""
a = _nx.asarray(a)
jfft = FastFourierTransformND(True)
if norm is not None:
jfft.setNormalization(norm)
if axes is None:
r = jfft.apply(a._array)
else:
r = jfft.apply(a._array, axes)
return _nx.NDArray(r)

View File

@ -55,6 +55,27 @@ public class DistributionUtil {
return r;
}
/**
* Random variates of given type.
* @param dis Distribution.
* @param size Size.
* @return Result array.
*/
public static Array rvs(ContinuousDistribution dis, List<Integer> size){
ContinuousDistribution.Sampler sampler = dis.createSampler(RandomSource.MT.create());
int n = 1;
for (int s : size) {
n = n * s;
}
double[] samples = new double[n];
for (int i = 0; i < n; i++) {
samples[i] = sampler.sample();
}
int[] shape = size.stream().mapToInt(Integer::intValue).toArray();
Array r = Array.factory(DataType.DOUBLE, shape, samples);
return r;
}
/**
* Random variates of given type.
* @param dis Distribution.

View File

@ -1,10 +1,7 @@
package org.meteoinfo.math.transform;
import org.apache.commons.numbers.core.ArithmeticUtils;
import org.meteoinfo.ndarray.Array;
import org.meteoinfo.ndarray.ArrayComplex;
import org.meteoinfo.ndarray.Complex;
import org.meteoinfo.ndarray.DataType;
import org.meteoinfo.ndarray.*;
import org.meteoinfo.ndarray.math.ArrayMath;
import org.meteoinfo.ndarray.math.ArrayUtil;
@ -84,7 +81,7 @@ public class FastFourierTransform implements ComplexTransform {
-0x1.921fb54442d18p-58, -0x1.921fb54442d18p-59, -0x1.921fb54442d18p-60 };
/** Type of DFT. */
protected final Norm normalization;
protected Norm normalization;
/** Inverse or forward. */
protected final boolean inverse;
@ -113,14 +110,38 @@ public class FastFourierTransform implements ComplexTransform {
* @param inverse Whether to perform the inverse transform.
*/
public FastFourierTransform(final boolean inverse) {
this(Norm.STD, inverse);
this(Norm.BACKWARD, inverse);
}
/**
* Constructor
*/
public FastFourierTransform() {
this(Norm.STD, false);
this(Norm.BACKWARD, false);
}
/**
* Get normalization
* @return Normalization
*/
public Norm getNormalization() {
return this.normalization;
}
/**
* Set normalization
* @param normalization Normalization
*/
public void setNormalization(Norm normalization) {
this.normalization = normalization;
}
/**
* Set normalization
* @param norm Normalization string
*/
public void setNormalization(String norm) {
this.normalization = Norm.valueOf(norm.toUpperCase());
}
/**
@ -318,11 +339,50 @@ public class FastFourierTransform implements ComplexTransform {
return apply(TransformUtils.sample(f, min, max, n));
}
/**
* Returns the transform of the specified data set.
*
* @param f Input data array
* @param axis The axis over which to compute the FFT
* @return Result array after FFT transformation
*/
public Array apply(Array f, int axis) throws InvalidRangeException {
if (f.getRank() == 1) {
return apply(f);
}
if (axis < 0) {
axis = f.getRank() + axis;
}
int[] shape = f.getShape();
Array r = Array.factory(DataType.COMPLEX, shape);
Index indexr = r.getIndex();
int[] current;
for (int i = 0; i < r.getSize(); i++) {
current = indexr.getCurrentCounter();
if (current[axis] == 0) {
List<Range> ranges = new ArrayList<>();
for (int j = 0; j < shape.length; j++) {
if (j == axis) {
ranges.add(new Range(0, shape[j] - 1, 1));
} else {
ranges.add(new Range(current[j], current[j], 1));
}
}
Array data;
data = ArrayMath.section(f, ranges).copy();
data = apply(data);
ArrayMath.setSection(r, ranges, data);
}
indexr.incr();
}
return r;
}
/**
* {@inheritDoc}
*
* @throws IllegalArgumentException if the length of the data array is
* not a power of two.
*/
@Override
public Array apply(Array f) {
@ -487,7 +547,7 @@ public class FastFourierTransform implements ComplexTransform {
final int n = dataR.length;
switch (normalization) {
case STD:
case BACKWARD:
if (inverse) {
final double scaleFactor = 1d / n;
for (int i = 0; i < n; i++) {
@ -495,18 +555,23 @@ public class FastFourierTransform implements ComplexTransform {
dataI[i] *= scaleFactor;
}
}
break;
case UNIT:
case FORWARD:
if (!inverse) {
final double scaleFactor = 1d / n;
for (int i = 0; i < n; i++) {
dataR[i] *= scaleFactor;
dataI[i] *= scaleFactor;
}
}
break;
case ORTHO:
final double scaleFactor = 1d / Math.sqrt(n);
for (int i = 0; i < n; i++) {
dataR[i] *= scaleFactor;
dataI[i] *= scaleFactor;
}
break;
default:
throw new IllegalStateException(); // Should never happen.
}
@ -592,16 +657,14 @@ public class FastFourierTransform implements ComplexTransform {
*/
public enum Norm {
/**
* Should be passed to the constructor of {@link FastFourierTransform}
* to use the <em>standard</em> normalization convention. This normalization
* convention is defined as follows
* <ul>
* <li>forward transform: \( y_n = \sum_{k = 0}^{N - 1} x_k e^{-2 \pi i n k / N} \),</li>
* <li>inverse transform: \( x_k = \frac{1}{N} \sum_{n = 0}^{N - 1} y_n e^{2 \pi i n k / N} \),</li>
* </ul>
* where \( N \) is the size of the data sample.
* meaning no normalization on the forward transforms and scaling by 1/n on the inverse
*/
STD,
BACKWARD,
/**
* meaning normalization on the forward transforms scaling by 1/n
*/
FORWARD,
/**
* Should be passed to the constructor of {@link FastFourierTransform}
@ -613,7 +676,7 @@ public class FastFourierTransform implements ComplexTransform {
* </ul>
* where \( N \) is the size of the data sample.
*/
UNIT,
ORTHO,
/**
* Not do normalization

View File

@ -0,0 +1,90 @@
package org.meteoinfo.math.transform;
import org.meteoinfo.ndarray.*;
import org.meteoinfo.ndarray.math.ArrayMath;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
public class FastFourierTransformND extends FastFourierTransform {
/**
* Constructor
*/
public FastFourierTransformND() {
super();
}
/**
* Constructor
* @param inverse Whether is inverse transform
*/
public FastFourierTransformND(boolean inverse) {
super(inverse);
}
@Override
public Array apply(Array f) {
List<Integer> axes = new ArrayList<>();
for (int i = f.getRank() - 1; i >= 0; i--) {
axes.add(i);
}
return apply(f, axes);
}
/**
* Apply
* @param f Input array
* @param axes The axes
* @return Array after N-D FFT transformation
*/
public Array apply(Array f, List<Integer> axes) {
f = f.copyIfView();
int[] shape = f.getShape();
Array r = Array.factory(DataType.COMPLEX, shape);
try {
FastFourierTransform fastFourierTransform = new FastFourierTransform(this.normalization, this.inverse);
int[] current;
int axisIdx = 0;
for (int axis : axes) {
if (axis < 0) {
axis = shape.length + axis;
}
Index indexr = r.getIndex();
for (int i = 0; i < r.getSize(); i++) {
current = indexr.getCurrentCounter();
if (current[axis] == 0) {
List<Range> ranges = new ArrayList<>();
for (int j = 0; j < shape.length; j++) {
if (j == axis) {
ranges.add(new Range(0, shape[j] - 1, 1));
} else {
ranges.add(new Range(current[j], current[j], 1));
}
}
Array data;
if (axisIdx == 0) {
data = ArrayMath.section(f, ranges).copy();
} else {
data = ArrayMath.section(r, ranges).copy();
}
data = fastFourierTransform.apply(data);
ArrayMath.setSection(r, ranges, data);
}
indexr.incr();
}
axisIdx += 1;
}
} catch (InvalidRangeException e) {
e.printStackTrace();
}
return r;
}
}