RFC: add support for a tuple of axes in expand_dims #760

Open
izaid opened this issue Mar 10, 2024 · 3 comments
Labels
Needs Discussion Needs further discussion. RFC Request for comments. Feature requests and proposed changes. topic: Manipulation Array manipulation and transformation.

Comments

@izaid

izaid commented Mar 10, 2024

Hello all! I raised this issue on array-api-compat earlier (data-apis/array-api-compat#105), but I think it might be more properly directed here.

In the array API, expand_dims supports only a single axis (https://data-apis.org/array-api/latest/API_specification/generated/array_api.expand_dims.html) rather than a tuple of axes. This differs from NumPy, CuPy, and JAX, which support a tuple of axes; PyTorch, however, supports only a single axis. I don't know why the array API supports only a single axis rather than a tuple, but it means that code calling expand_dims with a tuple of axes breaks in many places once it is ported to the array API.

In practice, expand_dims is just a light wrapper around reshape; see https://github.com/numpy/numpy/blob/3b246c6488cf246d488bbe5726ca58dc26b6ea74/numpy/lib/_shape_base_impl.py#L594. But it's not great to force every library to write its own version of expand_dims. Is the array API willing to update expand_dims to support a tuple of axes? If not, and expand_dims will only ever support a single axis, then in effect every user of expand_dims will end up copying and pasting the NumPy implementation.
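For illustration, here is a minimal sketch of a tuple-aware expand_dims built on reshape, following the same idea as the NumPy implementation linked above: normalize the axes against the output ndim, reject duplicates, then fill the remaining slots with the original dimensions in order. The names `expanded_shape` and `expand_dims_multi` are hypothetical, the sketch assumes `xp` is an array API namespace providing `reshape`, and it assumes the input shape is fully known (no `None` entries).

```python
def expanded_shape(shape, axes):
    """Return the shape produced by inserting size-1 dims at `axes`.

    NumPy-style semantics: `axes` index the *output*, negative values count
    from the end of the output, and duplicates after normalization are rejected.
    """
    if isinstance(axes, int):
        axes = (axes,)
    out_ndim = len(shape) + len(axes)
    positions = set()
    for ax in axes:
        if not -out_ndim <= ax < out_ndim:
            # Exception type is an assumption; NumPy raises its own AxisError here.
            raise IndexError(f"axis {ax} is out of bounds for output ndim {out_ndim}")
        pos = ax % out_ndim
        if pos in positions:
            raise ValueError("repeated axis")
        positions.add(pos)
    dims = iter(shape)
    # Each output slot is either a new size-1 dim or the next original dim, in order.
    return tuple(1 if i in positions else next(dims) for i in range(out_ndim))


def expand_dims_multi(x, axes, xp):
    """Hypothetical tuple-aware expand_dims for an array API namespace `xp`."""
    return xp.reshape(x, expanded_shape(tuple(x.shape), axes))


print(expanded_shape((2,), (1, -1)))  # (2, 1, 1)
print(expanded_shape((2, 3), (0, -1)))  # (1, 2, 3, 1)
```

With NumPy as the namespace, `expand_dims_multi(np.empty((2, 3)), (0, -1), np)` would be expected to have shape `(1, 2, 3, 1)`.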

@lucascolley pointed out to me that when expand_dims was added to the array API, only NumPy supported a tuple of axes (see #42). That was four years ago, and as described above, the situation has since changed.

@asmeurer
Member

asmeurer commented Mar 11, 2024

It seems tuple support was omitted because torch doesn't support it (#42). I found a few feature requests for it in torch.unsqueeze (the PyTorch equivalent of expand_dims): pytorch/pytorch#30702 and pytorch/pytorch#4692 (comment). It seems it was intentionally omitted because of the ambiguity that arises from mixing negative and positive indices.

I agree this ambiguity is a potential concern. If we standardize this, we should somehow only require a subset of behavior that omits this ambiguity, e.g., by leaving the mixing of negative and nonnegative indices unspecified.
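One way to see why mixing signs is the problematic case: for in-bounds axes, distinct values of the same sign always normalize to distinct output positions, whereas a negative and a nonnegative axis can land on the same slot. The sketch below checks this with two hypothetical helpers; a library that left mixed signs unspecified could use something like `mixes_signs` to reject them.

```python
def mixes_signs(axes):
    """Return True if `axes` contains both negative and nonnegative entries."""
    return any(ax < 0 for ax in axes) and any(ax >= 0 for ax in axes)


def normalized_positions(axes, out_ndim):
    """Map each in-bounds axis (negatives count from the end) to a slot in [0, out_ndim)."""
    return [ax % out_ndim for ax in axes]


# Same-sign, distinct axes never collide after normalization...
print(mixes_signs((0, 2)), normalized_positions((0, 2), 6))      # False [0, 2]
print(mixes_signs((-1, -3)), normalized_positions((-1, -3), 6))  # False [5, 3]
# ...but mixed signs can: 3 and -3 both land on slot 3 when out_ndim == 6.
print(mixes_signs((3, -3)), normalized_positions((3, -3), 6))    # True [3, 3]
```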

Consider for example:

>>> np.expand_dims(np.empty((2,)), (1, -1)).shape
(2, 1, 1)

The resulting shape has 1 in positions 1 and -1, but a result shape of (2, 1) would also satisfy this. I suppose one could argue that exactly len(axes) dimensions should be added.
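The ambiguity can be made concrete with a small hypothetical checker, `has_ones_at`, which tests whether each requested index points at a size-1 dimension. Both candidate shapes pass; only the extra requirement that exactly len(axes) dimensions be added singles out (2, 1, 1).

```python
def has_ones_at(shape, axes):
    """True if every (possibly negative) index in `axes` points at a size-1 dim of `shape`."""
    n = len(shape)
    return all(-n <= ax < n and shape[ax] == 1 for ax in axes)


axes = (1, -1)
input_ndim = 1  # the input in the example above has shape (2,)
for candidate in [(2, 1, 1), (2, 1)]:
    adds_len_axes = len(candidate) == input_ndim + len(axes)
    print(candidate, has_ones_at(candidate, axes), adds_len_axes)
# (2, 1, 1) True True
# (2, 1) True False
```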

But also consider

>>> np.expand_dims(np.empty((2, 3, 4, 5)), (3, -3)).shape
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/aaronmeurer/miniconda3/envs/array-apis/lib/python3.11/site-packages/numpy/lib/shape_base.py", line 597, in expand_dims
    axis = normalize_axis_tuple(axis, out_ndim)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/aaronmeurer/miniconda3/envs/array-apis/lib/python3.11/site-packages/numpy/core/numeric.py", line 1385, in normalize_axis_tuple
    raise ValueError('repeated axis')
ValueError: repeated axis

There's no way to insert size-1 dimensions into (2, 3, 4, 5) so that they appear at indices 3 and -3.

Here's a small proof: there is no list length n for which removing indices 3 and -3 leaves a list of length 4. In particular, at n = 6 (the length the output would need) both indices refer to the same element.
>>> def remove_indices(n, idxes):
...     """Return range(n) with `idxes` indices removed"""
...     x = list(range(n))
...     vals = [x[i] for i in idxes]
...     for v in vals:
...         try:
...             x.remove(v)
...         except ValueError: # Already removed
...             pass
...     return x
>>> [remove_indices(n, (-3, 3)) for n in range(4, 10)]
[[0, 2], [0, 1, 4], [0, 1, 2, 4, 5], [0, 1, 2, 5, 6], [0, 1, 2, 4, 6, 7], [0, 1, 2, 4, 5, 7, 8]]
>>> [len(remove_indices(n, (-3, 3))) for n in range(4, 10)]
[2, 3, 5, 5, 6, 7]

At the same time, if the goal of expand_dims is for the axes to refer to positions in the output (after unsqueezing/expanding), then emulating it with a sequence of single-axis expand_dims calls is not entirely trivial: if you apply the expansions in the wrong order, you shift the positions of dimensions inserted earlier. The correct logic is not hard, but it's the sort of thing that's easy to get wrong, so I think there is value in having native support for multiple axes.
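A sketch of that "correct logic" for emulating multiple axes with single-axis calls, done on shape tuples so it runs without any array library: normalize each axis against the output ndim, then insert in ascending order. On real arrays the loop body would instead call `xp.expand_dims(x, axis=pos)`; the helper names `insert_one` and `insert_many` are hypothetical.

```python
def insert_one(shape, axis):
    """Single-axis expand_dims on a shape tuple; `axis` may be in [-ndim-1, ndim]."""
    ndim = len(shape)
    if not -ndim - 1 <= axis <= ndim:
        raise IndexError(f"axis {axis} is out of bounds for ndim {ndim}")
    pos = axis % (ndim + 1)
    return shape[:pos] + (1,) + shape[pos:]


def insert_many(shape, axes):
    """Emulate a tuple-of-axes expand_dims with repeated single-axis insertions.

    Axes are normalized against the *output* ndim and applied in ascending
    order, so each insertion only shifts dims to its right.
    """
    out_ndim = len(shape) + len(axes)
    positions = sorted(ax % out_ndim for ax in axes)  # assumes in-bounds axes
    if len(set(positions)) != len(positions):
        raise ValueError("repeated axis")
    for pos in positions:
        shape = insert_one(shape, pos)
    return shape


# Target: 1s at output positions 0 and 2 of a 4-d result built from (2, 3).
print(insert_many((2, 3), (2, 0)))           # (1, 2, 1, 3)
# Applying the same axes in descending order puts a 1 at position 3 instead of 2.
print(insert_one(insert_one((2, 3), 2), 0))  # (1, 2, 3, 1)
```

Ascending order works because an insertion never moves anything to its left, so every 1 placed earlier stays where it was put, and each position is at most the current ndim at the moment it is inserted.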

@kgryte kgryte changed the title expand_dims for tuple of axes RFC: add support for a tuple of axes in expand_dims Apr 4, 2024
@kgryte kgryte added RFC Request for comments. Feature requests and proposed changes. topic: Manipulation Array manipulation and transformation. Needs Discussion Needs further discussion. labels Apr 4, 2024
@kgryte kgryte added this to the v2024 milestone Apr 4, 2024
@kgryte kgryte added this to Stage 0 in Proposals Apr 4, 2024
@Micky774
Contributor

Regarding removing ambiguity, I think it would suffice to impose an order in which to expand dims, right? For example, if we specify "negative indices get resolved first", then, borrowing your example above, this could be resolved as

x = np.empty((2, 3, 4, 5))
xp.expand_dims(x, (3, -3))  # would give the same result as np.expand_dims(np.expand_dims(x, -3), 3)

so that the final output shape is (2, 3, 1, 1, 4, 5), which seems reasonable.

Still, I'm not sure it is worth it, since users could do this as a two-step expansion in the first place (albeit with some more thought), and the resolution order (positive or negative indices first?) is rather arbitrary.
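The proposed "negative indices first" rule can be sketched as a sequence of ordinary single-axis insertions on shape tuples (hypothetical helper names; axes are assumed to be in bounds, and the order within each sign group simply follows the input). Flipping the rule to nonnegative-first gives a different shape for the same input, which illustrates how arbitrary the choice is.

```python
def insert_one(shape, axis):
    """Single-axis expand_dims on a shape tuple (assumes -ndim-1 <= axis <= ndim)."""
    pos = axis % (len(shape) + 1)
    return shape[:pos] + (1,) + shape[pos:]


def expand_in_sign_order(shape, axes, negative_first=True):
    """Hypothetical rule: resolve one sign group first, then the other."""
    neg = [a for a in axes if a < 0]
    nonneg = [a for a in axes if a >= 0]
    for ax in (neg + nonneg if negative_first else nonneg + neg):
        shape = insert_one(shape, ax)
    return shape


print(expand_in_sign_order((2, 3, 4, 5), (3, -3)))                        # (2, 3, 1, 1, 4, 5)
print(expand_in_sign_order((2, 3, 4, 5), (3, -3), negative_first=False))  # (2, 3, 4, 1, 1, 5)
```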

@asmeurer
Member

When you do repeated expand_dims, the inserted dimensions in the final shape won't necessarily be at the indices you initially specified (that's the whole point of this feature request: you need a way to do them all at once). (2, 3, 1, 1, 4, 5) has 1s at indices 2 and -3 (remember 0-based indexing), because the 1 that was at index -3 got shifted over.
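To see the shift directly, one can label each dimension and track where the inserted ones land under the two-step expansion. This is only a bookkeeping sketch on Python lists; the helper `insert_label` and the label strings are made up for illustration.

```python
def insert_label(labels, axis, label):
    """Single-axis insertion on a list of dimension labels (negative axes count from the end)."""
    pos = axis % (len(labels) + 1)  # assumes axis is in bounds
    return labels[:pos] + [label] + labels[pos:]


dims = ["2", "3", "4", "5"]  # labels standing in for the input dims
step1 = insert_label(dims, -3, "new@-3")
step2 = insert_label(step1, 3, "new@3")
print(step2)  # ['2', '3', 'new@-3', 'new@3', '4', '5']

n = len(step2)
for label in ("new@-3", "new@3"):
    i = step2.index(label)
    print(label, "ends up at index", i, "==", i - n)
# new@-3 ends up at index 2 == -4
# new@3 ends up at index 3 == -3
```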
