diff --git a/deployment/exporters/variance_exporter.py b/deployment/exporters/variance_exporter.py
index 74b455ee5..d63b78c72 100644
--- a/deployment/exporters/variance_exporter.py
+++ b/deployment/exporters/variance_exporter.py
@@ -388,33 +388,47 @@ def _torch_export_model(self):
         noise = torch.randn(shape, device=self.device)
         condition = torch.rand((1, hparams['hidden_size'], 15), device=self.device)
 
-        print(f'Tracing {self.pitch_backbone_class_name} backbone...')
         pitch_predictor = self.model.view_as_pitch_predictor()
-        pitch_predictor.pitch_predictor.set_backbone(
-            torch.jit.trace(
-                pitch_predictor.pitch_predictor.backbone,
-                (
-                    noise,
-                    dummy_time,
-                    condition
+
+        if torch.__version__.startswith('1.13.'):
+            print(f'Tracing {self.pitch_backbone_class_name} backbone...')
+            pitch_predictor.pitch_predictor.set_backbone(
+                torch.jit.trace(
+                    pitch_predictor.pitch_predictor.backbone,
+                    (
+                        noise,
+                        dummy_time,
+                        condition
+                    )
                 )
             )
-        )
-        print(f'Scripting {self.pitch_predictor_class_name}...')
-        pitch_predictor = torch.jit.script(
-            pitch_predictor,
-            example_inputs=[
-                (
-                    condition.transpose(1, 2),
-                    1  # p_sample branch
-                ),
-                (
-                    condition.transpose(1, 2),
-                    dummy_steps  # p_sample_plms branch
-                )
-            ]
-        )
+            print(f'Scripting {self.pitch_predictor_class_name}...')
+            pitch_predictor = torch.jit.script(
+                pitch_predictor,
+                example_inputs=[
+                    (
+                        condition.transpose(1, 2),
+                        1  # p_sample branch
+                    ),
+                    (
+                        condition.transpose(1, 2),
+                        dummy_steps  # p_sample_plms branch
+                    )
+                ]
+            )
+        else:
+            print(f'Wrapping {self.pitch_predictor_class_name} for trace-based export...')
+
+            class _PitchPredWrapper(torch.nn.Module):
+                def __init__(self, model):
+                    super().__init__()
+                    self.pitch_predictor = model.pitch_predictor
+
+                def forward(self, pitch_cond, steps):
+                    return self.pitch_predictor(pitch_cond, steps=steps)
+
+            pitch_predictor = _PitchPredWrapper(pitch_predictor)
 
 
         print(f'Exporting {self.pitch_predictor_class_name}...')
         torch.onnx.export(
@@ -535,33 +549,47 @@ def _torch_export_model(self):
         condition = torch.rand((1, hparams['hidden_size'], 15), device=self.device)
         step = (torch.rand((1,), device=self.device) * hparams['K_step']).long()
 
-        print(f'Tracing {self.variance_backbone_class_name} backbone...')
         multi_var_predictor = self.model.view_as_variance_predictor()
-        multi_var_predictor.variance_predictor.set_backbone(
-            torch.jit.trace(
-                multi_var_predictor.variance_predictor.backbone,
-                (
-                    noise,
-                    step,
-                    condition
+
+        if torch.__version__.startswith('1.13.'):
+            print(f'Tracing {self.variance_backbone_class_name} backbone...')
+            multi_var_predictor.variance_predictor.set_backbone(
+                torch.jit.trace(
+                    multi_var_predictor.variance_predictor.backbone,
+                    (
+                        noise,
+                        step,
+                        condition
+                    )
                 )
             )
-        )
-        print(f'Scripting {self.multi_var_predictor_class_name}...')
-        multi_var_predictor = torch.jit.script(
-            multi_var_predictor,
-            example_inputs=[
-                (
-                    condition.transpose(1, 2),
-                    1  # p_sample branch
-                ),
-                (
-                    condition.transpose(1, 2),
-                    dummy_steps  # p_sample_plms branch
-                )
-            ]
-        )
+            print(f'Scripting {self.multi_var_predictor_class_name}...')
+            multi_var_predictor = torch.jit.script(
+                multi_var_predictor,
+                example_inputs=[
+                    (
+                        condition.transpose(1, 2),
+                        1  # p_sample branch
+                    ),
+                    (
+                        condition.transpose(1, 2),
+                        dummy_steps  # p_sample_plms branch
+                    )
+                ]
+            )
+        else:
+            print(f'Wrapping {self.multi_var_predictor_class_name} for trace-based export...')
+
+            class _VarPredWrapper(torch.nn.Module):
+                def __init__(self, model):
+                    super().__init__()
+                    self.variance_predictor = model.variance_predictor
+
+                def forward(self, variance_cond, steps):
+                    return self.variance_predictor(variance_cond, steps=steps)
+
+            multi_var_predictor = _VarPredWrapper(multi_var_predictor)
 
 
         print(f'Exporting {self.multi_var_predictor_class_name}...')
         torch.onnx.export(
diff --git a/scripts/export.py b/scripts/export.py
index d666175d6..8d6eb2168 100644
--- a/scripts/export.py
+++ b/scripts/export.py
@@ -15,10 +15,13 @@
 
 
 def check_pytorch_version():
-    # Require PyTorch version to be exactly 1.13.x
+    import warnings
     if torch.__version__.startswith('1.13.'):
         return
-    raise RuntimeError('This script requires PyTorch 1.13.x. Please install the correct version.')
+    warnings.warn(
+        f'ONNX export is tested on PyTorch 1.13.x, but you have {torch.__version__}. '
+        f'Proceeding with trace-based fallback for variance models.'
+    )
 
 
 def find_exp(exp):