From 290cf9719f74fda235b2afb01ee47ba7f78164b9 Mon Sep 17 00:00:00 2001 From: IDExpensive-One Date: Mon, 9 Mar 2026 19:41:56 +0800 Subject: [PATCH] Support variance ONNX export on PyTorch 2.x The variance model export previously required PyTorch 1.13.x because it used torch.jit.script on models returned by view_as_*_predictor(). This fails on PyTorch 2.x with "Unsupported value kind: Tensor" as TorchScript tries to compile all methods on the class, including those referencing deleted attributes. This adds a fallback path for PyTorch 2.x that uses lightweight wrapper modules with trace-based torch.onnx.export instead of torch.jit.script. The original script-based path is preserved for PyTorch 1.13.x. The version check in export.py is relaxed from a hard error to a warning. --- deployment/exporters/variance_exporter.py | 120 +++++++++++++--------- scripts/export.py | 7 +- 2 files changed, 79 insertions(+), 48 deletions(-) diff --git a/deployment/exporters/variance_exporter.py b/deployment/exporters/variance_exporter.py index 74b455ee5..d63b78c72 100644 --- a/deployment/exporters/variance_exporter.py +++ b/deployment/exporters/variance_exporter.py @@ -388,33 +388,47 @@ def _torch_export_model(self): noise = torch.randn(shape, device=self.device) condition = torch.rand((1, hparams['hidden_size'], 15), device=self.device) - print(f'Tracing {self.pitch_backbone_class_name} backbone...') pitch_predictor = self.model.view_as_pitch_predictor() - pitch_predictor.pitch_predictor.set_backbone( - torch.jit.trace( - pitch_predictor.pitch_predictor.backbone, - ( - noise, - dummy_time, - condition + + if torch.__version__.startswith('1.13.'): + print(f'Tracing {self.pitch_backbone_class_name} backbone...') + pitch_predictor.pitch_predictor.set_backbone( + torch.jit.trace( + pitch_predictor.pitch_predictor.backbone, + ( + noise, + dummy_time, + condition + ) ) ) - ) - print(f'Scripting {self.pitch_predictor_class_name}...') - pitch_predictor = torch.jit.script( - pitch_predictor, - example_inputs=[ - ( - condition.transpose(1, 2), - 1 # p_sample branch - ), - ( - condition.transpose(1, 2), - dummy_steps # p_sample_plms branch - ) - ] - ) + print(f'Scripting {self.pitch_predictor_class_name}...') + pitch_predictor = torch.jit.script( + pitch_predictor, + example_inputs=[ + ( + condition.transpose(1, 2), + 1 # p_sample branch + ), + ( + condition.transpose(1, 2), + dummy_steps # p_sample_plms branch + ) + ] + ) + else: + print(f'Wrapping {self.pitch_predictor_class_name} for trace-based export...') + + class _PitchPredWrapper(torch.nn.Module): + def __init__(self, model): + super().__init__() + self.pitch_predictor = model.pitch_predictor + + def forward(self, pitch_cond, steps): + return self.pitch_predictor(pitch_cond, steps=steps) + + pitch_predictor = _PitchPredWrapper(pitch_predictor) print(f'Exporting {self.pitch_predictor_class_name}...') torch.onnx.export( @@ -535,33 +549,47 @@ def _torch_export_model(self): condition = torch.rand((1, hparams['hidden_size'], 15), device=self.device) step = (torch.rand((1,), device=self.device) * hparams['K_step']).long() - print(f'Tracing {self.variance_backbone_class_name} backbone...') multi_var_predictor = self.model.view_as_variance_predictor() - multi_var_predictor.variance_predictor.set_backbone( - torch.jit.trace( - multi_var_predictor.variance_predictor.backbone, - ( - noise, - step, - condition + + if torch.__version__.startswith('1.13.'): + print(f'Tracing {self.variance_backbone_class_name} backbone...') + multi_var_predictor.variance_predictor.set_backbone( + torch.jit.trace( + multi_var_predictor.variance_predictor.backbone, + ( + noise, + step, + condition + ) ) ) - ) - print(f'Scripting {self.multi_var_predictor_class_name}...') - multi_var_predictor = torch.jit.script( - multi_var_predictor, - example_inputs=[ - ( - condition.transpose(1, 2), - 1 # p_sample branch - ), - ( - condition.transpose(1, 2), - dummy_steps # p_sample_plms branch - ) - ] - ) + print(f'Scripting {self.multi_var_predictor_class_name}...') + multi_var_predictor = torch.jit.script( + multi_var_predictor, + example_inputs=[ + ( + condition.transpose(1, 2), + 1 # p_sample branch + ), + ( + condition.transpose(1, 2), + dummy_steps # p_sample_plms branch + ) + ] + ) + else: + print(f'Wrapping {self.multi_var_predictor_class_name} for trace-based export...') + + class _VarPredWrapper(torch.nn.Module): + def __init__(self, model): + super().__init__() + self.variance_predictor = model.variance_predictor + + def forward(self, variance_cond, steps): + return self.variance_predictor(variance_cond, steps=steps) + + multi_var_predictor = _VarPredWrapper(multi_var_predictor) print(f'Exporting {self.multi_var_predictor_class_name}...') torch.onnx.export( diff --git a/scripts/export.py b/scripts/export.py index d666175d6..8d6eb2168 100644 --- a/scripts/export.py +++ b/scripts/export.py @@ -15,10 +15,13 @@ def check_pytorch_version(): - # Require PyTorch version to be exactly 1.13.x + import warnings if torch.__version__.startswith('1.13.'): return - raise RuntimeError('This script requires PyTorch 1.13.x. Please install the correct version.') + warnings.warn( + f'ONNX export is tested on PyTorch 1.13.x, but you have {torch.__version__}. ' + f'Proceeding with trace-based fallback for variance models.' + ) def find_exp(exp):