cumulative_sum behavior for 0-D inputs #797

Open
asmeurer opened this issue Apr 22, 2024 · 8 comments

Comments

@asmeurer
Member

asmeurer commented Apr 22, 2024

The standard is not clear about what should happen in cumulative_sum for 0-D inputs: https://data-apis.org/array-api/latest/API_specification/generated/array_api.cumulative_sum.html#cumulative-sum

Note that NumPy and PyTorch have different conventions here:

>>> import numpy as np
>>> np.cumsum(np.asarray(0))
array([0])
>>> import torch
>>> torch.cumsum(torch.asarray(0), dim=0)
tensor(0)

torch.cumsum unconditionally requires the dim argument, whereas np.cumsum defaults to computing over a flattened array if axis=None. The standard requires axis if the dimensionality is greater than 1. However, axis=0 doesn't really make sense for a 0-D array. NumPy also allows specifying axis=0 and gives the same result:

>>> np.cumsum(np.asarray(0), axis=0)
array([0])
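NumPy's result follows from its axis=None rule: the input is flattened first, and flattening a 0-D array yields a 1-D array with one element. A minimal sketch checking this, assuming NumPy is installed:

```python
import numpy as np

# axis=None: np.cumsum operates on the flattened (raveled) input
a = np.array([[1, 2], [3, 4]])
flat_result = np.cumsum(a)
print(flat_result)                                        # [ 1  3  6 10]
print(np.array_equal(flat_result, np.cumsum(a.ravel())))  # True

# Raveling a 0-D array produces a 1-D array of length 1,
# which is why np.cumsum returns a shape-(1,) result for a 0-D input
scalar = np.asarray(0)
print(scalar.ndim, scalar.ravel().shape)                  # 0 (1,)
```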

Furthermore, there is ambiguity here on what should happen for a 0-D input when include_initial=True. The standard says:

if include_initial is True, the returned array must have the same shape as x, except the size of the axis along which to compute the cumulative sum must be N+1.

If the result should be 0-D, then clearly include_initial must do nothing, since there is no way to increase the number of elements in the result.
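In the 1-D case the quoted rule is straightforward: with include_initial=True the output gains a leading 0 (the sum of the empty prefix), so a length-N input gives a length-N+1 output. A stdlib-only sketch using itertools.accumulate (its initial argument needs Python 3.8+); the function name cumulative_sum_1d is hypothetical:

```python
from itertools import accumulate

def cumulative_sum_1d(values, include_initial=False):
    """Hypothetical 1-D illustration of the include_initial rule."""
    if include_initial:
        # initial=0 prepends the empty-prefix sum, giving N + 1 elements
        return list(accumulate(values, initial=0))
    # Otherwise the output has exactly N elements
    return list(accumulate(values))

x = [1, 2, 3]
print(cumulative_sum_1d(x))                        # [1, 3, 6]
print(cumulative_sum_1d(x, include_initial=True))  # [0, 1, 3, 6]
```

A 0-D input has no axis whose length could grow from N to N+1, which is the source of the ambiguity.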

This doesn't seem to have been discussed in the original pull request #653 or issue #597, and I don't recall it being brought up at the consortium meetings.

My suggested behavior would be

  • The result of cumulative_sum on a 0-D input x should be a 0-D output which is the same as x (i.e., it would work just like sum(x)). This matches the behavior that cumulative_sum always returns an output with the same dimensionality as the input.
  • The include_initial flag would do nothing when the input is 0-D. One can read the existing text as already supporting this behavior, since "the axis along which to compute the cumulative sum" is vacuous.
  • The axis argument must be None when the input is 0-D; otherwise an error is raised. This matches the usual "axis must be in the range [-ndim, ndim)" condition, which is not currently spelled out this way for cumulative_sum but is for other functions in the standard.
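A minimal sketch of these three rules, written as a hypothetical wrapper cumulative_sum_proposed around np.cumsum. It assumes NumPy is installed and emulates include_initial by prepending zeros along the axis, rather than relying on any particular library's implementation:

```python
import numpy as np

def cumulative_sum_proposed(x, /, *, axis=None, include_initial=False):
    """Hypothetical sketch of the suggested 0-D behavior."""
    x = np.asarray(x)
    if x.ndim == 0:
        # Rule 3: no axis is valid for a 0-D input
        if axis is not None:
            raise ValueError("axis must be None for a 0-D input")
        # Rules 1 and 2: 0-D in, 0-D out (equal to x, like sum(x));
        # include_initial has no axis to extend, so it is ignored
        return x.copy()
    if axis is None:
        # The standard only lets axis be omitted for 1-D inputs
        if x.ndim > 1:
            raise ValueError("axis must be specified for inputs with ndim > 1")
        axis = 0
    out = np.cumsum(x, axis=axis)
    if include_initial:
        # Prepend a slice of zeros so the size along `axis` becomes N + 1
        shape = list(out.shape)
        shape[axis] = 1
        out = np.concatenate([np.zeros(shape, dtype=out.dtype), out], axis=axis)
    return out

print(cumulative_sum_proposed(np.asarray(5)))                               # 5
print(cumulative_sum_proposed(np.asarray(5), include_initial=True).shape)   # ()
print(cumulative_sum_proposed(np.array([1, 2, 3]), include_initial=True))   # [0 1 3 6]
```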

Alternatively, we could leave the behavior unspecified. To me the above makes sense, but this does break with current cumsum conventions. On the other hand, since the name is different, it's not a big deal for libraries to change behavior between cumsum and cumulative_sum (this is at least the approach that NumPy has taken with some of the existing renames with breaking changes).

@asmeurer
Member Author

It would be useful if anyone is aware of any prior discussions about this in NumPy, PyTorch, or other libraries. It doesn't seem to have been mentioned at numpy/numpy#6044, but I didn't look any further in the NumPy tracker.

@asmeurer
Member Author

Another consideration: diff (not yet standardized) should be the inverse of cumulative_sum(include_initial=True), and cumulative_sum should be the inverse of diff up to an additive constant. diff raises an error for 0-D inputs in both NumPy and PyTorch:

>>> np.diff(np.asarray(0))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/aaronmeurer/miniconda3/envs/array-apis/lib/python3.11/site-packages/numpy/lib/function_base.py", line 1418, in diff
    raise ValueError("diff requires input that is at least one dimensional")
ValueError: diff requires input that is at least one dimensional
>>> torch.diff(torch.asarray(0))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: diff expects input to be at least one-dimensional
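The inverse relationship can be checked on 1-D data with NumPy. This sketch emulates cumulative_sum(include_initial=True) by prepending a zero to np.cumsum, since it does not assume any particular library's cumulative_sum:

```python
import numpy as np

x = np.array([3, 1, 4, 1, 5])

# Emulate cumulative_sum(x, include_initial=True): prepend the initial 0
c = np.concatenate([[0], np.cumsum(x)])

# diff undoes the cumulative sum when the initial value is included
print(np.array_equal(np.diff(c), x))  # True

# In the other direction, cumsum recovers y[1:] up to the constant y[0]
y = np.array([10, 13, 14, 18])
print(np.array_equal(np.cumsum(np.diff(y)) + y[0], y[1:]))  # True
```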

@kgryte
Contributor

kgryte commented Apr 22, 2024

Personally, I'd be more inclined to require that the input array should have at least one dimension. If a 0D array is considered the array equivalent of a scalar, then performing a cumulative sum over a scalar doesn't make sense to me.
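Under this alternative, an implementation would check the rank up front, much as np.diff does. A hypothetical sketch: the name cumulative_sum_strict and its error messages are illustrative only, loosely modeled on np.diff's message, and NumPy is assumed to be installed:

```python
import numpy as np

def cumulative_sum_strict(x, /, *, axis=None):
    """Hypothetical variant that requires at least one dimension."""
    x = np.asarray(x)
    if x.ndim == 0:
        # Treat a 0-D array like a scalar: there is nothing to accumulate over
        raise ValueError("cumulative_sum requires input that is at least one dimensional")
    if axis is None:
        if x.ndim > 1:
            raise ValueError("axis must be specified for inputs with ndim > 1")
        axis = 0
    return np.cumsum(x, axis=axis)

print(cumulative_sum_strict(np.array([1, 2, 3])))  # [1 3 6]
try:
    cumulative_sum_strict(np.asarray(0))
except ValueError as exc:
    print("ValueError:", exc)
```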

@seberg
Contributor

seberg commented Apr 23, 2024

I think the first question should be what the default is when you have N dimensions. I would be very surprised if the scalar rule doesn't fall out of that, since the results should only come from axis=None (ravel) or axis=() (empty tuple).
(Likely accumulate is limited to a single axis and defaults to that, but that already means that scalars just shouldn't work.)

For example the NumPy example:

np.cumsum(np.asarray(0), axis=0)

seems clearly incorrect (from today's point of view; 20 years ago we were more forgiving/guessing), and np.add.accumulate(np.asarray(0), axis=0) correctly raises.
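The contrast can be observed directly, assuming NumPy is installed. The exact exception class raised by np.add.accumulate for a 0-D input is not pinned down here and may differ between NumPy versions, so the sketch catches the likely candidates rather than one specific type:

```python
import numpy as np

scalar = np.asarray(0)

# The ufunc accumulate method refuses a 0-D input...
try:
    np.add.accumulate(scalar, axis=0)
    accumulate_raised = False
except (TypeError, ValueError, IndexError) as exc:
    accumulate_raised = True
    print(type(exc).__name__, exc)

# ...whereas a 1-D input accumulates normally along axis 0
print(np.add.accumulate(np.array([1, 2, 3]), axis=0))  # [1 3 6]
```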

@asmeurer
Member Author

The standard requires axis to be specified if there is more than one dimension. So right now, axis=None doesn't really mean "flatten"; it just means you don't have to specify it in the common 1-D case. There's no flattening at all in cumulative_sum, which is why I argue 0-D shouldn't do it either. And I definitely agree that no function should allow axis=0 on 0-D inputs.
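The "axis must be in the range [-ndim, ndim)" condition from the issue description already excludes axis=0 for a 0-D array, because that range is empty when ndim is 0. A pure-Python sketch with a hypothetical helper normalize_axis:

```python
def normalize_axis(axis: int, ndim: int) -> int:
    """Hypothetical helper: map an axis in [-ndim, ndim) to a non-negative index."""
    if not -ndim <= axis < ndim:
        # For ndim == 0 the range is empty, so every axis is rejected
        raise IndexError(f"axis {axis} is out of bounds for an array of dimension {ndim}")
    return axis + ndim if axis < 0 else axis

print(normalize_axis(-1, 2))  # 1
print(normalize_axis(0, 1))   # 0
try:
    normalize_axis(0, 0)
except IndexError as exc:
    print(exc)                # axis 0 is out of bounds for an array of dimension 0
```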

I'm also somewhat inclined towards disallowing this, or at least leaving it undefined.

@seberg
Contributor

seberg commented Apr 23, 2024

it just means you don't have to specify it in the common 1-D case

Right, and this already means it is unspecified: 0-D is N-D, and only 1-D is specified.

towards disallowing this

While I think the default of "ravelling" isn't useful or even very sensible for accumulations, I am not sure I see a big enough gain in prescribing cumsum != cumulative_sum for implementations.

@seberg
Contributor

seberg commented Apr 23, 2024

I see that the confusion here really came from axis=0, which just doesn't make sense for 0-D at all.

I think it would be completely fine to specify that as not allowed; I don't even think it is necessary to specify that explicitly. It is behavior that clearly should be deprecated, even if I might not jump at actually doing it.

@Micky774
Contributor

I agree that disallowing 0-D inputs makes sense here. Note that in JAX our current implementation raises a ValueError for 0-D inputs.

Labels
None yet
Projects
None yet
Development

No branches or pull requests

4 participants