Add MLX op handler for aten.isinf #18936
Conversation
Decompose `isinf(x)` into `abs(x) == inf` using the existing `AbsNode` and `EqualNode`, so the op runs on the Metal GPU via MLX instead of falling back to CPU execution. Closes pytorch#18922

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
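The decomposition can be sanity-checked outside the backend with plain PyTorch. This is a minimal standalone sketch (not the MLX handler itself), confirming that `abs(x) == inf` agrees with `torch.isinf` elementwise, including the NaN case:

```python
import torch

# Standalone check of the decomposition isinf(x) -> (abs(x) == inf).
# NaN is the tricky case: abs(nan) is nan, and nan == inf is False,
# which matches torch.isinf returning False for NaN.
x = torch.tensor([1.0, float("inf"), float("-inf"), float("nan"), -3.5])
decomposed = torch.abs(x) == float("inf")
print(decomposed.tolist())  # [False, True, True, False, False]
assert torch.equal(decomposed, torch.isinf(x))
```

The abs step is what lets a single equality comparison catch both +inf and -inf.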
Hi @Ai-chan-0411! Thank you for your pull request and welcome to our community.

**Action Required**

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

**Process**

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with `CLA Signed`.

If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!
```python
"""Return a callable(shape, dtype) that generates inputs with some inf values."""

def fn(shape, dtype):
    x = torch.randn(shape, dtype=dtype)
```
Can we add some nans to this generated test input as well?
Done — added NaN values to `_inf_input_fn`. The generated inputs now include `float("nan")` alongside `inf`/`-inf`, so the test verifies that `isinf` correctly returns `False` for NaN. Pushed in 7647a7e.
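For reference, a generator along these lines could look as follows. This is a hypothetical sketch, not the actual `_inf_input_fn` from the PR — the helper name, fractions, and masking scheme are all assumptions:

```python
import torch

def make_inf_input_fn(inf_fraction=0.1, nan_fraction=0.1):
    """Hypothetical sketch of an input generator: scatter inf, -inf,
    and nan into an otherwise random tensor. (The real helper in the
    PR is _inf_input_fn; its exact scheme may differ.)"""
    def fn(shape, dtype):
        x = torch.randn(shape, dtype=dtype)
        mask = torch.rand(shape)
        x[mask < inf_fraction] = float("inf")
        x[(mask >= inf_fraction) & (mask < 2 * inf_fraction)] = float("-inf")
        x[mask > 1 - nan_fraction] = float("nan")
        return x
    return fn

sample = make_inf_input_fn()((64, 64), torch.float32)
```

With inputs like this, a single unary-op test exercises the `+inf`, `-inf`, NaN, and finite paths at once.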
```python
    return out
```

```python
@REGISTRY.register(target=[torch.ops.aten.isinf.default])
```
Are there any other variants of this torch op? E.g., isinf.Tensor, etc.
There may not be, just curious.
I checked the PyTorch ATen operator registry and the executorch codebase — the only variant is `aten.isinf.default`. There is no `isinf.Tensor` or `isinf.out` overload defined. The other backends (MPS, Qualcomm) also only register `aten.isinf.default`, so this handler covers the complete surface.
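This can also be re-checked at runtime against an installed PyTorch build — `OpOverloadPacket.overloads()` lists the overload names registered for an op packet. A quick diagnostic, not part of the PR:

```python
import torch

# List every overload registered for aten.isinf in this PyTorch build.
# Each op packet (torch.ops.aten.<name>) exposes overloads() returning
# the registered overload names, e.g. "default", "Tensor", "out".
names = torch.ops.aten.isinf.overloads()
print(names)
```

If `names` contains only `"default"`, the single registration in this handler is exhaustive.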
Looks great @Ai-chan-0411! Can we add NaNs to the test case (to make sure `isinf` behaves correctly when its input is NaN)?
Include NaN in the generated test data alongside inf/-inf to ensure isinf correctly returns False for NaN inputs.
Thanks for the review! You're right about adding NaN test cases.
Good question about variants. In PyTorch's aten ops, `aten.isinf.default` is the only registered overload.
I noticed the MLX workflow run (24492915727) failed, but the failure was due to infrastructure issues — specifically, network connectivity errors when cloning submodules (NVIDIA/cutlass and google/flatbuffers). This is not related to the code changes in this PR. The implementation itself should be sound. A retry of the workflow run should succeed.
Summary
While looking at the MLX backend coverage for numerical-stability ops, I noticed `aten.isinf` was missing — any model using `torch.isinf` would silently fall back to CPU, breaking the GPU acceleration pipeline.

This PR adds a decomposed handler that expresses `isinf(x)` as `abs(x) == inf`, reusing the existing `AbsNode` and `EqualNode` infrastructure. Both positive and negative infinity are correctly detected through the abs step.

Changes:
- `backends/mlx/ops.py` — new `_isinf_handler` registered for `torch.ops.aten.isinf.default`
- `backends/mlx/test/test_ops.py` — added `_inf_input_fn` (generates tensors with scattered ±inf values) and an `isinf` entry in `_UNARY_OP_TESTS`

Closes #18922