Fix #18562: Method.execute() silently produces wrong results for no...
#18935
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
```diff
@@ -144,6 +144,12 @@ def execute(self, inputs: Sequence[Any]) -> Sequence[Any]:
         Returns:
             A list of output values, typically torch.Tensor objects.
         """
+        import torch
+
+        inputs = [
+            x.contiguous() if isinstance(x, torch.Tensor) and not x.is_contiguous() else x
+            for x in inputs
+        ]
         return self._method(inputs)
```
Comment on lines +149 to 153

Suggested change:

```diff
-        inputs = [
-            x.contiguous() if isinstance(x, torch.Tensor) and not x.is_contiguous() else x
-            for x in inputs
-        ]
-        return self._method(inputs)
+        converted_inputs = None
+        for i, x in enumerate(inputs):
+            if isinstance(x, torch.Tensor) and not x.is_contiguous():
+                if converted_inputs is None:
+                    converted_inputs = list(inputs)
+                converted_inputs[i] = x.contiguous()
+        return self._method(converted_inputs if converted_inputs is not None else inputs)
```
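The point of this suggestion is copy-on-write: the list is duplicated only when at least one input actually needs conversion, so the common all-contiguous path does no allocation. A minimal sketch of that pattern, using a hypothetical `FakeTensor` stand-in (not ExecuTorch or torch API) so it runs without torch installed:

```python
class FakeTensor:
    """Hypothetical stand-in mimicking torch.Tensor's contiguity methods."""

    def __init__(self, data, contiguous=True):
        self.data = data
        self._contiguous = contiguous

    def is_contiguous(self):
        return self._contiguous

    def contiguous(self):
        # Return a contiguous copy, mirroring torch.Tensor.contiguous().
        return FakeTensor(list(self.data), contiguous=True)


def normalize_inputs(inputs):
    """Copy the input list only if at least one element needs conversion."""
    converted = None
    for i, x in enumerate(inputs):
        if isinstance(x, FakeTensor) and not x.is_contiguous():
            if converted is None:
                converted = list(inputs)  # first conversion triggers the copy
            converted[i] = x.contiguous()
    return converted if converted is not None else inputs


a = FakeTensor([1, 2, 3])                    # already contiguous
b = FakeTensor([4, 5, 6], contiguous=False)  # needs conversion

untouched = [a]
assert normalize_inputs(untouched) is untouched  # fast path: no copy made

fixed = normalize_inputs([a, b])
assert fixed[0] is a                  # contiguous input passed through as-is
assert fixed[1].is_contiguous()       # non-contiguous input was converted
```

The same shape applies with real `torch.Tensor` objects; only the `isinstance` target changes.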
```diff
@@ -78,6 +78,24 @@ def test_add(program):
             program = runtime.load_program(f.read())
             test_add(program)

+    def test_execute_non_contiguous_inputs(self):
+        """Non-contiguous tensors (e.g. after permute) must produce the same
+        result as their contiguous equivalents."""
+        ep, inputs = create_program(ModuleAdd())
+        runtime = Runtime.get()
+        program = runtime.load_program(ep.buffer, verification=Verification.Minimal)
+
+        # Make a non-contiguous version of the first input via transpose.
```
Suggested change:

```diff
-        # Make a non-contiguous version of the first input via transpose.
+        # Make a non-contiguous version of the first input via
+        # unsqueeze/expand/permute followed by slicing.
```
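The test (and the comment fix above) hinge on constructing an input that is genuinely non-contiguous. The same stride-based notion of contiguity can be illustrated without torch using the stdlib `memoryview` type, which exposes a `c_contiguous` flag analogous to `Tensor.is_contiguous()`:

```python
# Contiguity demonstrated with stdlib memoryview: a step slice produces a
# strided, non-contiguous view over the same underlying buffer.
data = bytes(range(10))

flat = memoryview(data)   # plain view over the buffer
strided = flat[::2]       # step slicing yields a non-contiguous view

assert flat.c_contiguous
assert not strided.c_contiguous

# tobytes() materializes a contiguous copy, analogous to Tensor.contiguous().
copied = memoryview(strided.tobytes())
assert copied.c_contiguous
print(copied.tolist())  # [0, 2, 4, 6, 8]
```

In torch the equivalent non-contiguous views come from `transpose`, `permute`, `expand`, or step slicing, which is exactly what the reviewed test constructs.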
Importing `torch` inside `execute()` adds overhead on every call and also changes when an `ImportError` would surface (now at runtime call time instead of at module load). Prefer a module-level import, or at least import lazily only when a `torch.Tensor` is actually present (e.g., scan inputs first), optionally with a clear error if torch isn't available.