
Python scalars in elementwise functions #807

Open
shoyer opened this issue May 15, 2024 · 4 comments
Labels
API change Changes to existing functions or objects in the API.

Comments

@shoyer
Contributor

shoyer commented May 15, 2024

The array API supports Python scalars in arithmetic only, i.e., operations like x + 1.

For the same readability reasons that make supporting scalars in arithmetic valuable, it would be nice to also support Python scalars in other elementwise functions, at least those that take multiple arguments, like maximum(x, 0) or where(y, x, 0).

@rgommers
Member

Hmm, I am in two minds about reconsidering this choice.

On the con side: non-array input to functions is going against the design we have had from the start, it makes static typing a bit harder (we'd need both an Array protocol and an ArrayOrScalar union), and not all libraries support it yet - PyTorch in particular. E.g.:

>>> import torch
>>> t = torch.ones(3)
>>> torch.maximum(t, 1.5)
...
TypeError: maximum(): argument 'other' (position 2) must be Tensor, not float

It looks like PyTorch is in principle fine with adding this, but it's a nontrivial amount of work and as far as I know no one is working on it: pytorch/pytorch#110636. PyTorch does support scalars in the functions matching operators (e.g., torch.add) and in torch.where.
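For contrast, NumPy already accepts a Python scalar in exactly this call. A minimal illustration (assumes NumPy is installed; this is the behavior gap that makes portable code awkward):

```python
import numpy as np

x = np.ones(3)
# NumPy's maximum accepts a bare Python scalar directly...
result = np.maximum(x, 1.5)
print(result)  # [1.5 1.5 1.5]
# ...which is the same call that raises TypeError in PyTorch above.
```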

TensorFlow also doesn't support it (except in their experimental.numpy namespace, IIRC), but that's less relevant now since it doesn't look like they're going to implement anything.

> For the same readability reasons that supporting scalars in arithmetic is valuable

The readability argument is less compelling for functions than for operators, though: x + 1 is very short, so the relative increase in verbosity is larger for operators than for function calls (where modname.funcname is already long), and scalars are less commonly used in function calls anyway.


On the pro side: I agree that it is pretty annoying to get right in a completely portable and generic way. In the cases where one does need it, the natural choice of asarray(scalar) often doesn't work, since it must also match the dtype and device of the array argument. So xp.maximum(x, 1) becomes:

xp.maximum(x, xp.asarray(1, dtype=x.dtype, device=x.device))

Hence if this is a pattern that a project happens to need a lot, it will probably create a utility function like:

from array_api_compat import array_namespace

def as_zerodim(value, x, /, xp=None):
    """Return `value` as a 0-d array with x's dtype and device."""
    if xp is None:
        xp = array_namespace(x)
    return xp.asarray(value, dtype=x.dtype, device=x.device)


# Usage:
xp.maximum(x, as_zerodim(1, x))

PyTorch support comes through array-api-compat at this point, so wrapping the PyTorch functions isn't too hard; it is doable. I think I'm +0.5 on balance. It's not the highest-priority item, but it's nice to have if it works for all implementing libraries.

@rgommers rgommers added the API change Changes to existing functions or objects in the API. label May 17, 2024
@asmeurer
Member

We could support them in a bespoke way for the arguments of specific useful functions, like where. We already added scalar support specifically to the min and max arguments of clip: https://data-apis.org/array-api/latest/API_specification/generated/array_api.clip.html

@shoyer
Contributor Author

shoyer commented May 17, 2024

> On the pro side: I agree that it is pretty annoying to get right in a completely portable and generic way. In the cases where one does need it, the natural choice of asarray(scalar) often doesn't work, it should also use the dtype and device. So xp.maximum(x, 1) becomes:
>
> xp.maximum(x, xp.asarray(1, dtype=x.dtype, device=x.device))

It's even a little messier in the case Xarray is currently facing:

  1. We want this to work in a completely portable and generic way, with the minimum array-API requirements.
  2. We also still want to allow libraries like NumPy to figure out the result type themselves. For example, consider maximum(x, 0.5) in the case where x has an integer dtype. In the array API, mixed-dtype casting is undefined, but in most array libraries the result would be upcast to some form of float.
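To make the second point concrete, here is what NumPy does today for this call (a sketch of observed NumPy behavior; the array API itself leaves this mixed-kind case undefined):

```python
import numpy as np

x = np.array([1, 2, 3])        # integer dtype
result = np.maximum(x, 0.5)    # NumPy upcasts across kinds
print(result, result.dtype)    # [1. 2. 3.] float64
```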

@asmeurer
Member

asmeurer commented May 17, 2024

> In the array API, mixed dtype casting is undefined, but in most array libraries the result would be upcast to some form of float.

That's deviating from even the operator behavior in the array API. The specified scalar OP array behavior is to only upcast the scalar to the type of the array, not the other way around: https://data-apis.org/array-api/latest/API_specification/type_promotion.html#mixing-arrays-with-python-scalars. In other words, int OP float_array is OK, but float OP int_array is not. Implicitly casting an integer array to a floating-point dtype is cross-kind casting, and is something we've tried to explicitly avoid. (To be clear, these are all recommended, not required. Libraries like NumPy are free to implement this if they choose to.)
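A short sketch of the allowed versus unspecified direction, using NumPy's behavior for illustration (the cross-kind case is the one the spec leaves undefined):

```python
import numpy as np

f32 = np.ones(3, dtype=np.float32)
# int scalar OP float array: the scalar is cast to the array's dtype (allowed)
print((f32 + 2).dtype)    # float32

i64 = np.array([1, 2, 3])
# float scalar OP int array: requires cross-kind casting, which the spec
# leaves unspecified (NumPy happens to produce float64)
print((i64 + 0.5).dtype)  # float64
```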

Similarly, clip, which as I mentioned is an example of a function that already allows Python scalars, leaves mixed-kind scalars unspecified, although I personally think it should adopt the same logic as operators and allow int alongside floating-point arrays.
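For reference, the clip pattern in NumPy, including the mixed-kind variant under discussion here, namely Python int bounds with a floating-point array (NumPy allows it; the spec currently leaves it unspecified):

```python
import numpy as np

x = np.array([-1.5, 0.5, 2.5])
# float array with Python int bounds -- the case argued above to mirror
# the operator rule (int scalar alongside a floating-point array)
print(np.clip(x, 0, 1))  # [0.  0.5 1. ]
```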
