Skip to content

fix: detect incomplete fast refit on TRT-RTX via unset weights check#4198

Open
tp5uiuc wants to merge 1 commit intopytorch:mainfrom
tp5uiuc:fix/refit-wrong-weightmap-trtrtx
Open

fix: detect incomplete fast refit on TRT-RTX via unset weights check#4198
tp5uiuc wants to merge 1 commit intopytorch:mainfrom
tp5uiuc:fix/refit-wrong-weightmap-trtrtx

Conversation

@tp5uiuc
Copy link
Copy Markdown
Contributor

@tp5uiuc tp5uiuc commented Apr 21, 2026

Summary

  • Add complementary completeness check in fast refit path that compares engine weights against the mapping to detect unset weights directly
  • On TRT-RTX, get_missing_weights() may not report weights in independent wtsEngines (TRT-RTX native refit closures have no cross-engine dependencies), causing the fast refit to silently produce wrong results with stale weights
  • The new unset_weights check catches this regardless of the wtsEngine dependency structure
  • Unwaive test_refit_one_engine_with_wrong_weightmap on TRT-RTX

Test plan

  • test_refit_one_engine_with_wrong_weightmap passes on standard TRT (torch 2.13.0.dev20260416, TRT 10.16.1.11)
  • test_refit_one_engine_with_wrong_weightmap passes on TRT-RTX (torch 2.13.0.dev20260416, TRT-RTX 1.4.0.76)
  • CI L0/L1 refit tests

🤖 Generated with Claude Code

@meta-cla meta-cla Bot added the cla signed label Apr 21, 2026
@github-actions github-actions Bot added component: tests Issues re: Tests component: core Issues re: The core compiler component: api [Python] Issues re: Python API component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths labels Apr 21, 2026
@github-actions github-actions Bot requested a review from cehongwang April 21, 2026 00:10
Comment thread tests/py/dynamo/models/test_model_refit.py Outdated
The fast refit path relies on get_missing_weights() to detect when the
weight name map cache is incomplete and fall back to slow refit. On
TRT-RTX, get_missing_weights() may not report weights in independent
wtsEngines (Myelin native refit closures have no cross-engine
dependencies), causing the fast refit to silently produce wrong results.

Add a complementary check that compares engine weights against the
mapping to detect unset weights directly, independent of the wtsEngine
dependency structure. Also unwaive test_refit_one_engine_with_wrong_weightmap
on TRT-RTX since the fix enables it to pass.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

fix : import ordering

Signed-off-by: tejaswinp <tejaswinp@nvidia.com>
@tp5uiuc tp5uiuc force-pushed the fix/refit-wrong-weightmap-trtrtx branch from 57678d3 to bcc25ae Compare April 21, 2026 00:21
# weights in independent engines that get_missing_weights() may not report.
missing_weights = refitter.get_missing_weights()
unset_weights = {w for w in weight_list if w not in mapping}
assert len(missing_weights) == 0 and len(unset_weights) == 0, (
Copy link
Copy Markdown
Contributor Author

@tp5uiuc tp5uiuc Apr 21, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On TRT-RTX, get_missing_weights() may return 0 even when some engine weights were not set during fast refit. This happens because TRT-RTX engines can have independent weight groups with no cross-dependencies. Setting weights in one group doesn't make other groups visible to the missing-weights check. TRT has this behavior because global optimizations like batchnorm folding occurs at a user-visible level.

Note that get_missing_weights will NOT report all missing weights, it only reports weights missing from that group/combination (C++ API, Python API)

For example, if some Weights have been set, but the engine was optimized in a way that combines weights, any unsupplied Weights in the combination are considered missing.

For TRT-RTX, all weights are treated independently (i.e. no combination exists).

The existing fast refit path relies solely on get_missing_weights() == 0 to verify completeness. When this check passes incorrectly, the fallback to slow refit never triggers, and the engine runs with stale weights producing wrong output. The existing check is insufficient for TRT-RTX.

Propsed fix: After the fast refit loop, also compare the set of engine weights (get_all_weights()) against the weights that were actually present in the mapping. If any engine weight was skipped (not in the mapping), the fast refit is considered incomplete and falls back to slow refit — regardless of what get_missing_weights() reports.

This is a one-line addition (unset_weights = {w for w in weight_list if w not in mapping}) plus the extended assert. Works correctly on both standard TensorRT and TensorRT-RTX.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

cla signed component: api [Python] Issues re: Python API component: core Issues re: The core compiler component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths component: tests Issues re: Tests

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant