fix: detect incomplete fast refit on TRT-RTX via unset weights check#4198
fix: detect incomplete fast refit on TRT-RTX via unset weights check#4198tp5uiuc wants to merge 1 commit intopytorch:mainfrom
Conversation
The fast refit path relies on get_missing_weights() to detect when the weight name map cache is incomplete and fall back to slow refit. On TRT-RTX, get_missing_weights() may not report weights in independent wtsEngines (Myelin native refit closures have no cross-engine dependencies), causing the fast refit to silently produce wrong results. Add a complementary check that compares engine weights against the mapping to detect unset weights directly, independent of the wtsEngine dependency structure. Also unwaive test_refit_one_engine_with_wrong_weightmap on TRT-RTX since the fix enables it to pass. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> fix : import ordering Signed-off-by: tejaswinp <tejaswinp@nvidia.com>
57678d3 to
bcc25ae
Compare
| # weights in independent engines that get_missing_weights() may not report. | ||
| missing_weights = refitter.get_missing_weights() | ||
| unset_weights = {w for w in weight_list if w not in mapping} | ||
| assert len(missing_weights) == 0 and len(unset_weights) == 0, ( |
There was a problem hiding this comment.
On TRT-RTX, get_missing_weights() may return 0 even when some engine weights were not set during fast refit. This happens because TRT-RTX engines can have independent weight groups with no cross-dependencies. Setting weights in one group doesn't make other groups visible to the missing-weights check. TRT has this behavior because global optimizations like batchnorm folding occurs at a user-visible level.
Note that get_missing_weights will NOT report all missing weights, it only reports weights missing from that group/combination (C++ API, Python API)
For example, if some Weights have been set, but the engine was optimized in a way that combines weights, any unsupplied Weights in the combination are considered missing.
For TRT-RTX, all weights are treated independently (i.e. no combination exists).
The existing fast refit path relies solely on get_missing_weights() == 0 to verify completeness. When this check passes incorrectly, the fallback to slow refit never triggers, and the engine runs with stale weights producing wrong output. The existing check is insufficient for TRT-RTX.
Propsed fix: After the fast refit loop, also compare the set of engine weights (get_all_weights()) against the weights that were actually present in the mapping. If any engine weight was skipped (not in the mapping), the fast refit is considered incomplete and falls back to slow refit — regardless of what get_missing_weights() reports.
This is a one-line addition (unset_weights = {w for w in weight_list if w not in mapping}) plus the extended assert. Works correctly on both standard TensorRT and TensorRT-RTX.
Summary
get_missing_weights()may not report weights in independent wtsEngines (TRT-RTX native refit closures have no cross-engine dependencies), causing the fast refit to silently produce wrong results with stale weightsunset_weightscheck catches this regardless of the wtsEngine dependency structuretest_refit_one_engine_with_wrong_weightmapon TRT-RTXTest plan
test_refit_one_engine_with_wrong_weightmappasses on standard TRT (torch 2.13.0.dev20260416, TRT 10.16.1.11)test_refit_one_engine_with_wrong_weightmappasses on TRT-RTX (torch 2.13.0.dev20260416, TRT-RTX 1.4.0.76)🤖 Generated with Claude Code