
[Bug] NVFP4 TRT compile fails on B300 HyperPod (Server-Blackwell): Myelin pw_vtx_analyzer not implemented (works on RTX PRO 6000 Blackwell) #4196

@crissed-12labs

Bug Description

torch_tensorrt.dynamo.compile() fails when compiling a modelopt-NVFP4-quantized PE-Core Vision Transformer (public OpenCLIP-style visual encoder) on B300 HyperPod (GB300-class, server Blackwell). The same stack + model + quantization config + compile call succeeds on RTX PRO 6000 Blackwell (workstation Blackwell, sm_120). Two different failures occur depending on the NGC tag:

Failure A — NGC pytorch:26.03-py3 (torch_tensorrt 2.11, TRT 10.16, modelopt 0.41)

[TensorRT] ERROR: MyelinCheckException: pw_vtx_analyzer CHECK(false) Not implemented

raised during torch_tensorrt.dynamo.compile() at the TRT engine-build step, after the modelopt quantize-op graph is converted.

Failure B — NGC pytorch:26.01-py3 (torch_tensorrt 2.10, TRT 10.13, modelopt 0.40)

[TensorRT] ERROR: [DYNAMIC_QUANTIZE] double quant scale must be FP32

Same model, same static-shape compile call. (Note: this is static-shape NVFP4; we are not using torch.export.Dim(...) ranges on the blocked axis — this is a different issue from #3745.)
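For context on Failure B: NVFP4 uses two-level ("double") scaling, and the error refers to the second-level scale. The sketch below is illustrative only (not modelopt's implementation), assuming the standard scheme: FP4 E2M1 elements (max magnitude 6.0), one FP8 E4M3 scale per 16-element block, and a single per-tensor scale that normalizes the block scales into FP8 range. That per-tensor scale is the one TRT requires to be FP32.

```python
FP4_MAX = 6.0     # largest E2M1 magnitude
FP8_MAX = 448.0   # largest E4M3 magnitude
BLOCK = 16

def nvfp4_scales(values):
    """Return (per_tensor_scale, per_block_scales) for a flat list of floats.

    Illustrative sketch of NVFP4 double scaling: the per-tensor scale is the
    second-level ("double quant") scale that Failure B complains about.
    """
    amax = max(abs(v) for v in values)
    # Second-level scale, kept in FP32 so the block scales fit in FP8 range.
    per_tensor = amax / (FP4_MAX * FP8_MAX)
    blocks = [values[i:i + BLOCK] for i in range(0, len(values), BLOCK)]
    per_block = []
    for b in blocks:
        bmax = max(abs(v) for v in b)
        # First-level scale, stored in FP8 after dividing out the FP32 scale.
        per_block.append((bmax / FP4_MAX) / per_tensor if per_tensor else 0.0)
    return per_tensor, per_block
```

With this scheme, the largest block's scale lands exactly at the FP8 maximum (448.0), which is why the second-level scale cannot itself be a low-precision dtype.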

Working counterexample (same config on RTX Blackwell)

On g7e (NVIDIA RTX PRO 6000 Blackwell, Workstation Blackwell), the same container image + same model + same modelopt NVFP4_DEFAULT_CFG + same torch_tensorrt.dynamo.compile(...) call succeeds and reaches production-grade perf:

  • PE-Core-L14-336 visual encoder: 220.7x real-time ratio at B=1, T=16
  • PE-Core-G14-448 visual encoder: 47.7x real-time ratio at B=1, T=16

So the failure is B300-specific, not a general Blackwell-NVFP4 issue. The most likely cause is a missing kernel specialization or a missing vertex-analyzer code path for server Blackwell (GB200/GB300-class) in Myelin: RTX Blackwell (sm_120) and server Blackwell (sm_100 family) take different Myelin tactic paths for the DYNAMIC_QUANTIZE → Dyna fusion.
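A quick sanity check for triaging which path a host will take (hypothetical helper, not part of the repro): server Blackwell parts report CUDA compute capability 10.x (sm_100 family), while RTX PRO 6000 Blackwell reports 12.0 (sm_120).

```python
def blackwell_variant(major: int, minor: int) -> str:
    """Map a CUDA compute capability to the Blackwell variant name."""
    if major == 10:
        return "server-blackwell"  # sm_100 family: B200/B300/GB200-class parts
    if (major, minor) == (12, 0):
        return "rtx-blackwell"     # sm_120: workstation / RTX Blackwell
    return "other"

# With torch available on the target host, e.g.:
#   import torch
#   print(blackwell_variant(*torch.cuda.get_device_capability()))
```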

Environment

  • GPU: NVIDIA B300 (HyperPod, GB300-class, server Blackwell)
  • NGC container: nvcr.io/nvidia/pytorch:26.03-py3 (primary) and 26.01-py3 (secondary)
  • torch: 2.11.0a0 (NGC build) / 2.10 (26.01)
  • torch_tensorrt: 2.11.x / 2.10
  • TensorRT: 10.16 / 10.13
  • modelopt (nvidia-modelopt): 0.41 / 0.40
  • CUDA: 13.0 (26.03) / 12.9 (26.01)
  • Driver: HyperPod stock

Working-reference GPU: NVIDIA RTX PRO 6000 Blackwell, same container images.

Minimal reproducer

Same shape as the existing repro in #3745 but no torch.export.Dim — fully static inputs. Quantization config is public mtq.NVFP4_DEFAULT_CFG. Model is any PE-Core Vision Transformer checkpoint (e.g. the public PE-Core-L14-336 OpenCLIP-format weights — ViT-L/14, patch size 14, 336x336, plus standard conv patch embed + class token + learned abs pos emb + 24-layer transformer + global-pool head).

import torch, torch_tensorrt as torchtrt
import modelopt.torch.quantization as mtq
from modelopt.torch.quantization.utils import export_torch_mode

model = build_pe_core_l14_336().eval().cuda()  # user-side builder for the public PE-Core visual encoder
dummy = torch.randn(1, 16, 3, 336, 336, dtype=torch.bfloat16).cuda()

# Calibrate and insert NVFP4 quantizers using the public default config.
mtq.quantize(model, mtq.NVFP4_DEFAULT_CFG, forward_loop=lambda m: m(dummy))

with torch.no_grad(), export_torch_mode():
    # Private export entry point, used here to trace with strict=False.
    ep = torch.export._trace._export(model, (dummy,), strict=False)
    trt_model = torchtrt.dynamo.compile(
        ep, inputs=[dummy], min_block_size=1,
        enabled_precisions={torch.bfloat16},
        use_explicit_typing=True,
        cache_built_engines=False, reuse_cached_engines=False,
    )

On RTX PRO 6000 Blackwell: succeeds.
On B300 HyperPod: fails with the stack above.

Happy to attach a full runnable script + recorded stacktrace + full TRT verbose logs from both NGC tags if it helps — just say the word.

Why this matters

NVFP4 is the main reason we deploy Blackwell for vision-transformer inference. Today we have to fall back to FP8-dynamic on B300 while keeping NVFP4 on RTX Blackwell, effectively a 20%+ perf regression on server-class Blackwell versus workstation-class Blackwell for the same workload.
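For reference, the interim fallback can be sketched as gating the modelopt config on compute capability, so sm_100-family hosts get FP8-dynamic while RTX Blackwell keeps NVFP4. The config names are the public mtq attributes; the gating logic itself is our workaround, not an upstream recommendation.

```python
def pick_quant_cfg_name(major: int, minor: int) -> str:
    """Return the mtq config attribute name to use on this GPU."""
    if major == 10:                 # server Blackwell: NVFP4 engine build fails (this bug)
        return "FP8_DEFAULT_CFG"
    return "NVFP4_DEFAULT_CFG"      # sm_120 RTX Blackwell: NVFP4 works

# Usage (assuming torch + modelopt installed):
#   import torch
#   import modelopt.torch.quantization as mtq
#   cfg = getattr(mtq, pick_quant_cfg_name(*torch.cuda.get_device_capability()))
#   mtq.quantize(model, cfg, forward_loop=...)
```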
