Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 17 additions & 12 deletions diffsynth/pipelines/wan_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):


def enable_usp(self):
from ..utils.xfuser import get_sequence_parallel_world_size, usp_attn_forward, usp_dit_forward
from ..utils.xfuser import get_sequence_parallel_world_size, usp_attn_forward, usp_dit_forward, usp_vace_forward

for block in self.dit.blocks:
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
Expand All @@ -92,6 +92,14 @@ def enable_usp(self):
for block in self.dit2.blocks:
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
self.dit2.forward = types.MethodType(usp_dit_forward, self.dit2)
if self.vace is not None:
for block in self.vace.vace_blocks:
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
self.vace.forward = types.MethodType(usp_vace_forward, self.vace)
if self.vace2 is not None:
for block in self.vace2.vace_blocks:
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
self.vace2.forward = types.MethodType(usp_vace_forward, self.vace2)
Comment on lines +95 to +102
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

There's duplicated logic for patching self.vace and self.vace2. This can be refactored to improve maintainability and reduce redundancy. You can use a loop to apply the same patching logic to both models.

Suggested change
if self.vace is not None:
for block in self.vace.vace_blocks:
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
self.vace.forward = types.MethodType(usp_vace_forward, self.vace)
if self.vace2 is not None:
for block in self.vace2.vace_blocks:
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
self.vace2.forward = types.MethodType(usp_vace_forward, self.vace2)
for vace_model in [self.vace, self.vace2]:
if vace_model is not None:
for block in vace_model.vace_blocks:
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
vace_model.forward = types.MethodType(usp_vace_forward, vace_model)

self.sp_size = get_sequence_parallel_world_size()
self.use_unified_sequence_parallel = True

Expand Down Expand Up @@ -1303,21 +1311,21 @@ def model_fn_wan_video(
tea_cache_update = tea_cache.check(dit, x, t_mod)
else:
tea_cache_update = False

if vace_context is not None:
vace_hints = vace(
x, vace_context, context, t_mod, freqs,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload
)


# blocks
if use_unified_sequence_parallel:
if dist.is_initialized() and dist.get_world_size() > 1:
chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)
pad_shape = chunks[0].shape[1] - chunks[-1].shape[1]
chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in chunks]
x = chunks[get_sequence_parallel_rank()]

if vace_context is not None:
vace_hints = vace(
x, vace_context, context, t_mod, freqs,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload
)
if tea_cache_update:
x = tea_cache.update(x)
else:
Expand Down Expand Up @@ -1356,9 +1364,6 @@ def custom_forward(*inputs):
# VACE
if vace_context is not None and block_id in vace.vace_layers_mapping:
current_vace_hint = vace_hints[vace.vace_layers_mapping[block_id]]
if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1:
current_vace_hint = torch.chunk(current_vace_hint, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
current_vace_hint = torch.nn.functional.pad(current_vace_hint, (0, 0, 0, chunks[0].shape[1] - current_vace_hint.shape[1]), value=0)
x = x + current_vace_hint * vace_scale

# Animate
Expand Down
3 changes: 2 additions & 1 deletion diffsynth/utils/xfuser/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .xdit_context_parallel import usp_attn_forward, usp_dit_forward, get_sequence_parallel_world_size, initialize_usp
from .xdit_context_parallel import usp_attn_forward, usp_dit_forward, usp_vace_forward, get_sequence_parallel_world_size, initialize_usp

33 changes: 33 additions & 0 deletions diffsynth/utils/xfuser/xdit_context_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,39 @@ def usp_dit_forward(self,
return x


def usp_vace_forward(
    self, x, vace_context, context, t_mod, freqs,
    use_gradient_checkpointing: bool = False,
    use_gradient_checkpointing_offload: bool = False,
):
    """Sequence-parallel (USP) forward pass for the VACE branch.

    ``x`` arrives already sharded along the sequence dimension (dim 1),
    one shard per rank, so the full sequence length is recovered by
    multiplying its shard length by the sequence-parallel world size.

    Args:
        x: sharded hidden states, shape (batch, seq / sp_size, dim)
            -- assumed; confirm against the USP DiT forward.
        vace_context: iterable of per-sample VACE conditioning tensors,
            each embedded individually via ``self.vace_patch_embedding``.
        context, t_mod, freqs: passed through unchanged to each VACE block.
        use_gradient_checkpointing: enable activation checkpointing per block.
        use_gradient_checkpointing_offload: offload checkpointed activations.

    Returns:
        Tuple of hint tensors, already sharded per-rank; the final unbind
        element is dropped, mirroring the non-USP VACE forward.
    """
    # Recover the full (unsharded) sequence length from the sharded x.
    full_seq_len = x.shape[1] * get_sequence_parallel_world_size()

    # Embed each conditioning sample and flatten spatial dims into a
    # token sequence of shape (1, seq, dim).
    c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
    c = [u.flatten(2).transpose(1, 2) for u in c]

    # Normalize every sample to exactly full_seq_len so the chunking
    # below lines up with the sharding of x. The previous code assumed
    # seq <= full_seq_len and raised on a negative new_zeros size when a
    # context came out longer; clamp in both directions instead.
    aligned = []
    for u in c:
        if u.size(1) < full_seq_len:
            u = torch.cat(
                [u, u.new_zeros(1, full_seq_len - u.size(1), u.size(2))],
                dim=1,
            )
        elif u.size(1) > full_seq_len:
            u = u[:, :full_seq_len]
        aligned.append(u)
    c = torch.cat(aligned)

    # Shard the VACE context along the sequence dim BEFORE the blocks,
    # so each rank processes the slice matching its shard of x.
    c = torch.chunk(c, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]

    # Run the VACE blocks; self_attn on each block has already been
    # monkey-patched to usp_attn_forward by enable_usp().
    for block in self.vace_blocks:
        c = gradient_checkpoint_forward(
            block,
            use_gradient_checkpointing,
            use_gradient_checkpointing_offload,
            c, x, context, t_mod, freqs,
        )

    # Hints are already sharded per-rank; drop the trailing element to
    # mirror the non-USP implementation.
    hints = torch.unbind(c)[:-1]
    return hints

Comment on lines +120 to +151
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The implementation of usp_vace_forward can be made more efficient and robust. The current use of list comprehensions to process vace_context is inefficient, especially if batching is used in the future. It's better to perform batched tensor operations directly. Additionally, the padding logic can be made more robust to handle cases where the sequence length of the context might not match full_seq_len exactly (e.g., if it's longer).

def usp_vace_forward(
    self, x, vace_context, context, t_mod, freqs,
    use_gradient_checkpointing: bool = False,
    use_gradient_checkpointing_offload: bool = False,
):
    # Compute full sequence length from the sharded x
    full_seq_len = x.shape[1] * get_sequence_parallel_world_size()

    # Embed vace_context via patch embedding and flatten
    c = self.vace_patch_embedding(vace_context)
    c = c.flatten(2).transpose(1, 2)

    # Pad or truncate to full sequence length
    seq_len_diff = full_seq_len - c.size(1)
    if seq_len_diff > 0:
        padding = c.new_zeros(c.size(0), seq_len_diff, c.size(2))
        c = torch.cat([c, padding], dim=1)
    elif seq_len_diff < 0:
        c = c[:, :full_seq_len]

    # Chunk VACE context along sequence dim BEFORE processing through blocks
    c = torch.chunk(c, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]

    # Process through vace_blocks (self_attn already monkey-patched to usp_attn_forward)
    for block in self.vace_blocks:
        c = gradient_checkpoint_forward(
            block,
            use_gradient_checkpointing,
            use_gradient_checkpointing_offload,
            c, x, context, t_mod, freqs
        )

    # Hints are already sharded per-rank
    hints = torch.unbind(c)[:-1]
    return hints


def usp_attn_forward(self, x, freqs):
q = self.norm_q(self.q(x))
k = self.norm_k(self.k(x))
Expand Down