86 lines
3.3 KiB
Python

from .. import core as np
from ..lib import expand_dims, r_, unique
from ..lib._util import prod as _prod
from ..linalg import lstsq
# Names exported by ``from <this module> import *``.
__all__ = ['detrend']
def detrend(data, axis=-1, type='linear', bp=0, overwrite_data=False):
    """
    Remove linear trend along axis from data.

    Parameters
    ----------
    data : array_like
        The input data.
    axis : int, optional
        The axis along which to detrend the data. By default this is the
        last axis (-1).
    type : {'linear', 'constant'}, optional
        The type of detrending. If ``type == 'linear'`` (default),
        the result of a linear least-squares fit to `data` is subtracted
        from `data`.
        If ``type == 'constant'``, only the mean of `data` is subtracted.
        The abbreviations ``'l'`` and ``'c'`` are accepted as well.
    bp : array_like of ints, optional
        A sequence of break points. If given, an individual linear fit is
        performed for each part of `data` between two break points.
        Break points are specified as indices into `data`. This parameter
        only has an effect when ``type == 'linear'``.
    overwrite_data : bool, optional
        If True, skip the defensive copy so that detrending may be
        performed in place. Default is False.

    Returns
    -------
    ret : ndarray
        The detrended input data. For ``type == 'linear'``, input whose
        dtype is not floating point or complex is converted to floating
        point before fitting, so the result is floating point.

    Raises
    ------
    ValueError
        If `type` is not one of the accepted values, or if a break point
        is negative or larger than the length of `data` along `axis`.

    Examples
    --------
    >>> from mipylib.numeric import signal
    >>> randgen = np.random.RandomState(9)
    >>> npoints = 1000
    >>> noise = randgen.randn(npoints)
    >>> x = 3 + 2*np.linspace(0, 1, npoints) + noise
    >>> (signal.detrend(x) - noise).max() < 0.01
    True
    """
    if type not in ['linear', 'l', 'constant', 'c']:
        raise ValueError("Trend type must be 'linear' or 'constant'.")
    data = np.asarray(data)

    if type in ['constant', 'c']:
        return data - expand_dims(np.mean(data, axis), axis)

    dshape = data.shape
    N = dshape[axis]
    # Always include both ends so that consecutive entries delimit segments.
    bp = np.sort(unique(r_[0, bp, N]))
    if np.any(bp < 0):
        # A negative index would select an empty slice and make the fit
        # fail with an unrelated error; reject it up front.
        raise ValueError("Breakpoints must be non-negative.")
    if np.any(bp > N):
        raise ValueError("Breakpoints must be less than length "
                         "of data along given axis.")
    Nreg = len(bp) - 1

    # The fit must be carried out in floating point.  Keep the input dtype
    # when it already is float/complex; otherwise use the dtype produced by
    # float arithmetic (the same arithmetic that builds the ramp below).
    # Reusing an integer dtype here would truncate both the design matrix
    # and the residuals.
    dtype = data.dtype
    if dtype.char not in 'dfDF':
        dtype = (np.arange(1, 2) * 1.0).dtype

    # Restructure data so that axis is along first dimension and
    # all other dimensions are collapsed into second dimension
    rnk = len(dshape)
    if axis < 0:
        axis = axis + rnk
    newdims = r_[axis, 0:axis, axis + 1:rnk]
    newdata = np.reshape(np.transpose(data, tuple(newdims)),
                         (N, _prod(dshape) // N))
    if not overwrite_data:
        newdata = newdata.copy()  # make sure we have a copy
    if newdata.dtype.char not in 'dfDF':
        newdata = newdata.astype(dtype)

    # Find leastsq fit and remove it for each piece
    for m in range(Nreg):
        Npts = bp[m + 1] - bp[m]
        # Columns of A: normalised ramp (slope term) and ones (offset term).
        A = np.ones((Npts, 2), dtype)
        A[:, 0] = np.arange(1, Npts + 1) * 1.0 / Npts
        sl = slice(bp[m], bp[m + 1])
        # Only the coefficients are needed; indexing (rather than tuple
        # unpacking) does not depend on how many extra values lstsq returns.
        coef = lstsq(A, newdata[sl])[0]
        newdata[sl] = newdata[sl] - np.dot(A, coef)

    # Put data back in original shape.
    tdshape = np.take(dshape, newdims, 0)
    ret = np.reshape(newdata, tuple(tdshape))
    # Inverse of the permutation applied above: move axis 0 back to `axis`.
    vals = list(range(1, rnk))
    olddims = vals[:axis] + [0] + vals[axis:]
    ret = np.transpose(ret, tuple(olddims))
    return ret