# 2021-03-10 17:19:43 +08:00

# 76 lines
# 2.2 KiB
# Python

from ..core.numeric import asanyarray, normalize_axis_tuple
# Names exported by ``from <this module> import *``; only ``expand_dims`` is public.
__all__ = [
'expand_dims'
]
def expand_dims(a, axis):
    """
    Add one or more length-one axes to an array's shape.

    Each entry of `axis` names a position in the *resulting* shape at which
    a new axis of size 1 is inserted; the original dimensions fill the
    remaining positions in their existing order.

    Parameters
    ----------
    a : array_like
        Input array.
    axis : int or tuple of ints
        Position(s) in the expanded shape where the new axis (or axes) is
        placed.

        .. deprecated:: 1.13.0
            Passing an axis where ``axis > a.ndim`` will be treated as
            ``axis == a.ndim``, and passing ``axis < -a.ndim - 1`` will
            be treated as ``axis == 0``. This behavior is deprecated.

        .. versionchanged:: 1.18.0
            A tuple of axes is now supported.  Out of range axes as
            described above are now forbidden and raise an `AxisError`.

    Returns
    -------
    result : ndarray
        View of `a` with the number of dimensions increased.

    See Also
    --------
    squeeze : The inverse operation, removing singleton dimensions
    reshape : Insert, remove, and combine dimensions, and resize existing ones
    doc.indexing, atleast_1d, atleast_2d, atleast_3d

    Examples
    --------
    >>> x = np.array([1, 2])
    >>> x.shape
    (2,)

    The following is equivalent to ``x[np.newaxis, :]`` or ``x[np.newaxis]``:

    >>> y = np.expand_dims(x, axis=0)
    >>> y
    array([[1, 2]])
    >>> y.shape
    (1, 2)

    The following is equivalent to ``x[:, np.newaxis]``:

    >>> y = np.expand_dims(x, axis=1)
    >>> y
    array([[1],
           [2]])
    >>> y.shape
    (2, 1)

    ``axis`` may also be a tuple:

    >>> y = np.expand_dims(x, axis=(0, 1))
    >>> y
    array([[[1, 2]]])

    >>> y = np.expand_dims(x, axis=(2, 0))
    >>> y
    array([[[1],
            [2]]])

    Note that some examples may use ``None`` instead of ``np.newaxis``.  These
    are the same objects:

    >>> np.newaxis is None
    True

    """
    arr = asanyarray(a)

    # Only exact ``tuple``/``list`` objects are treated as a collection of
    # axes; anything else is wrapped as a single axis.
    axes = axis if type(axis) in (tuple, list) else (axis,)

    # Axis positions refer to the expanded shape, so normalize them against
    # the final number of dimensions rather than ``arr.ndim``.
    total_ndim = arr.ndim + len(axes)
    inserted = normalize_axis_tuple(axes, total_ndim)

    # Start with every slot as a new length-one axis, then write the original
    # extents, in order, into the slots that were not requested.
    expanded_shape = [1] * total_ndim
    kept_slots = (slot for slot in range(total_ndim) if slot not in inserted)
    for slot, extent in zip(kept_slots, arr.shape):
        expanded_shape[slot] = extent

    return arr.reshape(expanded_shape)