
[Feat,Refactor]: Offline Dflash; Spec Mixin; Deprecate parallel draft; #1271

Draft
h-guo18 wants to merge 1 commit into main from haoguo/spec-mixin

Conversation


@h-guo18 (Contributor) commented Apr 16, 2026

What does this PR do?

Type of change: New feature, refactoring

  • HFSpecDecMixin: Extract duplicated base-model discovery, forward pass, NVTX profiling, and torch.compile logic from HFEagleModel / HFDFlashModel into a shared mixin.
  • Offline DFlash: Add dflash_offline config flag for training from pre-computed hidden states; deletes base model layers to save memory.
  • Deprecate ParallelDraft: Remove parallel_draft_step, ParallelDraft module, and all related logic from Eagle.
  • File reorg: transformers.py → hf_eagle.py; HFMedusaModel → hf_medusa.py; DFlashModule → modeling_dflash.py; EagleModule → modeling_eagle.py.
  • Config validation: Move dflash_mask_token_id auto-detection from main.py into DFlashConfig Pydantic validators.
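The moved validation can be pictured with the plain-Python sketch below. Function and key names are assumptions for illustration; the actual implementation uses Pydantic validators on DFlashConfig, and the tokenizer here is a stand-in object.

```python
from types import SimpleNamespace


def resolve_dflash_config(config: dict, tokenizer=None) -> dict:
    """Sketch of the dflash_mask_token_id auto-detection moved out of main.py.

    Resolving at config-build time fails fast with a clear error instead of
    failing deep inside the training loop.
    """
    resolved = dict(config)
    resolved.setdefault("dflash_offline", False)  # offline mode defaults to off
    if resolved.get("dflash_mask_token_id") is None:
        mask_id = getattr(tokenizer, "mask_token_id", None)
        if mask_id is None:
            raise ValueError(
                "dflash_mask_token_id is not set and the tokenizer has no mask "
                "token; set it explicitly in the DFlash config."
            )
        resolved["dflash_mask_token_id"] = mask_id
    return resolved


tok = SimpleNamespace(mask_token_id=151643)  # stand-in for a HF tokenizer
cfg = resolve_dflash_config({"dflash_offline": True}, tokenizer=tok)
```

The same check that previously lived in the training entry point now runs wherever the config is constructed, so misconfiguration surfaces before any model is loaded.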

Testing

Validated with existing Eagle and DFlash training scripts (online + offline modes).

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ❌ — removes parallel_draft_step config; renames transformers.py → hf_eagle.py
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: N/A
  • Did you write any new necessary tests?: ❌
  • Did you update Changelog?: ❌

Additional Information

Breaking changes:

  • modelopt.torch.speculative.plugins.transformers → modelopt.torch.speculative.plugins.hf_eagle
  • parallel_draft_step / parallel_draft_heads_num_layers removed from Eagle config
  • _draft_model_config → eagle_config in export plugin

Summary by CodeRabbit

Release Notes

  • New Features

    • Added offline mode support for DFlash speculative decoding configuration.
    • Introduced HuggingFace Medusa speculative decoding plugin.
  • Improvements

    • Simplified parallel draft handling by removing parallel_draft_step > 1 support.
    • Enhanced base model configuration retrieval for better compatibility.
    • Reorganized speculative decoding plugin architecture for improved maintainability.
  • Chores

    • Cleaned up default EAGLE configuration exports.
    • Updated plugin module paths and imports.

@copy-pr-bot

copy-pr-bot bot commented Apr 16, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.


@coderabbitai

coderabbitai bot commented Apr 16, 2026

Important

Review skipped

Draft detected.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro Plus

Run ID: 1e03bc68-d7a9-44e5-b1c9-6f174d1d5273

📝 Walkthrough

The pull request reorganizes speculative decoding infrastructure by consolidating plugin implementations into separate, dedicated modules. It introduces a reusable HFSpecDecMixin base class, moves Eagle and DFlash draft model architectures to dedicated modeling_eagle.py and modeling_dflash.py modules, extracts Medusa into its own plugin file, and refactors existing plugins to use shared components. Configuration validation is enhanced for DFlash with automated field resolution and error checking. Parallel draft step functionality is removed from Eagle.

Changes

Cohort / File(s) / Summary
Plugin Import Reorganization
.pre-commit-config.yaml, modelopt/torch/speculative/plugins/__init__.py, modelopt/torch/speculative/utils.py
Redirected plugin loading from transformers module to hf_eagle and hf_medusa modules. Updated license-header exclusion paths and CP-TTT patch gating to reference the new plugin modules.
Example and Script Updates
examples/speculative_decoding/eagle_utils.py, examples/speculative_decoding/scripts/ar_validate.py, examples/speculative_decoding/main.py
Updated imports and references to use hf_eagle plugin module instead of transformers. Replaced manual dflash_mask_token_id detection with schema-based DFlashConfig validation in main training entry point.
DFlash Configuration Enhancement
modelopt/torch/speculative/config.py
Added dflash_offline field and Pydantic validators to auto-derive offline mode and resolve dflash_mask_token_id from tokenizer context during validation, with error handling for missing required fields.
DFlash Model Updates
modelopt/torch/speculative/dflash/dflash_model.py
Updated modify() to accept and assign the new dflash_offline configuration field.
Shared Mixin Infrastructure
modelopt/torch/speculative/plugins/hf_spec_mixin.py
Introduced new HFSpecDecMixin providing common base-model discovery, forward execution (with optional freezing), NVTX profiling, and torch.compile acceleration for speculative models.
Draft Model Architecture Modules
modelopt/torch/speculative/plugins/modeling_eagle.py, modelopt/torch/speculative/plugins/modeling_dflash.py
Created dedicated modules implementing EAGLE and DFlash draft model architectures (layers, attention blocks, output containers) extracted from plugin files.
Eagle Plugin Refactoring
modelopt/torch/speculative/plugins/hf_eagle.py
Removed Medusa implementation and parallel draft step support. Replaced local implementations with shared HFSpecDecMixin and modeling_eagle modules. Simplified draft handling to single logit tensor instead of parallel list. Updated base-model forward and TTT loop control flow.
Medusa Plugin Extraction
modelopt/torch/speculative/plugins/hf_medusa.py
Introduced new standalone HFMedusaModel plugin with frozen/non-frozen base model paths, Medusa head construction via ResBlock stacks, and loss computation with decay weighting.
DFlash Plugin Refactoring
modelopt/torch/speculative/plugins/hf_dflash.py
Removed local draft architecture definitions. Replaced with imports from modeling_dflash and HFSpecDecMixin. Introduced _dflash_base_model_forward() helper. Updated offline mode to delete base layers and fetch base outputs from precomputed context. Reworked draft attention and forward pass to use shared base output structure.
Configuration Defaults Cleanup
modelopt/torch/speculative/eagle/default_config.py, modelopt/torch/export/plugins/hf_spec_export.py
Removed parallel draft configuration keys from default Eagle configs. Updated exporter to source Eagle config from model.eagle_config instead of model._draft_model_config.
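The shared-mixin pattern described above can be sketched roughly as follows. Everything beyond the names HFSpecDecMixin and _compile_targets is an assumption, and torch.compile is replaced by a stub so the sketch stays self-contained and framework-free.

```python
class SpecDecMixinSketch:
    """Rough sketch of the mixin pattern: subclasses declare which methods to
    accelerate, and the mixin applies a compile step with eager fallback."""

    # Each entry: (method_name, extra_kwargs), populated by subclasses.
    _compile_targets: list = []

    def activate_compile(self, compile_fn):
        """compile_fn stands in for torch.compile(..., dynamic=False)."""
        for name, kwargs in self._compile_targets:
            try:
                setattr(self, name, compile_fn(getattr(self, name), **kwargs))
            except Exception:
                # Keep the original (eager) bound method on failure.
                print(f"Disabling compile for {name}")


class DraftModelSketch(SpecDecMixinSketch):
    """Stand-in for a draft model such as HFEagleModel / HFDFlashModel."""

    _compile_targets = [("step", {})]

    def step(self, x):
        return x + 1


def fake_compile(fn, **kwargs):
    # Pretend to compile: wrap the function and tag it.
    def wrapped(*args, **kw):
        return fn(*args, **kw)

    wrapped.compiled = True
    return wrapped


m = DraftModelSketch()
m.activate_compile(fake_compile)
```

The point of the design is that base-model discovery, profiling, and compilation live once in the mixin, while each plugin only declares its compile targets.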

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks: ✅ 4 passed
  • Description Check: ✅ Passed. Check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check: ✅ Passed. The title clearly summarizes the main changes: offline DFlash feature, the new HFSpecDecMixin, and deprecation of parallel draft support.
  • Docstring Coverage: ✅ Passed. Docstring coverage is 80.70%, above the required threshold of 80.00%.
  • Security Anti-Patterns: ✅ Passed. Comprehensive security audit found no torch.load without weights_only, numpy.load with allow_pickle, hardcoded trust_remote_code=True, eval/exec on external input, # nosec comments, or non-permissive PIP dependencies.


@h-guo18 changed the title from "code reuse for dflashj/eagle, depreacet parallel draft" to "Refactor: code reuse for dflash/eagle; Deprecate parallel draft" on Apr 16, 2026
@h-guo18 changed the title from "Refactor: code reuse for dflash/eagle; Deprecate parallel draft" to "[Feat,Refactor]: Offline Dflash; Spec Mixin; Deprecate parallel draft;" on Apr 16, 2026
@h-guo18 marked this pull request as ready for review April 16, 2026 07:32
@h-guo18 requested review from a team as code owners April 16, 2026 07:32
@h-guo18 marked this pull request as draft April 16, 2026 07:35
@h-guo18 self-assigned this Apr 16, 2026
@coderabbitai bot left a comment
Actionable comments posted: 5

🧹 Nitpick comments (1)
modelopt/torch/speculative/plugins/modeling_dflash.py (1)

118-124: Consider handling None for _attn_implementation more explicitly.

The code assumes config._attn_implementation is set (per the comment referencing dflash/default_config.py), but if it's None, ALL_ATTENTION_FUNCTIONS.get(None, ...) would still work and fall back to SDPA. However, this could be made more explicit for clarity.

♻️ Optional: More explicit None handling
     def _get_attn_fn(self):
         """Lazily resolve the HF attention function (default: sdpa)."""
         if self._attn_fn is not None:
             return self._attn_fn
-        impl = self.config._attn_implementation  # default set in dflash/default_config.py
+        impl = self.config._attn_implementation or "sdpa"
         self._attn_fn = ALL_ATTENTION_FUNCTIONS.get(impl, ALL_ATTENTION_FUNCTIONS["sdpa"])
         return self._attn_fn

📥 Commits

Reviewing files that changed from the base of the PR and between 07ae8e7 and 6df234f.

📒 Files selected for processing (16)
  • .pre-commit-config.yaml
  • examples/speculative_decoding/eagle_utils.py
  • examples/speculative_decoding/main.py
  • examples/speculative_decoding/scripts/ar_validate.py
  • modelopt/torch/export/plugins/hf_spec_export.py
  • modelopt/torch/speculative/config.py
  • modelopt/torch/speculative/dflash/dflash_model.py
  • modelopt/torch/speculative/eagle/default_config.py
  • modelopt/torch/speculative/plugins/__init__.py
  • modelopt/torch/speculative/plugins/hf_dflash.py
  • modelopt/torch/speculative/plugins/hf_eagle.py
  • modelopt/torch/speculative/plugins/hf_medusa.py
  • modelopt/torch/speculative/plugins/hf_spec_mixin.py
  • modelopt/torch/speculative/plugins/modeling_dflash.py
  • modelopt/torch/speculative/plugins/modeling_eagle.py
  • modelopt/torch/speculative/utils.py
💤 Files with no reviewable changes (1)
  • modelopt/torch/speculative/eagle/default_config.py

    use_cache=use_cache,
    output_attentions=output_attentions,
    output_hidden_states=output_hidden_states,
    rcache_position=cache_position,

⚠️ Potential issue | 🟡 Minor

Typo: rcache_position should be cache_position.

This will cause the cache position to be ignored, potentially leading to incorrect behavior with KV cache.

🐛 Proposed fix
-                rcache_position=cache_position,
+                cache_position=cache_position,

Comment on lines +162 to +171
    for i in range(self.medusa_num_heads):
        labels = labels[..., 1:].contiguous()
        loss_logits = medusa_logits[i][:, : -(1 + i)].contiguous()
        loss_logits = loss_logits.view(-1, loss_logits.shape[-1])
        loss_labels = labels.view(-1)
        loss += (
            loss_fct(loss_logits, loss_labels)
            * medusa_decay_coefficient**i
            * medusa_heads_coefficient
        )

⚠️ Potential issue | 🔴 Critical

Labels are mutated in-place, causing incorrect loss computation for heads > 0.

The line labels = labels[..., 1:].contiguous() overwrites labels on each iteration. For head i, the labels will be shifted i+1 times cumulatively instead of i+1 times from the original. This causes later heads to train on increasingly truncated and misaligned labels.

🐛 Proposed fix
+        # Store original labels for computing per-head loss
+        original_labels = labels
         # Medusa loss
         for i in range(self.medusa_num_heads):
-            labels = labels[..., 1:].contiguous()
-            loss_logits = medusa_logits[i][:, : -(1 + i)].contiguous()
+            # Shift labels by (i+1) positions for head i
+            shifted_labels = original_labels[..., (i + 1) :].contiguous()
+            loss_logits = medusa_logits[i][:, : -(1 + i)].contiguous()
             loss_logits = loss_logits.view(-1, loss_logits.shape[-1])
-            loss_labels = labels.view(-1)
+            loss_labels = shifted_labels.view(-1)
             loss += (
                 loss_fct(loss_logits, loss_labels)
                 * medusa_decay_coefficient**i
                 * medusa_heads_coefficient
             )

Comment on lines +157 to +165
    import torch._dynamo

    torch._dynamo.config.suppress_errors = True  # Allow fallback to eager mode

    for name, kwargs in self._compile_targets:
        try:
            setattr(self, name, torch.compile(getattr(self, name), dynamic=False, **kwargs))
        except Exception:  # noqa: PERF203
            print(f"Disabling torch.compile for {name} due to compilation error.")

⚠️ Potential issue | 🟠 Major


Restore torch._dynamo.config.suppress_errors after compilation.

Line 159 sets a process-global Dynamo flag without restoring it. After one model calls this helper, all subsequent code in the process will silently fall back to eager mode on compilation errors, even if not intended.

♻️ Proposed fix
     def _activate_torch_compile(self):
         """Apply ``torch.compile`` to methods listed in ``_compile_targets``.
 
         Each entry is ``(method_name, extra_kwargs)`` passed to ``torch.compile(..., dynamic=False)``.
         Failures fall back to eager mode silently.
         """
         import torch._dynamo
 
-        torch._dynamo.config.suppress_errors = True  # Allow fallback to eager mode
-
-        for name, kwargs in self._compile_targets:
-            try:
-                setattr(self, name, torch.compile(getattr(self, name), dynamic=False, **kwargs))
-            except Exception:  # noqa: PERF203
-                print(f"Disabling torch.compile for {name} due to compilation error.")
+        prev_suppress_errors = torch._dynamo.config.suppress_errors
+        torch._dynamo.config.suppress_errors = True  # Allow fallback to eager mode
+        try:
+            for name, kwargs in self._compile_targets:
+                try:
+                    setattr(self, name, torch.compile(getattr(self, name), dynamic=False, **kwargs))
+                except Exception:  # noqa: PERF203
+                    print(f"Disabling torch.compile for {name} due to compilation error.")
+        finally:
+            torch._dynamo.config.suppress_errors = prev_suppress_errors

Comment on lines +169 to +175
    def get_dummy_inputs(self) -> dict:
        """Construct dummy inputs for export forward pass. Subclasses must override."""
        raise NotImplementedError

    def get_exporter(self):
        """Return the exporter for the draft model. Subclasses must override."""
        raise NotImplementedError

⚠️ Potential issue | 🔴 Critical


HFDFlashModel is missing the required get_dummy_inputs() implementation.

HFDFlashModel (hf_dflash.py) only implements get_exporter() but not get_dummy_inputs(). It inherits the base NotImplementedError stub from HFSpecDecMixin. The export code at unified_export_hf.py:386 calls model.get_dummy_inputs() when hasattr(model, "get_dummy_inputs") is true—which passes for HFDFlashModel since the method exists (inherited)—but will raise NotImplementedError at runtime. This breaks the offline speculative decoding export flow used in examples/llm_ptq/hf_ptq.py when --specdec_offline_dataset is specified.

Add a get_dummy_inputs() override to HFDFlashModel matching the pattern in HFEagleModel.
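A fix along those lines could look like the sketch below. The key names, shapes, and nested-list stand-ins for tensors are all assumptions; the real override should be a method on HFDFlashModel that mirrors HFEagleModel's keys and returns torch.LongTensor inputs on the model's device.

```python
def get_dummy_inputs(batch_size=1, seq_len=8, vocab_size=32000):
    """Hypothetical body for an HFDFlashModel.get_dummy_inputs override.

    Nested Python lists stand in for torch tensors so the sketch runs
    anywhere; the export forward pass at unified_export_hf.py would consume
    the returned dict directly.
    """
    input_ids = [
        [(i * 31 + 7) % vocab_size for i in range(seq_len)] for _ in range(batch_size)
    ]
    attention_mask = [[1] * seq_len for _ in range(batch_size)]
    return {"input_ids": input_ids, "attention_mask": attention_mask}


dummy = get_dummy_inputs()
```

Any concrete override should be validated against the export flow, since the expected keys are defined by what unified_export_hf.py actually reads.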


Comment on lines +159 to +163
    inputs_embeds = inputs_embeds.to(hidden_states.dtype).to(hidden_states.device)
    # In EAGLE-3, we save input embeddings to attribute, and use it in first decoder layer by hook function
    # Also, we normalize input embeddings and hidden states before concatenating them.
    # The default input norm in first layer attn will be disabled.
    self._input_embeds = self.layers[0].input_layernorm(inputs_embeds)

⚠️ Potential issue | 🟠 Major

Only stash _input_embeds when the EAGLE-3 pre-hook will consume it.

These lines run for every config, but the pre-hook is only registered under use_aux_hidden_state. In the normal path this leaves one extra activation tensor hanging off self for no reason, which increases memory pressure and can keep autograd state alive longer than necessary.

♻️ Proposed fix
         inputs_embeds = inputs_embeds.to(hidden_states.dtype).to(hidden_states.device)
         # In EAGLE-3, we save input embeddings to attribute, and use it in first decoder layer by hook function
         # Also, we normalize input embeddings and hidden states before concatenating them.
         # The default input norm in first layer attn will be disabled.
-        self._input_embeds = self.layers[0].input_layernorm(inputs_embeds)
+        if self.config.use_aux_hidden_state:
+            self._input_embeds = self.layers[0].input_layernorm(inputs_embeds)
+        else:
+            self._input_embeds = None

@codecov

codecov bot commented Apr 16, 2026

Codecov Report

❌ Patch coverage is 61.84539% with 153 lines in your changes missing coverage. Please review.
✅ Project coverage is 71.35%. Comparing base (361f7e3) to head (6df234f).
⚠️ Report is 2 commits behind head on main.

Files with missing lines | Patch % | Lines missing
...delopt/torch/speculative/plugins/modeling_eagle.py 25.30% 62 Missing ⚠️
modelopt/torch/speculative/plugins/hf_medusa.py 35.41% 31 Missing ⚠️
modelopt/torch/speculative/plugins/hf_eagle.py 29.03% 22 Missing ⚠️
...odelopt/torch/speculative/plugins/hf_spec_mixin.py 70.17% 17 Missing ⚠️
modelopt/torch/speculative/plugins/hf_dflash.py 72.00% 7 Missing ⚠️
modelopt/torch/speculative/config.py 78.26% 5 Missing ⚠️
...elopt/torch/speculative/plugins/modeling_dflash.py 96.82% 4 Missing ⚠️
modelopt/torch/speculative/utils.py 0.00% 3 Missing ⚠️
modelopt/torch/export/plugins/hf_spec_export.py 0.00% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1271      +/-   ##
==========================================
- Coverage   76.07%   71.35%   -4.72%     
==========================================
  Files         459      464       +5     
  Lines       48528    48993     +465     
==========================================
- Hits        36917    34961    -1956     
- Misses      11611    14032    +2421     
Flag | Coverage | Δ
examples 39.11% <27.93%> (+0.79%) ⬆️
gpu 51.87% <61.84%> (-8.61%) ⬇️

Flags with carried forward coverage won't be shown.


(Outdated comment thread on modelopt/torch/speculative/plugins/hf_medusa.py)
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
@h-guo18 h-guo18 force-pushed the haoguo/spec-mixin branch from 6df234f to f91cf9d Compare April 16, 2026 21:47