[feat] add VACE sequence parallel #1345
Conversation
Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request enhances the multi-GPU inference capabilities of the VACE model by introducing sequence parallelism. The core change is a new forward function for VACE blocks that distributes computation across multiple GPUs, yielding a substantial speedup in inference. The modifications let the VACE model leverage distributed environments more effectively without impacting existing functionality.
Code Review
This pull request introduces sequence parallelism for the VACE block, which is a great feature for improving inference speed on multiple GPUs. The changes look mostly correct and follow the existing pattern for sequence parallelism in the repository.
I've left a couple of suggestions for improvement:
- Refactoring some duplicated code in `wan_video.py` to improve maintainability.
- Improving the implementation of the new `usp_vace_forward` function in `xdit_context_parallel.py` to be more efficient and robust by using batched tensor operations and handling sequence length mismatches.
After addressing these points, the PR should be in good shape.
```python
if self.vace is not None:
    for block in self.vace.vace_blocks:
        block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
    self.vace.forward = types.MethodType(usp_vace_forward, self.vace)
if self.vace2 is not None:
    for block in self.vace2.vace_blocks:
        block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
    self.vace2.forward = types.MethodType(usp_vace_forward, self.vace2)
```
There's duplicated logic for patching `self.vace` and `self.vace2`. This can be refactored to improve maintainability and reduce redundancy. You can use a loop to apply the same patching logic to both models.
Suggested change:

```python
for vace_model in [self.vace, self.vace2]:
    if vace_model is not None:
        for block in vace_model.vace_blocks:
            block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
        vace_model.forward = types.MethodType(usp_vace_forward, vace_model)
```
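For readers unfamiliar with the patching pattern above: `types.MethodType` rebinds a plain function as a bound method on an existing instance, which is how the repository swaps in the sequence-parallel forward without subclassing. A minimal, self-contained illustration (the class and function names here are hypothetical, not from the repository):

```python
import types

class Block:
    def forward(self, x):
        return x + 1

def patched_forward(self, x):
    # Replacement forward; "self" is the specific Block instance it is bound to
    return x * 10

block = Block()
# Rebind: only this instance's forward is replaced, not the class's
block.forward = types.MethodType(patched_forward, block)
print(block.forward(3))  # -> 30

other = Block()
print(other.forward(3))  # -> 4 (other instances are unaffected)
```

This per-instance rebinding is why the PR can patch `self.vace` and `self.vace2` independently while leaving the class definition untouched.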
```python
def usp_vace_forward(
    self, x, vace_context, context, t_mod, freqs,
    use_gradient_checkpointing: bool = False,
    use_gradient_checkpointing_offload: bool = False,
):
    # Compute full sequence length from the sharded x
    full_seq_len = x.shape[1] * get_sequence_parallel_world_size()

    # Embed vace_context via patch embedding
    c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
    c = [u.flatten(2).transpose(1, 2) for u in c]
    c = torch.cat([
        torch.cat([u, u.new_zeros(1, full_seq_len - u.size(1), u.size(2))],
                  dim=1) for u in c
    ])

    # Chunk VACE context along sequence dim BEFORE processing through blocks
    c = torch.chunk(c, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]

    # Process through vace_blocks (self_attn already monkey-patched to usp_attn_forward)
    for block in self.vace_blocks:
        c = gradient_checkpoint_forward(
            block,
            use_gradient_checkpointing,
            use_gradient_checkpointing_offload,
            c, x, context, t_mod, freqs
        )

    # Hints are already sharded per-rank
    hints = torch.unbind(c)[:-1]
    return hints
```
|
The implementation of `usp_vace_forward` can be made more efficient and robust. The current use of list comprehensions to process `vace_context` is inefficient, especially if batching is used in the future. It's better to perform batched tensor operations directly. Additionally, the padding logic can be made more robust to handle cases where the sequence length of the context might not match `full_seq_len` exactly (e.g., if it's longer).
```python
def usp_vace_forward(
    self, x, vace_context, context, t_mod, freqs,
    use_gradient_checkpointing: bool = False,
    use_gradient_checkpointing_offload: bool = False,
):
    # Compute full sequence length from the sharded x
    full_seq_len = x.shape[1] * get_sequence_parallel_world_size()
    # Embed vace_context via patch embedding and flatten
    c = self.vace_patch_embedding(vace_context)
    c = c.flatten(2).transpose(1, 2)
    # Pad or truncate to full sequence length
    seq_len_diff = full_seq_len - c.size(1)
    if seq_len_diff > 0:
        padding = c.new_zeros(c.size(0), seq_len_diff, c.size(2))
        c = torch.cat([c, padding], dim=1)
    elif seq_len_diff < 0:
        c = c[:, :full_seq_len]
    # Chunk VACE context along sequence dim BEFORE processing through blocks
    c = torch.chunk(c, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
    # Process through vace_blocks (self_attn already monkey-patched to usp_attn_forward)
    for block in self.vace_blocks:
        c = gradient_checkpoint_forward(
            block,
            use_gradient_checkpointing,
            use_gradient_checkpointing_offload,
            c, x, context, t_mod, freqs
        )
    # Hints are already sharded per-rank
    hints = torch.unbind(c)[:-1]
    return hints
```
This repository currently does not optimize multi-GPU parallel inference for the VACE block of the VACE model. I referred to the multi-GPU parallel inference optimization in the original VACE repository and added the `usp_vace_forward` function to replace the forward function of the VACE block. With this PR, the inference speed of the VACE model is accelerated by approximately 50%.

I have tested the Wan model and the VACE model, and the modified pipeline works fine for both.
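For intuition, the core sharding step in `usp_vace_forward` pads (or, per the review's suggestion, truncates) the embedded context to the full sequence length and then keeps only the contiguous chunk belonging to the current rank. A dependency-free sketch using plain Python lists in place of tensors (the function name and parameters here are illustrative, not the repository's API):

```python
def shard_sequence(seq, full_seq_len, world_size, rank, pad_value=0):
    # Pad or truncate to the full sequence length, so every rank sees
    # an equal-sized chunk after splitting
    if len(seq) < full_seq_len:
        seq = seq + [pad_value] * (full_seq_len - len(seq))
    else:
        seq = seq[:full_seq_len]
    # Each rank keeps one contiguous chunk of the sequence dimension,
    # mirroring torch.chunk(c, world_size, dim=1)[rank]
    chunk = full_seq_len // world_size
    return seq[rank * chunk:(rank + 1) * chunk]

# Length-6 sequence padded to 8, sharded across 4 ranks; rank 1 gets tokens 2..3
print(shard_sequence([1, 2, 3, 4, 5, 6], 8, 4, 1))   # -> [3, 4]
# Length-10 sequence truncated to 8, sharded across 2 ranks; rank 0 gets tokens 0..3
print(shard_sequence(list(range(10)), 8, 2, 0))      # -> [0, 1, 2, 3]
```

Because each rank processes only its own chunk through the VACE blocks (with attention handled by the sequence-parallel `usp_attn_forward`), the per-GPU work shrinks roughly by the world size, which is the source of the reported speedup.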