2021-07-19 14:15:52 +08:00

43 lines
1.3 KiB
Python

"""
multiarray module
"""
from .fromnumeric import nonzero
from ._exceptions import AxisError
__all__ = ['normalize_axis_index','where']
def normalize_axis_index(axis, ndim, msg_prefix=None):
"""
Normalizes an axis index, `axis`, such that is a valid positive index into
the shape of array with `ndim` dimensions. Raises an AxisError with an
appropriate message if this is not possible.
:param axis: (*int*) The un-normalized index of the axis. Can be negative.
:param ndim: (*int*) The number of dimensions of the array that `axis` should be normalized
against.
:param msg_prefix: (*str*) A prefix to put before the message, typically the name of the argument.
:return: (*int*) The normalized axis index, such that `0 <= normalized_axis < ndim`.
"""
if 0 <= axis < ndim:
return axis
elif axis < 0:
axis = ndim + axis
if 0 <= axis < ndim:
return axis
else:
raise AxisError(axis, ndim, msg_prefix)
else:
raise AxisError(axis, ndim, msg_prefix)
def where(condition):
'''
Return elements, either from x or y, depending on condition.
If only condition is given, return condition.nonzero().
:param condition: (*array_like*) Input array.
:returns: (*tuple*) Indices of elements that are non-zero.
'''
return nonzero(condition)