
feat: add TRT-RTX native CUDA graph support#4187

Draft
tp5uiuc wants to merge 1 commit into pytorch:main from tp5uiuc:feat/trtrtx-cudagraphs

Conversation

Contributor

@tp5uiuc tp5uiuc commented Apr 14, 2026

Description

Add cuda_graph_strategy compilation setting and automatic RTX-native CUDA graph integration for the Python runtime path (PythonTorchTensorRTModule).

TensorRT-RTX has native CUDA graph support via IRuntimeConfig.cuda_graph_strategy, where the JIT compiler handles capture/replay/invalidation internally. This is superior to manual torch.cuda.CUDAGraph() capture on RTX because:

  • Manual capture freezes fallback kernels; lazy-compiled specialized kernels can never replace them
  • Runtime allocation or data-dependent shapes can cause cudaStreamBeginCapture to fail
  • The JIT compiler automatically manages graph staleness (shape changes, pointer changes, kernel readiness)

Key changes

  • New cuda_graph_strategy setting on CompilationSettings ("disabled" / "whole_graph_capture")
  • Mapped to trt.CudaGraphStrategy on IRuntimeConfig (same pattern as dynamic_shapes_kernel_specialization_strategy)
  • SUBGRAPH mode (set_cudagraphs_mode(True)): On RTX, always use RTX-native CUDA graphs — manual capture is bypassed. If cuda_graph_strategy was not explicitly set, the runtime overrides to whole_graph_capture and warns.
  • WHOLE_GRAPH mode (enable_cudagraphs() with mixed TRT + PyTorch ops): Validates all TRT engines are monolithically capturable via context.is_stream_capturable(stream) and strategy != "lazy". If capturable, proceeds with outer monolithic capture (RTX-native disabled per-engine). If not capturable, raises RuntimeError.
  • _is_monolithic_capturable() — runtime check combining stream capturability and kernel specialization strategy
  • _enable_rtx_native_cudagraphs() — recreates execution context with WHOLE_GRAPH_CAPTURE
  • _check_monolithic_capturability() in CudaGraphsTorchTensorRTModule for mixed graph validation
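The string-setting-to-enum mapping described above can be sketched in pure Python. This is a minimal illustration only: `CudaGraphStrategy` below is a stub standing in for the real `trt.CudaGraphStrategy` enum in the TensorRT-RTX bindings (its member names are assumptions), and `map_cuda_graph_strategy` is a hypothetical helper name.

```python
from enum import Enum, auto

# Stub standing in for trt.CudaGraphStrategy from the TensorRT-RTX
# Python bindings (member names are assumptions for illustration).
class CudaGraphStrategy(Enum):
    NONE = auto()
    WHOLE_GRAPH_CAPTURE = auto()

# Maps the string compilation setting to the runtime-config enum,
# mirroring the dynamic_shapes_kernel_specialization_strategy pattern.
_STRATEGY_MAP = {
    "disabled": CudaGraphStrategy.NONE,
    "whole_graph_capture": CudaGraphStrategy.WHOLE_GRAPH_CAPTURE,
}

def map_cuda_graph_strategy(name: str) -> CudaGraphStrategy:
    """Validate and translate the cuda_graph_strategy setting."""
    try:
        return _STRATEGY_MAP[name]
    except KeyError:
        raise ValueError(f"Invalid cuda_graph_strategy: {name!r}") from None
```

On a real RTX build the resulting enum value would be assigned onto the `IRuntimeConfig` before the execution context is created.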

Behavior matrix

| Graph type | cudagraph mode | RTX? | Behavior |
|---|---|---|---|
| TRT-only | SUBGRAPH | Yes | RTX-native always (override if needed) |
| TRT-only | SUBGRAPH | No | Manual capture (existing) |
| Mixed | WHOLE_GRAPH | Yes + capturable | Monolithic capture; RTX-native disabled per-engine |
| Mixed | WHOLE_GRAPH | Yes + NOT capturable | RuntimeError |
| Mixed | WHOLE_GRAPH | No | Monolithic capture (existing) |
| Any | No cudagraphs + strategy set | Yes | RTX-native runs transparently |
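The behavior matrix can be expressed as a small decision function. The sketch below is illustrative only — the function and return labels are hypothetical, not the PR's actual implementation:

```python
# Pure-Python sketch of the behavior matrix (names are illustrative).
def cudagraph_behavior(
    graph_type: str,        # "trt_only" or "mixed"
    mode: str,              # "subgraph", "whole_graph", or "none"
    is_rtx: bool,
    capturable: bool = True,
    strategy_set: bool = False,
) -> str:
    if mode == "none":
        # No cudagraphs mode: an explicitly set strategy still runs
        # RTX-native graphs transparently on RTX builds.
        return "rtx_native" if is_rtx and strategy_set else "eager"
    if mode == "subgraph" and graph_type == "trt_only":
        # On RTX, native graphs always win; manual capture is bypassed.
        return "rtx_native" if is_rtx else "manual_capture"
    if mode == "whole_graph" and graph_type == "mixed":
        if not is_rtx:
            return "monolithic_capture"
        if capturable:
            # Outer monolithic capture; RTX-native disabled per-engine.
            return "monolithic_capture_rtx_disabled"
        raise RuntimeError("TRT engines are not monolithically capturable")
    return "eager"
```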

Depends on #4180 (runtime cache) and #4184 (dynamic shapes strategy).

Type of change

  • New feature (non-breaking change which adds functionality)

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR so that relevant reviewers are notified

@meta-cla meta-cla Bot added the cla signed label Apr 14, 2026
@github-actions github-actions Bot added documentation Improvements or additions to documentation component: tests Issues re: Tests component: conversion Issues re: Conversion stage component: core Issues re: The core compiler component: build system Issues re: Build system component: api [Python] Issues re: Python API component: runtime component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths labels Apr 14, 2026
@github-actions github-actions Bot requested a review from zewenli98 April 14, 2026 11:12
@tp5uiuc tp5uiuc force-pushed the feat/trtrtx-cudagraphs branch 2 times, most recently from 56f1d3f to 52d9ba5 Compare April 20, 2026 16:08
@narendasan
Collaborator

@tp5uiuc The Python runtime is about to get re-implemented with #4164. Make sure to align your implementation with that. Also there are existing CUDA graph facilities you might want to integrate into.

@tp5uiuc
Contributor Author

tp5uiuc commented Apr 20, 2026

> @tp5uiuc The Python runtime is about to get re-implemented with #4164. Make sure to align your implementation with that. Also there are existing CUDA graph facilities you might want to integrate into.

Thanks for calling this out Naren, I was unaware. I will align my implementation as you suggest, thanks!

Add cuda_graph_strategy compilation setting and automatic RTX-native
CUDA graph integration for the Python runtime path.

Key changes:
- New cuda_graph_strategy setting ("disabled" / "whole_graph_capture")
  on CompilationSettings, mapped to trt.CudaGraphStrategy on
  IRuntimeConfig (same pattern as dynamic_shapes_kernel_specialization)
- In SUBGRAPH cudagraph mode on RTX, always use RTX-native CUDA graphs
  (manual torch.cuda.CUDAGraph capture is not safe due to lazy kernel
  specialization and potential runtime allocation)
- _is_monolithic_capturable() check using context.is_stream_capturable()
  and strategy != "lazy" for WHOLE_GRAPH mode safety validation
- _enable_rtx_native_cudagraphs() for runtime context recreation
- _check_monolithic_capturability() in CudaGraphsTorchTensorRTModule
  for mixed TRT + PyTorch graph validation
- Comprehensive unit tests covering all code paths

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@tp5uiuc tp5uiuc force-pushed the feat/trtrtx-cudagraphs branch from 52d9ba5 to c36fba4 Compare April 22, 2026 14:38
def set_use_output_allocator(self, enable: bool) -> None:
    self.use_output_allocator_outputs = enable

def _check_monolithic_capturability(self, stream: torch.cuda.Stream) -> None:
tp5uiuc (Contributor, Author) commented on the diff:
In subgraphs mode, TRT-RTX needs to avoid the torch-TRT behavior documented here: "If your input shapes change between requests, the graph is re-recorded for each new shape." TRT-RTX takes care of re-capturing graphs internally when shapes change.

https://docs.pytorch.org/TensorRT/tutorials/runtime_opt/cuda_graphs.html

We should add an explicit test to verify this.
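A test along the lines suggested above can be sketched with stubs, no GPU required. Everything here is illustrative: the stub classes model the two capture semantics (RTX-native internal re-capture vs. torch-TRT per-shape re-recording), not the real engines.

```python
# Stubs modeling the two capture behaviors (names are illustrative).
class StubRtxEngine:
    """Models RTX-native capture: the JIT re-captures internally, so the
    caller-visible manual capture count never increases."""
    def __init__(self):
        self.manual_captures = 0

    def run(self, shape):
        return shape  # graph staleness handled inside the runtime

class StubManualCaptureEngine:
    """Models torch-TRT manual capture: a new graph is recorded for
    every previously unseen input shape."""
    def __init__(self):
        self.manual_captures = 0
        self._seen = set()

    def run(self, shape):
        if shape not in self._seen:
            self._seen.add(shape)
            self.manual_captures += 1  # re-record on shape change
        return shape

def captures_after(engine, shapes):
    """Run a sequence of input shapes and report manual capture count."""
    for s in shapes:
        engine.run(s)
    return engine.manual_captures
```

A real test would instead compile a module with `cuda_graph_strategy="whole_graph_capture"` on an RTX build and assert the manual `torch.cuda.CUDAGraph` path is never entered when input shapes change.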

# Check 2: Lazy kernel specialization would invalidate captured graph
if self.settings.dynamic_shapes_kernel_specialization_strategy == "lazy":
    return False
return True
tp5uiuc (Contributor, Author) commented on the diff:
Refactor to use any(conditions) rather than individual checks.

if ENABLED_FEATURES.tensorrt_rtx:
    self._setup_runtime_config()
    self._rtx_native_cudagraphs = (
        ENABLED_FEATURES.tensorrt_rtx
tp5uiuc (Contributor, Author) commented on the diff:
ENABLED_FEATURES.tensorrt_rtx is already true inside this branch, so there is no need to check it again.
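The simplification the comment asks for could look as follows. This is a runnable sketch with stubs: `ENABLED_FEATURES` and the class body are reconstructed from the snippet above, and the exact flag expression (`cuda_graph_strategy != "disabled"`) is an assumption for illustration.

```python
from types import SimpleNamespace

# Stub standing in for the real feature-flag object on an RTX build.
ENABLED_FEATURES = SimpleNamespace(tensorrt_rtx=True)

class PythonTorchTensorRTModule:  # stub of the real module class
    def __init__(self, settings):
        self.settings = settings
        self._rtx_native_cudagraphs = False
        if ENABLED_FEATURES.tensorrt_rtx:
            self._setup_runtime_config()
            # tensorrt_rtx is known True inside this branch, so the
            # flag expression does not re-check it:
            self._rtx_native_cudagraphs = (
                self.settings.cuda_graph_strategy != "disabled"
            )

    def _setup_runtime_config(self):
        pass  # would create the IRuntimeConfig on a real RTX build
```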

tp5uiuc added a commit to tp5uiuc/TensorRT that referenced this pull request Apr 22, 2026
Address PR review comments that asked the new C++ runtime tests be
folded into existing feature-level files rather than shipped as
parallel `*_cpp.py` files.

What
 - Merge `test_000_runtime_cache_cpp.py` into the existing
   `test_000_runtime_cache.py`. The file already covered the Python
   runtime path; two new classes (`TestRuntimeCacheCppPersistence`,
   `TestCppSerializationIndices`) cover the C++ runtime path via
   `use_python_runtime=False`, and the serialization-index
   assertions. Skip on non-RTX builds.
 - Fold the C++ runtime cases for dynamic shapes kernel
   specialization strategy into
   `test_001_dynamic_shapes_kernel_strategy.py` (introduced upstream
   in PR pytorch#4184). Two new classes
   (`TestDynamicShapesKernelStrategyCpp`,
   `TestDynamicShapesKernelStrategyCppInvalidValue`) exercise
   lazy/eager/none end-to-end and reject invalid strategy names. The
   pre-existing Python runtime tests remain untouched.
 - Rename `test_000_cuda_graph_strategy.py` to
   `test_001_cuda_graph_strategy.py` to match the `test_001_*`
   convention used for L1 RTX-only features. When upstream lands the
   Python runtime counterpart (PR pytorch#4187), both sets fold into
   the same file.
 - Add model-level tests: `test_runtime_cache_models.py` gains a
   `TestRuntimeCacheCppModels` class exercising ResNet18 through the
   C++ runtime with warm-cache roundtrip.
   `test_dynamic_shapes_kernel_strategy_models.py` gains
   `TestDynamicShapesKernelStrategyCppModels` covering
   lazy/eager/none on ResNet18 via the C++ runtime.

Verified
 - 35 passed / 3 skipped in the runtime/ tests (merged file plus
   test_001 strategy files).
 - No regression in test_002_cudagraphs_cpp.py (8 passed) or
   test_005_dynamic_allocation.py (1 passed).

Addresses PR pytorch#4202 review comments asking for test file merges and
the addition of model-level runtime_cache_models.py /
dynamic_shapes_kernel_strategy_models.py coverage.
