from .. import core as np
from ..lib import expand_dims, r_, unique
from ..lib._util import prod as _prod
from ..linalg import lstsq

__all__ = ['detrend']


def detrend(data, axis=-1, type='linear', bp=0, overwrite_data=False):
    """
    Remove a linear trend along an axis from data.

    Parameters
    ----------
    data : array_like
        The input data.
    axis : int, optional
        The axis along which to detrend the data. By default this is the
        last axis (-1).
    type : {'linear', 'constant'}, optional
        The type of detrending. If ``type == 'linear'`` (default),
        the result of a linear least-squares fit to `data` is subtracted
        from `data`.
        If ``type == 'constant'``, only the mean of `data` is subtracted.
    bp : array_like of ints, optional
        A sequence of break points. If given, an individual linear fit is
        performed for each part of `data` between two break points.
        Break points are specified as indices into `data`. This parameter
        only has an effect when ``type == 'linear'``.
    overwrite_data : bool, optional
        If True, perform in-place detrending and avoid a copy. Default is
        False.

    Returns
    -------
    ret : ndarray
        The detrended input data.
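
    Notes
    -----
    For ``type == 'linear'``, the segment boundaries are 0, the entries of
    `bp`, and the length of `data` along `axis`, sorted with duplicates
    removed. Each segment of :math:`n` samples, and each 1-D slice along
    `axis`, is fitted independently: with the normalized abscissa
    :math:`t_i = i / n` for :math:`i = 1, \\ldots, n`, the coefficients
    :math:`a` and :math:`b` solve

    .. math::

        \\min_{a, b} \\sum_{i=1}^{n} (y_i - a t_i - b)^2

    and the fitted value :math:`a t_i + b` is then subtracted from
    :math:`y_i`.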

    Examples
    --------
    >>> import mipylib.numeric as np
    >>> from mipylib.numeric import signal
    >>> randgen = np.random.RandomState(9)
    >>> npoints = 1000
    >>> noise = randgen.randn(npoints)
    >>> x = 3 + 2*np.linspace(0, 1, npoints) + noise
    >>> (signal.detrend(x) - noise).max() < 0.01
    True
    """
    if type not in ['linear', 'l', 'constant', 'c']:
        raise ValueError("Trend type must be 'linear' or 'constant'.")
    data = np.asarray(data)
    dtype = data.dtype
    if type in ['constant', 'c']:
        # Subtract the mean along `axis`; expand_dims restores the reduced
        # axis so that the mean broadcasts against `data`.
        ret = data - expand_dims(np.mean(data, axis), axis)
        return ret
    else:
        dshape = data.shape
        N = dshape[axis]
        # Segment boundaries: 0, the user break points and N, sorted and
        # de-duplicated.
        bp = np.sort(unique(r_[0, bp, N]))
        if np.any(bp > N):
            raise ValueError("Breakpoints must not exceed the length "
                             "of data along the given axis.")
        Nreg = len(bp) - 1
        # Restructure data so that `axis` is the first dimension and
        # all other dimensions are collapsed into the second dimension.
        rnk = len(dshape)
        if axis < 0:
            axis = axis + rnk
        newdims = r_[axis, 0:axis, axis + 1:rnk]
        newdata = np.reshape(np.transpose(data, tuple(newdims)),
                             (N, _prod(dshape) // N))
        if not overwrite_data:
            newdata = newdata.copy()  # make sure we have a copy
        if newdata.dtype.char not in 'dfDF':
            newdata = newdata.astype(dtype)

        # Find the least-squares fit for each segment and subtract it.
        for m in range(Nreg):
            Npts = bp[m + 1] - bp[m]
            # Design matrix: column 0 holds the normalized positions
            # 1/Npts, 2/Npts, ..., 1 and column 1 the constant term.
            A = np.ones((Npts, 2), dtype)
            A[:, 0] = np.arange(1, Npts + 1) * 1.0 / Npts
            sl = slice(bp[m], bp[m + 1])
            coef, resids = lstsq(A, newdata[sl])
            newdata[sl] = newdata[sl] - np.dot(A, coef)

        # Put data back in the original shape.
        tdshape = np.take(dshape, newdims, 0)
        ret = np.reshape(newdata, tuple(tdshape))
        vals = list(range(1, rnk))
        olddims = vals[:axis] + [0] + vals[axis:]
        ret = np.transpose(ret, tuple(olddims))
        return ret
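

# Minimal pure-Python sketch of the per-segment linear fit that
# ``detrend(..., type='linear')`` removes, for a 1-D sequence.
# ``_detrend_1d_sketch`` is a hypothetical helper for illustration only:
# ``detrend`` does not call it and it is not listed in ``__all__``.
# Assumptions: ``values`` is a flat sequence of numbers and ``bp`` is an int
# or a list/tuple of ints. Instead of ``lstsq`` it solves the 2x2 normal
# equations in closed form, which gives the same residuals in exact
# arithmetic for this two-column model.
def _detrend_1d_sketch(values, bp=0):
    values = [float(v) for v in values]
    n_total = len(values)
    # Accept a single break point or a list/tuple, like ``r_[0, bp, N]``.
    if isinstance(bp, (list, tuple)):
        extra = [int(b) for b in bp]
    else:
        extra = [int(bp)]
    # Boundaries: 0, the break points and N, sorted without duplicates.
    bounds = sorted(set([0] + extra + [n_total]))
    if bounds[0] < 0 or bounds[-1] > n_total:
        raise ValueError("Breakpoints must lie within [0, len(values)].")
    out = list(values)
    for start, stop in zip(bounds[:-1], bounds[1:]):
        npts = stop - start
        # Same regressor as detrend: t = 1/npts, 2/npts, ..., 1.
        t = [(i + 1) * 1.0 / npts for i in range(npts)]
        y = values[start:stop]
        s_t = sum(t)
        s_tt = sum(ti * ti for ti in t)
        s_y = sum(y)
        s_ty = sum(ti * yi for ti, yi in zip(t, y))
        det = npts * s_tt - s_t * s_t
        if det == 0:
            # One-point segment: the fitted line passes through the point.
            slope, intercept = 0.0, s_y / npts
        else:
            slope = (npts * s_ty - s_t * s_y) / det
            intercept = (s_tt * s_y - s_t * s_ty) / det
        # Subtract the fitted line a*t + b from this segment.
        for i in range(npts):
            out[start + i] = y[i] - (slope * t[i] + intercept)
    return out

# For example, _detrend_1d_sketch([1.0, 3.0, 5.0]) returns values within
# rounding error of 0.0, because the input is already a straight line.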