From 90979be0547c5973ab5834c63b661ec4db14564e Mon Sep 17 00:00:00 2001 From: yuan Date: Wed, 11 Mar 2026 13:47:11 +0800 Subject: [PATCH] add VACE sequence parallel --- diffsynth/pipelines/wan_video.py | 29 +++++++++------- diffsynth/utils/xfuser/__init__.py | 3 +- .../utils/xfuser/xdit_context_parallel.py | 33 +++++++++++++++++++ 3 files changed, 52 insertions(+), 13 deletions(-) diff --git a/diffsynth/pipelines/wan_video.py b/diffsynth/pipelines/wan_video.py index bbc479e29..fe884fe1c 100644 --- a/diffsynth/pipelines/wan_video.py +++ b/diffsynth/pipelines/wan_video.py @@ -83,7 +83,7 @@ def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16): def enable_usp(self): - from ..utils.xfuser import get_sequence_parallel_world_size, usp_attn_forward, usp_dit_forward + from ..utils.xfuser import get_sequence_parallel_world_size, usp_attn_forward, usp_dit_forward, usp_vace_forward for block in self.dit.blocks: block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn) @@ -92,6 +92,14 @@ def enable_usp(self): for block in self.dit2.blocks: block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn) self.dit2.forward = types.MethodType(usp_dit_forward, self.dit2) + if self.vace is not None: + for block in self.vace.vace_blocks: + block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn) + self.vace.forward = types.MethodType(usp_vace_forward, self.vace) + if self.vace2 is not None: + for block in self.vace2.vace_blocks: + block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn) + self.vace2.forward = types.MethodType(usp_vace_forward, self.vace2) self.sp_size = get_sequence_parallel_world_size() self.use_unified_sequence_parallel = True @@ -1303,14 +1311,7 @@ def model_fn_wan_video( tea_cache_update = tea_cache.check(dit, x, t_mod) else: tea_cache_update = False - - if vace_context is not None: - vace_hints = vace( - x, vace_context, context, t_mod, freqs, - use_gradient_checkpointing=use_gradient_checkpointing, - use_gradient_checkpointing_offload=use_gradient_checkpointing_offload - ) - + # blocks if use_unified_sequence_parallel: if dist.is_initialized() and dist.get_world_size() > 1: @@ -1318,6 +1319,13 @@ def model_fn_wan_video( pad_shape = chunks[0].shape[1] - chunks[-1].shape[1] chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in chunks] x = chunks[get_sequence_parallel_rank()] + + if vace_context is not None: + vace_hints = vace( + x, vace_context, context, t_mod, freqs, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload + ) if tea_cache_update: x = tea_cache.update(x) else: @@ -1356,9 +1364,6 @@ def custom_forward(*inputs): # VACE if vace_context is not None and block_id in vace.vace_layers_mapping: current_vace_hint = vace_hints[vace.vace_layers_mapping[block_id]] - if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1: - current_vace_hint = torch.chunk(current_vace_hint, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()] - current_vace_hint = torch.nn.functional.pad(current_vace_hint, (0, 0, 0, chunks[0].shape[1] - current_vace_hint.shape[1]), value=0) x = x + current_vace_hint * vace_scale # Animate diff --git a/diffsynth/utils/xfuser/__init__.py b/diffsynth/utils/xfuser/__init__.py index 13dd178e2..e1eb8c074 100644 --- a/diffsynth/utils/xfuser/__init__.py +++ b/diffsynth/utils/xfuser/__init__.py @@ -1 +1,2 @@ -from .xdit_context_parallel import usp_attn_forward, usp_dit_forward, get_sequence_parallel_world_size, initialize_usp +from .xdit_context_parallel import usp_attn_forward, usp_dit_forward, usp_vace_forward, get_sequence_parallel_world_size, initialize_usp + diff --git a/diffsynth/utils/xfuser/xdit_context_parallel.py b/diffsynth/utils/xfuser/xdit_context_parallel.py index 228e7b877..c38a24650 100644 --- a/diffsynth/utils/xfuser/xdit_context_parallel.py +++ b/diffsynth/utils/xfuser/xdit_context_parallel.py @@ -117,6 +117,39 @@ def usp_dit_forward(self, return x +def usp_vace_forward( + self, x, vace_context, context, t_mod, freqs, + use_gradient_checkpointing: bool = False, + use_gradient_checkpointing_offload: bool = False, +): + # Compute full sequence length from the sharded x + full_seq_len = x.shape[1] * get_sequence_parallel_world_size() + + # Embed vace_context via patch embedding + c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context] + c = [u.flatten(2).transpose(1, 2) for u in c] + c = torch.cat([ + torch.cat([u, u.new_zeros(1, full_seq_len - u.size(1), u.size(2))], + dim=1) for u in c + ]) + + # Chunk VACE context along sequence dim BEFORE processing through blocks + c = torch.chunk(c, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()] + + # Process through vace_blocks (self_attn already monkey-patched to usp_attn_forward) + for block in self.vace_blocks: + c = gradient_checkpoint_forward( + block, + use_gradient_checkpointing, + use_gradient_checkpointing_offload, + c, x, context, t_mod, freqs + ) + + # Hints are already sharded per-rank + hints = torch.unbind(c)[:-1] + return hints + + def usp_attn_forward(self, x, freqs): q = self.norm_q(self.q(x)) k = self.norm_k(self.k(x))