feat(runtime): add TensorRT-RTX runtime cache, dynamic shapes strategy, and native CUDA graph support to C++ runtime #4202
Conversation
trt_engine_profiler.reset();
exec_ctx = make_trt(cuda_engine->createExecutionContext());
TORCHTRT_CHECK((exec_ctx.get() != nullptr), "Unable to recreate TensorRT execution context");
recreate_execution_context();
Disabling profiling doesn't seem to respect nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED if this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic. Is this expected?
TORCHTRT_CHECK(
    (exec_ctx.get() != nullptr),
    "Unable to recreate TensorRT execution context after setting new device memory budget");
recreate_execution_context();
Same as https://github.com/pytorch/TensorRT/pull/4202/changes#r3120574611, please confirm.
apply_dynamic_shapes_kernel_strategy();
apply_cuda_graph_strategy();
}
runtime_config->setExecutionContextAllocationStrategy(
If runtime_config is not null, I wonder if we should save_runtime_cache() here so that the next reload can reuse any kernels compiled during the present run.
Lay down the shared infrastructure used by three upcoming TensorRT-RTX-only
runtime features (runtime cache, dynamic shapes kernel specialization
strategy, native CUDA graph strategy) in the C++ runtime path.
Core changes
- Bump ABI_VERSION from "8" to "9" and add three new SerializedInfoIndex
entries (RUNTIME_CACHE_PATH_IDX, DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX,
CUDA_GRAPH_STRATEGY_IDX). One bump covers all three feature fields.
- Add an IRuntimeConfig + IRuntimeCache shared_ptr pair to TRTEngine
behind TRT_MAJOR_RTX, plus three plain string/int fields that remain
serializable on non-RTX builds so the ABI is stable across both.
- Extract a private recreate_execution_context() helper that is the
single site where exec_ctx is built. On RTX builds it creates (once)
the IRuntimeConfig, invokes per-feature appliers, and then creates
the execution context via createExecutionContext(IRuntimeConfig*).
Replaces four prior direct createExecutionContext call sites in the
constructor, disable_profiling, set_device_memory_budget, and
set_resource_allocation_strategy so each automatically inherits the
runtime-config path on RTX.
- Declare apply_runtime_cache / apply_dynamic_shapes_kernel_strategy /
apply_cuda_graph_strategy as private RTX-only helpers with empty
bodies; follow-up commits fill these in per feature. The empty
stubs keep this commit behavior-neutral.
- Extend TRTEngine::serialize, the deserialization constructor, the
__obj_flatten__ tuple, and to_str so the new fields round-trip.
- Expose RUNTIME_CACHE_PATH_IDX, DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX,
and CUDA_GRAPH_STRATEGY_IDX via torch.ops.tensorrt.
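The fixed-index serialization scheme above can be sketched in Python. This is a hedged illustration only: the index values, `SERIALIZATION_LEN`, and the `pack_serialized_info` helper are invented for the sketch; the real constants are C++ `SerializedInfoIndex` entries exposed via `torch.ops.tensorrt`.

```python
# Illustrative index values only -- the real ones live in runtime.h.
RUNTIME_CACHE_PATH_IDX = 11
DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX = 12
CUDA_GRAPH_STRATEGY_IDX = 13
SERIALIZATION_LEN = 14

def pack_serialized_info(info: list, cache_path: str, dsk_code: int, cg_code: int) -> list:
    """Append the three RTX feature slots at fixed indices so both the
    serialize and deserialize sides agree on positions (hypothetical helper)."""
    out = list(info) + [""] * (SERIALIZATION_LEN - len(info))
    out[RUNTIME_CACHE_PATH_IDX] = cache_path
    out[DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX] = str(dsk_code)
    out[CUDA_GRAPH_STRATEGY_IDX] = str(cg_code)
    return out
```

Because the slots sit at fixed indices behind a single ABI bump, a deserializer can read all three fields back with plain indexed access.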
Python side
- Add dynamic_shapes_kernel_specialization_strategy ("lazy" default)
and cuda_graph_strategy ("disabled" default) to _defaults.py,
CompilationSettings, and the three compile() entry points.
- Thread runtime_cache_path, dynamic_shapes_kernel_specialization_strategy,
  and cuda_graph_strategy through _TorchTensorRTModule._pack_engine_info
  with string-to-int maps so the C++ engine sees validated integer codes
  (0/1/2 for strategies) and raises ValueError for unknown strings.
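The string-to-int validation can be sketched as follows. The dynamic-shapes codes (0=lazy, 1=eager, 2=none) come from the commit text; the cuda-graph codes and the `encode_strategy` helper name are assumptions for illustration, not the PR's actual API.

```python
# Codes for the dynamic shapes strategy are stated in the commit message;
# the cuda-graph codes below are assumed for this sketch.
DYNAMIC_SHAPES_KERNEL_STRATEGY_CODES = {"lazy": 0, "eager": 1, "none": 2}
CUDA_GRAPH_STRATEGY_CODES = {"disabled": 0, "whole_graph_capture": 1}

def encode_strategy(name: str, codes: dict) -> str:
    """Validate a strategy name and return its integer code as a string,
    mirroring how engine-info packing would hand validated codes to C++."""
    if name not in codes:
        raise ValueError(f"Unknown strategy {name!r}; expected one of {sorted(codes)}")
    return str(codes[name])
```

Validating on the Python side means the C++ engine only ever sees integers that map to a known enum value.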
No behavior change yet: the RTX appliers are empty and all new strategy
defaults select the prior code paths.
Implement TensorRT-RTX runtime cache persistence in the C++ runtime path
(TorchTensorRTModule / TRTEngine). Mirrors the Python-runtime feature
landed in pytorch#4180.
What
- apply_runtime_cache() (a no-op stub from the prior commit) now creates
  an IRuntimeCache from the IRuntimeConfig, loads any existing cache file
  from the configured path, and attaches the cache to the config via
  IRuntimeConfig::setRuntimeCache (taken by const reference).
- load_runtime_cache() reads the cache under an advisory shared lock
  (flock LOCK_SH) on POSIX. Concurrent readers coexist; transient
  failures downgrade to warnings so inference never blocks on cache I/O.
- save_runtime_cache() writes the serialized cache atomically via
  tmp-file + rename under an exclusive lock (flock LOCK_EX). The write
  path creates intermediate directories as needed. On Windows the save
  falls back to a best-effort write without advisory locking and emits a
  warning; LockFileEx support is a follow-up.
- ~TRTEngine() now invokes save_runtime_cache() before tearing down the
  cache, config, and execution context so JIT compilation results survive
  process exit.
Why
- TensorRT-RTX JIT-compiles specialized kernels at inference time. The
  runtime cache lets those compilations persist across runs and across
  processes, which was measured at ~8x warm-vs-cold speedup in the
  Python-runtime implementation.
- Without this commit, users relying on the C++ runtime (TorchScript
  deployments, use_python_runtime=False) would have no way to retain JIT
  work and would pay the cold-start cost on every process start.
Tests
- tests/py/dynamo/runtime/test_000_runtime_cache_cpp.py exercises the
  C++ runtime path (use_python_runtime=False) with cache save on
  destructor, directory creation, warm-cache roundtrip correctness via
  cosine similarity, and ABI/index registration.
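The tmp-file + rename save under `flock(LOCK_EX)` and the shared-lock load described above can be sketched in Python on POSIX. Function names and the exact error handling are illustrative, not the PR's C++ API.

```python
import fcntl
import os
import tempfile
from typing import Optional

def save_cache_atomically(path: str, blob: bytes) -> None:
    """Create intermediate directories, then write tmp-file + rename under
    an exclusive advisory lock (illustrative analogue of save_runtime_cache)."""
    directory = os.path.dirname(path) or "."
    os.makedirs(directory, exist_ok=True)
    fd, tmp = tempfile.mkstemp(dir=directory)
    try:
        with os.fdopen(fd, "wb") as f:
            fcntl.flock(f.fileno(), fcntl.LOCK_EX)  # exclusive writer lock
            f.write(blob)
            f.flush()
            os.fsync(f.fileno())
            os.replace(tmp, path)  # atomic rename within the same filesystem
    except BaseException:
        if os.path.exists(tmp):
            os.unlink(tmp)  # never leave a partial tmp file behind
        raise

def load_cache(path: str) -> Optional[bytes]:
    """Shared lock: concurrent readers coexist; a missing cache is not an error."""
    try:
        with open(path, "rb") as f:
            fcntl.flock(f.fileno(), fcntl.LOCK_SH)
            return f.read()
    except FileNotFoundError:
        return None
```

The rename is what makes the save atomic: a concurrent reader sees either the old cache or the new one, never a half-written file.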
Wire the dynamic_shapes_kernel_specialization_strategy compile setting
into the C++ runtime path on TensorRT-RTX by filling in the
apply_dynamic_shapes_kernel_strategy() body introduced in the
scaffolding commit.
What
- apply_dynamic_shapes_kernel_strategy() now calls
IRuntimeConfig::setDynamicShapesKernelSpecializationStrategy with
the integer code (0=lazy, 1=eager, 2=none) that was validated at
engine construction.
- The setting is applied once when the IRuntimeConfig is first built
inside recreate_execution_context(); the value is serialized with
the engine so deserialized modules restore the same strategy.
Why
- "lazy" (the default) compiles specialized kernels in the background
and uses fallbacks until they are ready - good for latency of the
first call but hands-off for steady-state throughput.
- "eager" compiles the specialized kernel synchronously on first use,
blocking inference but eliminating the fallback phase.
- "none" disables kernel specialization entirely and always uses the
generic fallback. Useful in combination with outer CUDA graph
capture where a stable set of kernels is required.
Tests
- tests/py/dynamo/runtime/test_000_dynamic_shapes_kernel_strategy.py
validates the setting default, the full {lazy, eager, none}
matrix through the C++ runtime (use_python_runtime=False), dynamic
shape traversal under "eager", and ValueError rejection of unknown
strategy names at engine-packing time.
Wire cuda_graph_strategy into the C++ runtime and make the execute_engine
CUDA graph path TensorRT-RTX-aware. Fills in the apply_cuda_graph_strategy
stub and adds coexistence handling for outer whole-graph capture.
What
- apply_cuda_graph_strategy() now calls IRuntimeConfig::setCudaGraphStrategy
with either kDISABLED (default) or kWHOLE_GRAPH_CAPTURE. On RTX this
hands capture/replay off to the TRT-RTX runtime, avoiding the lazy-kernel
and dynamic-shape hazards of wrapping enqueueV3 in at::cuda::CUDAGraph.
- is_monolithic_capturable(stream) returns whether an engine can safely
be captured by an outer torch.cuda.CUDAGraph: RTX builds check
IExecutionContext::isStreamCapturable and require a non-lazy kernel
strategy; non-RTX builds always return true.
- disable_rtx_native_cudagraphs() is a one-shot switch that turns off
the engine internal capture and recreates the execution context so
that outer stream captures contain the kernel launches directly.
- execute_engine.cpp now computes effective_cudagraphs. On RTX, if a
cuda_graph_strategy is set or SUBGRAPH cudagraphs is enabled, it
bypasses the manual at::cuda::CUDAGraph path (the TRT-RTX runtime
handles that inside enqueueV3). It also polls cudaStreamIsCapturing
on the engine stream and, if an outer capture is already running,
invokes disable_rtx_native_cudagraphs() so the outer capture proceeds
without collision.
Why
- On TRT-RTX, the manual at::cuda::CUDAGraph wrapper around enqueueV3
can freeze fallback kernels in the captured graph (kLAZY specialization
would swap them later), and it fails outright when the engine needs
runtime allocation, DDS, control flow, or weight streaming.
- Letting the TRT-RTX runtime own capture fixes both problems, and the
outer-capture detection keeps the feature compatible with the
existing CudaGraphsTorchTensorRTModule whole-graph wrapper without
requiring it to know anything about RTX internals.
Tests
- tests/py/dynamo/runtime/test_000_cuda_graph_strategy.py validates the
setting default, both {disabled, whole_graph_capture} through the
C++ runtime, the RTX-native override when set_cudagraphs_mode(True)
is combined with a strategy, repeated inference correctness, and
ValueError rejection of unknown strategy names.
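The effective_cudagraphs decision described in the commit can be sketched as pure logic. The function name and boolean inputs are hypothetical distillations of what execute_engine.cpp computes; they are not the actual code.

```python
def use_manual_cudagraph_wrapper(
    is_rtx: bool,
    rtx_cuda_graph_strategy_set: bool,
    subgraph_cudagraphs_enabled: bool,
) -> bool:
    """Return whether the manual at::cuda::CUDAGraph wrapper should run
    (illustrative sketch of the effective_cudagraphs computation)."""
    if is_rtx and (rtx_cuda_graph_strategy_set or subgraph_cudagraphs_enabled):
        # TRT-RTX owns capture/replay inside enqueueV3; skip the manual path.
        return False
    # Pre-existing behavior on standard TRT: wrap only under SUBGRAPH mode.
    return subgraph_cudagraphs_enabled
```

Outer-capture detection is a separate step: if cudaStreamIsCapturing reports an in-progress capture on the engine stream, the RTX-native internal capture is disabled one-shot so the outer capture proceeds without collision.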
Force-pushed from 37ba9f5 to 2b630e8.
Address the structural PR feedback by extracting TensorRT-RTX-specific
IRuntimeConfig state into its own type and collapsing the per-feature
appliers that previously scattered `#ifdef TRT_MAJOR_RTX` through
TRTEngine.
What
- New core/runtime/TRTRuntimeConfig.{h,cpp} owns the IRuntimeConfig
shared_ptr plus (on TRT-RTX) the IRuntimeCache, runtime-cache path,
dynamic shapes kernel strategy, CUDA graph strategy, and the
rtx_native_cudagraphs_disabled one-shot flag. All per-feature
appliers live there as public members and are no-ops on non-RTX
builds, keeping the only `#ifdef TRT_MAJOR_RTX` scatter contained
in this new file.
- Strategy fields are now strongly-typed enums
  (`DynamicShapesKernelStrategy`, `CudaGraphStrategyOption`) with
  matching `to_string`/`to_int` helpers, validated at engine
  construction via `to_dynamic_shapes_kernel_strategy` /
  `to_cuda_graph_strategy_option` rather than raw int ranges.
- `TRTEngine::recreate_execution_context` is now backend-agnostic:
  it calls `runtime_cfg.ensure_initialized`, applies the allocation
  strategy, and creates the execution context via
  `createExecutionContext(IRuntimeConfig*)`. Both standard TensorRT
  and TRT-RTX go through this uniform path; only the three RTX-only
  setters (`setRuntimeCache`, `setDynamicShapesKernelSpecializationStrategy`,
  `setCudaGraphStrategy`) stay behind an `#ifdef TRT_MAJOR_RTX` guard
  inside the struct.
- `~TRTEngine` now wraps cleanup in try/catch and delegates cache
persistence to `TRTRuntimeConfig::save_runtime_cache_nothrow`, so
stack unwinding can no longer propagate a cache-save failure out
of the destructor.
- `save_runtime_cache_nothrow` uses `std::filesystem` + atomic
`tmp+rename` only; file locking is out of scope for this PR and
will be introduced in a follow-up once we pick a portable
mechanism.
- `is_monolithic_capturable` asserts `exec_ctx` is non-null; the
three RTX-only appliers `TORCHTRT_ASSERT` that `config` is live
before dereferencing.
- `disable_rtx_native_cudagraphs` persists the runtime cache before
flipping the strategy so any kernels compiled under the internal
capture survive to the next reload.
- `TRTEngine::to_str` now emits human-readable strategy names (via
`to_string(enum)`) instead of integer codes.
- New serialization indices (`RUNTIME_CACHE_PATH_IDX`,
  `DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX`, `CUDA_GRAPH_STRATEGY_IDX`) are now
  `#ifdef TRT_MAJOR_RTX`-gated in runtime.h, register_jit_hooks.cpp,
  the FlattenedState tuple, the serialize/deserialize constructors,
  and `__obj_flatten__`. Standard TRT builds keep `SERIALIZATION_LEN == 11`
  so engines serialized there do not carry RTX-only slots.
- Python `_TorchTensorRTModule` reads the RTX-only index accessors
and writes the RTX-only engine-info slots only when
`ENABLED_FEATURES.tensorrt_rtx` is true. Standard TRT users see
no new behavior at runtime.
- Deduplicated `_compiler.py` arguments after rebase on upstream
main where PR pytorch#4184 had already added
`dynamic_shapes_kernel_specialization_strategy`. Kept one copy of
each arg; `cuda_graph_strategy` is threaded through all three
compile() entry points.
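The strongly-typed enum helpers described above have a natural Python analogue, sketched here. The class and function names mirror the C++ ones for readability, but this is an illustration of the validate-at-the-boundary pattern, not the actual implementation.

```python
from enum import IntEnum

class DynamicShapesKernelStrategy(IntEnum):
    # Codes mirror the serialized integers (0=lazy, 1=eager, 2=none).
    LAZY = 0
    EAGER = 1
    NONE = 2

_STRATEGY_NAMES = {
    DynamicShapesKernelStrategy.LAZY: "lazy",
    DynamicShapesKernelStrategy.EAGER: "eager",
    DynamicShapesKernelStrategy.NONE: "none",
}

def to_dynamic_shapes_kernel_strategy(v: int) -> DynamicShapesKernelStrategy:
    """Validate a raw serialized integer at the deserialization boundary."""
    try:
        return DynamicShapesKernelStrategy(v)
    except ValueError as e:
        raise ValueError(f"Invalid dynamic shapes kernel strategy code: {v}") from e

def to_string(s: DynamicShapesKernelStrategy) -> str:
    """Human-readable name for to_str output, instead of an integer code."""
    return _STRATEGY_NAMES[s]
```

Round-tripping through the enum guarantees that any integer reaching the runtime maps to a named strategy, and that to_str emits names rather than codes.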
Build + tests
- RTX build on A100 / L40S: libtorchtrt.so and libtorchtrt_runtime.so
  link clean, no `#ifdef` diagnostics. Pre-commit checks pass
  (clang-format, black, isort, ruff, mypy, typos, buildifier).
- All 35 runtime-cache/strategy tests pass; regression across
  test_000_runtime_cache.py (Python runtime), test_002_cudagraphs_cpp.py,
  and test_005_dynamic_allocation.py is green.
Addresses review comments on PR pytorch#4202:
- Guarding of new IDX entries and Python accessors on
TRT_MAJOR_RTX / ENABLED_FEATURES.tensorrt_rtx.
- Encapsulation of RTX-specific state in a dedicated type with
enumerated strategies and transparent standard-TRT/RTX behavior.
- Destructor exception safety.
- Unification of the execution-context creation path via
IRuntimeConfig.
- Removal of file locking for runtime-cache persistence.
- Debug asserts before dereferencing the live IRuntimeConfig.
- Human-readable to_str output.
- save_runtime_cache invoked from disable_rtx_native_cudagraphs.
Address PR review comments that asked for the new C++ runtime tests to be folded into existing feature-level files rather than shipped as parallel `*_cpp.py` files.
What
- Merge `test_000_runtime_cache_cpp.py` into the existing
  `test_000_runtime_cache.py`. That file already covered the Python
  runtime path; two new classes (`TestRuntimeCacheCppPersistence`,
  `TestCppSerializationIndices`) cover the C++ runtime path via
  `use_python_runtime=False` and the serialization-index assertions.
  Skipped on non-RTX builds.
- Fold the C++ runtime cases for the dynamic shapes kernel specialization
  strategy into `test_001_dynamic_shapes_kernel_strategy.py` (introduced
  upstream in PR pytorch#4184). Two new classes
  (`TestDynamicShapesKernelStrategyCpp`,
  `TestDynamicShapesKernelStrategyCppInvalidValue`) exercise
  lazy/eager/none end-to-end and reject invalid strategy names. The
  pre-existing Python runtime tests remain untouched.
- Rename `test_000_cuda_graph_strategy.py` to
  `test_001_cuda_graph_strategy.py` to match the `test_001_*` convention
  used for L1 RTX-only features. When upstream lands the Python runtime
  counterpart (PR pytorch#4187), both sets fold into the same file.
- Add model-level tests: `test_runtime_cache_models.py` gains a
  `TestRuntimeCacheCppModels` class exercising ResNet18 through the C++
  runtime with a warm-cache roundtrip.
  `test_dynamic_shapes_kernel_strategy_models.py` gains
  `TestDynamicShapesKernelStrategyCppModels` covering lazy/eager/none on
  ResNet18 via the C++ runtime.
Verified
- 35 passed / 3 skipped in the runtime/ tests (merged file plus the
  test_001 strategy files).
- No regression in test_002_cudagraphs_cpp.py (8 passed) or
  test_005_dynamic_allocation.py (1 passed).
Addresses PR pytorch#4202 review comments asking for test-file merges and the addition of model-level runtime_cache_models.py / dynamic_shapes_kernel_strategy_models.py coverage.
LOG_DEBUG("Dynamic shapes kernel specialization strategy set to " << to_string(dynamic_shapes_kernel_strategy));

// CUDA graph strategy -- TRT-RTX only.
bool ok = config->setCudaGraphStrategy(
Can we enforce const correctness? This should be `const bool ok` or, better yet, don't introduce the variable at all and use the call directly at the site.
Follow-up to 54f9ccd / 1fa8c82 addressing the second batch of PR
pytorch#4202 review feedback. Pure refactor with no user-visible behavior
change; all tests green on A100 (35 passed / 3 skipped + 9 regression
passed).
TRTEngine
- Constructor signature simplified: three separate `runtime_cache_path` /
  `dynamic_shapes_kernel_strategy` / `cuda_graph_strategy` parameters
  collapsed into a single `TRTRuntimeConfig runtime_cfg` sink parameter.
  The forwarding ctor std::moves it into the primary ctor, which
  std::moves it into the member.
- String sink parameters (mod_name, serialized_engine,
  serialized_metadata) taken by value and moved into members / slugify.
- Deserialization constructor routes through the new free function
  make_runtime_config_from_serialized, which internalizes the
  TRT_MAJOR_RTX-gated index reads so the constructor itself stays
  unguarded.
- FlattenedState uses a single TRTRTX_FLATTENED_STATE_EXTRAS macro for
  the three RTX-only tuple entries instead of duplicating the first
  eleven entries across two branches.
- Destructor restored to the pre-refactor structure:
  torch::cuda::synchronize runs outside a try block and
  runtime_cfg.save_runtime_cache (now noexcept by signature) is called
  directly. Exception safety is guaranteed by the member's type, not by a
  defensive try/catch.
- __obj_flatten__ and serialize cast enum values via
  std::underlying_type_t<...> instead of int so serialization stays in
  lockstep with any future underlying-type change on the enums.
TRTRuntimeConfig
- Conversion helpers take std::underlying_type_t<Enum> (the declared
  32-bit integer type) instead of raw int. Callers at serialization
  boundaries explicitly std::stoi / static_cast into the right type.
- [[nodiscard]] added to to_string, to_dynamic_shapes_kernel_strategy,
  to_cuda_graph_strategy_option, uses_internal_capture,
  is_monolithic_capturable, to_str, and
  make_runtime_config_from_serialized.
- to_string default cases now TORCHTRT_CHECK(false, ...) with the
  unexpected integer value; std::unreachable is C++23.
- set_execution_context_allocation_strategy is now const.
- Cache I/O split into two layers:
  - Free functions load_runtime_cache(path, cache) and
    save_runtime_cache(path, cache) perform the raw std::filesystem I/O
    and use TORCHTRT_CHECK on failure -- exception-propagating, easier to
    test in isolation.
  - Member TRTRuntimeConfig::save_runtime_cache() is a noexcept wrapper
    that calls the free function and swallows exceptions via try/catch --
    safe from a destructor. The _nothrow suffix is dropped from the
    member name (the signature now carries that contract).
- write_to_str(ostream&) replaced by two functions: a const-correct
  to_str() -> std::string, and a free operator<<(ostream&, const
  TRTRuntimeConfig&) that wraps it with "Runtime cfg { ... }" delimiters.
  TRTEngine::to_str streams the config via the free operator.
Python
- _settings.py: removed a duplicated
  dynamic_shapes_kernel_specialization_strategy field and its duplicated
  docstring left over from the upstream rebase of PR pytorch#4184 into
  our changes.
Covers review comments 3126538200, 3126541782, 3126547529, 3126549147,
3126682329, 3126683329, 3126693226, 3126715369, 3126725953, 3126736626,
3126738422, 3126745230, 3126747553, 3126749405, 3126764831, 3126772536,
3126786564, 3126803652, 3126816780, 3126818065, 3126818561, 3126819429,
3126823781, 3126840987, 3126846827.
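The two-layer cache I/O split (exception-propagating free function, noexcept member wrapper) can be sketched in Python. The names mirror the C++ ones, but this is an illustration of the pattern under stated assumptions, not the PR's code.

```python
import logging

def save_runtime_cache(path: str, blob: bytes) -> None:
    """Raw layer: propagates any I/O failure to the caller, analogous to
    the free function that uses TORCHTRT_CHECK on failure."""
    with open(path, "wb") as f:
        f.write(blob)

class RuntimeConfig:
    """Minimal stand-in for TRTRuntimeConfig's destructor-safe wrapper layer."""

    def __init__(self, cache_path: str, cache_blob: bytes) -> None:
        self.cache_path = cache_path
        self.cache_blob = cache_blob

    def save_runtime_cache(self) -> bool:
        # Destructor-safe: swallow failures and warn, never propagate.
        try:
            save_runtime_cache(self.cache_path, self.cache_blob)
            return True
        except OSError as e:
            logging.warning("runtime cache save failed: %s", e)
            return False
```

Splitting the layers keeps the raw function easy to test in isolation while the wrapper's signature carries the no-throw contract the destructor relies on.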
ss << " Hardware Compatibility: " << (hardware_compatible ? "Enabled" : "Disabled") << std::endl;
ss << " Target Platform: " << target_platform << std::endl;
ss << " Resource Allocation Strategy: " << (resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "Dynamic" : "Static") << std::endl;
ss << runtime_cfg;
Use `ss << runtime_cfg.to_str()` here.
std::tuple<std::string, std::string>, // Platform
std::tuple<std::string, std::string>>; // Resource Allocation Strategy
std::tuple<std::string, std::string> /* Resource Allocation Strategy */
TRTRTX_FLATTENED_STATE_EXTRAS>;
TODO: Inline and fix.
to_dynamic_shapes_kernel_strategy(std::stoi(info[DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX]));
cfg.cuda_graph_strategy = to_cuda_graph_strategy_option(std::stoi(info[CUDA_GRAPH_STRATEGY_IDX]));
#else
(void)info;
Use `TORCHTRT_UNUSED` instead.
// narrowing.
[[nodiscard]] std::string to_string(DynamicShapesKernelStrategy s);
[[nodiscard]] std::string to_string(CudaGraphStrategyOption s);
[[nodiscard]] DynamicShapesKernelStrategy to_dynamic_shapes_kernel_strategy(
    std::underlying_type_t<DynamicShapesKernelStrategy> v);
[[nodiscard]] CudaGraphStrategyOption to_cuda_graph_strategy_option(std::underlying_type_t<CudaGraphStrategyOption> v);
We don't need to expose these functions in the header; it's perfectly fine to keep them as helper functions in the C++ files, in an anonymous namespace.
}

bool TRTRuntimeConfig::is_monolithic_capturable(nvinfer1::IExecutionContext* exec_ctx, cudaStream_t stream) const {
#if defined(TRT_MAJOR_RTX) && defined(ENABLE_FEATURE_DISABLE_RUNTIME_ALLOCATION)
Remove the `defined(ENABLE_FEATURE_DISABLE_RUNTIME_ALLOCATION)` everywhere; it is always true for `TRT_MAJOR_RTX`.
Description
Extends three TensorRT-RTX runtime features that landed on the Python runtime (PythonTorchTensorRTModule) to the C++ runtime path (TorchTensorRTModule → core/runtime/TRTEngine). All three features center on nvinfer1::IRuntimeConfig, which the C++ runtime previously did not use; it called createExecutionContext(...) directly at four sites.
Without this PR, users on the C++ runtime path (TorchScript deployments, use_python_runtime=False) cannot access any of these TRT-RTX features, and runtime-cache warm-start savings (~8× measured in #4180) are unavailable on that path.
Commits (stacked, each reviewable independently)
1. feat(runtime): introduce IRuntimeConfig scaffolding and bump ABI to v9 -- shared infrastructure. Adds the IRuntimeConfig / IRuntimeCache members (RTX-only), a private recreate_execution_context() helper replacing 4 direct createExecutionContext call sites, and three new SerializedInfoIndex entries (RUNTIME_CACHE_PATH_IDX, DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX, CUDA_GRAPH_STRATEGY_IDX) with a single ABI bump 8 → 9. Python settings + compile() parameter threading. No behavior change: the per-feature apply_* appliers are empty stubs filled in by subsequent commits.
2. feat(runtime): add runtime cache to C++ runtime for TensorRT-RTX -- mirror of #4180 (feat: add runtime cache API for TensorRT-RTX). Load on engine setup under flock(LOCK_SH), atomic save on destructor (tmp + rename under flock(LOCK_EX)). Linux-only file locking in Phase 1; Windows falls back to a best-effort write with a warning.
3. feat(runtime): add dynamic shapes kernel specialization strategy to C++ runtime -- mirror of #4184 (feat: add dynamic shapes kernel specialization strategy for TRT-RTX). Wires IRuntimeConfig::setDynamicShapesKernelSpecializationStrategy with the validated integer code from the compile setting.
4. feat(runtime): add TensorRT-RTX native CUDA graph strategy to C++ runtime -- mirror of #4187 (feat: add TRT-RTX native CUDA graph support). Wires IRuntimeConfig::setCudaGraphStrategy and makes execute_engine.cpp TensorRT-RTX-aware: bypasses manual at::cuda::CUDAGraph capture on RTX (TRT-RTX handles it internally) and uses cudaStreamIsCapturing to detect outer whole-graph capture, disabling per-engine RTX-native capture one-shot so CudaGraphsTorchTensorRTModule coexists without special handling.
Why bundle three features in one PR
All three features require an IRuntimeConfig on the engine, a single ABI bump, and extensions to the same serialization/deserialization code paths. Splitting into three independent PRs would trigger three consecutive ABI bumps and triple the surface area for backward-compat fallout. Keeping them in one stack keeps ABI changes atomic while still giving reviewers clean per-feature diffs.
Type of change
- Breaking ABI change: ABI_VERSION bumped "8" → "9". Old .pt/.ep files targeting the C++ runtime will fail verify_serialization_fmt with a clear error, as with every prior ABI bump.
- New feature defaults ("lazy", "disabled") keep existing behavior; existing docs for the Python-runtime runtime cache already cover the concept.
Checklist
- New settings documented in CompilationSettings; full user-guide updates pending the feature PRs on the Python path.
- New tests added under tests/py/dynamo/runtime/, each forcing use_python_runtime=False so the C++ runtime is exercised end-to-end.
- Existing regression suites green (test_000_runtime_cache.py, test_002_cudagraphs_cpp.py, test_005_dynamic_allocation.py).
Test plan
Verified on A100 (via L40S in this iteration) with TensorRT-RTX-1.4.0.76, CUDA 13.0, PyTorch nightly 2.13.0.dev20260420:
- python3 -m pytest runtime/test_000_runtime_cache_cpp.py -- 5/5 pass (save on __del__, directory creation, cosine-similarity roundtrip after warm cache, ABI/index registration)
- python3 -m pytest runtime/test_000_dynamic_shapes_kernel_strategy.py -- 7/7 pass (settings defaults + {lazy, eager, none} end-to-end + dynamic-shape traversal + invalid-value rejection)
- python3 -m pytest runtime/test_000_cuda_graph_strategy.py -- 7/7 pass (settings defaults + {disabled, whole_graph_capture} end-to-end + RTX-native override under set_cudagraphs_mode(True) + repeated inference + invalid-value rejection)
- Regression: test_000_runtime_cache.py (Python runtime) 12 passed, 2 skipped; test_005_dynamic_allocation.py 1 passed; test_002_cudagraphs_cpp.py 8 passed
- Manual run with use_python_runtime=False, dynamic_shapes_kernel_specialization_strategy="eager", runtime_cache_path=/tmp/...; confirms the cache (402534 bytes) is saved on destructor under flock