[Feat,Refactor]: Offline Dflash; Spec Mixin; Deprecate parallel draft; #1271
Conversation
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually. Contributors can view more details about this message here.

Important: Review skipped — draft detected. Please check the settings in the CodeRabbit UI or the ⚙️ run configuration. Configuration used: Path: .coderabbit.yaml · Review profile: CHILL · Plan: Pro Plus · Run ID:
📝 Walkthrough
The pull request reorganizes speculative decoding infrastructure by consolidating plugin implementations into separate, dedicated modules. It introduces a reusable speculative-decoding mixin.

Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 4 passed
Actionable comments posted: 5
🧹 Nitpick comments (1)
modelopt/torch/speculative/plugins/modeling_dflash.py (1)
118-124: Consider handling `None` for `_attn_implementation` more explicitly.

The code assumes `config._attn_implementation` is set (per the comment referencing `dflash/default_config.py`), but if it's `None`, `ALL_ATTENTION_FUNCTIONS.get(None, ...)` would still work and fall back to SDPA. However, this could be made more explicit for clarity.

♻️ Optional: More explicit None handling
```diff
 def _get_attn_fn(self):
     """Lazily resolve the HF attention function (default: sdpa)."""
     if self._attn_fn is not None:
         return self._attn_fn
-    impl = self.config._attn_implementation  # default set in dflash/default_config.py
+    impl = self.config._attn_implementation or "sdpa"
     self._attn_fn = ALL_ATTENTION_FUNCTIONS.get(impl, ALL_ATTENTION_FUNCTIONS["sdpa"])
     return self._attn_fn
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/speculative/plugins/modeling_dflash.py` around lines 118 - 124, The _get_attn_fn method should explicitly handle a None or missing config._attn_implementation instead of relying on dict.get's default; update _get_attn_fn to read impl = self.config._attn_implementation and if impl is None or impl not in ALL_ATTENTION_FUNCTIONS explicitly assign impl = "sdpa" (or the intended default) before setting self._attn_fn via ALL_ATTENTION_FUNCTIONS[impl], so callers of _get_attn_fn and readers of the code see clear, intentional fallback behavior referencing the _get_attn_fn method, config._attn_implementation, and ALL_ATTENTION_FUNCTIONS.
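To make the nitpick concrete, here is a minimal, library-free sketch of the explicit fallback being suggested. The registry contents are placeholders, not real attention functions; the real `ALL_ATTENTION_FUNCTIONS` comes from `transformers`.

```python
# Hypothetical stand-in registry; values are labels, not actual attention kernels.
ALL_ATTENTION_FUNCTIONS = {"sdpa": "sdpa_fn", "eager": "eager_fn", "flash_attention_2": "fa2_fn"}

def resolve_attn_impl(impl):
    # Explicit fallback reads clearer than relying on dict.get's default when impl is None
    if impl is None or impl not in ALL_ATTENTION_FUNCTIONS:
        impl = "sdpa"
    return ALL_ATTENTION_FUNCTIONS[impl]
```

Behavior is identical to `dict.get(impl, default)` for the cases above; the win is purely readability, which is why the review flags it as optional.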
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/torch/speculative/plugins/hf_medusa.py`:
- Line 137: The call in hf_medusa.py is passing an incorrectly named keyword
rcache_position which causes the cache position to be ignored; update the
invocation (around the code that constructs the KV cache / calls the method
using rcache_position) to use the correct keyword name cache_position so the
value is honored (search for rcache_position in the hf_medusa.py plugin and
replace it with cache_position in that function/method call).
- Around line 162-171: The loop mutates labels in-place causing misaligned
targets for later heads; preserve the original labels (e.g., store
original_labels = labels before the for loop) and inside the loop compute a
shifted copy like shifted = original_labels[..., 1 + i :].contiguous() (instead
of reassigning labels), then use loss_logits = medusa_logits[i][:, : -(1 +
i)].contiguous(), loss_labels = shifted.view(-1), and compute loss as before
with loss_fct, medusa_decay_coefficient, and medusa_heads_coefficient to avoid
cumulative shifts across iterations.
In `@modelopt/torch/speculative/plugins/hf_spec_mixin.py`:
- Around line 157-165: The code temporarily mutates the global
torch._dynamo.config.suppress_errors but never restores it; wrap the change
around the compile loop by saving the original value of
torch._dynamo.config.suppress_errors, set it to True before iterating
self._compile_targets and compiling each target (using getattr(self, name),
torch.compile(...), setattr(self, name, ...)), and ensure you restore the
original suppress_errors value in a finally block so the global config is
returned to its prior state regardless of compilation success or exceptions.
- Around line 169-175: HFDFlashModel currently inherits
HFSpecDecMixin.get_dummy_inputs which raises NotImplementedError; add an
override in HFDFlashModel (hf_dflash.py) implementing get_dummy_inputs()
following the same shape/keys pattern used by HFEagleModel (so it returns a dict
of dummy tensors/arrays for the export forward pass rather than raising),
ensuring the method signature matches the base and provides the expected keys
consumed by unified_export_hf.py (called at the export flow around
unified_export_hf.py:386) so speculative/offline export does not fail at
runtime.
In `@modelopt/torch/speculative/plugins/modeling_eagle.py`:
- Around line 159-163: The code always assigns self._input_embeds from
self.layers[0].input_layernorm(inputs_embeds) even when the EAGLE-3 pre-hook
that consumes it is not registered; guard the assignment so it only runs when
the pre-hook will consume the value (e.g., check
self.config.use_aux_hidden_state or the same condition used when registering the
pre-hook) to avoid retaining an unnecessary activation tensor. Update the block
that sets self._input_embeds (referencing inputs_embeds,
self.layers[0].input_layernorm, and self._input_embeds) to perform the
assignment only when use_aux_hidden_state (or the hook-registered flag) is true.
---
Nitpick comments:
In `@modelopt/torch/speculative/plugins/modeling_dflash.py`:
- Around line 118-124: The _get_attn_fn method should explicitly handle a None
or missing config._attn_implementation instead of relying on dict.get's default;
update _get_attn_fn to read impl = self.config._attn_implementation and if impl
is None or impl not in ALL_ATTENTION_FUNCTIONS explicitly assign impl = "sdpa"
(or the intended default) before setting self._attn_fn via
ALL_ATTENTION_FUNCTIONS[impl], so callers of _get_attn_fn and readers of the
code see clear, intentional fallback behavior referencing the _get_attn_fn
method, config._attn_implementation, and ALL_ATTENTION_FUNCTIONS.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro Plus
Run ID: 0cc18e20-ee50-4dfc-bcc2-bad5a1a6e6e1
📒 Files selected for processing (16)
- .pre-commit-config.yaml
- examples/speculative_decoding/eagle_utils.py
- examples/speculative_decoding/main.py
- examples/speculative_decoding/scripts/ar_validate.py
- modelopt/torch/export/plugins/hf_spec_export.py
- modelopt/torch/speculative/config.py
- modelopt/torch/speculative/dflash/dflash_model.py
- modelopt/torch/speculative/eagle/default_config.py
- modelopt/torch/speculative/plugins/__init__.py
- modelopt/torch/speculative/plugins/hf_dflash.py
- modelopt/torch/speculative/plugins/hf_eagle.py
- modelopt/torch/speculative/plugins/hf_medusa.py
- modelopt/torch/speculative/plugins/hf_spec_mixin.py
- modelopt/torch/speculative/plugins/modeling_dflash.py
- modelopt/torch/speculative/plugins/modeling_eagle.py
- modelopt/torch/speculative/utils.py
💤 Files with no reviewable changes (1)
- modelopt/torch/speculative/eagle/default_config.py
```python
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                rcache_position=cache_position,
```
Typo: rcache_position should be cache_position.
This will cause the cache position to be ignored, potentially leading to incorrect behavior with KV cache.
🐛 Proposed fix
```diff
-                rcache_position=cache_position,
+                cache_position=cache_position,
```
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
                cache_position=cache_position,
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/speculative/plugins/hf_medusa.py` at line 137, The call in
hf_medusa.py is passing an incorrectly named keyword rcache_position which
causes the cache position to be ignored; update the invocation (around the code
that constructs the KV cache / calls the method using rcache_position) to use
the correct keyword name cache_position so the value is honored (search for
rcache_position in the hf_medusa.py plugin and replace it with cache_position in
that function/method call).
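The reason this typo fails silently rather than raising a `TypeError` is that HF forward methods typically accept `**kwargs`, which absorbs any misspelled keyword. A toy reproduction (`fake_forward` is illustrative, not the real signature):

```python
def fake_forward(input_ids=None, cache_position=None, **kwargs):
    # A misspelled keyword like rcache_position lands in **kwargs,
    # so the real cache_position parameter silently keeps its default of None.
    return cache_position, kwargs

pos, extra = fake_forward(input_ids=[1, 2, 3], rcache_position=[0, 1, 2])
# pos is None; the intended value sits unused in extra["rcache_position"]
```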
```python
        for i in range(self.medusa_num_heads):
            labels = labels[..., 1:].contiguous()
            loss_logits = medusa_logits[i][:, : -(1 + i)].contiguous()
            loss_logits = loss_logits.view(-1, loss_logits.shape[-1])
            loss_labels = labels.view(-1)
            loss += (
                loss_fct(loss_logits, loss_labels)
                * medusa_decay_coefficient**i
                * medusa_heads_coefficient
            )
```
Labels are mutated in-place, causing incorrect loss computation for heads > 0.
The line labels = labels[..., 1:].contiguous() overwrites labels on each iteration. For head i, the labels will be shifted i+1 times cumulatively instead of i+1 times from the original. This causes later heads to train on increasingly truncated and misaligned labels.
🐛 Proposed fix
```diff
+        # Store original labels for computing per-head loss
+        original_labels = labels
         # Medusa loss
         for i in range(self.medusa_num_heads):
-            labels = labels[..., 1:].contiguous()
-            loss_logits = medusa_logits[i][:, : -(1 + i)].contiguous()
+            # Shift labels by (i+1) positions for head i
+            shifted_labels = original_labels[..., (i + 1) :].contiguous()
+            loss_logits = medusa_logits[i][:, : -(1 + i)].contiguous()
             loss_logits = loss_logits.view(-1, loss_logits.shape[-1])
-            loss_labels = labels.view(-1)
+            loss_labels = shifted_labels.view(-1)
             loss += (
                 loss_fct(loss_logits, loss_labels)
                 * medusa_decay_coefficient**i
                 * medusa_heads_coefficient
             )
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/speculative/plugins/hf_medusa.py` around lines 162 - 171, The
loop mutates labels in-place causing misaligned targets for later heads;
preserve the original labels (e.g., store original_labels = labels before the
for loop) and inside the loop compute a shifted copy like shifted =
original_labels[..., 1 + i :].contiguous() (instead of reassigning labels), then
use loss_logits = medusa_logits[i][:, : -(1 + i)].contiguous(), loss_labels =
shifted.view(-1), and compute loss as before with loss_fct,
medusa_decay_coefficient, and medusa_heads_coefficient to avoid cumulative
shifts across iterations.
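For context on the alignment being discussed: head `i`'s logits at position `t` predict the token at `t + i + 1`, so its labels must be shifted left by `i + 1` and its logits truncated by `i + 1` from the end, keeping lengths equal. A torch-free shape check of that invariant, using a plain list as a stand-in for the sequence axis:

```python
seq_len, num_heads = 10, 3
labels = list(range(seq_len))  # stand-in for the label tensor's sequence dimension

for i in range(num_heads):
    shifted_labels = labels[i + 1 :]       # targets for head i, shifted from the original
    kept_positions = labels[: -(i + 1)]    # logit positions that still have a valid target
    # the two slices must have equal length for the cross-entropy terms to line up
    assert len(shifted_labels) == len(kept_positions) == seq_len - (i + 1)
```

Shifting from a preserved `original_labels` each iteration, as the proposed fix does, makes this per-head offset explicit rather than an accumulated side effect of the loop.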
```python
        import torch._dynamo

        torch._dynamo.config.suppress_errors = True  # Allow fallback to eager mode

        for name, kwargs in self._compile_targets:
            try:
                setattr(self, name, torch.compile(getattr(self, name), dynamic=False, **kwargs))
            except Exception:  # noqa: PERF203
                print(f"Disabling torch.compile for {name} due to compilation error.")
```
🧩 Analysis chain
🏁 Scripts executed (Repository: NVIDIA/Model-Optimizer):
```shell
cat -n modelopt/torch/speculative/plugins/hf_spec_mixin.py | sed -n '150,175p'

# Search for torch._dynamo.config.suppress_errors usage in the codebase
rg 'torch\._dynamo\.config\.suppress_errors' -A 2 -B 2

# Check if there are any imports or context managers related to torch._dynamo
rg 'torch\._dynamo' -A 1 -B 1
```
Restore torch._dynamo.config.suppress_errors after compilation.
Line 159 sets a process-global Dynamo flag without restoring it. After one model calls this helper, all subsequent code in the process will silently fall back to eager mode on compilation errors, even if not intended.
♻️ Proposed fix
```diff
 def _activate_torch_compile(self):
     """Apply ``torch.compile`` to methods listed in ``_compile_targets``.

     Each entry is ``(method_name, extra_kwargs)`` passed to ``torch.compile(..., dynamic=False)``.
     Failures fall back to eager mode silently.
     """
     import torch._dynamo

-    torch._dynamo.config.suppress_errors = True  # Allow fallback to eager mode
-
-    for name, kwargs in self._compile_targets:
-        try:
-            setattr(self, name, torch.compile(getattr(self, name), dynamic=False, **kwargs))
-        except Exception:  # noqa: PERF203
-            print(f"Disabling torch.compile for {name} due to compilation error.")
+    prev_suppress_errors = torch._dynamo.config.suppress_errors
+    torch._dynamo.config.suppress_errors = True  # Allow fallback to eager mode
+    try:
+        for name, kwargs in self._compile_targets:
+            try:
+                setattr(self, name, torch.compile(getattr(self, name), dynamic=False, **kwargs))
+            except Exception:  # noqa: PERF203
+                print(f"Disabling torch.compile for {name} due to compilation error.")
+    finally:
+        torch._dynamo.config.suppress_errors = prev_suppress_errors
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/speculative/plugins/hf_spec_mixin.py` around lines 157 - 165,
The code temporarily mutates the global torch._dynamo.config.suppress_errors but
never restores it; wrap the change around the compile loop by saving the
original value of torch._dynamo.config.suppress_errors, set it to True before
iterating self._compile_targets and compiling each target (using getattr(self,
name), torch.compile(...), setattr(self, name, ...)), and ensure you restore the
original suppress_errors value in a finally block so the global config is
returned to its prior state regardless of compilation success or exceptions.
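The save/restore pattern can also be packaged once as a context manager so every caller gets the `finally` for free. A sketch against a stub config object (not `torch._dynamo.config` itself, to keep it dependency-free):

```python
from contextlib import contextmanager

@contextmanager
def temporarily(obj, attr, value):
    """Set obj.attr = value for the duration of the with-block, then restore it."""
    prev = getattr(obj, attr)
    setattr(obj, attr, value)
    try:
        yield
    finally:
        setattr(obj, attr, prev)

class DynamoConfigStub:
    # stand-in for torch._dynamo.config
    suppress_errors = False

cfg = DynamoConfigStub()
with temporarily(cfg, "suppress_errors", True):
    compiled_under_suppression = cfg.suppress_errors  # the compile loop would run here
# after the block, cfg.suppress_errors is back to False
```

The restore runs even if compilation raises, which is exactly the property the review asks for.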
```python
    def get_dummy_inputs(self) -> dict:
        """Construct dummy inputs for export forward pass. Subclasses must override."""
        raise NotImplementedError

    def get_exporter(self):
        """Return the exporter for the draft model. Subclasses must override."""
        raise NotImplementedError
```
🧩 Analysis chain
🏁 Scripts executed (Repository: NVIDIA/Model-Optimizer):
```shell
#!/bin/bash
set -euo pipefail
echo "=== overrides in speculative plugins ==="
for name in hf_spec_mixin.py hf_eagle.py hf_dflash.py; do
  file="$(fd -a "^${name}$" modelopt/torch/speculative/plugins || true)"
  if [ -n "$file" ]; then
    echo "--- $file ---"
    rg -nC2 'class\s+HF(Eagle|DFlash)Model|def\s+get_dummy_inputs|def\s+get_exporter' "$file"
  fi
done
echo
echo "=== generic get_dummy_inputs call sites ==="
rg -nC2 '\bget_dummy_inputs\s*\(' modelopt examples

rg -A 100 "class HFDFlashModel" modelopt/torch/speculative/plugins/hf_dflash.py | head -120
grep -n "get_dummy_inputs" modelopt/torch/speculative/plugins/hf_dflash.py
wc -l modelopt/torch/speculative/plugins/hf_dflash.py

cat -n modelopt/torch/export/unified_export_hf.py | sed -n '375,395p'
grep -B 10 -A 5 "model.get_dummy_inputs()" modelopt/torch/export/unified_export_hf.py
rg -n "get_dummy_inputs\s*\(" --type py modelopt examples | grep -v "def get_dummy_inputs\|^\s*#"

grep -r "HFDFlashModel" modelopt examples --include="*.py" | head -20
rg -n "export.*dflash|dflash.*export" modelopt examples --type py -i | head -20
grep -r "specdec_offline" examples --include="*.py" -B 3 -A 3

rg -n "class DFlashExporter" -A 20 modelopt/torch/export/plugins/hf_spec_export.py | head -30
rg -n "def export" modelopt/torch/export/plugins/hf_spec_export.py | head -10
sed -n '354,380p' modelopt/torch/export/plugins/hf_spec_export.py

rg -n "specdec_offline" examples/llm_ptq/hf_ptq.py -B 5 -A 5 | grep -i "model\|eagle\|dflash" | head -20
sed -n '860,885p' examples/llm_ptq/hf_ptq.py

# Confirm that an inherited stub method still satisfies hasattr()
python3 - << 'EOF'
class Base:
    def method(self):
        raise NotImplementedError

class Child(Base):
    pass

c = Child()
print("hasattr check:", hasattr(c, "method"))
try:
    c.method()
except NotImplementedError:
    print("NotImplementedError raised as expected")
EOF
```
HFDFlashModel is missing the required get_dummy_inputs() implementation.
HFDFlashModel (hf_dflash.py) only implements get_exporter() but not get_dummy_inputs(). It inherits the base NotImplementedError stub from HFSpecDecMixin. The export code at unified_export_hf.py:386 calls model.get_dummy_inputs() when hasattr(model, "get_dummy_inputs") is true—which passes for HFDFlashModel since the method exists (inherited)—but will raise NotImplementedError at runtime. This breaks the offline speculative decoding export flow used in examples/llm_ptq/hf_ptq.py when --specdec_offline_dataset is specified.
Add a get_dummy_inputs() override to HFDFlashModel matching the pattern in HFEagleModel.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/speculative/plugins/hf_spec_mixin.py` around lines 169 - 175,
HFDFlashModel currently inherits HFSpecDecMixin.get_dummy_inputs which raises
NotImplementedError; add an override in HFDFlashModel (hf_dflash.py)
implementing get_dummy_inputs() following the same shape/keys pattern used by
HFEagleModel (so it returns a dict of dummy tensors/arrays for the export
forward pass rather than raising), ensuring the method signature matches the
base and provides the expected keys consumed by unified_export_hf.py (called at
the export flow around unified_export_hf.py:386) so speculative/offline export
does not fail at runtime.
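The core pitfall here is that `hasattr` is satisfied by an inherited stub, so the export-time guard passes and the failure surfaces only when the stub executes. A self-contained sketch with stub class names (`SpecDecMixinStub`/`DFlashStub` are illustrative; the real keys must match what `unified_export_hf.py` consumes):

```python
class SpecDecMixinStub:
    # stand-in for HFSpecDecMixin's base stub
    def get_dummy_inputs(self) -> dict:
        raise NotImplementedError

class DFlashStub(SpecDecMixinStub):
    # hypothetical override following the pattern the review asks for
    def get_dummy_inputs(self) -> dict:
        return {"input_ids": [[1, 2, 3, 4]], "attention_mask": [[1, 1, 1, 1]]}

# hasattr() is True even for the inherited stub, so a
# hasattr(model, "get_dummy_inputs") guard does not catch the missing override
stub_passes_hasattr = hasattr(SpecDecMixinStub(), "get_dummy_inputs")
```

A stricter guard would check whether the method is still the base-class attribute (e.g. `type(model).get_dummy_inputs is SpecDecMixinStub.get_dummy_inputs`), but adding the override, as the review recommends, is the direct fix.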
```python
        inputs_embeds = inputs_embeds.to(hidden_states.dtype).to(hidden_states.device)
        # In EAGLE-3, we save input embeddings to attribute, and use it in first decoder layer by hook function
        # Also, we normalize input embeddings and hidden states before concatenating them.
        # The default input norm in first layer attn will be disabled.
        self._input_embeds = self.layers[0].input_layernorm(inputs_embeds)
```
Only stash _input_embeds when the EAGLE-3 pre-hook will consume it.
These lines run for every config, but the pre-hook is only registered under use_aux_hidden_state. In the normal path this leaves one extra activation tensor hanging off self for no reason, which increases memory pressure and can keep autograd state alive longer than necessary.
♻️ Proposed fix
```diff
         inputs_embeds = inputs_embeds.to(hidden_states.dtype).to(hidden_states.device)
         # In EAGLE-3, we save input embeddings to attribute, and use it in first decoder layer by hook function
         # Also, we normalize input embeddings and hidden states before concatenating them.
         # The default input norm in first layer attn will be disabled.
-        self._input_embeds = self.layers[0].input_layernorm(inputs_embeds)
+        if self.config.use_aux_hidden_state:
+            self._input_embeds = self.layers[0].input_layernorm(inputs_embeds)
+        else:
+            self._input_embeds = None
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/speculative/plugins/modeling_eagle.py` around lines 159 - 163,
The code always assigns self._input_embeds from
self.layers[0].input_layernorm(inputs_embeds) even when the EAGLE-3 pre-hook
that consumes it is not registered; guard the assignment so it only runs when
the pre-hook will consume the value (e.g., check
self.config.use_aux_hidden_state or the same condition used when registering the
pre-hook) to avoid retaining an unnecessary activation tensor. Update the block
that sets self._input_embeds (referencing inputs_embeds,
self.layers[0].input_layernorm, and self._input_embeds) to perform the
assignment only when use_aux_hidden_state (or the hook-registered flag) is true.
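The guarded-stash pattern being proposed is easy to demonstrate in isolation. `DecoderStub` below is a toy stand-in (the real code applies `input_layernorm` before stashing); the point is that the attribute is only populated when the consuming hook is actually registered, so no extra activation tensor is retained on the normal path:

```python
class DecoderStub:
    """Toy stand-in for the decoder showing the guarded assignment."""

    def __init__(self, use_aux_hidden_state: bool):
        self.use_aux_hidden_state = use_aux_hidden_state
        self._input_embeds = None

    def forward(self, inputs_embeds):
        # only retain the (normalized) embeddings when the EAGLE-3 pre-hook will read them
        if self.use_aux_hidden_state:
            self._input_embeds = inputs_embeds  # real code: input_layernorm(inputs_embeds)
        return inputs_embeds
```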
Codecov Report
❌ Patch coverage is
Additional details and impacted files
```
@@            Coverage Diff             @@
##             main    #1271      +/-   ##
==========================================
- Coverage   76.07%   71.35%   -4.72%
==========================================
  Files         459      464       +5
  Lines       48528    48993     +465
==========================================
- Hits        36917    34961    -1956
- Misses      11611    14032    +2421
```
Flags with carried forward coverage won't be shown.
☔ View full report in Codecov by Sentry.
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Force-pushed 6df234f to f91cf9d (Compare)
What does this PR do?
Type of change: New feature, refactoring
- `HFSpecDecMixin`: Extract duplicated base-model discovery, forward pass, NVTX profiling, and `torch.compile` logic from `HFEagleModel`/`HFDFlashModel` into a shared mixin.
- Offline DFlash: `dflash_offline` config flag for training from pre-computed hidden states; deletes base model layers to save memory.
- Deprecate `ParallelDraft`: Remove `parallel_draft_step`, the `ParallelDraft` module, and all related logic from Eagle.
- Renames: `transformers.py` → `hf_eagle.py`; `HFMedusaModel` → `hf_medusa.py`; `DFlashModule` → `modeling_dflash.py`; `EagleModule` → `modeling_eagle.py`.
- Move `dflash_mask_token_id` auto-detection from `main.py` into `DFlashConfig` Pydantic validators.

Testing
Validated with existing Eagle and DFlash training scripts (online + offline modes).
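The config-validator move described above ("fill `dflash_mask_token_id` automatically when unset") can be sketched without a pydantic dependency using a dataclass `__post_init__`; `DEFAULT_MASK_TOKEN_ID` is a made-up placeholder, since the real value is detected from the tokenizer:

```python
from dataclasses import dataclass
from typing import Optional

DEFAULT_MASK_TOKEN_ID = 0  # hypothetical fallback; real detection queries the tokenizer

@dataclass
class DFlashConfigSketch:
    """Dataclass stand-in for the Pydantic DFlashConfig validator behavior."""

    dflash_mask_token_id: Optional[int] = None

    def __post_init__(self):
        # mirrors a Pydantic model_validator: auto-fill the mask token id when unset
        if self.dflash_mask_token_id is None:
            self.dflash_mask_token_id = DEFAULT_MASK_TOKEN_ID
```

Putting the default in the config object rather than in `main.py` means every entry point (training scripts, export, tests) sees the same resolved value.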
Before your PR is "Ready for review"
Make sure you read and follow Contributor guidelines and your commits are signed (`git commit -s -S`).
Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).
Deprecates `parallel_draft_step` config; renames `transformers.py` → `hf_eagle.py`.
CONTRIBUTING.md: N/A
Additional Information
Breaking changes:
- `modelopt.torch.speculative.plugins.transformers` → `.hf_eagle`
- `parallel_draft_step` / `parallel_draft_heads_num_layers` removed from Eagle config
- `_draft_model_config` → `eagle_config` in export plugin

Summary by CodeRabbit
Release Notes
New Features
Improvements
- Removed `parallel_draft_step > 1` support.
Chores