Python scalars in elementwise functions #807
Hmm, I am in two minds about reconsidering this choice. On the con side: non-array input to functions goes against the design we have had from the start, and it makes static typing a bit harder (we'd need both an
>>> import torch
>>> t = torch.ones(3)
>>> torch.maximum(t, 1.5)
...
TypeError: maximum(): argument 'other' (position 2) must be Tensor, not float
In principle PyTorch looks fine with adding this, but it's a nontrivial amount of work and, as far as I know, no one is working on it: pytorch/pytorch#110636. PyTorch does support it in functions matching operators (e.g.,
TensorFlow also doesn't support it (except for in their
The readability argument is less prominent for functions than for operators, though. Both because
On the pro side: I agree that it is pretty annoying to get right in a completely portable and generic way. In the cases where one does need it, the natural choice of
xp.maximum(x, xp.asarray(1, dtype=x.dtype, device=x.device))
Hence, if this is a pattern that a project needs a lot, it will probably create a utility function like:
from array_api_compat import array_namespace  # assumed source of array_namespace

def as_zerodim(value, x, /, xp=None):
    # Build a 0-D array holding `value` with the same dtype and device as `x`.
    if xp is None:
        xp = array_namespace(x)
    return xp.asarray(value, dtype=x.dtype, device=x.device)

# Usage:
xp.maximum(x, as_zerodim(1, x))
PyTorch support comes through
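One way to push the `as_zerodim` idea further is to wrap a two-argument elementwise function once, so that call sites can pass a Python scalar directly, as in `maximum(x, 0)`. This is a hypothetical sketch: `with_scalar_args` is not part of the standard or any library, and it assumes the namespace `xp` is passed in explicitly and that arrays expose `.dtype` and `.device` as the standard specifies.

```python
import functools

# Python scalar types the wrapper converts (bool is covered, being a subclass of int).
_SCALAR_TYPES = (int, float, complex)


def with_scalar_args(func, xp):
    """Hypothetical helper: let either argument of a binary elementwise function be a Python scalar."""

    @functools.wraps(func)
    def wrapper(x1, x2, /):
        x1_is_scalar = isinstance(x1, _SCALAR_TYPES)
        x2_is_scalar = isinstance(x2, _SCALAR_TYPES)
        if x1_is_scalar and x2_is_scalar:
            # With no array argument there is no dtype or device to match.
            raise TypeError("at least one argument must be an array")
        if x1_is_scalar:
            # Same conversion as as_zerodim, keyed on the array argument.
            x1 = xp.asarray(x1, dtype=x2.dtype, device=x2.device)
        elif x2_is_scalar:
            x2 = xp.asarray(x2, dtype=x1.dtype, device=x1.device)
        return func(x1, x2)

    return wrapper
```

With NumPy 2.x, which accepts `device=` in `asarray` and provides `ndarray.device`, something like `maximum = with_scalar_args(np.maximum, np)` followed by `maximum(x, 0)` should work. The wrapper does not check whether the scalar fits the array's dtype (for example a float scalar with an integer array), and `where(y, x, 0)` would need a variant that leaves the boolean condition argument alone.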
We could support them in a bespoke way for specific useful functions' arguments like
It's even a little messier in the case Xarray is currently facing:
That deviates even from the operator behavior in the array API. The specified
Similarly,
The array API supports Python scalars in arithmetic only, i.e., operations like
x + 1
For the same readability reasons that make supporting scalars in arithmetic valuable, it would be nice to also support Python scalars in other elementwise functions, at least those that take multiple arguments, like
maximum(x, 0)
or
where(y, x, 0)
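For reference, the operator behavior this issue wants to extend only allows certain scalar/array combinations. As we read the standard (2022.12 and later), the allowed pairs are: a `bool` with a boolean array, an `int` with an integer array (within the dtype's range) or a floating-point array, a `float` with a floating-point array, and a `complex` with a complex floating-point array. In each case the scalar is then treated as a 0-D array of the array's dtype. The function below is a hypothetical sketch of that compatibility check; it reuses the kind names of the standard's `isdtype()` as plain strings rather than real dtype objects.

```python
def scalar_compatible(scalar, dtype_kind, int_bounds=None):
    """Hypothetical check: may `scalar` be mixed with an array of kind `dtype_kind`?

    dtype_kind is one of "bool", "integral", "real floating" or "complex floating".
    int_bounds, if given, is the inclusive (min, max) range of the integer dtype.
    """
    # bool must be tested first because bool is a subclass of int in Python.
    if isinstance(scalar, bool):
        return dtype_kind == "bool"
    if isinstance(scalar, int):
        if dtype_kind == "integral":
            if int_bounds is None:
                return True
            lo, hi = int_bounds
            return lo <= scalar <= hi
        return dtype_kind in ("real floating", "complex floating")
    if isinstance(scalar, float):
        return dtype_kind in ("real floating", "complex floating")
    if isinstance(scalar, complex):
        return dtype_kind == "complex floating"
    raise TypeError(f"not a Python scalar: {type(scalar).__name__}")
```

A helper like this could back a scalar-accepting `maximum` or `where` so that it rejects the same combinations the operators leave unspecified, instead of passing them straight to `asarray`.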