mirror of
https://github.com/meteoinfo/MeteoInfo.git
synced 2025-12-08 20:36:05 +00:00
86 lines
3.3 KiB
Python
86 lines
3.3 KiB
Python
|
|
from .. import core as np
|
|
from ..lib import expand_dims, r_, unique
|
|
from ..lib._util import prod as _prod
|
|
from ..linalg import lstsq
|
|
|
|
# Names exported by ``from <this module> import *``.
__all__ = ["detrend"]
|
|
|
|
def detrend(data, axis=-1, type='linear', bp=0, overwrite_data=False):
    """
    Remove linear trend along axis from data.

    Parameters
    ----------
    data : array_like
        The input data.
    axis : int, optional
        The axis along which to detrend the data. By default this is the
        last axis (-1).
    type : {'linear', 'constant'}, optional
        The type of detrending. If ``type == 'linear'`` (default),
        the result of a linear least-squares fit to `data` is subtracted
        from `data`.
        If ``type == 'constant'``, only the mean of `data` is subtracted.
        The abbreviations ``'l'`` and ``'c'`` are accepted as well.
    bp : array_like of ints, optional
        A sequence of break points. If given, an individual linear fit is
        performed for each part of `data` between two break points.
        Break points are specified as indices into `data` along `axis` and
        must lie in ``[0, data.shape[axis]]``. This parameter only has an
        effect when ``type == 'linear'``.
    overwrite_data : bool, optional
        If True, skip the defensive copy of the working array so that
        detrending may happen in place. Whether `data` itself is modified
        depends on whether the internal transpose/reshape yields a view, and
        integer input is always converted to a new floating-point array, so
        always use the returned value. Default is False.

    Returns
    -------
    ret : ndarray
        The detrended input data. For ``type == 'linear'`` the result has a
        floating-point dtype even when `data` is integer-valued.

    Raises
    ------
    ValueError
        If `type` is not one of the accepted values, or a break point lies
        outside ``[0, data.shape[axis]]``.

    Examples
    --------
    >>> from mipylib.numeric import signal
    >>> randgen = np.random.RandomState(9)
    >>> npoints = 1000
    >>> noise = randgen.randn(npoints)
    >>> x = 3 + 2*np.linspace(0, 1, npoints) + noise
    >>> (signal.detrend(x) - noise).max() < 0.01
    True
    """
    if type not in ['linear', 'l', 'constant', 'c']:
        raise ValueError("Trend type must be 'linear' or 'constant'.")
    data = np.asarray(data)

    if type in ['constant', 'c']:
        ret = data - expand_dims(np.mean(data, axis), axis)
        return ret

    dshape = data.shape
    N = dshape[axis]
    # Always include both ends of the axis so the segments cover all of it.
    bp = np.sort(unique(r_[0, bp, N]))
    if np.any(bp > N):
        raise ValueError("Breakpoints must be less than length "
                         "of data along given axis.")
    if np.any(bp < 0):
        # A negative break point would produce an empty/mismatched slice below.
        raise ValueError("Breakpoints must be non-negative.")
    Nreg = len(bp) - 1

    # Restructure data so that axis is along first dimension and
    # all other dimensions are collapsed into second dimension
    rnk = len(dshape)
    if axis < 0:
        axis = axis + rnk
    newdims = r_[axis, 0:axis, axis + 1:rnk]
    newdata = np.reshape(np.transpose(data, tuple(newdims)),
                         (N, _prod(dshape) // N))
    if newdata.dtype.char not in 'dfDF':
        # Integer/boolean input must be fitted in floating point; otherwise
        # the design matrix and the subtraction are truncated to integers.
        # The multiplication already yields a fresh array, so no copy needed.
        newdata = newdata * 1.0
    elif not overwrite_data:
        newdata = newdata.copy()  # make sure we have a copy
    dtype = newdata.dtype

    # Find leastsq fit and remove it for each piece
    for m in range(Nreg):
        Npts = bp[m + 1] - bp[m]
        # Design matrix: column 0 is a normalised ramp, column 1 the offset.
        A = np.ones((Npts, 2), dtype)
        A[:, 0] = np.arange(1, Npts + 1) * 1.0 / Npts
        sl = slice(bp[m], bp[m + 1])
        # Only the coefficients are needed; index instead of unpacking so the
        # call does not depend on how many extra values lstsq returns.
        coef = lstsq(A, newdata[sl])[0]
        newdata[sl] = newdata[sl] - np.dot(A, coef)

    # Put data back in original shape.
    tdshape = np.take(dshape, newdims, 0)
    ret = np.reshape(newdata, tuple(tdshape))
    vals = list(range(1, rnk))
    olddims = vals[:axis] + [0] + vals[axis:]
    ret = np.transpose(ret, tuple(olddims))
    return ret