From 0e06e680a0756deedbd8a68545e98a56f913268d Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Fri, 6 Mar 2026 11:30:26 +0800 Subject: [PATCH 1/6] support mova inference --- diffsynth/configs/model_configs.py | 24 +- .../configs/vram_management_module_maps.py | 18 + diffsynth/models/mova_audio_dit.py | 57 ++ diffsynth/models/mova_audio_vae.py | 796 ++++++++++++++++++ diffsynth/models/mova_dual_tower_bridge.py | 595 +++++++++++++ diffsynth/models/wan_video_dit.py | 14 +- diffsynth/pipelines/mova_audio_video.py | 455 ++++++++++ diffsynth/utils/xfuser/__init__.py | 2 +- .../utils/xfuser/xdit_context_parallel.py | 30 +- .../acceleration/unified_sequence_parallel.py | 55 ++ .../mova/model_inference/MOVA-360p-TI2AV.py | 52 ++ .../mova/model_inference/MOVA-720p-TI2AV.py | 53 ++ 12 files changed, 2147 insertions(+), 4 deletions(-) create mode 100644 diffsynth/models/mova_audio_dit.py create mode 100644 diffsynth/models/mova_audio_vae.py create mode 100644 diffsynth/models/mova_dual_tower_bridge.py create mode 100644 diffsynth/pipelines/mova_audio_video.py create mode 100644 examples/mova/acceleration/unified_sequence_parallel.py create mode 100644 examples/mova/model_inference/MOVA-360p-TI2AV.py create mode 100644 examples/mova/model_inference/MOVA-720p-TI2AV.py diff --git a/diffsynth/configs/model_configs.py b/diffsynth/configs/model_configs.py index f9fa595c1..998052325 100644 --- a/diffsynth/configs/model_configs.py +++ b/diffsynth/configs/model_configs.py @@ -735,4 +735,26 @@ "state_dict_converter": "diffsynth.utils.state_dict_converters.anima_dit.AnimaDiTStateDictConverter", } ] -MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series + ltx2_series + anima_series + +mova_series = [ + # Example: ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="audio_dit/diffusion_pytorch_model.safetensors") + { + "model_hash": "8c57e12790e2c45a64817e0ce28cde2f", + "model_name": "mova_audio_dit", + "model_class": 
"diffsynth.models.mova_audio_dit.MovaAudioDit", + "extra_kwargs": {'has_image_input': False, 'patch_size': [1], 'in_dim': 128, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 128, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06} + }, + # Example: ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="audio_vae/diffusion_pytorch_model.safetensors") + { + "model_hash": "418517fb2b4e919d2cac8f314fcf82ac", + "model_name": "mova_audio_vae", + "model_class": "diffsynth.models.mova_audio_vae.DacVAE", + }, + # Example: ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="dual_tower_bridge/diffusion_pytorch_model.safetensors") + { + "model_hash": "d1139dbbc8b4ab53cf4b4243d57bbceb", + "model_name": "mova_dual_tower_bridge", + "model_class": "diffsynth.models.mova_dual_tower_bridge.DualTowerConditionalBridge", + }, +] +MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series + ltx2_series + anima_series + mova_series diff --git a/diffsynth/configs/vram_management_module_maps.py b/diffsynth/configs/vram_management_module_maps.py index 902c38b41..0142958ad 100644 --- a/diffsynth/configs/vram_management_module_maps.py +++ b/diffsynth/configs/vram_management_module_maps.py @@ -249,6 +249,24 @@ "torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule", }, + "diffsynth.models.mova_audio_dit.MovaAudioDit": { + "diffsynth.models.wan_video_dit.DiTBlock": "diffsynth.core.vram.layers.AutoWrappedNonRecurseModule", + "diffsynth.models.wan_video_dit.Head": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.wan_video_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + 
"diffsynth.models.mova_dual_tower_bridge.DualTowerConditionalBridge": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.wan_video_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.mova_audo_vae.DacVAE": { + "diffsynth.models.mova_audo_vae.Snake1d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.ConvTranspose1d": "diffsynth.core.vram.layers.AutoWrappedModule", + }, } def QwenImageTextEncoder_Module_Map_Updater(): diff --git a/diffsynth/models/mova_audio_dit.py b/diffsynth/models/mova_audio_dit.py new file mode 100644 index 000000000..7b2b5d133 --- /dev/null +++ b/diffsynth/models/mova_audio_dit.py @@ -0,0 +1,57 @@ +import torch +import torch.nn as nn +from .wan_video_dit import WanModel, precompute_freqs_cis, sinusoidal_embedding_1d +from einops import rearrange +from ..core import gradient_checkpoint_forward + +def precompute_freqs_cis_1d(dim: int, end: int = 16384, theta: float = 10000.0): + f_freqs_cis = precompute_freqs_cis(dim, end, theta) + return f_freqs_cis.chunk(3, dim=-1) + +class MovaAudioDit(WanModel): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + head_dim = kwargs.get("dim", 1536) // kwargs.get("num_heads", 12) + self.freqs = precompute_freqs_cis_1d(head_dim) + self.patch_embedding = nn.Conv1d( + kwargs.get("in_dim", 128), kwargs.get("dim", 1536), kernel_size=[1], stride=[1] + ) + + def precompute_freqs_cis(self, dim: int, end: int = 16384, theta: float = 10000.0): + self.f_freqs_cis = precompute_freqs_cis_1d(dim, end, theta) + + def forward(self, + x: torch.Tensor, + timestep: torch.Tensor, + context: torch.Tensor, + use_gradient_checkpointing: bool = False, + use_gradient_checkpointing_offload: bool = False, + **kwargs, + ): + t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, 
timestep)) + t_mod = self.time_projection(t).unflatten(1, (6, self.dim)) + context = self.text_embedding(context) + x, (f, ) = self.patchify(x) + freqs = torch.cat([ + self.freqs[0][:f].view(f, -1).expand(f, -1), + self.freqs[1][:f].view(f, -1).expand(f, -1), + self.freqs[2][:f].view(f, -1).expand(f, -1), + ], dim=-1).reshape(f, 1, -1).to(x.device) + + for block in self.blocks: + x = gradient_checkpoint_forward( + block, + use_gradient_checkpointing, + use_gradient_checkpointing_offload, + x, context, t_mod, freqs, + ) + x = self.head(x, t) + x = self.unpatchify(x, (f, )) + return x + + def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor): + return rearrange( + x, 'b f (p c) -> b c (f p)', + f=grid_size[0], + p=self.patch_size[0] + ) diff --git a/diffsynth/models/mova_audio_vae.py b/diffsynth/models/mova_audio_vae.py new file mode 100644 index 000000000..570cd43f8 --- /dev/null +++ b/diffsynth/models/mova_audio_vae.py @@ -0,0 +1,796 @@ +import math +from typing import List, Union +import numpy as np +import torch +from torch import nn +from torch.nn.utils import weight_norm +import torch.nn.functional as F +from einops import rearrange + +def WNConv1d(*args, **kwargs): + return weight_norm(nn.Conv1d(*args, **kwargs)) + + +def WNConvTranspose1d(*args, **kwargs): + return weight_norm(nn.ConvTranspose1d(*args, **kwargs)) + + +# Scripting this brings model speed up 1.4x +@torch.jit.script +def snake(x, alpha): + shape = x.shape + x = x.reshape(shape[0], shape[1], -1) + x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2) + x = x.reshape(shape) + return x + + +class Snake1d(nn.Module): + def __init__(self, channels): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1, channels, 1)) + + def forward(self, x): + return snake(x, self.alpha) + + +class VectorQuantize(nn.Module): + """ + Implementation of VQ similar to Karpathy's repo: + https://github.com/karpathy/deep-vector-quantization + Additionally uses following tricks from Improved 
VQGAN + (https://arxiv.org/pdf/2110.04627.pdf): + 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space + for improved codebook usage + 2. l2-normalized codes: Converts euclidean distance to cosine similarity which + improves training stability + """ + + def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int): + super().__init__() + self.codebook_size = codebook_size + self.codebook_dim = codebook_dim + + self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1) + self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1) + self.codebook = nn.Embedding(codebook_size, codebook_dim) + + def forward(self, z): + """Quantized the input tensor using a fixed codebook and returns + the corresponding codebook vectors + + Parameters + ---------- + z : Tensor[B x D x T] + + Returns + ------- + Tensor[B x D x T] + Quantized continuous representation of input + Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + Tensor[1] + Codebook loss to update the codebook + Tensor[B x T] + Codebook indices (quantized discrete representation of input) + Tensor[B x D x T] + Projected latents (continuous representation of input before quantization) + """ + + # Factorized codes (ViT-VQGAN) Project input into low-dimensional space + z_e = self.in_proj(z) # z_e : (B x D x T) + z_q, indices = self.decode_latents(z_e) + + commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2]) + codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2]) + + z_q = ( + z_e + (z_q - z_e).detach() + ) # noop in forward pass, straight-through gradient estimator in backward pass + + z_q = self.out_proj(z_q) + + return z_q, commitment_loss, codebook_loss, indices, z_e + + def embed_code(self, embed_id): + return F.embedding(embed_id, self.codebook.weight) + + def decode_code(self, embed_id): + return self.embed_code(embed_id).transpose(1, 2) + + def decode_latents(self, latents): + 
encodings = rearrange(latents, "b d t -> (b t) d") + codebook = self.codebook.weight # codebook: (N x D) + + # L2 normalize encodings and codebook (ViT-VQGAN) + encodings = F.normalize(encodings) + codebook = F.normalize(codebook) + + # Compute euclidean distance with codebook + dist = ( + encodings.pow(2).sum(1, keepdim=True) + - 2 * encodings @ codebook.t() + + codebook.pow(2).sum(1, keepdim=True).t() + ) + indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0)) + z_q = self.decode_code(indices) + return z_q, indices + + +class ResidualVectorQuantize(nn.Module): + """ + Introduced in SoundStream: An end2end neural audio codec + https://arxiv.org/abs/2107.03312 + """ + + def __init__( + self, + input_dim: int = 512, + n_codebooks: int = 9, + codebook_size: int = 1024, + codebook_dim: Union[int, list] = 8, + quantizer_dropout: float = 0.0, + ): + super().__init__() + if isinstance(codebook_dim, int): + codebook_dim = [codebook_dim for _ in range(n_codebooks)] + + self.n_codebooks = n_codebooks + self.codebook_dim = codebook_dim + self.codebook_size = codebook_size + + self.quantizers = nn.ModuleList( + [ + VectorQuantize(input_dim, codebook_size, codebook_dim[i]) + for i in range(n_codebooks) + ] + ) + self.quantizer_dropout = quantizer_dropout + + def forward(self, z, n_quantizers: int = None): + """Quantized the input tensor using a fixed set of `n` codebooks and returns + the corresponding codebook vectors + Parameters + ---------- + z : Tensor[B x D x T] + n_quantizers : int, optional + No. of quantizers to use + (n_quantizers < self.n_codebooks ex: for quantizer dropout) + Note: if `self.quantizer_dropout` is True, this argument is ignored + when in training mode, and a random number of quantizers is used. 
+ Returns + ------- + dict + A dictionary with the following keys: + + "z" : Tensor[B x D x T] + Quantized continuous representation of input + "codes" : Tensor[B x N x T] + Codebook indices for each codebook + (quantized discrete representation of input) + "latents" : Tensor[B x N*D x T] + Projected latents (continuous representation of input before quantization) + "vq/commitment_loss" : Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + "vq/codebook_loss" : Tensor[1] + Codebook loss to update the codebook + """ + z_q = 0 + residual = z + commitment_loss = 0 + codebook_loss = 0 + + codebook_indices = [] + latents = [] + + if n_quantizers is None: + n_quantizers = self.n_codebooks + if self.training: + n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1 + dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],)) + n_dropout = int(z.shape[0] * self.quantizer_dropout) + n_quantizers[:n_dropout] = dropout[:n_dropout] + n_quantizers = n_quantizers.to(z.device) + + for i, quantizer in enumerate(self.quantizers): + if self.training is False and i >= n_quantizers: + break + + z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer( + residual + ) + + # Create mask to apply quantizer dropout + mask = ( + torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers + ) + z_q = z_q + z_q_i * mask[:, None, None] + residual = residual - z_q_i + + # Sum losses + commitment_loss += (commitment_loss_i * mask).mean() + codebook_loss += (codebook_loss_i * mask).mean() + + codebook_indices.append(indices_i) + latents.append(z_e_i) + + codes = torch.stack(codebook_indices, dim=1) + latents = torch.cat(latents, dim=1) + + return z_q, codes, latents, commitment_loss, codebook_loss + + def from_codes(self, codes: torch.Tensor): + """Given the quantized codes, reconstruct the continuous representation + Parameters + ---------- + codes : Tensor[B x N x T] + Quantized discrete representation of input + 
Returns + ------- + Tensor[B x D x T] + Quantized continuous representation of input + """ + z_q = 0.0 + z_p = [] + n_codebooks = codes.shape[1] + for i in range(n_codebooks): + z_p_i = self.quantizers[i].decode_code(codes[:, i, :]) + z_p.append(z_p_i) + + z_q_i = self.quantizers[i].out_proj(z_p_i) + z_q = z_q + z_q_i + return z_q, torch.cat(z_p, dim=1), codes + + def from_latents(self, latents: torch.Tensor): + """Given the unquantized latents, reconstruct the + continuous representation after quantization. + + Parameters + ---------- + latents : Tensor[B x N x T] + Continuous representation of input after projection + + Returns + ------- + Tensor[B x D x T] + Quantized representation of full-projected space + Tensor[B x D x T] + Quantized representation of latent space + """ + z_q = 0 + z_p = [] + codes = [] + dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers]) + + n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[ + 0 + ] + for i in range(n_codebooks): + j, k = dims[i], dims[i + 1] + z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :]) + z_p.append(z_p_i) + codes.append(codes_i) + + z_q_i = self.quantizers[i].out_proj(z_p_i) + z_q = z_q + z_q_i + + return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1) + + +class AbstractDistribution: + def sample(self): + raise NotImplementedError() + + def mode(self): + raise NotImplementedError() + + +class DiracDistribution(AbstractDistribution): + def __init__(self, value): + self.value = value + + def sample(self): + return self.value + + def mode(self): + return self.value + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: 
+ self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) + + def sample(self): + x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.0]) + else: + if other is None: + return 0.5 * torch.mean( + torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, + dim=[1, 2], + ) + else: + return 0.5 * torch.mean( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var + - 1.0 + - self.logvar + + other.logvar, + dim=[1, 2], + ) + + def nll(self, sample, dims=[1, 2]): + if self.deterministic: + return torch.Tensor([0.0]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims, + ) + + def mode(self): + return self.mean + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, torch.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for torch.exp(). 
+ logvar1, logvar2 = [x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) for x in (logvar1, logvar2)] + + return 0.5 * ( + -1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) + ) + + +def init_weights(m): + if isinstance(m, nn.Conv1d): + nn.init.trunc_normal_(m.weight, std=0.02) + nn.init.constant_(m.bias, 0) + + +class ResidualUnit(nn.Module): + def __init__(self, dim: int = 16, dilation: int = 1): + super().__init__() + pad = ((7 - 1) * dilation) // 2 + self.block = nn.Sequential( + Snake1d(dim), + WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad), + Snake1d(dim), + WNConv1d(dim, dim, kernel_size=1), + ) + + def forward(self, x): + y = self.block(x) + pad = (x.shape[-1] - y.shape[-1]) // 2 + if pad > 0: + x = x[..., pad:-pad] + return x + y + + +class EncoderBlock(nn.Module): + def __init__(self, dim: int = 16, stride: int = 1): + super().__init__() + self.block = nn.Sequential( + ResidualUnit(dim // 2, dilation=1), + ResidualUnit(dim // 2, dilation=3), + ResidualUnit(dim // 2, dilation=9), + Snake1d(dim // 2), + WNConv1d( + dim // 2, + dim, + kernel_size=2 * stride, + stride=stride, + padding=math.ceil(stride / 2), + ), + ) + + def forward(self, x): + return self.block(x) + + +class Encoder(nn.Module): + def __init__( + self, + d_model: int = 64, + strides: list = [2, 4, 8, 8], + d_latent: int = 64, + ): + super().__init__() + # Create first convolution + self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)] + + # Create EncoderBlocks that double channels as they downsample by `stride` + for stride in strides: + d_model *= 2 + self.block += [EncoderBlock(d_model, stride=stride)] + + # Create last convolution + self.block += [ + Snake1d(d_model), + WNConv1d(d_model, d_latent, kernel_size=3, padding=1), + ] + + # Wrap black into nn.Sequential + self.block = nn.Sequential(*self.block) + self.enc_dim = d_model + + def forward(self, x): + return self.block(x) + + +class 
DecoderBlock(nn.Module): + def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1): + super().__init__() + self.block = nn.Sequential( + Snake1d(input_dim), + WNConvTranspose1d( + input_dim, + output_dim, + kernel_size=2 * stride, + stride=stride, + padding=math.ceil(stride / 2), + output_padding=stride % 2, + ), + ResidualUnit(output_dim, dilation=1), + ResidualUnit(output_dim, dilation=3), + ResidualUnit(output_dim, dilation=9), + ) + + def forward(self, x): + return self.block(x) + + +class Decoder(nn.Module): + def __init__( + self, + input_channel, + channels, + rates, + d_out: int = 1, + ): + super().__init__() + + # Add first conv layer + layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)] + + # Add upsampling + MRF blocks + for i, stride in enumerate(rates): + input_dim = channels // 2**i + output_dim = channels // 2 ** (i + 1) + layers += [DecoderBlock(input_dim, output_dim, stride)] + + # Add final conv layer + layers += [ + Snake1d(output_dim), + WNConv1d(output_dim, d_out, kernel_size=7, padding=3), + nn.Tanh(), + ] + + self.model = nn.Sequential(*layers) + + def forward(self, x): + return self.model(x) + + +class DacVAE(nn.Module): + + def __init__( + self, + encoder_dim: int = 128, + encoder_rates: List[int] = [2, 3, 4, 5, 8], + latent_dim: int = 128, + decoder_dim: int = 2048, + decoder_rates: List[int] = [8, 5, 4, 3, 2], + n_codebooks: int = 9, + codebook_size: int = 1024, + codebook_dim: Union[int, list] = 8, + quantizer_dropout: bool = False, + sample_rate: int = 48000, + continuous: bool = True, + use_weight_norm: bool = False, + ): + super().__init__() + + self.encoder_dim = encoder_dim + self.encoder_rates = encoder_rates + self.decoder_dim = decoder_dim + self.decoder_rates = decoder_rates + self.sample_rate = sample_rate + self.continuous = continuous + self.use_weight_norm = use_weight_norm + + if latent_dim is None: + latent_dim = encoder_dim * (2 ** len(encoder_rates)) + + self.latent_dim = 
latent_dim + + self.hop_length = np.prod(encoder_rates) + self.encoder = Encoder(encoder_dim, encoder_rates, latent_dim) + + if not continuous: + self.n_codebooks = n_codebooks + self.codebook_size = codebook_size + self.codebook_dim = codebook_dim + self.quantizer = ResidualVectorQuantize( + input_dim=latent_dim, + n_codebooks=n_codebooks, + codebook_size=codebook_size, + codebook_dim=codebook_dim, + quantizer_dropout=quantizer_dropout, + ) + else: + self.quant_conv = torch.nn.Conv1d(latent_dim, 2 * latent_dim, 1) + self.post_quant_conv = torch.nn.Conv1d(latent_dim, latent_dim, 1) + + self.decoder = Decoder( + latent_dim, + decoder_dim, + decoder_rates, + ) + self.sample_rate = sample_rate + self.apply(init_weights) + + self.delay = self.get_delay() + + if not self.use_weight_norm: + self.remove_weight_norm() + + def get_delay(self): + # Any number works here, delay is invariant to input length + l_out = self.get_output_length(0) + L = l_out + + layers = [] + for layer in self.modules(): + if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)): + layers.append(layer) + + for layer in reversed(layers): + d = layer.dilation[0] + k = layer.kernel_size[0] + s = layer.stride[0] + + if isinstance(layer, nn.ConvTranspose1d): + L = ((L - d * (k - 1) - 1) / s) + 1 + elif isinstance(layer, nn.Conv1d): + L = (L - 1) * s + d * (k - 1) + 1 + + L = math.ceil(L) + + l_in = L + + return (l_in - l_out) // 2 + + def get_output_length(self, input_length): + L = input_length + # Calculate output length + for layer in self.modules(): + if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)): + d = layer.dilation[0] + k = layer.kernel_size[0] + s = layer.stride[0] + + if isinstance(layer, nn.Conv1d): + L = ((L - d * (k - 1) - 1) / s) + 1 + elif isinstance(layer, nn.ConvTranspose1d): + L = (L - 1) * s + d * (k - 1) + 1 + + L = math.floor(L) + return L + + @property + def dtype(self): + """Get the dtype of the model parameters.""" + # Return the dtype of the first parameter found + for 
param in self.parameters(): + return param.dtype + return torch.float32 # fallback + + @property + def device(self): + """Get the device of the model parameters.""" + # Return the device of the first parameter found + for param in self.parameters(): + return param.device + return torch.device('cpu') # fallback + + def preprocess(self, audio_data, sample_rate): + if sample_rate is None: + sample_rate = self.sample_rate + assert sample_rate == self.sample_rate + + length = audio_data.shape[-1] + right_pad = math.ceil(length / self.hop_length) * self.hop_length - length + audio_data = nn.functional.pad(audio_data, (0, right_pad)) + + return audio_data + + def encode( + self, + audio_data: torch.Tensor, + n_quantizers: int = None, + ): + """Encode given audio data and return quantized latent codes + + Parameters + ---------- + audio_data : Tensor[B x 1 x T] + Audio data to encode + n_quantizers : int, optional + Number of quantizers to use, by default None + If None, all quantizers are used. 
+ + Returns + ------- + dict + A dictionary with the following keys: + "z" : Tensor[B x D x T] + Quantized continuous representation of input + "codes" : Tensor[B x N x T] + Codebook indices for each codebook + (quantized discrete representation of input) + "latents" : Tensor[B x N*D x T] + Projected latents (continuous representation of input before quantization) + "vq/commitment_loss" : Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + "vq/codebook_loss" : Tensor[1] + Codebook loss to update the codebook + "length" : int + Number of samples in input audio + """ + z = self.encoder(audio_data) # [B x D x T] + if not self.continuous: + z, codes, latents, commitment_loss, codebook_loss = self.quantizer(z, n_quantizers) + else: + z = self.quant_conv(z) # [B x 2D x T] + z = DiagonalGaussianDistribution(z) + codes, latents, commitment_loss, codebook_loss = None, None, 0, 0 + + return z, codes, latents, commitment_loss, codebook_loss + + def decode(self, z: torch.Tensor): + """Decode given latent codes and return audio data + + Parameters + ---------- + z : Tensor[B x D x T] + Quantized continuous representation of input + length : int, optional + Number of samples in output audio, by default None + + Returns + ------- + dict + A dictionary with the following keys: + "audio" : Tensor[B x 1 x length] + Decoded audio data. + """ + if not self.continuous: + audio = self.decoder(z) + else: + z = self.post_quant_conv(z) + audio = self.decoder(z) + + return audio + + def forward( + self, + audio_data: torch.Tensor, + sample_rate: int = None, + n_quantizers: int = None, + ): + """Model forward pass + + Parameters + ---------- + audio_data : Tensor[B x 1 x T] + Audio data to encode + sample_rate : int, optional + Sample rate of audio data in Hz, by default None + If None, defaults to `self.sample_rate` + n_quantizers : int, optional + Number of quantizers to use, by default None. + If None, all quantizers are used. 
+ + Returns + ------- + dict + A dictionary with the following keys: + "z" : Tensor[B x D x T] + Quantized continuous representation of input + "codes" : Tensor[B x N x T] + Codebook indices for each codebook + (quantized discrete representation of input) + "latents" : Tensor[B x N*D x T] + Projected latents (continuous representation of input before quantization) + "vq/commitment_loss" : Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + "vq/codebook_loss" : Tensor[1] + Codebook loss to update the codebook + "length" : int + Number of samples in input audio + "audio" : Tensor[B x 1 x length] + Decoded audio data. + """ + length = audio_data.shape[-1] + audio_data = self.preprocess(audio_data, sample_rate) + if not self.continuous: + z, codes, latents, commitment_loss, codebook_loss = self.encode(audio_data, n_quantizers) + + x = self.decode(z) + return { + "audio": x[..., :length], + "z": z, + "codes": codes, + "latents": latents, + "vq/commitment_loss": commitment_loss, + "vq/codebook_loss": codebook_loss, + } + else: + posterior, _, _, _, _ = self.encode(audio_data, n_quantizers) + z = posterior.sample() + x = self.decode(z) + + kl_loss = posterior.kl() + kl_loss = kl_loss.mean() + + return { + "audio": x[..., :length], + "z": z, + "kl_loss": kl_loss, + } + + def remove_weight_norm(self): + """ + Remove weight_norm from all modules in the model. + This fuses the weight_g and weight_v parameters into a single weight parameter. + Should be called before inference for better performance. 
+ Returns: + self: The model with weight_norm removed + """ + from torch.nn.utils import remove_weight_norm + num_removed = 0 + for name, module in list(self.named_modules()): + if hasattr(module, "_forward_pre_hooks"): + for hook_id, hook in list(module._forward_pre_hooks.items()): + if "WeightNorm" in str(type(hook)): + try: + remove_weight_norm(module) + num_removed += 1 + # print(f"Removed weight_norm from: {name}") + except ValueError as e: + print(f"Failed to remove weight_norm from {name}: {e}") + if num_removed > 0: + # print(f"Successfully removed weight_norm from {num_removed} modules") + self.use_weight_norm = False + else: + print("No weight_norm found in the model") + return self diff --git a/diffsynth/models/mova_dual_tower_bridge.py b/diffsynth/models/mova_dual_tower_bridge.py new file mode 100644 index 000000000..ddb342e01 --- /dev/null +++ b/diffsynth/models/mova_dual_tower_bridge.py @@ -0,0 +1,595 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Dict, List, Tuple, Optional +from einops import rearrange +from .wan_video_dit import AttentionModule, RMSNorm +from ..core import gradient_checkpoint_forward + +class RotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, base: float, dim: int, device=None): + super().__init__() + self.base = base + self.dim = dim + self.attention_scaling = 1.0 + + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with 
torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +@torch.compile(fullgraph=True) +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
+ """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class PerFrameAttentionPooling(nn.Module): + """ + Per-frame multi-head attention pooling. + + Given a flattened token sequence [B, L, D] and grid size (T, H, W), perform a + single-query attention pooling over the H*W tokens for each time frame, producing + [B, T, D]. + + Inspired by SigLIP's Multihead Attention Pooling head (without MLP/residual stack). + """ + + def __init__(self, dim: int, num_heads: int, eps: float = 1e-6): + super().__init__() + assert dim % num_heads == 0, "dim must be divisible by num_heads" + self.dim = dim + self.num_heads = num_heads + + self.probe = nn.Parameter(torch.randn(1, 1, dim)) + nn.init.normal_(self.probe, std=0.02) + + self.attention = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, batch_first=True) + self.layernorm = nn.LayerNorm(dim, eps=eps) + + def forward(self, x: torch.Tensor, grid_size: Tuple[int, int, int]) -> torch.Tensor: + """ + Args: + x: [B, L, D], where L = T*H*W + grid_size: (T, H, W) + Returns: + pooled: [B, T, D] + """ + B, L, D = x.shape + T, H, W = grid_size + assert D == self.dim, f"Channel dimension mismatch: D={D} vs dim={self.dim}" + assert L == T * H * W, f"Flattened length mismatch: L={L} vs T*H*W={T*H*W}" + + S = H * W + # Re-arrange tokens grouped by frame. + x_bt_s_d = x.view(B, T, S, D).contiguous().view(B * T, S, D) # [B*T, S, D] + + # A learnable probe as the query (one query per frame). + probe = self.probe.expand(B * T, -1, -1) # [B*T, 1, D] + + # Attention pooling: query=probe, key/value=H*W tokens within the frame. + pooled_bt_1_d = self.attention(probe, x_bt_s_d, x_bt_s_d, need_weights=False)[0] # [B*T, 1, D] + pooled_bt_d = pooled_bt_1_d.squeeze(1) # [B*T, D] + + # Restore to [B, T, D]. 
+ pooled = pooled_bt_d.view(B, T, D) + pooled = self.layernorm(pooled) + return pooled + + +class CrossModalInteractionController: + """ + Strategy class that controls interactions between two towers. + Manages the interaction mapping between visual DiT (e.g. 30 layers) and audio DiT (e.g. 30 layers). + """ + + def __init__(self, visual_layers: int = 30, audio_layers: int = 30): + self.visual_layers = visual_layers + self.audio_layers = audio_layers + self.min_layers = min(visual_layers, audio_layers) + + def get_interaction_layers(self, strategy: str = "shallow_focus") -> Dict[str, List[Tuple[int, int]]]: + """ + Get interaction layer mappings. + + Args: + strategy: interaction strategy + - "shallow_focus": emphasize shallow layers to avoid deep-layer asymmetry + - "distributed": distributed interactions across the network + - "progressive": dense shallow interactions, sparse deeper interactions + - "custom": custom interaction layers + + Returns: + A dict containing mappings for 'v2a' (visual -> audio) and 'a2v' (audio -> visual). + """ + + if strategy == "shallow_focus": + # Emphasize the first ~1/3 layers to avoid deep-layer asymmetry. + num_interact = min(10, self.min_layers // 3) + interact_layers = list(range(0, num_interact)) + + elif strategy == "distributed": + # Distribute interactions across the network (every few layers). + step = 3 + interact_layers = list(range(0, self.min_layers, step)) + + elif strategy == "progressive": + # Progressive: dense shallow interactions, sparse deeper interactions. + shallow = list(range(0, min(8, self.min_layers))) # Dense for the first 8 layers. + if self.min_layers > 8: + deep = list(range(8, self.min_layers, 3)) # Every 3 layers afterwards. + interact_layers = shallow + deep + else: + interact_layers = shallow + + elif strategy == "custom": + # Custom strategy: adjust as needed. + interact_layers = [0, 2, 4, 6, 8, 12, 16, 20] # Explicit layer indices. 
+ interact_layers = [i for i in interact_layers if i < self.min_layers] + + elif strategy == "full": + interact_layers = list(range(0, self.min_layers)) + + else: + raise ValueError(f"Unknown interaction strategy: {strategy}") + + # Build bidirectional mapping. + mapping = { + 'v2a': [(i, i) for i in interact_layers], # visual layer i -> audio layer i + 'a2v': [(i, i) for i in interact_layers] # audio layer i -> visual layer i + } + + return mapping + + def should_interact(self, layer_idx: int, direction: str, interaction_mapping: Dict) -> bool: + """ + Check whether a given layer should interact. + + Args: + layer_idx: current layer index + direction: interaction direction ('v2a' or 'a2v') + interaction_mapping: interaction mapping table + + Returns: + bool: whether to interact + """ + if direction not in interaction_mapping: + return False + + return any(src == layer_idx for src, _ in interaction_mapping[direction]) + + +class ConditionalCrossAttention(nn.Module): + def __init__(self, dim: int, kv_dim: int, num_heads: int, eps: float = 1e-6): + super().__init__() + self.q_dim = dim + self.kv_dim = kv_dim + self.num_heads = num_heads + self.head_dim = self.q_dim // num_heads + + self.q = nn.Linear(dim, dim) + self.k = nn.Linear(kv_dim, dim) + self.v = nn.Linear(kv_dim, dim) + self.o = nn.Linear(dim, dim) + self.norm_q = RMSNorm(dim, eps=eps) + self.norm_k = RMSNorm(dim, eps=eps) + + self.attn = AttentionModule(self.num_heads) + + def forward(self, x: torch.Tensor, y: torch.Tensor, x_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, y_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None): + ctx = y + q = self.norm_q(self.q(x)) + k = self.norm_k(self.k(ctx)) + v = self.v(ctx) + if x_freqs is not None: + x_cos, x_sin = x_freqs + B, L, _ = q.shape + q_view = rearrange(q, 'b l (h d) -> b l h d', d=self.head_dim) + x_cos = x_cos.to(q_view.dtype).to(q_view.device) + x_sin = x_sin.to(q_view.dtype).to(q_view.device) + # Expect x_cos/x_sin shape: [B or 1, L, 
head_dim] + q_view, _ = apply_rotary_pos_emb(q_view, q_view, x_cos, x_sin, unsqueeze_dim=2) + q = rearrange(q_view, 'b l h d -> b l (h d)') + if y_freqs is not None: + y_cos, y_sin = y_freqs + Bc, Lc, _ = k.shape + k_view = rearrange(k, 'b l (h d) -> b l h d', d=self.head_dim) + y_cos = y_cos.to(k_view.dtype).to(k_view.device) + y_sin = y_sin.to(k_view.dtype).to(k_view.device) + # Expect y_cos/y_sin shape: [B or 1, L, head_dim] + _, k_view = apply_rotary_pos_emb(k_view, k_view, y_cos, y_sin, unsqueeze_dim=2) + k = rearrange(k_view, 'b l h d -> b l (h d)') + x = self.attn(q, k, v) + return self.o(x) + + +# from diffusers.models.attention import AdaLayerNorm +class AdaLayerNorm(nn.Module): + r""" + Norm layer modified to incorporate timestep embeddings. + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + num_embeddings (`int`, *optional*): The size of the embeddings dictionary. + output_dim (`int`, *optional*): + norm_elementwise_affine (`bool`, defaults to `False): + norm_eps (`bool`, defaults to `False`): + chunk_dim (`int`, defaults to `0`): + """ + + def __init__( + self, + embedding_dim: int, + num_embeddings: Optional[int] = None, + output_dim: Optional[int] = None, + norm_elementwise_affine: bool = False, + norm_eps: float = 1e-5, + chunk_dim: int = 0, + ): + super().__init__() + + self.chunk_dim = chunk_dim + output_dim = output_dim or embedding_dim * 2 + + if num_embeddings is not None: + self.emb = nn.Embedding(num_embeddings, embedding_dim) + else: + self.emb = None + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, output_dim) + self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine) + + def forward( + self, x: torch.Tensor, timestep: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None + ) -> torch.Tensor: + if self.emb is not None: + temb = self.emb(timestep) + + temb = self.linear(self.silu(temb)) + + if self.chunk_dim == 2: + scale, shift = temb.chunk(2, dim=2) + # 
print(f"{x.shape = }, {scale.shape = }, {shift.shape = }") + elif self.chunk_dim == 1: + # This is a bit weird why we have the order of "shift, scale" here and "scale, shift" in the + # other if-branch. This branch is specific to CogVideoX and OmniGen for now. + shift, scale = temb.chunk(2, dim=1) + shift = shift[:, None, :] + scale = scale[:, None, :] + else: + scale, shift = temb.chunk(2, dim=0) + + x = self.norm(x) * (1 + scale) + shift + return x + + +class ConditionalCrossAttentionBlock(nn.Module): + """ + A thin wrapper around ConditionalCrossAttention. + Applies LayerNorm to the conditioning input `y` before cross-attention. + """ + def __init__(self, dim: int, kv_dim: int, num_heads: int, eps: float = 1e-6, pooled_adaln: bool = False): + super().__init__() + self.y_norm = nn.LayerNorm(kv_dim, eps=eps) + self.inner = ConditionalCrossAttention(dim=dim, kv_dim=kv_dim, num_heads=num_heads, eps=eps) + self.pooled_adaln = pooled_adaln + if pooled_adaln: + self.per_frame_pooling = PerFrameAttentionPooling(kv_dim, num_heads=num_heads, eps=eps) + self.adaln = AdaLayerNorm(kv_dim, output_dim=dim*2, chunk_dim=2) + + def forward( + self, + x: torch.Tensor, + y: torch.Tensor, + x_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + y_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + video_grid_size: Optional[Tuple[int, int, int]] = None, + ) -> torch.Tensor: + if self.pooled_adaln: + assert video_grid_size is not None, "video_grid_size must not be None" + pooled_y = self.per_frame_pooling(y, video_grid_size) + # Interpolate pooled_y along its temporal dimension to match x's sequence length. 
+ if pooled_y.shape[1] != x.shape[1]: + pooled_y = F.interpolate( + pooled_y.permute(0, 2, 1), # [B, C, T] + size=x.shape[1], + mode='linear', + align_corners=False, + ).permute(0, 2, 1) # [B, T, C] + x = self.adaln(x, temb=pooled_y) + y = self.y_norm(y) + return self.inner(x=x, y=y, x_freqs=x_freqs, y_freqs=y_freqs) + + +class DualTowerConditionalBridge(nn.Module): + """ + Dual-tower conditional bridge. + """ + def __init__(self, + visual_layers: int = 40, + audio_layers: int = 30, + visual_hidden_dim: int = 5120, # visual DiT hidden state dimension + audio_hidden_dim: int = 1536, # audio DiT hidden state dimension + audio_fps: float = 50.0, + head_dim: int = 128, # attention head dimension + interaction_strategy: str = "full", + apply_cross_rope: bool = True, # whether to apply RoPE in cross-attention + apply_first_frame_bias_in_rope: bool = False, # whether to account for 1/video_fps bias for the first frame in RoPE alignment + trainable_condition_scale: bool = False, + pooled_adaln: bool = False, + ): + super().__init__() + + self.visual_hidden_dim = visual_hidden_dim + self.audio_hidden_dim = audio_hidden_dim + self.audio_fps = audio_fps + self.head_dim = head_dim + self.apply_cross_rope = apply_cross_rope + self.apply_first_frame_bias_in_rope = apply_first_frame_bias_in_rope + self.trainable_condition_scale = trainable_condition_scale + self.pooled_adaln = pooled_adaln + if self.trainable_condition_scale: + self.condition_scale = nn.Parameter(torch.tensor([1.0], dtype=torch.float32)) + else: + self.condition_scale = 1.0 + + self.controller = CrossModalInteractionController(visual_layers, audio_layers) + self.interaction_mapping = self.controller.get_interaction_layers(interaction_strategy) + + # Conditional cross-attention modules operating at the DiT hidden-state level. 
+ self.audio_to_video_conditioners = nn.ModuleDict() # audio hidden states -> visual DiT conditioning + self.video_to_audio_conditioners = nn.ModuleDict() # visual hidden states -> audio DiT conditioning + + # Build conditioners for layers that should interact. + # audio hidden states condition the visual DiT + self.rotary = RotaryEmbedding(base=10000.0, dim=head_dim) + for v_layer, _ in self.interaction_mapping['a2v']: + self.audio_to_video_conditioners[str(v_layer)] = ConditionalCrossAttentionBlock( + dim=visual_hidden_dim, # 3072 (visual DiT hidden states) + kv_dim=audio_hidden_dim, # 1536 (audio DiT hidden states) + num_heads=visual_hidden_dim // head_dim, # derive number of heads from hidden dim + pooled_adaln=False # a2v typically does not need pooled AdaLN + ) + + # visual hidden states condition the audio DiT + for a_layer, _ in self.interaction_mapping['v2a']: + self.video_to_audio_conditioners[str(a_layer)] = ConditionalCrossAttentionBlock( + dim=audio_hidden_dim, # 1536 (audio DiT hidden states) + kv_dim=visual_hidden_dim, # 3072 (visual DiT hidden states) + num_heads=audio_hidden_dim // head_dim, # safe head count derivation + pooled_adaln=self.pooled_adaln + ) + + @torch.no_grad() + def build_aligned_freqs(self, + video_fps: float, + grid_size: Tuple[int, int, int], + audio_steps: int, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: + """ + Build aligned RoPE (cos, sin) pairs based on video fps, video grid size (f_v, h, w), + and audio sequence length `audio_steps` (with fixed audio fps = 44100/2048). 
+ + Returns: + visual_freqs: (cos_v, sin_v), shape [1, f_v*h*w, head_dim] + audio_freqs: (cos_a, sin_a), shape [1, audio_steps, head_dim] + """ + f_v, h, w = grid_size + L_v = f_v * h * w + L_a = int(audio_steps) + + device = device or next(self.parameters()).device + dtype = dtype or torch.float32 + + # Audio positions: 0,1,2,...,L_a-1 (audio as reference). + audio_pos = torch.arange(L_a, device=device, dtype=torch.float32).unsqueeze(0) + + # Video positions: align video frames to audio-step units. + # FIXME(dhyu): hard-coded VAE temporal stride = 4 + if self.apply_first_frame_bias_in_rope: + # Account for the "first frame lasts 1/video_fps" bias. + video_effective_fps = float(video_fps) / 4.0 + if f_v > 0: + t_starts = torch.zeros((f_v,), device=device, dtype=torch.float32) + if f_v > 1: + t_starts[1:] = (1.0 / float(video_fps)) + torch.arange(f_v - 1, device=device, dtype=torch.float32) * (1.0 / video_effective_fps) + else: + t_starts = torch.zeros((0,), device=device, dtype=torch.float32) + # Convert to audio-step units. + video_pos_per_frame = t_starts * float(self.audio_fps) + else: + # No first-frame bias: uniform alignment. + scale = float(self.audio_fps) / float(video_fps / 4.0) + video_pos_per_frame = torch.arange(f_v, device=device, dtype=torch.float32) * scale + # Flatten to f*h*w; tokens within the same frame share the same time position. + video_pos = video_pos_per_frame.repeat_interleave(h * w).unsqueeze(0) + + # print(f"video fps: {video_fps}, audio fps: {self.audio_fps}, scale: {scale}") + # print(f"video pos: {video_pos.shape}, audio pos: {audio_pos.shape}") + + # Build dummy x to produce cos/sin, dim=head_dim. 
+ dummy_v = torch.zeros((1, L_v, self.head_dim), device=device, dtype=dtype) + dummy_a = torch.zeros((1, L_a, self.head_dim), device=device, dtype=dtype) + + cos_v, sin_v = self.rotary(dummy_v, position_ids=video_pos) + cos_a, sin_a = self.rotary(dummy_a, position_ids=audio_pos) + + return (cos_v, sin_v), (cos_a, sin_a) + + def should_interact(self, layer_idx: int, direction: str) -> bool: + return self.controller.should_interact(layer_idx, direction, self.interaction_mapping) + + def apply_conditional_control( + self, + layer_idx: int, + direction: str, + primary_hidden_states: torch.Tensor, + condition_hidden_states: torch.Tensor, + x_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + y_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + condition_scale: Optional[float] = None, + video_grid_size: Optional[Tuple[int, int, int]] = None, + use_gradient_checkpointing: Optional[bool] = False, + use_gradient_checkpointing_offload: Optional[bool] = False, + ) -> torch.Tensor: + """ + Apply conditional control (at the DiT hidden-state level). 
+ + Args: + layer_idx: current layer index + direction: conditioning direction + - 'a2v': audio hidden states -> visual DiT + - 'v2a': visual hidden states -> audio DiT + primary_hidden_states: primary DiT hidden states [B, L, hidden_dim] + condition_hidden_states: condition DiT hidden states [B, L, hidden_dim] + condition_scale: conditioning strength (similar to CFG scale) + + Returns: + Conditioned primary DiT hidden states [B, L, hidden_dim] + """ + + if not self.controller.should_interact(layer_idx, direction, self.interaction_mapping): + return primary_hidden_states + + if direction == 'a2v': + # audio hidden states condition the visual DiT + conditioner = self.audio_to_video_conditioners[str(layer_idx)] + + elif direction == 'v2a': + # visual hidden states condition the audio DiT + conditioner = self.video_to_audio_conditioners[str(layer_idx)] + else: + raise ValueError(f"Invalid direction: {direction}") + + conditioned_features = gradient_checkpoint_forward( + conditioner, + use_gradient_checkpointing, + use_gradient_checkpointing_offload, + x=primary_hidden_states, + y=condition_hidden_states, + x_freqs=x_freqs, + y_freqs=y_freqs, + video_grid_size=video_grid_size, + ) + + if self.trainable_condition_scale and condition_scale is not None: + print( + "[WARN] This model has a trainable condition_scale, but an external " + f"condition_scale={condition_scale} was provided. The trainable condition_scale " + "will be ignored in favor of the external value." 
+ ) + + scale = condition_scale if condition_scale is not None else self.condition_scale + + primary_hidden_states = primary_hidden_states + conditioned_features * scale + + return primary_hidden_states + + def forward( + self, + layer_idx: int, + visual_hidden_states: torch.Tensor, + audio_hidden_states: torch.Tensor, + *, + x_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + y_freqs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + a2v_condition_scale: Optional[float] = None, + v2a_condition_scale: Optional[float] = None, + condition_scale: Optional[float] = None, + video_grid_size: Optional[Tuple[int, int, int]] = None, + use_gradient_checkpointing: Optional[bool] = False, + use_gradient_checkpointing_offload: Optional[bool] = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply bidirectional conditional control to both visual/audio towers. + + Args: + layer_idx: current layer index + visual_hidden_states: visual DiT hidden states + audio_hidden_states: audio DiT hidden states + x_freqs / y_freqs: cross-modal RoPE (cos, sin) pairs. + If provided, x_freqs is assumed to correspond to the primary tower and y_freqs + to the conditioning tower. + a2v_condition_scale: audio->visual conditioning strength (overrides global condition_scale) + v2a_condition_scale: visual->audio conditioning strength (overrides global condition_scale) + condition_scale: fallback conditioning strength when per-direction scale is None + video_grid_size: (F, H, W), used on the audio side when pooled_adaln is enabled + + Returns: + (visual_hidden_states, audio_hidden_states), both conditioned in their respective directions. 
+ """ + + visual_conditioned = self.apply_conditional_control( + layer_idx=layer_idx, + direction="a2v", + primary_hidden_states=visual_hidden_states, + condition_hidden_states=audio_hidden_states, + x_freqs=x_freqs, + y_freqs=y_freqs, + condition_scale=a2v_condition_scale if a2v_condition_scale is not None else condition_scale, + video_grid_size=video_grid_size, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + ) + + audio_conditioned = self.apply_conditional_control( + layer_idx=layer_idx, + direction="v2a", + primary_hidden_states=audio_hidden_states, + condition_hidden_states=visual_hidden_states, + x_freqs=y_freqs, + y_freqs=x_freqs, + condition_scale=v2a_condition_scale if v2a_condition_scale is not None else condition_scale, + video_grid_size=video_grid_size, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + ) + + return visual_conditioned, audio_conditioned diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py index d957717c8..25ddb9276 100644 --- a/diffsynth/models/wan_video_dit.py +++ b/diffsynth/models/wan_video_dit.py @@ -99,18 +99,30 @@ def rope_apply(x, freqs, num_heads): return x_out.to(x.dtype) +def set_to_torch_norm(models): + for model in models: + for module in model.modules(): + if isinstance(module, RMSNorm): + module.use_torch_norm = True + + class RMSNorm(nn.Module): def __init__(self, dim, eps=1e-5): super().__init__() self.eps = eps self.weight = nn.Parameter(torch.ones(dim)) + self.use_torch_norm = False + self.normalized_shape = (dim,) def norm(self, x): return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) def forward(self, x): dtype = x.dtype - return self.norm(x.float()).to(dtype) * self.weight + if self.use_torch_norm: + return F.rms_norm(x, self.normalized_shape, self.weight, self.eps) + else: + return 
self.norm(x.float()).to(dtype) * self.weight class AttentionModule(nn.Module): diff --git a/diffsynth/pipelines/mova_audio_video.py b/diffsynth/pipelines/mova_audio_video.py new file mode 100644 index 000000000..2b6c7d8fa --- /dev/null +++ b/diffsynth/pipelines/mova_audio_video.py @@ -0,0 +1,455 @@ +import sys +import torch, types +from PIL import Image +from typing import Optional, Union +from einops import rearrange +import numpy as np +from PIL import Image +from tqdm import tqdm +from typing import Optional + +from ..core.device.npu_compatible_device import get_device_type +from ..diffusion import FlowMatchScheduler +from ..core import ModelConfig, gradient_checkpoint_forward +from ..diffusion.base_pipeline import BasePipeline, PipelineUnit + +from ..models.wan_video_dit import WanModel, sinusoidal_embedding_1d, set_to_torch_norm +from ..models.wan_video_text_encoder import WanTextEncoder, HuggingfaceTokenizer +from ..models.wan_video_vae import WanVideoVAE +from ..models.mova_audio_dit import MovaAudioDit +from ..models.mova_audio_vae import DacVAE +from ..models.mova_dual_tower_bridge import DualTowerConditionalBridge + + +class MovaAudioVideoPipeline(BasePipeline): + + def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16): + super().__init__( + device=device, torch_dtype=torch_dtype, + height_division_factor=16, width_division_factor=16, time_division_factor=4, time_division_remainder=1 + ) + self.scheduler = FlowMatchScheduler("Wan") + self.tokenizer: HuggingfaceTokenizer = None + self.text_encoder: WanTextEncoder = None + self.video_dit: WanModel = None # high noise model + self.video_dit2: WanModel = None # low noise model + self.audio_dit: MovaAudioDit = None + self.dual_tower_bridge: DualTowerConditionalBridge = None + self.video_vae: WanVideoVAE = None + self.audio_vae: DacVAE = None + + self.in_iteration_models = ("video_dit", "audio_dit", "dual_tower_bridge") + self.in_iteration_models_2 = ("video_dit2", "audio_dit", 
"dual_tower_bridge") + + self.units = [ + MovaAudioVideoUnit_ShapeChecker(), + MovaAudioVideoUnit_NoiseInitializer(), + MovaAudioVideoUnit_InputVideoEmbedder(), + MovaAudioVideoUnit_InputAudioEmbedder(), + MovaAudioVideoUnit_PromptEmbedder(), + MovaAudioVideoUnit_ImageEmbedderVAE(), + MovaAudioVideoUnit_UnifiedSequenceParallel(), + ] + self.model_fn = model_fn_mova_audio_video + + def enable_usp(self): + from ..utils.xfuser import get_sequence_parallel_world_size, usp_attn_forward + for block in self.video_dit.blocks + self.audio_dit.blocks + self.video_dit2.blocks: + block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn) + self.sp_size = get_sequence_parallel_world_size() + self.use_unified_sequence_parallel = True + + @staticmethod + def from_pretrained( + torch_dtype: torch.dtype = torch.bfloat16, + device: Union[str, torch.device] = get_device_type(), + model_configs: list[ModelConfig] = [], + tokenizer_config: ModelConfig = ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="tokenizer/"), + use_usp: bool = False, + vram_limit: float = None, + ): + if use_usp: + from ..utils.xfuser import initialize_usp + initialize_usp(device) + import torch.distributed as dist + from ..core.device.npu_compatible_device import get_device_name + if dist.is_available() and dist.is_initialized(): + device = get_device_name() + # Initialize pipeline + pipe = MovaAudioVideoPipeline(device=device, torch_dtype=torch_dtype) + model_pool = pipe.download_and_load_models(model_configs, vram_limit) + + # Fetch models + pipe.text_encoder = model_pool.fetch_model("wan_video_text_encoder") + pipe.video_dit, pipe.video_dit2 = model_pool.fetch_model("wan_video_dit", index=2) + pipe.audio_dit = model_pool.fetch_model("mova_audio_dit") + pipe.dual_tower_bridge = model_pool.fetch_model("mova_dual_tower_bridge") + pipe.video_vae = model_pool.fetch_model("wan_video_vae") + pipe.audio_vae = model_pool.fetch_model("mova_audio_vae") + 
set_to_torch_norm([pipe.video_dit, pipe.video_dit2, pipe.audio_dit, pipe.dual_tower_bridge]) + + # Size division factor + if pipe.video_vae is not None: + pipe.height_division_factor = pipe.video_vae.upsampling_factor * 2 + pipe.width_division_factor = pipe.video_vae.upsampling_factor * 2 + + # Initialize tokenizer and processor + if tokenizer_config is not None: + tokenizer_config.download_if_necessary() + pipe.tokenizer = HuggingfaceTokenizer(name=tokenizer_config.path, seq_len=512, clean='whitespace') + + # Unified Sequence Parallel + if use_usp: pipe.enable_usp() + + # VRAM Management + pipe.vram_management_enabled = pipe.check_vram_management_state() + return pipe + + @torch.no_grad() + def __call__( + self, + # Prompt + prompt: str, + negative_prompt: Optional[str] = "", + # Image-to-video + input_image: Optional[Image.Image] = None, + # First-last-frame-to-video + end_image: Optional[Image.Image] = None, + # Video-to-video + denoising_strength: Optional[float] = 1.0, + # Randomness + seed: Optional[int] = None, + rand_device: Optional[str] = "cpu", + # Shape + height: Optional[int] = 352, + width: Optional[int] = 640, + num_frames: Optional[int] = 81, + frame_rate: Optional[int] = 24, + # Classifier-free guidance + cfg_scale: Optional[float] = 5.0, + # Boundary + switch_DiT_boundary: Optional[float] = 0.9, + # Scheduler + num_inference_steps: Optional[int] = 50, + sigma_shift: Optional[float] = 5.0, + # VAE tiling + tiled: Optional[bool] = True, + tile_size: Optional[tuple[int, int]] = (30, 52), + tile_stride: Optional[tuple[int, int]] = (15, 26), + # progress_bar + progress_bar_cmd=tqdm, + ): + # Scheduler + self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift) + + # Inputs + inputs_posi = { + "prompt": prompt, + } + inputs_nega = { + "negative_prompt": negative_prompt, + } + inputs_shared = { + "input_image": input_image, + "end_image": end_image, + "denoising_strength": denoising_strength, + "seed": 
seed, "rand_device": rand_device, + "height": height, "width": width, "num_frames": num_frames, "frame_rate": frame_rate, + "cfg_scale": cfg_scale, + "sigma_shift": sigma_shift, + "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride, + } + for unit in self.units: + inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) + + # Denoise + self.load_models_to_device(self.in_iteration_models) + models = {name: getattr(self, name) for name in self.in_iteration_models} + for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): + # Switch DiT if necessary + if timestep.item() < switch_DiT_boundary * 1000 and self.video_dit2 is not None and not models["video_dit"] is self.video_dit2: + self.load_models_to_device(self.in_iteration_models_2) + models["video_dit"] = self.video_dit2 + # Timestep + timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) + noise_pred_video, noise_pred_audio = self.cfg_guided_model_fn( + self.model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega, + **models, timestep=timestep, progress_id=progress_id + ) + # Scheduler + inputs_shared["video_latents"] = self.step(self.scheduler, inputs_shared["video_latents"], progress_id=progress_id, noise_pred=noise_pred_video, **inputs_shared) + inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id, noise_pred=noise_pred_audio, **inputs_shared) + + # Decode + self.load_models_to_device(['video_vae']) + video = self.video_vae.decode(inputs_shared["video_latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + video = self.vae_output_to_video(video) + self.load_models_to_device(["audio_vae"]) + audio = self.audio_vae.decode(inputs_shared["audio_latents"]).to(dtype=torch.float32, device='cpu').squeeze() + self.load_models_to_device([]) + return video, audio + + +class 
MovaAudioVideoUnit_ShapeChecker(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("height", "width", "num_frames"), + output_params=("height", "width", "num_frames"), + ) + + def process(self, pipe: MovaAudioVideoPipeline, height, width, num_frames): + height, width, num_frames = pipe.check_resize_height_width(height, width, num_frames) + return {"height": height, "width": width, "num_frames": num_frames} + + +class MovaAudioVideoUnit_NoiseInitializer(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("height", "width", "num_frames", "seed", "rand_device", "frame_rate"), + output_params=("video_noise", "audio_noise") + ) + + def process(self, pipe: MovaAudioVideoPipeline, height, width, num_frames, seed, rand_device, frame_rate): + length = (num_frames - 1) // 4 + 1 + video_shape = (1, pipe.video_vae.model.z_dim, length, height // pipe.video_vae.upsampling_factor, width // pipe.video_vae.upsampling_factor) + video_noise = pipe.generate_noise(video_shape, seed=seed, rand_device=rand_device) + + audio_num_samples = (int(pipe.audio_vae.sample_rate * num_frames / frame_rate) - 1) // int(pipe.audio_vae.hop_length) + 1 + audio_shape = (1, pipe.audio_vae.latent_dim, audio_num_samples) + audio_noise = pipe.generate_noise(audio_shape, seed=seed, rand_device=rand_device) + return {"video_noise": video_noise, "audio_noise": audio_noise} + + +class MovaAudioVideoUnit_InputVideoEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_video", "video_noise", "tiled", "tile_size", "tile_stride"), + output_params=("latents", "input_latents"), + onload_model_names=("video_vae",) + ) + + def process(self, pipe: MovaAudioVideoPipeline, input_video, video_noise, tiled, tile_size, tile_stride): + if input_video is None: + return {"video_latents": video_noise} + # TODO: check for train + pipe.load_models_to_device(self.onload_model_names) + input_video = pipe.preprocess_video(input_video) + input_latents = 
pipe.video_vae.encode(input_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + if pipe.scheduler.training: + return {"latents": video_noise, "input_latents": input_latents} + else: + latents = pipe.scheduler.add_noise(input_latents, video_noise, timestep=pipe.scheduler.timesteps[0]) + return {"latents": latents} + + +class MovaAudioVideoUnit_InputAudioEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_audio", "audio_noise"), + output_params=("audio_latents", "audio_input_latents"), + onload_model_names=("audio_vae_encoder",) + ) + + def process(self, pipe: MovaAudioVideoPipeline, input_audio, audio_noise): + if input_audio is None: + return {"audio_latents": audio_noise} + else: + # TODO: support audio training + if pipe.scheduler.training: + return {"audio_latents": audio_noise, "audio_input_latents": audio_noise} + else: + raise NotImplementedError("Audio-to-video not supported.") + + +class MovaAudioVideoUnit_PromptEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + seperate_cfg=True, + input_params_posi={"prompt": "prompt"}, + input_params_nega={"prompt": "negative_prompt"}, + output_params=("context",), + onload_model_names=("text_encoder",) + ) + + def encode_prompt(self, pipe: MovaAudioVideoPipeline, prompt): + ids, mask = pipe.tokenizer( + prompt, + padding="max_length", + max_length=512, + truncation=True, + add_special_tokens=True, + return_mask=True, + return_tensors="pt", + ) + ids = ids.to(pipe.device) + mask = mask.to(pipe.device) + seq_lens = mask.gt(0).sum(dim=1).long() + prompt_emb = pipe.text_encoder(ids, mask) + for i, v in enumerate(seq_lens): + prompt_emb[:, v:] = 0 + return prompt_emb + + def process(self, pipe: MovaAudioVideoPipeline, prompt) -> dict: + pipe.load_models_to_device(self.onload_model_names) + prompt_emb = self.encode_prompt(pipe, prompt) + return {"context": prompt_emb} + + +class 
MovaAudioVideoUnit_ImageEmbedderVAE(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"), + output_params=("y",), + onload_model_names=("video_vae",) + ) + + def process(self, pipe: MovaAudioVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride): + if input_image is None or not pipe.video_dit.require_vae_embedding: + return {} + pipe.load_models_to_device(self.onload_model_names) + + image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device) + msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device) + msk[:, 1:] = 0 + if end_image is not None: + end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device) + vae_input = torch.concat([image.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image.device), end_image.transpose(0,1)],dim=1) + msk[:, -1:] = 1 + else: + vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1) + + msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8) + msk = msk.transpose(1, 2)[0] + + y = pipe.video_vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + y = torch.concat([msk, y]) + y = y.unsqueeze(0) + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + return {"y": y} + +class MovaAudioVideoUnit_UnifiedSequenceParallel(PipelineUnit): + def __init__(self): + super().__init__(input_params=(), output_params=("use_unified_sequence_parallel",)) + + def process(self, pipe: MovaAudioVideoPipeline): + if hasattr(pipe, "use_unified_sequence_parallel"): + if pipe.use_unified_sequence_parallel: + 
return {"use_unified_sequence_parallel": True} + return {} + +def model_fn_mova_audio_video( + video_dit: WanModel, + audio_dit: MovaAudioDit, + dual_tower_bridge: DualTowerConditionalBridge, + video_latents: torch.Tensor = None, + audio_latents: torch.Tensor = None, + timestep: torch.Tensor = None, + context: torch.Tensor = None, + y: Optional[torch.Tensor] = None, + frame_rate: Optional[int] = 24, + use_unified_sequence_parallel: bool = False, + use_gradient_checkpointing: bool = False, + use_gradient_checkpointing_offload: bool = False, + **kwargs, +): + video_x, audio_x = video_latents, audio_latents + # First-Last Frame + if y is not None: + video_x = torch.cat([video_x, y], dim=1) + + # Timestep + video_t = video_dit.time_embedding(sinusoidal_embedding_1d(video_dit.freq_dim, timestep)) + video_t_mod = video_dit.time_projection(video_t).unflatten(1, (6, video_dit.dim)) + audio_t = audio_dit.time_embedding(sinusoidal_embedding_1d(audio_dit.freq_dim, timestep)) + audio_t_mod = audio_dit.time_projection(audio_t).unflatten(1, (6, audio_dit.dim)) + + # Context + video_context = video_dit.text_embedding(context) + audio_context = audio_dit.text_embedding(context) + + # Patchify + video_x = video_dit.patch_embedding(video_x) + f_v, h, w = video_x.shape[2:] + video_x = rearrange(video_x, 'b c f h w -> b (f h w) c').contiguous() + seq_len_video = video_x.shape[1] + + audio_x = audio_dit.patch_embedding(audio_x) + f_a = audio_x.shape[2] + audio_x = rearrange(audio_x, 'b c f -> b f c').contiguous() + seq_len_audio = audio_x.shape[1] + + # Freqs + video_freqs = torch.cat([ + video_dit.freqs[0][:f_v].view(f_v, 1, 1, -1).expand(f_v, h, w, -1), + video_dit.freqs[1][:h].view(1, h, 1, -1).expand(f_v, h, w, -1), + video_dit.freqs[2][:w].view(1, 1, w, -1).expand(f_v, h, w, -1) + ], dim=-1).reshape(f_v * h * w, 1, -1).to(video_x.device) + audio_freqs = torch.cat([ + audio_dit.freqs[0][:f_a].view(f_a, -1).expand(f_a, -1), + audio_dit.freqs[1][:f_a].view(f_a, -1).expand(f_a, -1), + 
audio_dit.freqs[2][:f_a].view(f_a, -1).expand(f_a, -1), + ], dim=-1).reshape(f_a, 1, -1).to(audio_x.device) + + video_rope, audio_rope = dual_tower_bridge.build_aligned_freqs( + video_fps=frame_rate, + grid_size=(f_v, h, w), + audio_steps=audio_x.shape[1], + device=video_x.device, + dtype=video_x.dtype, + ) + # usp func + if use_unified_sequence_parallel: + from ..utils.xfuser import get_current_chunk, gather_all_chunks + else: + get_current_chunk = lambda x, dim=1: x + gather_all_chunks = lambda x, seq_len, dim=1: x + # Forward blocks + for block_id in range(len(audio_dit.blocks)): + if dual_tower_bridge.should_interact(block_id, "a2v"): + video_x, audio_x = dual_tower_bridge( + block_id, + video_x, + audio_x, + x_freqs=video_rope, + y_freqs=audio_rope, + condition_scale=1.0, + video_grid_size=(f_v, h, w), + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + ) + video_x = get_current_chunk(video_x, dim=1) + video_x = gradient_checkpoint_forward( + video_dit.blocks[block_id], + use_gradient_checkpointing, + use_gradient_checkpointing_offload, + video_x, video_context, video_t_mod, video_freqs + ) + video_x = gather_all_chunks(video_x, seq_len=seq_len_video, dim=1) + audio_x = get_current_chunk(audio_x, dim=1) + audio_x = gradient_checkpoint_forward( + audio_dit.blocks[block_id], + use_gradient_checkpointing, + use_gradient_checkpointing_offload, + audio_x, audio_context, audio_t_mod, audio_freqs + ) + audio_x = gather_all_chunks(audio_x, seq_len=seq_len_audio, dim=1) + + video_x = get_current_chunk(video_x, dim=1) + for block_id in range(len(audio_dit.blocks), len(video_dit.blocks)): + video_x = gradient_checkpoint_forward( + video_dit.blocks[block_id], + use_gradient_checkpointing, + use_gradient_checkpointing_offload, + video_x, video_context, video_t_mod, video_freqs + ) + video_x = gather_all_chunks(video_x, seq_len=seq_len_video, dim=1) + + # Head + video_x = video_dit.head(video_x, 
video_t) + video_x = video_dit.unpatchify(video_x, (f_v, h, w)) + + audio_x = audio_dit.head(audio_x, audio_t) + audio_x = audio_dit.unpatchify(audio_x, (f_a,)) + return video_x, audio_x diff --git a/diffsynth/utils/xfuser/__init__.py b/diffsynth/utils/xfuser/__init__.py index 13dd178e2..6b4271d30 100644 --- a/diffsynth/utils/xfuser/__init__.py +++ b/diffsynth/utils/xfuser/__init__.py @@ -1 +1 @@ -from .xdit_context_parallel import usp_attn_forward, usp_dit_forward, get_sequence_parallel_world_size, initialize_usp +from .xdit_context_parallel import usp_attn_forward, usp_dit_forward, get_sequence_parallel_world_size, initialize_usp, get_current_chunk, gather_all_chunks diff --git a/diffsynth/utils/xfuser/xdit_context_parallel.py b/diffsynth/utils/xfuser/xdit_context_parallel.py index 228e7b877..0f5d10524 100644 --- a/diffsynth/utils/xfuser/xdit_context_parallel.py +++ b/diffsynth/utils/xfuser/xdit_context_parallel.py @@ -6,6 +6,7 @@ get_sequence_parallel_world_size, get_sp_group) from xfuser.core.long_ctx_attention import xFuserLongContextAttention +import torch.distributed as dist from ... 
import IS_NPU_AVAILABLE from ...core.device import parse_nccl_backend, parse_device_type @@ -143,4 +144,31 @@ def usp_attn_forward(self, x, freqs): del q, k, v getattr(torch, parse_device_type(x.device)).empty_cache() - return self.o(x) \ No newline at end of file + return self.o(x) + + +def get_current_chunk(x, dim=1): + chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=dim) + ndims = len(chunks[0].shape) + pad_list = [0] * (2 * ndims) + pad_end_index = 2 * (ndims - 1 - dim) + 1 + max_size = chunks[0].size(dim) + chunks = [ + torch.nn.functional.pad( + chunk, + tuple(pad_list[:pad_end_index] + [max_size - chunk.size(dim)] + pad_list[pad_end_index+1:]), + value=0 + ) + for chunk in chunks + ] + x = chunks[get_sequence_parallel_rank()] + return x + + +def gather_all_chunks(x, seq_len=None, dim=1): + x = get_sp_group().all_gather(x, dim=dim) + if seq_len is not None: + slices = [slice(None)] * x.ndim + slices[dim] = slice(0, seq_len) + x = x[tuple(slices)] + return x diff --git a/examples/mova/acceleration/unified_sequence_parallel.py b/examples/mova/acceleration/unified_sequence_parallel.py new file mode 100644 index 000000000..ff5db8ea2 --- /dev/null +++ b/examples/mova/acceleration/unified_sequence_parallel.py @@ -0,0 +1,55 @@ +import torch +from PIL import Image +from diffsynth.utils.data.media_io_mova import save_video_with_audio +from diffsynth.pipelines.mova_audio_video import MovaAudioVideoPipeline, ModelConfig +import torch.distributed as dist + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cuda", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = MovaAudioVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + use_usp=True, + model_configs=[ + ModelConfig(model_id="openmoss/MOVA-360p", 
origin_file_pattern="video_dit/diffusion_pytorch_model-*.safetensors", **vram_config), + ModelConfig(model_id="openmoss/MOVA-360p", origin_file_pattern="video_dit_2/diffusion_pytorch_model-*.safetensors", **vram_config), + ModelConfig(model_id="openmoss/MOVA-360p", origin_file_pattern="audio_dit/diffusion_pytorch_model.safetensors", **vram_config), + ModelConfig(model_id="openmoss/MOVA-360p", origin_file_pattern="dual_tower_bridge/diffusion_pytorch_model.safetensors", **vram_config), + ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="audio_vae/diffusion_pytorch_model.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/Wan-Series-Converted-Safetensors", origin_file_pattern="Wan2.1_VAE.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/Wan-Series-Converted-Safetensors", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="tokenizer/"), +) +negative_prompt = ( + "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止," + "整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指" +) + +prompt = "Two cute orange cats, wearing boxing gloves, stand on a boxing ring and fight each other." 
+height, width, num_frames = 352, 640, 121 +frame_rate=24 +input_image = Image.open("data/examples/wan/cat_fightning.jpg").resize((width, height)).convert("RGB") +# Image-to-video +video, audio = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + height=height, + width=width, + num_frames=num_frames, + input_image=input_image, + num_inference_steps=50, + seed=0, + tiled=True, + frame_rate=frame_rate, +) +if dist.get_rank() == 0: + save_video_with_audio(video, audio, "MOVA-360p-cat.mp4", fps=24, sample_rate=pipe.audio_vae.sample_rate) diff --git a/examples/mova/model_inference/MOVA-360p-TI2AV.py b/examples/mova/model_inference/MOVA-360p-TI2AV.py new file mode 100644 index 000000000..2ad77cd02 --- /dev/null +++ b/examples/mova/model_inference/MOVA-360p-TI2AV.py @@ -0,0 +1,52 @@ +import torch +from PIL import Image +from diffsynth.pipelines.mova_audio_video import ModelConfig, MovaAudioVideoPipeline +from diffsynth.utils.data.media_io_mova import save_video_with_audio + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cuda", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = MovaAudioVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="openmoss/MOVA-360p", origin_file_pattern="video_dit/diffusion_pytorch_model-*.safetensors", **vram_config), + ModelConfig(model_id="openmoss/MOVA-360p", origin_file_pattern="video_dit_2/diffusion_pytorch_model-*.safetensors", **vram_config), + ModelConfig(model_id="openmoss/MOVA-360p", origin_file_pattern="audio_dit/diffusion_pytorch_model.safetensors", **vram_config), + ModelConfig(model_id="openmoss/MOVA-360p", origin_file_pattern="dual_tower_bridge/diffusion_pytorch_model.safetensors", **vram_config), + ModelConfig(model_id="openmoss/MOVA-720p", 
origin_file_pattern="audio_vae/diffusion_pytorch_model.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/Wan-Series-Converted-Safetensors", origin_file_pattern="Wan2.1_VAE.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/Wan-Series-Converted-Safetensors", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="tokenizer/"), +) +negative_prompt = ( + "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止," + "整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指" +) + +prompt = "Two cute orange cats, wearing boxing gloves, stand on a boxing ring and fight each other." +height, width, num_frames = 352, 640, 121 +frame_rate = 24 +input_image = Image.open("data/examples/wan/cat_fightning.jpg").resize((width, height)).convert("RGB") +# Image-to-video +video, audio = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + height=height, + width=width, + num_frames=num_frames, + input_image=input_image, + num_inference_steps=50, + seed=0, + tiled=True, + frame_rate=frame_rate, +) +save_video_with_audio(video, audio, "MOVA-360p-cat_singlegpu_49.mp4", fps=24, sample_rate=pipe.audio_vae.sample_rate) diff --git a/examples/mova/model_inference/MOVA-720p-TI2AV.py b/examples/mova/model_inference/MOVA-720p-TI2AV.py new file mode 100644 index 000000000..6294d6a6f --- /dev/null +++ b/examples/mova/model_inference/MOVA-720p-TI2AV.py @@ -0,0 +1,53 @@ +import torch +from PIL import Image +from diffsynth.utils.data.media_io_mova import save_video_with_audio +from diffsynth.pipelines.mova_audio_video import MovaAudioVideoPipeline, ModelConfig + + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cuda", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = 
MovaAudioVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="video_dit/diffusion_pytorch_model-*.safetensors", **vram_config), + ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="video_dit_2/diffusion_pytorch_model-*.safetensors", **vram_config), + ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="audio_dit/diffusion_pytorch_model.safetensors", **vram_config), + ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="dual_tower_bridge/diffusion_pytorch_model.safetensors", **vram_config), + ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="audio_vae/diffusion_pytorch_model.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/Wan-Series-Converted-Safetensors", origin_file_pattern="Wan2.1_VAE.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/Wan-Series-Converted-Safetensors", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="tokenizer/"), +) + +negative_prompt = ( + "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止," + "整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指" +) +prompt = "Two cute orange cats, wearing boxing gloves, stand on a boxing ring and fight each other." 
+height, width, num_frames = 720, 1280, 121 +frame_rate=24 +input_image = Image.open("data/examples/wan/cat_fightning.jpg").resize((width, height)).convert("RGB") +# Image-to-video +video, audio = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + height=height, + width=width, + num_frames=num_frames, + input_image=input_image, + num_inference_steps=50, + seed=0, + tiled=True, + frame_rate=frame_rate, +) +save_video_with_audio(video, audio, "MOVA-720p-cat.mp4", fps=24, sample_rate=pipe.audio_vae.sample_rate) From 6911e35b018b2d4c033e212697c5117d950e310a Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Mon, 9 Mar 2026 12:58:45 +0800 Subject: [PATCH 2/6] mova media_io --- diffsynth/utils/data/media_io_mova.py | 104 ++++++++++++++++++++++++++ 1 file changed, 104 insertions(+) create mode 100644 diffsynth/utils/data/media_io_mova.py diff --git a/diffsynth/utils/data/media_io_mova.py b/diffsynth/utils/data/media_io_mova.py new file mode 100644 index 000000000..4d3c4772e --- /dev/null +++ b/diffsynth/utils/data/media_io_mova.py @@ -0,0 +1,104 @@ +import os +import shutil +import subprocess +import tempfile +import wave + +import imageio +import numpy as np +import torch +from tqdm import tqdm +from ..data import merge_video_audio +try: + import imageio_ffmpeg as _imageio_ffmpeg +except ImportError: # pragma: no cover + _imageio_ffmpeg = None + + +def _write_wav_wave(audio, wav_path, sample_rate=44100): + """ + Write int16 PCM WAV using standard library wave. 
+ - audio: torch.Tensor or np.ndarray, shape [samples] / [channels, samples] + - If float, assumed range is approximately [-1, 1], will be converted to int16 PCM + """ + if isinstance(audio, torch.Tensor): + a = audio.detach().cpu().numpy() + else: + a = np.asarray(audio) + + if a.ndim == 1: + a = a[None, :] + if a.ndim != 2: + raise ValueError(f"audio shape needs to be [S] / [C,S], current shape is {a.shape}") + + channels, samples = int(a.shape[0]), int(a.shape[1]) + if channels > 2: + a = a[:2, :] + channels = 2 + + if np.issubdtype(a.dtype, np.floating): + a = np.clip(a, -1.0, 1.0) + a = (a * 32767.0).astype(np.int16) + elif a.dtype != np.int16: + a = np.clip(a, -32768, 32767).astype(np.int16) + + if channels == 1: + interleaved = a.reshape(-1) + else: + interleaved = a.T.reshape(-1) + + with wave.open(wav_path, "wb") as wf: + wf.setnchannels(channels) + wf.setsampwidth(2) # int16 + wf.setframerate(int(sample_rate)) + wf.writeframes(interleaved.tobytes(order="C")) + + +def save_video(frames, save_path, fps, quality=9, ffmpeg_params=None): + writer = imageio.get_writer(save_path, fps=fps, quality=quality, ffmpeg_params=ffmpeg_params) + for frame in tqdm(frames, desc="Saving video"): + frame = np.array(frame) + writer.append_data(frame) + writer.close() + + +# Copied from https://github.com/sgl-project/sglang/blob/7106f6c8e1509cd57abeafd5d50cb1beaffbc63c/python/sglang/multimodal_gen/runtime/entrypoints/utils.py#L96 +def _resolve_ffmpeg_exe() -> str: + ffmpeg_exe = "ffmpeg" + ffmpeg_on_path = shutil.which("ffmpeg") + if ffmpeg_on_path: + ffmpeg_exe = ffmpeg_on_path + try: + if _imageio_ffmpeg is not None: + ffmpeg_exe = _imageio_ffmpeg.get_ffmpeg_exe() + except Exception: + pass + + ffmpeg_ok = False + if ffmpeg_exe: + if os.path.isabs(ffmpeg_exe): + ffmpeg_ok = os.path.exists(ffmpeg_exe) + else: + ffmpeg_ok = shutil.which(ffmpeg_exe) is not None + if not ffmpeg_ok: + raise RuntimeError("ffmpeg not found") + return ffmpeg_exe + +def save_video_with_audio(frames, 
audio_data, save_path, fps=24, sample_rate=44100, quality=9, ffmpeg_path=None): + """ + Save video with audio. + - frames: List[PIL.Image | np.ndarray] + - audio: torch.Tensor or np.ndarray, shape [channels, samples] or [samples] + - save_path: Output mp4 path + - fps: Video frame rate + - sample_rate: Audio sample rate (default 44100) + Depend on ffmpeg executable program for audio/video reuse. + """ + if ffmpeg_path is None: + ffmpeg_path = _resolve_ffmpeg_exe() + + with tempfile.TemporaryDirectory(prefix='save_vwa_') as tmp_dir: + tmp_audio = os.path.join(tmp_dir, 'audio.wav') + save_video(frames, save_path, fps=fps, quality=quality) + _write_wav_wave(audio_data, tmp_audio, sample_rate=sample_rate) + merge_video_audio(save_path, tmp_audio) From 4a9391df96c432f421cd030a74d8cb5ce5bba490 Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Thu, 12 Mar 2026 16:23:17 +0800 Subject: [PATCH 3/6] add unified audio_video api & fix bug of mono audio input for ltx --- diffsynth/pipelines/ltx2_audio_video.py | 3 + diffsynth/utils/data/audio.py | 108 ++++++++++++ diffsynth/utils/data/audio_video.py | 134 +++++++++++++++ diffsynth/utils/data/media_io_ltx2.py | 161 +----------------- diffsynth/utils/data/media_io_mova.py | 104 ----------- .../model_inference/LTX-2.3-A2V-TwoStage.py | 5 +- .../LTX-2.3-T2AV-TwoStage-Retake.py | 5 +- .../LTX-2.3-A2V-TwoStage.py | 5 +- .../LTX-2.3-T2AV-TwoStage-Retake.py | 5 +- .../acceleration/unified_sequence_parallel.py | 4 +- .../mova/model_inference/MOVA-360p-TI2AV.py | 4 +- .../mova/model_inference/MOVA-720p-TI2AV.py | 4 +- pyproject.toml | 5 +- 13 files changed, 268 insertions(+), 279 deletions(-) create mode 100644 diffsynth/utils/data/audio.py create mode 100644 diffsynth/utils/data/audio_video.py delete mode 100644 diffsynth/utils/data/media_io_mova.py diff --git a/diffsynth/pipelines/ltx2_audio_video.py b/diffsynth/pipelines/ltx2_audio_video.py index a06213a6e..5f78c2917 100644 --- a/diffsynth/pipelines/ltx2_audio_video.py +++ 
b/diffsynth/pipelines/ltx2_audio_video.py @@ -22,6 +22,7 @@ from ..models.ltx2_upsampler import LTX2LatentUpsampler from ..models.ltx2_common import VideoLatentShape, AudioLatentShape, VideoPixelShape, get_pixel_coords, VIDEO_SCALE_FACTORS from ..utils.data.media_io_ltx2 import ltx2_preprocess +from ..utils.data.audio import convert_to_stereo class LTX2AudioVideoPipeline(BasePipeline): @@ -389,6 +390,7 @@ def process(self, pipe: LTX2AudioVideoPipeline, input_audio, audio_noise): return {"audio_latents": audio_noise} else: input_audio, sample_rate = input_audio + input_audio = convert_to_stereo(input_audio) pipe.load_models_to_device(self.onload_model_names) input_audio = pipe.audio_processor.waveform_to_mel(input_audio.unsqueeze(0), waveform_sample_rate=sample_rate).to(dtype=pipe.torch_dtype) audio_input_latents = pipe.audio_vae_encoder(input_audio) @@ -441,6 +443,7 @@ def process(self, pipe: LTX2AudioVideoPipeline, retake_audio, seed, rand_device, return {} else: input_audio, sample_rate = retake_audio + input_audio = convert_to_stereo(input_audio) pipe.load_models_to_device(self.onload_model_names) input_audio = pipe.audio_processor.waveform_to_mel(input_audio.unsqueeze(0), waveform_sample_rate=sample_rate).to(dtype=pipe.torch_dtype, device=pipe.device) input_latents_audio = pipe.audio_vae_encoder(input_audio) diff --git a/diffsynth/utils/data/audio.py b/diffsynth/utils/data/audio.py new file mode 100644 index 000000000..fe482dbdc --- /dev/null +++ b/diffsynth/utils/data/audio.py @@ -0,0 +1,108 @@ +import torch +import torchaudio +from torchcodec.decoders import AudioDecoder +from torchcodec.encoders import AudioEncoder + + +def convert_to_mono(audio_tensor: torch.Tensor) -> torch.Tensor: + """ + Convert audio to mono by averaging channels. + Supports [C, T] or [B, C, T]. Output shape: [1, T] or [B, 1, T]. + """ + return audio_tensor.mean(dim=-2, keepdim=True) + + +def convert_to_stereo(audio_tensor: torch.Tensor) -> torch.Tensor: + """ + Convert audio to stereo. 
+ Supports [C, T] or [B, C, T]. Duplicate mono, keep stereo. + """ + if audio_tensor.size(-2) == 1: + return audio_tensor.repeat(1, 2, 1) if audio_tensor.dim() == 3 else audio_tensor.repeat(2, 1) + return audio_tensor + + +def resample_waveform(waveform: torch.Tensor, source_rate: int, target_rate: int) -> torch.Tensor: + """Resample waveform to target sample rate if needed.""" + if source_rate == target_rate: + return waveform + resampled = torchaudio.functional.resample(waveform, source_rate, target_rate) + return resampled.to(dtype=waveform.dtype) + + +def read_audio_with_torchcodec( + path: str, + start_time: float = 0, + duration: float | None = None, +) -> tuple[torch.Tensor, int]: + """ + Read audio from file natively using torchcodec, with optional start time and duration. + + Args: + path (str): The file path to the audio file. + start_time (float, optional): The start time in seconds to read from. Defaults to 0. + duration (float | None, optional): The duration in seconds to read. If None, reads until the end. Defaults to None. + + Returns: + tuple[torch.Tensor, int]: A tuple containing the audio tensor and the sample rate. + The audio tensor shape is [C, T] where C is the number of channels and T is the number of audio frames. + """ + decoder = AudioDecoder(path) + stop_seconds = None if duration is None else start_time + duration + waveform = decoder.get_samples_played_in_range(start_seconds=start_time, stop_seconds=stop_seconds).data + return waveform, decoder.metadata.sample_rate + + +def read_audio( + path: str, + start_time: float = 0, + duration: float | None = None, + resample: bool = False, + resample_rate: int = 48000, + backend: str = "torchcodec", +) -> tuple[torch.Tensor, int]: + """ + Read audio from file, with optional start time, duration, and resampling. + + Args: + path (str): The file path to the audio file. + start_time (float, optional): The start time in seconds to read from. Defaults to 0. 
+ duration (float | None, optional): The duration in seconds to read. If None, reads until the end. Defaults to None. + resample (bool, optional): Whether to resample the audio to a different sample rate. Defaults to False. + resample_rate (int, optional): The target sample rate for resampling if resample is True. Defaults to 48000. + backend (str, optional): The audio backend to use for reading. Defaults to "torchcodec". + + Returns: + tuple[torch.Tensor, int]: A tuple containing the audio tensor and the sample rate. + The audio tensor shape is [C, T] where C is the number of channels and T is the number of audio frames. + """ + if backend == "torchcodec": + waveform, sample_rate = read_audio_with_torchcodec(path, start_time, duration) + else: + raise ValueError(f"Unsupported audio backend: {backend}") + + if resample: + waveform = resample_waveform(waveform, sample_rate, resample_rate) + sample_rate = resample_rate + + return waveform, sample_rate + + +def save_audio(waveform: torch.Tensor, sample_rate: int, save_path: str, backend: str = "torchcodec"): + """ + Save audio tensor to file. + + Args: + waveform (torch.Tensor): The audio tensor to save. Shape can be [C, T] or [B, C, T]. + sample_rate (int): The sample rate of the audio. + save_path (str): The file path to save the audio to. + backend (str, optional): The audio backend to use for saving. Defaults to "torchcodec". 
+ """ + if waveform.dim() == 3: + waveform = waveform[0] + + if backend == "torchcodec": + encoder = AudioEncoder(waveform, sample_rate=sample_rate) + encoder.to_file(dest=save_path) + else: + raise ValueError(f"Unsupported audio backend: {backend}") diff --git a/diffsynth/utils/data/audio_video.py b/diffsynth/utils/data/audio_video.py new file mode 100644 index 000000000..015434510 --- /dev/null +++ b/diffsynth/utils/data/audio_video.py @@ -0,0 +1,134 @@ +import av +from fractions import Fraction +import torch +from PIL import Image +from tqdm import tqdm +from .audio import convert_to_stereo + + +def _resample_audio( + container: av.container.Container, audio_stream: av.audio.AudioStream, frame_in: av.AudioFrame +) -> None: + cc = audio_stream.codec_context + + # Use the encoder's format/layout/rate as the *target* + target_format = cc.format or "fltp" # AAC → usually fltp + target_layout = cc.layout or "stereo" + target_rate = cc.sample_rate or frame_in.sample_rate + + audio_resampler = av.audio.resampler.AudioResampler( + format=target_format, + layout=target_layout, + rate=target_rate, + ) + + audio_next_pts = 0 + for rframe in audio_resampler.resample(frame_in): + if rframe.pts is None: + rframe.pts = audio_next_pts + audio_next_pts += rframe.samples + rframe.sample_rate = frame_in.sample_rate + container.mux(audio_stream.encode(rframe)) + + # flush audio encoder + for packet in audio_stream.encode(): + container.mux(packet) + + +def _write_audio( + container: av.container.Container, audio_stream: av.audio.AudioStream, samples: torch.Tensor, audio_sample_rate: int +) -> None: + if samples.ndim == 1: + samples = samples[:, None] + samples = convert_to_stereo(samples) + assert samples.ndim == 2 and samples.shape[0] == 2, "audio samples must be [C, S] or [S], C must be 1 or 2" + samples = samples.T + # Convert to int16 packed for ingestion; resampler converts to encoder fmt. 
+ if samples.dtype != torch.int16: + samples = torch.clip(samples, -1.0, 1.0) + samples = (samples * 32767.0).to(torch.int16) + + frame_in = av.AudioFrame.from_ndarray( + samples.contiguous().reshape(1, -1).cpu().numpy(), + format="s16", + layout="stereo", + ) + frame_in.sample_rate = audio_sample_rate + + _resample_audio(container, audio_stream, frame_in) + + +def _prepare_audio_stream(container: av.container.Container, audio_sample_rate: int) -> av.audio.AudioStream: + """ + Prepare the audio stream for writing. + """ + audio_stream = container.add_stream("aac") + supported_sample_rates = audio_stream.codec_context.codec.audio_rates + if supported_sample_rates: + best_rate = min(supported_sample_rates, key=lambda x: abs(x - audio_sample_rate)) + if best_rate != audio_sample_rate: + print(f"Using closest supported audio sample rate: {best_rate}") + else: + best_rate = audio_sample_rate + audio_stream.codec_context.sample_rate = best_rate + audio_stream.codec_context.layout = "stereo" + audio_stream.codec_context.time_base = Fraction(1, best_rate) + return audio_stream + + +def write_video_audio( + video: list[Image.Image], + audio: torch.Tensor | None, + output_path: str, + fps: int = 24, + audio_sample_rate: int | None = None, +) -> None: + """ + Writes a sequence of images and an audio tensor to a video file. + + This function utilizes PyAV (or a similar multimedia library) to encode a list of PIL images into a video stream + and multiplex a PyTorch tensor as the audio stream into the output container. + + Args: + video (list[Image.Image]): A list of PIL Image objects representing the video frames. + The length of this list determines the total duration of the video based on the FPS. + audio (torch.Tensor | None): The audio data as a PyTorch tensor. + The shape is typically (channels, samples). If no audio is required, pass None. + channels can be 1 or 2. 1 for mono, 2 for stereo. 
+ output_path (str): The file path (including extension) where the output video will be saved. + fps (int, optional): The frame rate (frames per second) for the video. Defaults to 24. + audio_sample_rate (int | None, optional): The sample rate (e.g., 44100, 48000) for the audio. + If the audio tensor is provided and this is None, the function attempts to infer the rate + based on the audio tensor's length and the video duration. + Raises: + ValueError: If an audio tensor is provided but the sample rate cannot be determined. + """ + duration = len(video) / fps + if audio_sample_rate is None: + audio_sample_rate = int(audio.shape[-1] / duration) + + width, height = video[0].size + container = av.open(output_path, mode="w") + stream = container.add_stream("libx264", rate=int(fps)) + stream.width = width + stream.height = height + stream.pix_fmt = "yuv420p" + + if audio is not None: + if audio_sample_rate is None: + raise ValueError("audio_sample_rate is required when audio is provided") + audio_stream = _prepare_audio_stream(container, audio_sample_rate) + + for frame in tqdm(video, total=len(video)): + frame = av.VideoFrame.from_image(frame) + for packet in stream.encode(frame): + container.mux(packet) + + # Flush encoder + for packet in stream.encode(): + container.mux(packet) + + if audio is not None: + _write_audio(container, audio_stream, audio, audio_sample_rate) + + container.close() diff --git a/diffsynth/utils/data/media_io_ltx2.py b/diffsynth/utils/data/media_io_ltx2.py index c31b0507d..425278651 100644 --- a/diffsynth/utils/data/media_io_ltx2.py +++ b/diffsynth/utils/data/media_io_ltx2.py @@ -1,166 +1,7 @@ -from fractions import Fraction -import torch -import torchaudio import av -from tqdm import tqdm -from PIL import Image import numpy as np from io import BytesIO - - -def _resample_audio( - container: av.container.Container, audio_stream: av.audio.AudioStream, frame_in: av.AudioFrame -) -> None: - cc = audio_stream.codec_context - - # Use the encoder's 
format/layout/rate as the *target* - target_format = cc.format or "fltp" # AAC → usually fltp - target_layout = cc.layout or "stereo" - target_rate = cc.sample_rate or frame_in.sample_rate - - audio_resampler = av.audio.resampler.AudioResampler( - format=target_format, - layout=target_layout, - rate=target_rate, - ) - - audio_next_pts = 0 - for rframe in audio_resampler.resample(frame_in): - if rframe.pts is None: - rframe.pts = audio_next_pts - audio_next_pts += rframe.samples - rframe.sample_rate = frame_in.sample_rate - container.mux(audio_stream.encode(rframe)) - - # flush audio encoder - for packet in audio_stream.encode(): - container.mux(packet) - - -def _write_audio( - container: av.container.Container, audio_stream: av.audio.AudioStream, samples: torch.Tensor, audio_sample_rate: int -) -> None: - if samples.ndim == 1: - samples = samples[:, None] - if samples.shape[0] == 1: - samples = samples.repeat(2, 1) - assert samples.ndim == 2 and samples.shape[0] == 2, "audio samples must be [C, S] or [S], C must be 1 or 2" - samples = samples.T - # Convert to int16 packed for ingestion; resampler converts to encoder fmt. - if samples.dtype != torch.int16: - samples = torch.clip(samples, -1.0, 1.0) - samples = (samples * 32767.0).to(torch.int16) - - frame_in = av.AudioFrame.from_ndarray( - samples.contiguous().reshape(1, -1).cpu().numpy(), - format="s16", - layout="stereo", - ) - frame_in.sample_rate = audio_sample_rate - - _resample_audio(container, audio_stream, frame_in) - - -def _prepare_audio_stream(container: av.container.Container, audio_sample_rate: int) -> av.audio.AudioStream: - """ - Prepare the audio stream for writing. 
- """ - audio_stream = container.add_stream("aac") - supported_sample_rates = audio_stream.codec_context.codec.audio_rates - if supported_sample_rates: - best_rate = min(supported_sample_rates, key=lambda x: abs(x - audio_sample_rate)) - if best_rate != audio_sample_rate: - print(f"Using closest supported audio sample rate: {best_rate}") - else: - best_rate = audio_sample_rate - audio_stream.codec_context.sample_rate = best_rate - audio_stream.codec_context.layout = "stereo" - audio_stream.codec_context.time_base = Fraction(1, best_rate) - return audio_stream - - -def write_video_audio_ltx2( - video: list[Image.Image], - audio: torch.Tensor | None, - output_path: str, - fps: int = 24, - audio_sample_rate: int | None = None, -) -> None: - """ - Writes a sequence of images and an audio tensor to a video file. - - This function utilizes PyAV (or a similar multimedia library) to encode a list of PIL images into a video stream - and multiplex a PyTorch tensor as the audio stream into the output container. - - Args: - video (list[Image.Image]): A list of PIL Image objects representing the video frames. - The length of this list determines the total duration of the video based on the FPS. - audio (torch.Tensor | None): The audio data as a PyTorch tensor. - The shape is typically (channels, samples). If no audio is required, pass None. - channels can be 1 or 2. 1 for mono, 2 for stereo. - output_path (str): The file path (including extension) where the output video will be saved. - fps (int, optional): The frame rate (frames per second) for the video. Defaults to 24. - audio_sample_rate (int | None, optional): The sample rate (e.g., 44100, 48000) for the audio. - If the audio tensor is provided and this is None, the function attempts to infer the rate - based on the audio tensor's length and the video duration. - Raises: - ValueError: If an audio tensor is provided but the sample rate cannot be determined. 
- """ - duration = len(video) / fps - if audio_sample_rate is None: - audio_sample_rate = int(audio.shape[-1] / duration) - - width, height = video[0].size - container = av.open(output_path, mode="w") - stream = container.add_stream("libx264", rate=int(fps)) - stream.width = width - stream.height = height - stream.pix_fmt = "yuv420p" - - if audio is not None: - if audio_sample_rate is None: - raise ValueError("audio_sample_rate is required when audio is provided") - audio_stream = _prepare_audio_stream(container, audio_sample_rate) - - for frame in tqdm(video, total=len(video)): - frame = av.VideoFrame.from_image(frame) - for packet in stream.encode(frame): - container.mux(packet) - - # Flush encoder - for packet in stream.encode(): - container.mux(packet) - - if audio is not None: - _write_audio(container, audio_stream, audio, audio_sample_rate) - - container.close() - - -def resample_waveform(waveform: torch.Tensor, source_rate: int, target_rate: int) -> torch.Tensor: - """Resample waveform to target sample rate if needed.""" - if source_rate == target_rate: - return waveform - resampled = torchaudio.functional.resample(waveform, source_rate, target_rate) - return resampled.to(dtype=waveform.dtype) - - -def read_audio_with_torchaudio( - path: str, - start_time: float = 0, - duration: float | None = None, - resample: bool = False, - resample_rate: int = 48000, -) -> tuple[torch.Tensor, int]: - waveform, sample_rate = torchaudio.load(path, channels_first=True) - if resample: - waveform = resample_waveform(waveform, sample_rate, resample_rate) - sample_rate = resample_rate - start_frame = int(start_time * sample_rate) - if start_frame > waveform.shape[-1]: - raise ValueError(f"start_time of {start_time} exceeds max duration of {waveform.shape[-1] / sample_rate:.2f}") - end_frame = None if duration is None else int(duration * sample_rate + start_frame) - return waveform[..., start_frame:end_frame], sample_rate +from .audio_video import write_video_audio as 
write_video_audio_ltx2 def encode_single_frame(output_file: str, image_array: np.ndarray, crf: float) -> None: diff --git a/diffsynth/utils/data/media_io_mova.py b/diffsynth/utils/data/media_io_mova.py deleted file mode 100644 index 4d3c4772e..000000000 --- a/diffsynth/utils/data/media_io_mova.py +++ /dev/null @@ -1,104 +0,0 @@ -import os -import shutil -import subprocess -import tempfile -import wave - -import imageio -import numpy as np -import torch -from tqdm import tqdm -from ..data import merge_video_audio -try: - import imageio_ffmpeg as _imageio_ffmpeg -except ImportError: # pragma: no cover - _imageio_ffmpeg = None - - -def _write_wav_wave(audio, wav_path, sample_rate=44100): - """ - Write int16 PCM WAV using standard library wave. - - audio: torch.Tensor or np.ndarray, shape [samples] / [channels, samples] - - If float, assumed range is approximately [-1, 1], will be converted to int16 PCM - """ - if isinstance(audio, torch.Tensor): - a = audio.detach().cpu().numpy() - else: - a = np.asarray(audio) - - if a.ndim == 1: - a = a[None, :] - if a.ndim != 2: - raise ValueError(f"audio shape needs to be [S] / [C,S], current shape is {a.shape}") - - channels, samples = int(a.shape[0]), int(a.shape[1]) - if channels > 2: - a = a[:2, :] - channels = 2 - - if np.issubdtype(a.dtype, np.floating): - a = np.clip(a, -1.0, 1.0) - a = (a * 32767.0).astype(np.int16) - elif a.dtype != np.int16: - a = np.clip(a, -32768, 32767).astype(np.int16) - - if channels == 1: - interleaved = a.reshape(-1) - else: - interleaved = a.T.reshape(-1) - - with wave.open(wav_path, "wb") as wf: - wf.setnchannels(channels) - wf.setsampwidth(2) # int16 - wf.setframerate(int(sample_rate)) - wf.writeframes(interleaved.tobytes(order="C")) - - -def save_video(frames, save_path, fps, quality=9, ffmpeg_params=None): - writer = imageio.get_writer(save_path, fps=fps, quality=quality, ffmpeg_params=ffmpeg_params) - for frame in tqdm(frames, desc="Saving video"): - frame = np.array(frame) - 
writer.append_data(frame) - writer.close() - - -# Copied from https://github.com/sgl-project/sglang/blob/7106f6c8e1509cd57abeafd5d50cb1beaffbc63c/python/sglang/multimodal_gen/runtime/entrypoints/utils.py#L96 -def _resolve_ffmpeg_exe() -> str: - ffmpeg_exe = "ffmpeg" - ffmpeg_on_path = shutil.which("ffmpeg") - if ffmpeg_on_path: - ffmpeg_exe = ffmpeg_on_path - try: - if _imageio_ffmpeg is not None: - ffmpeg_exe = _imageio_ffmpeg.get_ffmpeg_exe() - except Exception: - pass - - ffmpeg_ok = False - if ffmpeg_exe: - if os.path.isabs(ffmpeg_exe): - ffmpeg_ok = os.path.exists(ffmpeg_exe) - else: - ffmpeg_ok = shutil.which(ffmpeg_exe) is not None - if not ffmpeg_ok: - raise RuntimeError("ffmpeg not found") - return ffmpeg_exe - -def save_video_with_audio(frames, audio_data, save_path, fps=24, sample_rate=44100, quality=9, ffmpeg_path=None): - """ - Save video with audio. - - frames: List[PIL.Image | np.ndarray] - - audio: torch.Tensor or np.ndarray, shape [channels, samples] or [samples] - - save_path: Output mp4 path - - fps: Video frame rate - - sample_rate: Audio sample rate (default 44100) - Depend on ffmpeg executable program for audio/video reuse. 
- """ - if ffmpeg_path is None: - ffmpeg_path = _resolve_ffmpeg_exe() - - with tempfile.TemporaryDirectory(prefix='save_vwa_') as tmp_dir: - tmp_audio = os.path.join(tmp_dir, 'audio.wav') - save_video(frames, save_path, fps=fps, quality=quality) - _write_wav_wave(audio_data, tmp_audio, sample_rate=sample_rate) - merge_video_audio(save_path, tmp_audio) diff --git a/examples/ltx2/model_inference/LTX-2.3-A2V-TwoStage.py b/examples/ltx2/model_inference/LTX-2.3-A2V-TwoStage.py index 62b88782a..ad0b91834 100644 --- a/examples/ltx2/model_inference/LTX-2.3-A2V-TwoStage.py +++ b/examples/ltx2/model_inference/LTX-2.3-A2V-TwoStage.py @@ -1,6 +1,7 @@ import torch from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig -from diffsynth.utils.data.media_io_ltx2 import read_audio_with_torchaudio, write_video_audio_ltx2 +from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2 +from diffsynth.utils.data.audio import read_audio from modelscope import dataset_snapshot_download vram_config = { @@ -42,7 +43,7 @@ ) height, width, num_frames, frame_rate = 512 * 2, 768 * 2, 121, 24 duration = num_frames / frame_rate -audio, audio_sample_rate = read_audio_with_torchaudio("data/example_video_dataset/ltx2/sing.MP3", start_time=1, duration=duration) +audio, audio_sample_rate = read_audio("data/example_video_dataset/ltx2/sing.MP3", start_time=1, duration=duration) video, audio = pipe( prompt=prompt, negative_prompt=negative_prompt, diff --git a/examples/ltx2/model_inference/LTX-2.3-T2AV-TwoStage-Retake.py b/examples/ltx2/model_inference/LTX-2.3-T2AV-TwoStage-Retake.py index e717fe9b1..d241f69da 100644 --- a/examples/ltx2/model_inference/LTX-2.3-T2AV-TwoStage-Retake.py +++ b/examples/ltx2/model_inference/LTX-2.3-T2AV-TwoStage-Retake.py @@ -1,6 +1,7 @@ import torch from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig -from diffsynth.utils.data.media_io_ltx2 import read_audio_with_torchaudio, write_video_audio_ltx2 +from 
diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2 +from diffsynth.utils.data.audio import read_audio from modelscope import dataset_snapshot_download from diffsynth.utils.data import VideoData @@ -47,7 +48,7 @@ video = VideoData(path, height=height, width=width).raw_data()[:num_frames] assert len(video) == num_frames, f"Input video has {len(video)} frames, but expected {num_frames} frames based on the specified num_frames argument." duration = num_frames / frame_rate -audio, audio_sample_rate = read_audio_with_torchaudio(path) +audio, audio_sample_rate = read_audio(path) # Regenerate the video within time regions. You can specify different time regions for video frames and audio retake. # retake regions are in seconds, and the example below retakes video frames in the time regions of [1s, 2s] and [3s, 4s], and retakes audio in the time regions of [0s, 1s] and [4s, 5s]. diff --git a/examples/ltx2/model_inference_low_vram/LTX-2.3-A2V-TwoStage.py b/examples/ltx2/model_inference_low_vram/LTX-2.3-A2V-TwoStage.py index 267a30597..022160bb9 100644 --- a/examples/ltx2/model_inference_low_vram/LTX-2.3-A2V-TwoStage.py +++ b/examples/ltx2/model_inference_low_vram/LTX-2.3-A2V-TwoStage.py @@ -1,6 +1,7 @@ import torch from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig -from diffsynth.utils.data.media_io_ltx2 import read_audio_with_torchaudio, write_video_audio_ltx2 +from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2 +from diffsynth.utils.data.audio import read_audio from modelscope import dataset_snapshot_download vram_config = { @@ -43,7 +44,7 @@ ) height, width, num_frames, frame_rate = 512 * 2, 768 * 2, 121, 24 duration = num_frames / frame_rate -audio, audio_sample_rate = read_audio_with_torchaudio("data/example_video_dataset/ltx2/sing.MP3", start_time=1, duration=duration) +audio, audio_sample_rate = read_audio("data/example_video_dataset/ltx2/sing.MP3", start_time=1, duration=duration) video, audio = pipe( 
prompt=prompt, negative_prompt=negative_prompt, diff --git a/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-TwoStage-Retake.py b/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-TwoStage-Retake.py index 7eb1e7eb8..65a6ebfe5 100644 --- a/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-TwoStage-Retake.py +++ b/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-TwoStage-Retake.py @@ -1,6 +1,7 @@ import torch from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig -from diffsynth.utils.data.media_io_ltx2 import read_audio_with_torchaudio, write_video_audio_ltx2 +from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2 +from diffsynth.utils.data.audio import read_audio from modelscope import dataset_snapshot_download from diffsynth.utils.data import VideoData @@ -48,7 +49,7 @@ video = VideoData(path, height=height, width=width).raw_data()[:num_frames] assert len(video) == num_frames, f"Input video has {len(video)} frames, but expected {num_frames} frames based on the specified num_frames argument." duration = num_frames / frame_rate -audio, audio_sample_rate = read_audio_with_torchaudio(path) +audio, audio_sample_rate = read_audio(path) # Regenerate the video within time regions. You can specify different time regions for video frames and audio retake. # retake regions are in seconds, and the example below retakes video frames in the time regions of [1s, 2s] and [3s, 4s], and retakes audio in the time regions of [0s, 1s] and [4s, 5s]. 
diff --git a/examples/mova/acceleration/unified_sequence_parallel.py b/examples/mova/acceleration/unified_sequence_parallel.py index ff5db8ea2..d90036845 100644 --- a/examples/mova/acceleration/unified_sequence_parallel.py +++ b/examples/mova/acceleration/unified_sequence_parallel.py @@ -1,6 +1,6 @@ import torch from PIL import Image -from diffsynth.utils.data.media_io_mova import save_video_with_audio +from diffsynth.utils.data.audio_video import write_video_audio from diffsynth.pipelines.mova_audio_video import MovaAudioVideoPipeline, ModelConfig import torch.distributed as dist @@ -52,4 +52,4 @@ frame_rate=frame_rate, ) if dist.get_rank() == 0: - save_video_with_audio(video, audio, "MOVA-360p-cat.mp4", fps=24, sample_rate=pipe.audio_vae.sample_rate) + write_video_audio(video, audio, "MOVA-360p-cat.mp4", fps=24, audio_sample_rate=pipe.audio_vae.sample_rate) diff --git a/examples/mova/model_inference/MOVA-360p-TI2AV.py b/examples/mova/model_inference/MOVA-360p-TI2AV.py index 2ad77cd02..28e03f8cc 100644 --- a/examples/mova/model_inference/MOVA-360p-TI2AV.py +++ b/examples/mova/model_inference/MOVA-360p-TI2AV.py @@ -1,7 +1,7 @@ import torch from PIL import Image from diffsynth.pipelines.mova_audio_video import ModelConfig, MovaAudioVideoPipeline -from diffsynth.utils.data.media_io_mova import save_video_with_audio +from diffsynth.utils.data.audio_video import write_video_audio vram_config = { "offload_dtype": torch.bfloat16, @@ -49,4 +49,4 @@ tiled=True, frame_rate=frame_rate, ) -save_video_with_audio(video, audio, "MOVA-360p-cat_singlegpu_49.mp4", fps=24, sample_rate=pipe.audio_vae.sample_rate) +write_video_audio(video, audio, "MOVA-360p-cat.mp4", fps=24, audio_sample_rate=pipe.audio_vae.sample_rate) diff --git a/examples/mova/model_inference/MOVA-720p-TI2AV.py b/examples/mova/model_inference/MOVA-720p-TI2AV.py index 6294d6a6f..22b4757d7 100644 --- a/examples/mova/model_inference/MOVA-720p-TI2AV.py +++ b/examples/mova/model_inference/MOVA-720p-TI2AV.py @@ -1,6 +1,6 
@@ import torch from PIL import Image -from diffsynth.utils.data.media_io_mova import save_video_with_audio +from diffsynth.utils.data.audio_video import write_video_audio from diffsynth.pipelines.mova_audio_video import MovaAudioVideoPipeline, ModelConfig @@ -50,4 +50,4 @@ tiled=True, frame_rate=frame_rate, ) -save_video_with_audio(video, audio, "MOVA-720p-cat.mp4", fps=24, sample_rate=pipe.audio_vae.sample_rate) +write_video_audio(video, audio, "MOVA-720p-cat.mp4", fps=24, audio_sample_rate=pipe.audio_vae.sample_rate) diff --git a/pyproject.toml b/pyproject.toml index 5ddc5606d..05269c725 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,6 @@ requires-python = ">=3.10.1" dependencies = [ "torch>=2.0.0", "torchvision", - "torchaudio", "transformers", "imageio", "imageio[ffmpeg]", @@ -48,6 +47,10 @@ npu = [ "torch-npu==2.7.1", "torchvision==0.22.1+cpu" ] +audio = [ + "torchaudio", + "torchcodec" +] [tool.setuptools] include-package-data = true From b3ecbb62018f08d20cdbf919bcbd1a95fbcdd530 Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Thu, 12 Mar 2026 17:50:57 +0800 Subject: [PATCH 4/6] support mova train --- diffsynth/diffusion/base_pipeline.py | 7 - diffsynth/pipelines/mova_audio_video.py | 50 ++--- .../{MOVA-360p-TI2AV.py => MOVA-360p-I2AV.py} | 0 .../{MOVA-720p-TI2AV.py => MOVA-720p-I2AV.py} | 0 .../model_training/full/MOVA-360P-I2AV.sh | 39 ++++ .../model_training/full/MOVA-720P-I2AV.sh | 39 ++++ .../model_training/lora/MOVA-360P-I2AV.sh | 43 ++++ .../model_training/lora/MOVA-720P-I2AV.sh | 43 ++++ examples/mova/model_training/train.py | 193 ++++++++++++++++++ .../validate_full/MOVA-360p-I2AV.py | 53 +++++ .../validate_full/MOVA-720p-I2AV.py | 54 +++++ .../validate_lora/MOVA-360p-I2AV.py | 54 +++++ .../validate_lora/MOVA-720p-I2AV.py | 54 +++++ 13 files changed, 599 insertions(+), 30 deletions(-) rename examples/mova/model_inference/{MOVA-360p-TI2AV.py => MOVA-360p-I2AV.py} (100%) rename 
examples/mova/model_inference/{MOVA-720p-TI2AV.py => MOVA-720p-I2AV.py} (100%) create mode 100644 examples/mova/model_training/full/MOVA-360P-I2AV.sh create mode 100644 examples/mova/model_training/full/MOVA-720P-I2AV.sh create mode 100644 examples/mova/model_training/lora/MOVA-360P-I2AV.sh create mode 100644 examples/mova/model_training/lora/MOVA-720P-I2AV.sh create mode 100644 examples/mova/model_training/train.py create mode 100644 examples/mova/model_training/validate_full/MOVA-360p-I2AV.py create mode 100644 examples/mova/model_training/validate_full/MOVA-720p-I2AV.py create mode 100644 examples/mova/model_training/validate_lora/MOVA-360p-I2AV.py create mode 100644 examples/mova/model_training/validate_lora/MOVA-720p-I2AV.py diff --git a/diffsynth/diffusion/base_pipeline.py b/diffsynth/diffusion/base_pipeline.py index 8e3649a70..face31911 100644 --- a/diffsynth/diffusion/base_pipeline.py +++ b/diffsynth/diffusion/base_pipeline.py @@ -152,13 +152,6 @@ def output_audio_format_check(self, audio_output): # remove batch dim if audio_output.ndim == 3: audio_output = audio_output.squeeze(0) - # Transform to stereo - if audio_output.shape[0] == 1: - audio_output = audio_output.repeat(2, 1) - elif audio_output.shape[0] == 2: - pass - else: - raise ValueError("The output audio should be [C, T] or [1, C, T] or [2, C, T].") return audio_output.float() def load_models_to_device(self, model_names): diff --git a/diffsynth/pipelines/mova_audio_video.py b/diffsynth/pipelines/mova_audio_video.py index 2b6c7d8fa..89933b4eb 100644 --- a/diffsynth/pipelines/mova_audio_video.py +++ b/diffsynth/pipelines/mova_audio_video.py @@ -19,6 +19,7 @@ from ..models.mova_audio_dit import MovaAudioDit from ..models.mova_audio_vae import DacVAE from ..models.mova_dual_tower_bridge import DualTowerConditionalBridge +from ..utils.data.audio import convert_to_mono, resample_waveform class MovaAudioVideoPipeline(BasePipeline): @@ -81,12 +82,16 @@ def from_pretrained( # Fetch models pipe.text_encoder 
= model_pool.fetch_model("wan_video_text_encoder") - pipe.video_dit, pipe.video_dit2 = model_pool.fetch_model("wan_video_dit", index=2) + dit = model_pool.fetch_model("wan_video_dit", index=2) + if isinstance(dit, list): + pipe.video_dit, pipe.video_dit2 = dit + else: + pipe.video_dit = dit pipe.audio_dit = model_pool.fetch_model("mova_audio_dit") pipe.dual_tower_bridge = model_pool.fetch_model("mova_dual_tower_bridge") pipe.video_vae = model_pool.fetch_model("wan_video_vae") pipe.audio_vae = model_pool.fetch_model("mova_audio_vae") - set_to_torch_norm([pipe.video_dit, pipe.video_dit2, pipe.audio_dit, pipe.dual_tower_bridge]) + set_to_torch_norm([pipe.video_dit, pipe.audio_dit, pipe.dual_tower_bridge] + ([pipe.video_dit2] if pipe.video_dit2 is not None else [])) # Size division factor if pipe.video_vae is not None: @@ -185,7 +190,8 @@ def __call__( video = self.video_vae.decode(inputs_shared["video_latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) video = self.vae_output_to_video(video) self.load_models_to_device(["audio_vae"]) - audio = self.audio_vae.decode(inputs_shared["audio_latents"]).to(dtype=torch.float32, device='cpu').squeeze() + audio = self.audio_vae.decode(inputs_shared["audio_latents"]) + audio = self.output_audio_format_check(audio) self.load_models_to_device([]) return video, audio @@ -229,17 +235,13 @@ def __init__(self): ) def process(self, pipe: MovaAudioVideoPipeline, input_video, video_noise, tiled, tile_size, tile_stride): - if input_video is None: + if input_video is None or not pipe.scheduler.training: return {"video_latents": video_noise} - # TODO: check for train - pipe.load_models_to_device(self.onload_model_names) - input_video = pipe.preprocess_video(input_video) - input_latents = pipe.video_vae.encode(input_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) - if pipe.scheduler.training: - return {"latents": 
video_noise, "input_latents": input_latents} else: - latents = pipe.scheduler.add_noise(input_latents, video_noise, timestep=pipe.scheduler.timesteps[0]) - return {"latents": latents} + pipe.load_models_to_device(self.onload_model_names) + input_video = pipe.preprocess_video(input_video) + input_latents = pipe.video_vae.encode(input_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + return {"input_latents": input_latents} class MovaAudioVideoUnit_InputAudioEmbedder(PipelineUnit): @@ -247,18 +249,19 @@ def __init__(self): super().__init__( input_params=("input_audio", "audio_noise"), output_params=("audio_latents", "audio_input_latents"), - onload_model_names=("audio_vae_encoder",) + onload_model_names=("audio_vae",) ) def process(self, pipe: MovaAudioVideoPipeline, input_audio, audio_noise): - if input_audio is None: + if input_audio is None or not pipe.scheduler.training: return {"audio_latents": audio_noise} else: - # TODO: support audio training - if pipe.scheduler.training: - return {"audio_latents": audio_noise, "audio_input_latents": audio_noise} - else: - raise NotImplementedError("Audio-to-video not supported.") + input_audio, sample_rate = input_audio + input_audio = convert_to_mono(input_audio) + input_audio = resample_waveform(input_audio, sample_rate, pipe.audio_vae.sample_rate) + input_audio = pipe.audio_vae.preprocess(input_audio.unsqueeze(0), pipe.audio_vae.sample_rate) + z, _, _, _, _ = pipe.audio_vae.encode(input_audio) + return {"audio_input_latents": z.mode()} class MovaAudioVideoUnit_PromptEmbedder(PipelineUnit): @@ -329,15 +332,16 @@ def process(self, pipe: MovaAudioVideoPipeline, input_image, end_image, num_fram y = y.to(dtype=pipe.torch_dtype, device=pipe.device) return {"y": y} + class MovaAudioVideoUnit_UnifiedSequenceParallel(PipelineUnit): def __init__(self): super().__init__(input_params=(), output_params=("use_unified_sequence_parallel",)) def 
process(self, pipe: MovaAudioVideoPipeline): - if hasattr(pipe, "use_unified_sequence_parallel"): - if pipe.use_unified_sequence_parallel: - return {"use_unified_sequence_parallel": True} - return {} + if hasattr(pipe, "use_unified_sequence_parallel") and pipe.use_unified_sequence_parallel: + return {"use_unified_sequence_parallel": True} + return {"use_unified_sequence_parallel": False} + def model_fn_mova_audio_video( video_dit: WanModel, diff --git a/examples/mova/model_inference/MOVA-360p-TI2AV.py b/examples/mova/model_inference/MOVA-360p-I2AV.py similarity index 100% rename from examples/mova/model_inference/MOVA-360p-TI2AV.py rename to examples/mova/model_inference/MOVA-360p-I2AV.py diff --git a/examples/mova/model_inference/MOVA-720p-TI2AV.py b/examples/mova/model_inference/MOVA-720p-I2AV.py similarity index 100% rename from examples/mova/model_inference/MOVA-720p-TI2AV.py rename to examples/mova/model_inference/MOVA-720p-I2AV.py diff --git a/examples/mova/model_training/full/MOVA-360P-I2AV.sh b/examples/mova/model_training/full/MOVA-360P-I2AV.sh new file mode 100644 index 000000000..fa7c18c5a --- /dev/null +++ b/examples/mova/model_training/full/MOVA-360P-I2AV.sh @@ -0,0 +1,39 @@ +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/mova/model_training/train.py \ + --dataset_base_path data/example_video_dataset/ltx2 \ + --dataset_metadata_path data/example_video_dataset/ltx2_t2av.csv \ + --data_file_keys "video,input_audio" \ + --extra_inputs "input_audio,input_image" \ + --height 352 \ + --width 640 \ + --num_frames 121 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths 
"openmoss/MOVA-360p:video_dit/diffusion_pytorch_model-*.safetensors,openmoss/MOVA-360p:audio_dit/diffusion_pytorch_model.safetensors,openmoss/MOVA-360p:dual_tower_bridge/diffusion_pytorch_model.safetensors,openmoss/MOVA-720p:audio_vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:Wan2.1_VAE.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:models_t5_umt5-xxl-enc-bf16.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.video_dit." \ + --output_path "./models/train/MOVA-360p-I2AV_high_noise_full" \ + --trainable_models "dit" \ + --max_timestep_boundary 0.358 \ + --min_timestep_boundary 0 \ + --use_gradient_checkpointing +# boundary corresponds to timesteps [900, 1000] + +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/mova/model_training/train.py \ + --dataset_base_path data/example_video_dataset/ltx2 \ + --dataset_metadata_path data/example_video_dataset/ltx2_t2av.csv \ + --data_file_keys "video,input_audio" \ + --extra_inputs "input_audio,input_image" \ + --height 352 \ + --width 640 \ + --num_frames 121 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "openmoss/MOVA-360p:video_dit_2/diffusion_pytorch_model-*.safetensors,openmoss/MOVA-360p:audio_dit/diffusion_pytorch_model.safetensors,openmoss/MOVA-360p:dual_tower_bridge/diffusion_pytorch_model.safetensors,openmoss/MOVA-720p:audio_vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:Wan2.1_VAE.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:models_t5_umt5-xxl-enc-bf16.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.video_dit." 
\ + --output_path "./models/train/MOVA-360p-I2AV_low_noise_full" \ + --trainable_models "dit" \ + --max_timestep_boundary 1 \ + --min_timestep_boundary 0.358 \ + --use_gradient_checkpointing +# boundary corresponds to timesteps [0, 900) \ No newline at end of file diff --git a/examples/mova/model_training/full/MOVA-720P-I2AV.sh b/examples/mova/model_training/full/MOVA-720P-I2AV.sh new file mode 100644 index 000000000..955efb1c2 --- /dev/null +++ b/examples/mova/model_training/full/MOVA-720P-I2AV.sh @@ -0,0 +1,39 @@ +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/mova/model_training/train.py \ + --dataset_base_path data/example_video_dataset/ltx2 \ + --dataset_metadata_path data/example_video_dataset/ltx2_t2av.csv \ + --data_file_keys "video,input_audio" \ + --extra_inputs "input_audio,input_image" \ + --height 720 \ + --width 1280 \ + --num_frames 121 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "openmoss/MOVA-720p:video_dit/diffusion_pytorch_model-*.safetensors,openmoss/MOVA-720p:audio_dit/diffusion_pytorch_model.safetensors,openmoss/MOVA-720p:dual_tower_bridge/diffusion_pytorch_model.safetensors,openmoss/MOVA-720p:audio_vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:Wan2.1_VAE.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:models_t5_umt5-xxl-enc-bf16.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.video_dit." 
\ + --output_path "./models/train/MOVA-720p-I2AV_high_noise_full" \ + --trainable_models "dit" \ + --max_timestep_boundary 0.358 \ + --min_timestep_boundary 0 \ + --use_gradient_checkpointing +# boundary corresponds to timesteps [900, 1000] + +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/mova/model_training/train.py \ + --dataset_base_path data/example_video_dataset/ltx2 \ + --dataset_metadata_path data/example_video_dataset/ltx2_t2av.csv \ + --data_file_keys "video,input_audio" \ + --extra_inputs "input_audio,input_image" \ + --height 720 \ + --width 1280 \ + --num_frames 121 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "openmoss/MOVA-720p:video_dit_2/diffusion_pytorch_model-*.safetensors,openmoss/MOVA-720p:audio_dit/diffusion_pytorch_model.safetensors,openmoss/MOVA-720p:dual_tower_bridge/diffusion_pytorch_model.safetensors,openmoss/MOVA-720p:audio_vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:Wan2.1_VAE.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:models_t5_umt5-xxl-enc-bf16.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.video_dit." 
\ + --output_path "./models/train/MOVA-720p-I2AV_low_noise_full" \ + --trainable_models "dit" \ + --max_timestep_boundary 1 \ + --min_timestep_boundary 0.358 \ + --use_gradient_checkpointing +# boundary corresponds to timesteps [0, 900) \ No newline at end of file diff --git a/examples/mova/model_training/lora/MOVA-360P-I2AV.sh b/examples/mova/model_training/lora/MOVA-360P-I2AV.sh new file mode 100644 index 000000000..0485968d7 --- /dev/null +++ b/examples/mova/model_training/lora/MOVA-360P-I2AV.sh @@ -0,0 +1,43 @@ +accelerate launch examples/mova/model_training/train.py \ + --dataset_base_path data/example_video_dataset/ltx2 \ + --dataset_metadata_path data/example_video_dataset/ltx2_t2av.csv \ + --data_file_keys "video,input_audio" \ + --extra_inputs "input_audio,input_image" \ + --height 352 \ + --width 640 \ + --num_frames 121 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "openmoss/MOVA-360p:video_dit/diffusion_pytorch_model-*.safetensors,openmoss/MOVA-360p:audio_dit/diffusion_pytorch_model.safetensors,openmoss/MOVA-360p:dual_tower_bridge/diffusion_pytorch_model.safetensors,openmoss/MOVA-720p:audio_vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:Wan2.1_VAE.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:models_t5_umt5-xxl-enc-bf16.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.video_dit." 
\ + --output_path "./models/train/MOVA-360p-I2AV_high_noise_lora" \ + --lora_base_model "video_dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --max_timestep_boundary 0.358 \ + --min_timestep_boundary 0 \ + --use_gradient_checkpointing +# boundary corresponds to timesteps [900, 1000] + +# accelerate launch examples/mova/model_training/train.py \ +# --dataset_base_path data/example_video_dataset/ltx2 \ +# --dataset_metadata_path data/example_video_dataset/ltx2_t2av.csv \ +# --data_file_keys "video,input_audio" \ +# --extra_inputs "input_audio,input_image" \ +# --height 352 \ +# --width 640 \ +# --num_frames 121 \ +# --dataset_repeat 100 \ +# --model_id_with_origin_paths "openmoss/MOVA-360p:video_dit_2/diffusion_pytorch_model-*.safetensors,openmoss/MOVA-360p:audio_dit/diffusion_pytorch_model.safetensors,openmoss/MOVA-360p:dual_tower_bridge/diffusion_pytorch_model.safetensors,openmoss/MOVA-720p:audio_vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:Wan2.1_VAE.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:models_t5_umt5-xxl-enc-bf16.safetensors" \ +# --learning_rate 1e-4 \ +# --num_epochs 5 \ +# --remove_prefix_in_ckpt "pipe.video_dit." 
\ +# --output_path "./models/train/MOVA-360p-I2AV_low_noise_lora" \ +# --lora_base_model "video_dit" \ +# --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ +# --lora_rank 32 \ +# --max_timestep_boundary 1 \ +# --min_timestep_boundary 0.358 \ +# --use_gradient_checkpointing +# boundary corresponds to timesteps [0, 900) \ No newline at end of file diff --git a/examples/mova/model_training/lora/MOVA-720P-I2AV.sh b/examples/mova/model_training/lora/MOVA-720P-I2AV.sh new file mode 100644 index 000000000..ae3dae141 --- /dev/null +++ b/examples/mova/model_training/lora/MOVA-720P-I2AV.sh @@ -0,0 +1,43 @@ +accelerate launch examples/mova/model_training/train.py \ + --dataset_base_path data/example_video_dataset/ltx2 \ + --dataset_metadata_path data/example_video_dataset/ltx2_t2av.csv \ + --data_file_keys "video,input_audio" \ + --extra_inputs "input_audio,input_image" \ + --height 720 \ + --width 1280 \ + --num_frames 121 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "openmoss/MOVA-720p:video_dit/diffusion_pytorch_model-*.safetensors,openmoss/MOVA-720p:audio_dit/diffusion_pytorch_model.safetensors,openmoss/MOVA-720p:dual_tower_bridge/diffusion_pytorch_model.safetensors,openmoss/MOVA-720p:audio_vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:Wan2.1_VAE.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:models_t5_umt5-xxl-enc-bf16.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.video_dit." 
\ + --output_path "./models/train/MOVA-720p-I2AV_high_noise_lora" \ + --lora_base_model "video_dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --max_timestep_boundary 0.358 \ + --min_timestep_boundary 0 \ + --use_gradient_checkpointing +# boundary corresponds to timesteps [900, 1000] + +accelerate launch examples/mova/model_training/train.py \ + --dataset_base_path data/example_video_dataset/ltx2 \ + --dataset_metadata_path data/example_video_dataset/ltx2_t2av.csv \ + --data_file_keys "video,input_audio" \ + --extra_inputs "input_audio,input_image" \ + --height 720 \ + --width 1280 \ + --num_frames 121 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "openmoss/MOVA-720p:video_dit_2/diffusion_pytorch_model-*.safetensors,openmoss/MOVA-720p:audio_dit/diffusion_pytorch_model.safetensors,openmoss/MOVA-720p:dual_tower_bridge/diffusion_pytorch_model.safetensors,openmoss/MOVA-720p:audio_vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:Wan2.1_VAE.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:models_t5_umt5-xxl-enc-bf16.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.video_dit." 
\ + --output_path "./models/train/MOVA-720p-I2AV_low_noise_lora" \ + --lora_base_model "video_dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --max_timestep_boundary 1 \ + --min_timestep_boundary 0.358 \ + --use_gradient_checkpointing +# boundary corresponds to timesteps [0, 900) \ No newline at end of file diff --git a/examples/mova/model_training/train.py b/examples/mova/model_training/train.py new file mode 100644 index 000000000..24f08b1c9 --- /dev/null +++ b/examples/mova/model_training/train.py @@ -0,0 +1,193 @@ +import torch, os, argparse, accelerate, warnings +from diffsynth.core import UnifiedDataset +from diffsynth.core.data.operators import LoadAudioWithTorchaudio, ToAbsolutePath, RouteByType, SequencialProcess +from diffsynth.pipelines.mova_audio_video import MovaAudioVideoPipeline, ModelConfig +from diffsynth.diffusion import * +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +class MOVATrainingModule(DiffusionTrainingModule): + def __init__( + self, + model_paths=None, model_id_with_origin_paths=None, + tokenizer_path=None, + trainable_models=None, + lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None, + preset_lora_path=None, preset_lora_model=None, + use_gradient_checkpointing=True, + use_gradient_checkpointing_offload=False, + extra_inputs=None, + fp8_models=None, + offload_models=None, + device="cpu", + task="sft", + max_timestep_boundary=1.0, + min_timestep_boundary=0.0, + ): + super().__init__() + # Warning + if not use_gradient_checkpointing: + warnings.warn("Gradient checkpointing is detected as disabled. 
To prevent out-of-memory errors, the training framework will forcibly enable gradient checkpointing.") + use_gradient_checkpointing = True + + # Load models + model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device) + tokenizer_config = ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized") if tokenizer_path is None else ModelConfig(tokenizer_path) + self.pipe = MovaAudioVideoPipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config) + self.pipe = self.split_pipeline_units( + task, self.pipe, trainable_models, lora_base_model, + remove_unnecessary_params=True, + force_remove_params_shared=("audio_latents", "video_latents"), + force_remove_params_nega=("audio_context", "video_context") + ) + # Training mode + self.switch_pipe_to_training_mode( + self.pipe, trainable_models, + lora_base_model, lora_target_modules, lora_rank, lora_checkpoint, + preset_lora_path, preset_lora_model, + task=task, + ) + + # Store other configs + self.use_gradient_checkpointing = use_gradient_checkpointing + self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload + self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else [] + self.fp8_models = fp8_models + self.task = task + self.task_to_loss = { + "sft:data_process": lambda pipe, *args: args, + "sft": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTAudioVideoLoss(pipe, **inputs_shared, **inputs_posi), + "sft:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTAudioVideoLoss(pipe, **inputs_shared, **inputs_posi), + } + self.max_timestep_boundary = max_timestep_boundary + self.min_timestep_boundary = min_timestep_boundary + + def parse_extra_inputs(self, data, extra_inputs, inputs_shared): + for extra_input in extra_inputs: + if extra_input == "input_image": + 
inputs_shared["input_image"] = data["video"][0] + else: + inputs_shared[extra_input] = data[extra_input] + return inputs_shared + + def get_pipeline_inputs(self, data): + inputs_posi = {"prompt": data["prompt"]} + inputs_nega = {} + inputs_shared = { + # Assume you are using this pipeline for inference, + # please fill in the input parameters. + "input_video": data["video"], + "height": data["video"][0].size[1], + "width": data["video"][0].size[0], + "num_frames": len(data["video"]), + "frame_rate": data.get("frame_rate", 24), + # Please do not modify the following parameters + # unless you clearly know what this will cause. + "cfg_scale": 1, + "tiled": False, + "rand_device": self.pipe.device, + "use_gradient_checkpointing": self.use_gradient_checkpointing, + "use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload, + "max_timestep_boundary": self.max_timestep_boundary, + "min_timestep_boundary": self.min_timestep_boundary, + } + inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared) + return inputs_shared, inputs_posi, inputs_nega + + def forward(self, data, inputs=None): + if inputs is None: inputs = self.get_pipeline_inputs(data) + inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype) + for unit in self.pipe.units: + inputs = self.pipe.unit_runner(unit, self.pipe, *inputs) + loss = self.task_to_loss[self.task](self.pipe, *inputs) + return loss + + +def ltx2_parser(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser = add_general_config(parser) + parser = add_video_size_config(parser) + parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to tokenizer.") + parser.add_argument("--frame_rate", type=float, default=24, help="Frame rate of the training videos. 
Mova is trained with a frame rate of 24, so it's recommended to use the same frame rate.") + parser.add_argument("--max_timestep_boundary", type=float, default=1.0, help="Max timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).") + parser.add_argument("--min_timestep_boundary", type=float, default=0.0, help="Min timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).") + parser.add_argument("--initialize_model_on_cpu", default=False, action="store_true", help="Whether to initialize models on CPU.") + return parser + + +if __name__ == "__main__": + parser = ltx2_parser() + args = parser.parse_args() + accelerator = accelerate.Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + kwargs_handlers=[accelerate.DistributedDataParallelKwargs(find_unused_parameters=args.find_unused_parameters)], + ) + model = MOVATrainingModule( + model_paths=args.model_paths, + model_id_with_origin_paths=args.model_id_with_origin_paths, + tokenizer_path=args.tokenizer_path, + trainable_models=args.trainable_models, + lora_base_model=args.lora_base_model, + lora_target_modules=args.lora_target_modules, + lora_rank=args.lora_rank, + lora_checkpoint=args.lora_checkpoint, + preset_lora_path=args.preset_lora_path, + preset_lora_model=args.preset_lora_model, + use_gradient_checkpointing=args.use_gradient_checkpointing, + use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload, + extra_inputs=args.extra_inputs, + fp8_models=args.fp8_models, + offload_models=args.offload_models, + task=args.task, + device="cpu" if args.initialize_model_on_cpu else accelerator.device, + max_timestep_boundary=args.max_timestep_boundary, + min_timestep_boundary=args.min_timestep_boundary, + ) + video_processor = UnifiedDataset.default_video_operator( + base_path=args.dataset_base_path, + max_pixels=args.max_pixels, + height=args.height, + width=args.width, + height_division_factor=model.pipe.height_division_factor, + 
width_division_factor=model.pipe.width_division_factor, + num_frames=args.num_frames, + time_division_factor=model.pipe.time_division_factor, + time_division_remainder=model.pipe.time_division_remainder, + frame_rate=args.frame_rate, + fix_frame_rate=True, + ) + dataset = UnifiedDataset( + base_path=args.dataset_base_path, + metadata_path=args.dataset_metadata_path, + repeat=args.dataset_repeat, + data_file_keys=args.data_file_keys.split(","), + main_data_operator=video_processor, + special_operator_map={ + "input_audio": + ToAbsolutePath(args.dataset_base_path) >> LoadAudioWithTorchaudio( + num_frames=args.num_frames, + time_division_factor=model.pipe.time_division_factor, + time_division_remainder=model.pipe.time_division_remainder, + frame_rate=args.frame_rate, + ), + "in_context_videos": + RouteByType(operator_map=[ + (str, video_processor), + (list, SequencialProcess(video_processor)), + ]), + }, + ) + + model_logger = ModelLogger( + args.output_path, + remove_prefix_in_ckpt=args.remove_prefix_in_ckpt, + ) + launcher_map = { + "sft:data_process": launch_data_process_task, + "direct_distill:data_process": launch_data_process_task, + "sft": launch_training_task, + "sft:train": launch_training_task, + "direct_distill": launch_training_task, + "direct_distill:train": launch_training_task, + } + launcher_map[args.task](accelerator, dataset, model, model_logger, args=args) diff --git a/examples/mova/model_training/validate_full/MOVA-360p-I2AV.py b/examples/mova/model_training/validate_full/MOVA-360p-I2AV.py new file mode 100644 index 000000000..606880a44 --- /dev/null +++ b/examples/mova/model_training/validate_full/MOVA-360p-I2AV.py @@ -0,0 +1,53 @@ +import torch +from PIL import Image +from diffsynth.pipelines.mova_audio_video import ModelConfig, MovaAudioVideoPipeline +from diffsynth.utils.data.audio_video import write_video_audio +from diffsynth.utils.data import VideoData + + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + 
"onload_dtype": torch.bfloat16, + "onload_device": "cuda", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = MovaAudioVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(path="./models/train/MOVA-360p-I2AV_high_noise_full/epoch-4.safetensors", **vram_config), + ModelConfig(model_id="openmoss/MOVA-360p", origin_file_pattern="video_dit_2/diffusion_pytorch_model-*.safetensors", **vram_config), + ModelConfig(model_id="openmoss/MOVA-360p", origin_file_pattern="audio_dit/diffusion_pytorch_model.safetensors", **vram_config), + ModelConfig(model_id="openmoss/MOVA-360p", origin_file_pattern="dual_tower_bridge/diffusion_pytorch_model.safetensors", **vram_config), + ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="audio_vae/diffusion_pytorch_model.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/Wan-Series-Converted-Safetensors", origin_file_pattern="Wan2.1_VAE.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/Wan-Series-Converted-Safetensors", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="tokenizer/"), +) +negative_prompt = ( + "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止," + "整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指" +) +prompt = "A beautiful sunset over the ocean." 
+height, width, num_frames = 352, 640, 121 +frame_rate = 24 +input_image = VideoData("data/example_video_dataset/ltx2/video.mp4", height=height, width=width)[0] +# Image-to-video +video, audio = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + height=height, + width=width, + num_frames=num_frames, + input_image=input_image, + num_inference_steps=50, + seed=0, + tiled=True, + frame_rate=frame_rate, +) +write_video_audio(video, audio, "MOVA-360p.mp4", fps=24, audio_sample_rate=pipe.audio_vae.sample_rate) diff --git a/examples/mova/model_training/validate_full/MOVA-720p-I2AV.py b/examples/mova/model_training/validate_full/MOVA-720p-I2AV.py new file mode 100644 index 000000000..8c0ef824d --- /dev/null +++ b/examples/mova/model_training/validate_full/MOVA-720p-I2AV.py @@ -0,0 +1,54 @@ +import torch +from PIL import Image +from diffsynth.utils.data.audio_video import write_video_audio +from diffsynth.pipelines.mova_audio_video import MovaAudioVideoPipeline, ModelConfig +from diffsynth.utils.data import VideoData + + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cuda", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = MovaAudioVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(path="./models/train/MOVA-720p-I2AV_high_noise_full/epoch-4.safetensors", **vram_config), + ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="video_dit_2/diffusion_pytorch_model-*.safetensors", **vram_config), + ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="audio_dit/diffusion_pytorch_model.safetensors", **vram_config), + ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="dual_tower_bridge/diffusion_pytorch_model.safetensors", **vram_config), + ModelConfig(model_id="openmoss/MOVA-720p", 
origin_file_pattern="audio_vae/diffusion_pytorch_model.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/Wan-Series-Converted-Safetensors", origin_file_pattern="Wan2.1_VAE.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/Wan-Series-Converted-Safetensors", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="tokenizer/"), +) +# Full-finetune validation: the trained checkpoint is already loaded above, so no LoRA is applied here. +negative_prompt = ( + "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止," + "整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指" +) +prompt = "A beautiful sunset over the ocean." +height, width, num_frames = 720, 1280, 121 +frame_rate = 24 +input_image = VideoData("data/example_video_dataset/ltx2/video.mp4", height=height, width=width)[0] +# Image-to-video +video, audio = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + height=height, + width=width, + num_frames=num_frames, + input_image=input_image, + num_inference_steps=50, + seed=0, + tiled=True, + frame_rate=frame_rate, +) +write_video_audio(video, audio, "MOVA-720p.mp4", fps=24, audio_sample_rate=pipe.audio_vae.sample_rate) diff --git a/examples/mova/model_training/validate_lora/MOVA-360p-I2AV.py b/examples/mova/model_training/validate_lora/MOVA-360p-I2AV.py new file mode 100644 index 000000000..00f437690 --- /dev/null +++ b/examples/mova/model_training/validate_lora/MOVA-360p-I2AV.py @@ -0,0 +1,54 @@ +import torch +from PIL import Image +from diffsynth.pipelines.mova_audio_video import ModelConfig, MovaAudioVideoPipeline +from diffsynth.utils.data.audio_video import write_video_audio +from diffsynth.utils.data import VideoData + + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cuda", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + 
"computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = MovaAudioVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="openmoss/MOVA-360p", origin_file_pattern="video_dit/diffusion_pytorch_model-*.safetensors", **vram_config), + ModelConfig(model_id="openmoss/MOVA-360p", origin_file_pattern="video_dit_2/diffusion_pytorch_model-*.safetensors", **vram_config), + ModelConfig(model_id="openmoss/MOVA-360p", origin_file_pattern="audio_dit/diffusion_pytorch_model.safetensors", **vram_config), + ModelConfig(model_id="openmoss/MOVA-360p", origin_file_pattern="dual_tower_bridge/diffusion_pytorch_model.safetensors", **vram_config), + ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="audio_vae/diffusion_pytorch_model.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/Wan-Series-Converted-Safetensors", origin_file_pattern="Wan2.1_VAE.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/Wan-Series-Converted-Safetensors", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="tokenizer/"), +) +pipe.load_lora(pipe.video_dit, "models/train/MOVA-360p-I2AV_high_noise_lora/epoch-4.safetensors") +negative_prompt = ( + "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止," + "整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指" +) +prompt = "A beautiful sunset over the ocean." 
+height, width, num_frames = 352, 640, 121 +frame_rate = 24 +input_image = VideoData("data/example_video_dataset/ltx2/video.mp4", height=height, width=width)[0] +# Image-to-video +video, audio = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + height=height, + width=width, + num_frames=num_frames, + input_image=input_image, + num_inference_steps=50, + seed=0, + tiled=True, + frame_rate=frame_rate, +) +write_video_audio(video, audio, "MOVA-360p.mp4", fps=24, audio_sample_rate=pipe.audio_vae.sample_rate) diff --git a/examples/mova/model_training/validate_lora/MOVA-720p-I2AV.py b/examples/mova/model_training/validate_lora/MOVA-720p-I2AV.py new file mode 100644 index 000000000..282a8b090 --- /dev/null +++ b/examples/mova/model_training/validate_lora/MOVA-720p-I2AV.py @@ -0,0 +1,54 @@ +import torch +from PIL import Image +from diffsynth.utils.data.audio_video import write_video_audio +from diffsynth.pipelines.mova_audio_video import MovaAudioVideoPipeline, ModelConfig +from diffsynth.utils.data import VideoData + + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cuda", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = MovaAudioVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="video_dit/diffusion_pytorch_model-*.safetensors", **vram_config), + ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="video_dit_2/diffusion_pytorch_model-*.safetensors", **vram_config), + ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="audio_dit/diffusion_pytorch_model.safetensors", **vram_config), + ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="dual_tower_bridge/diffusion_pytorch_model.safetensors", **vram_config), + 
ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="audio_vae/diffusion_pytorch_model.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/Wan-Series-Converted-Safetensors", origin_file_pattern="Wan2.1_VAE.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/Wan-Series-Converted-Safetensors", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="tokenizer/"), +) +pipe.load_lora(pipe.video_dit, "models/train/MOVA-720p-I2AV_high_noise_lora/epoch-4.safetensors") +negative_prompt = ( + "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止," + "整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指" +) +prompt = "A beautiful sunset over the ocean." +height, width, num_frames = 720, 1280, 121 +frame_rate = 24 +input_image = VideoData("data/example_video_dataset/ltx2/video.mp4", height=height, width=width)[0] +# Image-to-video +video, audio = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + height=height, + width=width, + num_frames=num_frames, + input_image=input_image, + num_inference_steps=50, + seed=0, + tiled=True, + frame_rate=frame_rate, +) +write_video_audio(video, audio, "MOVA-720p.mp4", fps=24, audio_sample_rate=pipe.audio_vae.sample_rate) From d94be3fe01e97550554d4c8de637582acac7f6d7 Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Thu, 12 Mar 2026 18:17:58 +0800 Subject: [PATCH 5/6] mova docs --- README.md | 5 +- README_zh.md | 4 ++ docs/en/Model_Details/Wan.md | 2 + docs/zh/Model_Details/Wan.md | 2 + .../mova/model_inference/MOVA-720p-I2AV.py | 3 +- .../MOVA-360p-I2AV.py | 53 +++++++++++++++++++ .../MOVA-720p-I2AV.py | 53 +++++++++++++++++++ 7 files changed, 119 insertions(+), 3 deletions(-) create mode 100644 examples/mova/model_inference_low_vram/MOVA-360p-I2AV.py create mode 100644 examples/mova/model_inference_low_vram/MOVA-720p-I2AV.py diff --git a/README.md b/README.md index 66dcefab0..b9f8ab02e 100644 
--- a/README.md +++ b/README.md @@ -32,8 +32,9 @@ We believe that a well-developed open-source code framework can lower the thresh > DiffSynth-Studio has undergone major version updates, and some old features are no longer maintained. If you need to use old features, please switch to the [last historical version](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3) before the major version update. > Currently, the development personnel of this project are limited, with most of the work handled by [Artiprocher](https://github.com/Artiprocher). Therefore, the progress of new feature development will be relatively slow, and the speed of responding to and resolving issues is limited. We apologize for this and ask developers to understand. +- **January 19, 2026**: Added support for [openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p) and [openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p) models, including training and inference capabilities. [Documentation](/docs/en/Model_Details/Wan.md) and [example code](/examples/mova/) are now available. -- **March 12, 2026**: We have added support for the [LTX-2.3](https://modelscope.cn/models/Lightricks/LTX-2.3) audio-video generation model. The features includes text-to-audio/video, image-to-audio/video, IC-LoRA control, audio-to-video, and audio-video inpainting. We have supported the complete inference and training functionalities. For details, please refer to the [documentation](/docs/zh/Model_Details/LTX-2.md) and [code](/examples/ltx2/). +- **March 12, 2026**: We have added support for the [LTX-2.3](https://modelscope.cn/models/Lightricks/LTX-2.3) audio-video generation model. The features includes text-to-audio/video, image-to-audio/video, IC-LoRA control, audio-to-video, and audio-video inpainting. We have supported the complete inference and training functionalities. 
For details, please refer to the [documentation](/docs/en/Model_Details/LTX-2.md) and [code](/examples/ltx2/). - **March 3, 2026**: We released the [DiffSynth-Studio/Qwen-Image-Layered-Control-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control-V2) model, which is an updated version of Qwen-Image-Layered-Control. In addition to the originally supported text-guided functionality, it adds brush-controlled layer separation capabilities. @@ -867,6 +868,8 @@ Example code for Wan is available at: [/examples/wanvideo/](/examples/wanvideo/) |[PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py)| |[PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py)| |[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, 
`input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)| +| [openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p) | `input_image` | [code](/examples/mova/model_inference/MOVA-360p-I2AV.py) | [code](/examples/mova/model_training/full/MOVA-360P-I2AV.sh) | [code](/examples/mova/model_training/validate_full/MOVA-360p-I2AV.py) | [code](/examples/mova/model_training/lora/MOVA-360P-I2AV.sh) | [code](/examples/mova/model_training/validate_lora/MOVA-360p-I2AV.py) | +| [openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p) | `input_image` | [code](/examples/mova/model_inference/MOVA-720p-I2AV.py) | [code](/examples/mova/model_training/full/MOVA-720P-I2AV.sh) | [code](/examples/mova/model_training/validate_full/MOVA-720p-I2AV.py) | [code](/examples/mova/model_training/lora/MOVA-720P-I2AV.sh) | [code](/examples/mova/model_training/validate_lora/MOVA-720p-I2AV.py) | diff --git a/README_zh.md b/README_zh.md index c7feb85ff..76b29d873 100644 --- a/README_zh.md +++ b/README_zh.md @@ -32,8 +32,10 @@ DiffSynth 目前包括两个开源项目: > DiffSynth-Studio 经历了大版本更新,部分旧功能已停止维护,如需使用旧版功能,请切换到大版本更新前的[最后一个历史版本](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3)。 > 目前本项目的开发人员有限,大部分工作由 [Artiprocher](https://github.com/Artiprocher) 负责,因此新功能的开发进展会比较缓慢,issue 的回复和解决速度有限,我们对此感到非常抱歉,请各位开发者理解。 +- **2026年1月19日** 新增对 [openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p) 和 [openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p) 模型的支持,包括完整的训练和推理功能。[文档](/docs/zh/Model_Details/Wan.md)和[示例代码](/examples/mova/)现已可用。 - **2026年3月12日** 我们新增了 
[LTX-2.3](https://modelscope.cn/models/Lightricks/LTX-2.3) 音视频生成模型的支持,模型支持的功能包括文生音视频、图生音视频、IC-LoRA控制、音频生视频、音视频局部Inpainting,框架支持完整的推理和训练功能。详细信息请参考 [文档](/docs/zh/Model_Details/LTX-2.md) 和 [示例代码](/examples/ltx2/)。 + - **2026年3月3日** 我们发布了 [DiffSynth-Studio/Qwen-Image-Layered-Control-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control-V2) 模型,这是 Qwen-Image-Layered-Control 的更新版本。除了原本就支持的文本引导功能,新增了画笔控制的图层拆分能力。 - **2026年3月2日** 新增对[Anima](https://modelscope.cn/models/circlestone-labs/Anima)的支持,详见[文档](docs/zh/Model_Details/Anima.md)。这是一个有趣的动漫风格图像生成模型,我们期待其后续的模型更新。 @@ -866,6 +868,8 @@ Wan 的示例代码位于:[/examples/wanvideo/](/examples/wanvideo/) |[PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py)| |[PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py)| |[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, 
`input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)| +| [openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p) | `input_image` | [code](/examples/mova/model_inference/MOVA-360p-I2AV.py) | [code](/examples/mova/model_training/full/MOVA-360P-I2AV.sh) | [code](/examples/mova/model_training/validate_full/MOVA-360p-I2AV.py) | [code](/examples/mova/model_training/lora/MOVA-360P-I2AV.sh) | [code](/examples/mova/model_training/validate_lora/MOVA-360p-I2AV.py) | +| [openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p) | `input_image` | [code](/examples/mova/model_inference/MOVA-720p-I2AV.py) | [code](/examples/mova/model_training/full/MOVA-720P-I2AV.sh) | [code](/examples/mova/model_training/validate_full/MOVA-720p-I2AV.py) | [code](/examples/mova/model_training/lora/MOVA-720P-I2AV.sh) | [code](/examples/mova/model_training/validate_lora/MOVA-720p-I2AV.py) | diff --git a/docs/en/Model_Details/Wan.md b/docs/en/Model_Details/Wan.md index 20e12822b..805a06944 100644 --- a/docs/en/Model_Details/Wan.md +++ b/docs/en/Model_Details/Wan.md @@ -137,6 +137,8 @@ graph LR; | [PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP) | `input_image`, `end_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py) | 
[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py) | | [PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control) | `control_video`, `reference_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py) | | [PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera) | `control_camera_video`, `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py) | +| [openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p) | 
`input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_inference/MOVA-360p-I2AV.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/full/MOVA-360P-I2AV.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_full/MOVA-360p-I2AV.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/lora/MOVA-360P-I2AV.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_lora/MOVA-360p-I2AV.py) | +| [openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p) | `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_inference/MOVA-720p-I2AV.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/full/MOVA-720P-I2AV.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_full/MOVA-720p-I2AV.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/lora/MOVA-720P-I2AV.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_lora/MOVA-720p-I2AV.py) | * FP8 Precision Training: [doc](../Training/FP8_Precision.md), [code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/model_training/special/fp8_training/) * Two-stage Split Training: [doc](../Training/Split_Training.md), [code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/model_training/special/split_training/) diff --git a/docs/zh/Model_Details/Wan.md b/docs/zh/Model_Details/Wan.md index 0144bd212..611a38cb2 100644 --- a/docs/zh/Model_Details/Wan.md +++ b/docs/zh/Model_Details/Wan.md @@ -138,6 +138,8 @@ graph LR; |[PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP)|`input_image`, 
`end_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py)| |[PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control)|`control_video`, `reference_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py)| |[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, 
`input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)| +| [openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p) | `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_inference/MOVA-360p-I2AV.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/full/MOVA-360P-I2AV.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_full/MOVA-360p-I2AV.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/lora/MOVA-360P-I2AV.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_lora/MOVA-360p-I2AV.py) | +| [openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p) | `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_inference/MOVA-720p-I2AV.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/full/MOVA-720P-I2AV.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_full/MOVA-720p-I2AV.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/lora/MOVA-720P-I2AV.sh) | 
[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_lora/MOVA-720p-I2AV.py) | * FP8 精度训练:[doc](../Training/FP8_Precision.md)、[code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/model_training/special/fp8_training/) * 两阶段拆分训练:[doc](../Training/Split_Training.md)、[code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/model_training/special/split_training/) diff --git a/examples/mova/model_inference/MOVA-720p-I2AV.py b/examples/mova/model_inference/MOVA-720p-I2AV.py index 22b4757d7..b82c4fc3c 100644 --- a/examples/mova/model_inference/MOVA-720p-I2AV.py +++ b/examples/mova/model_inference/MOVA-720p-I2AV.py @@ -3,7 +3,6 @@ from diffsynth.utils.data.audio_video import write_video_audio from diffsynth.pipelines.mova_audio_video import MovaAudioVideoPipeline, ModelConfig - vram_config = { "offload_dtype": torch.bfloat16, "offload_device": "cpu", @@ -35,7 +34,7 @@ ) prompt = "Two cute orange cats, wearing boxing gloves, stand on a boxing ring and fight each other." 
height, width, num_frames = 720, 1280, 121 -frame_rate=24 +frame_rate = 24 input_image = Image.open("data/examples/wan/cat_fightning.jpg").resize((width, height)).convert("RGB") # Image-to-video video, audio = pipe( diff --git a/examples/mova/model_inference_low_vram/MOVA-360p-I2AV.py b/examples/mova/model_inference_low_vram/MOVA-360p-I2AV.py new file mode 100644 index 000000000..badd94978 --- /dev/null +++ b/examples/mova/model_inference_low_vram/MOVA-360p-I2AV.py @@ -0,0 +1,53 @@ +import torch +from PIL import Image +from diffsynth.pipelines.mova_audio_video import ModelConfig, MovaAudioVideoPipeline +from diffsynth.utils.data.audio_video import write_video_audio + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cuda", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = MovaAudioVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="openmoss/MOVA-360p", origin_file_pattern="video_dit/diffusion_pytorch_model-*.safetensors", **vram_config), + ModelConfig(model_id="openmoss/MOVA-360p", origin_file_pattern="video_dit_2/diffusion_pytorch_model-*.safetensors", **vram_config), + ModelConfig(model_id="openmoss/MOVA-360p", origin_file_pattern="audio_dit/diffusion_pytorch_model.safetensors", **vram_config), + ModelConfig(model_id="openmoss/MOVA-360p", origin_file_pattern="dual_tower_bridge/diffusion_pytorch_model.safetensors", **vram_config), + ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="audio_vae/diffusion_pytorch_model.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/Wan-Series-Converted-Safetensors", origin_file_pattern="Wan2.1_VAE.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/Wan-Series-Converted-Safetensors", 
origin_file_pattern="models_t5_umt5-xxl-enc-bf16.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="tokenizer/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2, +) +negative_prompt = ( + "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止," + "整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指" +) + +prompt = "Two cute orange cats, wearing boxing gloves, stand on a boxing ring and fight each other." +height, width, num_frames = 352, 640, 121 +frame_rate = 24 +input_image = Image.open("data/examples/wan/cat_fightning.jpg").resize((width, height)).convert("RGB") +# Image-to-video +video, audio = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + height=height, + width=width, + num_frames=num_frames, + input_image=input_image, + num_inference_steps=50, + seed=0, + tiled=True, + frame_rate=frame_rate, +) +write_video_audio(video, audio, "MOVA-360p-cat.mp4", fps=24, audio_sample_rate=pipe.audio_vae.sample_rate) diff --git a/examples/mova/model_inference_low_vram/MOVA-720p-I2AV.py b/examples/mova/model_inference_low_vram/MOVA-720p-I2AV.py new file mode 100644 index 000000000..3d6888349 --- /dev/null +++ b/examples/mova/model_inference_low_vram/MOVA-720p-I2AV.py @@ -0,0 +1,53 @@ +import torch +from PIL import Image +from diffsynth.utils.data.audio_video import write_video_audio +from diffsynth.pipelines.mova_audio_video import MovaAudioVideoPipeline, ModelConfig + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cuda", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = MovaAudioVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="video_dit/diffusion_pytorch_model-*.safetensors", **vram_config), + 
ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="video_dit_2/diffusion_pytorch_model-*.safetensors", **vram_config), + ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="audio_dit/diffusion_pytorch_model.safetensors", **vram_config), + ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="dual_tower_bridge/diffusion_pytorch_model.safetensors", **vram_config), + ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="audio_vae/diffusion_pytorch_model.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/Wan-Series-Converted-Safetensors", origin_file_pattern="Wan2.1_VAE.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/Wan-Series-Converted-Safetensors", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="openmoss/MOVA-720p", origin_file_pattern="tokenizer/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2, +) + +negative_prompt = ( + "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止," + "整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指" +) +prompt = "Two cute orange cats, wearing boxing gloves, stand on a boxing ring and fight each other." 
+height, width, num_frames = 720, 1280, 121 +frame_rate = 24 +input_image = Image.open("data/examples/wan/cat_fightning.jpg").resize((width, height)).convert("RGB") +# Image-to-video +video, audio = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + height=height, + width=width, + num_frames=num_frames, + input_image=input_image, + num_inference_steps=50, + seed=0, + tiled=True, + frame_rate=frame_rate, +) +write_video_audio(video, audio, "MOVA-720p-cat.mp4", fps=24, audio_sample_rate=pipe.audio_vae.sample_rate) From 24fa2b2b9f7cae8c85523736b022393463329a0a Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Thu, 12 Mar 2026 18:47:27 +0800 Subject: [PATCH 6/6] fix bug --- diffsynth/configs/vram_management_module_maps.py | 4 ++-- diffsynth/pipelines/mova_audio_video.py | 3 ++- diffsynth/utils/data/audio_video.py | 2 +- diffsynth/utils/xfuser/xdit_context_parallel.py | 1 - 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/diffsynth/configs/vram_management_module_maps.py b/diffsynth/configs/vram_management_module_maps.py index 0142958ad..de276891f 100644 --- a/diffsynth/configs/vram_management_module_maps.py +++ b/diffsynth/configs/vram_management_module_maps.py @@ -262,8 +262,8 @@ "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule", "diffsynth.models.wan_video_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", }, - "diffsynth.models.mova_audo_vae.DacVAE": { - "diffsynth.models.mova_audo_vae.Snake1d": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.mova_audio_vae.DacVAE": { + "diffsynth.models.mova_audio_vae.Snake1d": "diffsynth.core.vram.layers.AutoWrappedModule", "torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule", "torch.nn.ConvTranspose1d": "diffsynth.core.vram.layers.AutoWrappedModule", }, diff --git a/diffsynth/pipelines/mova_audio_video.py b/diffsynth/pipelines/mova_audio_video.py index 89933b4eb..b74a648e3 100644 --- a/diffsynth/pipelines/mova_audio_video.py +++ 
b/diffsynth/pipelines/mova_audio_video.py @@ -230,7 +230,7 @@ class MovaAudioVideoUnit_InputVideoEmbedder(PipelineUnit): def __init__(self): super().__init__( input_params=("input_video", "video_noise", "tiled", "tile_size", "tile_stride"), - output_params=("latents", "input_latents"), + output_params=("video_latents", "input_latents"), onload_model_names=("video_vae",) ) @@ -256,6 +256,7 @@ def process(self, pipe: MovaAudioVideoPipeline, input_audio, audio_noise): if input_audio is None or not pipe.scheduler.training: return {"audio_latents": audio_noise} else: + pipe.load_models_to_device(self.onload_model_names) input_audio, sample_rate = input_audio input_audio = convert_to_mono(input_audio) input_audio = resample_waveform(input_audio, sample_rate, pipe.audio_vae.sample_rate) diff --git a/diffsynth/utils/data/audio_video.py b/diffsynth/utils/data/audio_video.py index 015434510..6914b2d12 100644 --- a/diffsynth/utils/data/audio_video.py +++ b/diffsynth/utils/data/audio_video.py @@ -39,7 +39,7 @@ def _write_audio( container: av.container.Container, audio_stream: av.audio.AudioStream, samples: torch.Tensor, audio_sample_rate: int ) -> None: if samples.ndim == 1: - samples = samples[:, None] + samples = samples.unsqueeze(0) samples = convert_to_stereo(samples) assert samples.ndim == 2 and samples.shape[0] == 2, "audio samples must be [C, S] or [S], C must be 1 or 2" samples = samples.T diff --git a/diffsynth/utils/xfuser/xdit_context_parallel.py b/diffsynth/utils/xfuser/xdit_context_parallel.py index 0f5d10524..b9cf13f8c 100644 --- a/diffsynth/utils/xfuser/xdit_context_parallel.py +++ b/diffsynth/utils/xfuser/xdit_context_parallel.py @@ -6,7 +6,6 @@ get_sequence_parallel_world_size, get_sp_group) from xfuser.core.long_ctx_attention import xFuserLongContextAttention -import torch.distributed as dist from ... import IS_NPU_AVAILABLE from ...core.device import parse_nccl_backend, parse_device_type