
[feat] Add activation offloading to training pipelines#1106

Open
ketakitank wants to merge 7 commits into hao-ai-lab:main from ketakitank:ketaki/activation-offloading

Conversation

@ketakitank
Contributor

This PR implements asynchronous activation offloading to reduce VRAM consumption during training. By leveraging PyTorch's saved_tensors_hooks, intermediate activations are moved to CPU memory during the forward pass and fetched back just-in-time for the backward pass.
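For readers unfamiliar with the API, here is a minimal, self-contained sketch (not the PR's code) of how `torch.autograd.graph.saved_tensors_hooks` intercepts saved activations: the pack hook runs when autograd saves a tensor during the forward pass, and the unpack hook runs when backward needs it again.

```python
import torch

events = []

def pack_hook(tensor: torch.Tensor) -> torch.Tensor:
    # Called when autograd saves an activation during the forward pass.
    events.append("pack")
    return tensor.to("cpu")  # a no-op copy on a CPU-only run

def unpack_hook(tensor: torch.Tensor) -> torch.Tensor:
    # Called when the backward pass needs the activation again.
    events.append("unpack")
    return tensor  # real offloading code would move it back to the GPU here

x = torch.randn(3, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    y = x.exp().sum()  # exp saves its output for backward

y.backward()
assert "pack" in events and "unpack" in events
assert torch.allclose(x.grad, x.exp())  # d/dx sum(exp(x)) = exp(x)
```

Note that the hooks apply only to tensors saved while the context manager is active, which is why the PR wraps each block's forward call.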

@gemini-code-assist
Contributor

Summary of Changes

Hello @ketakitank, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a new feature that allows for asynchronous activation offloading during model training. By moving intermediate activations from GPU to CPU memory during the forward pass and fetching them back just-in-time for the backward pass, this change significantly reduces VRAM consumption, enabling the training of larger models or larger batch sizes on memory-constrained GPUs.

Highlights

  • Activation Offloading Argument: A new boolean argument, activation_offloading, was introduced to TrainingArgs and the command-line interface, allowing users to enable or disable this feature.
  • Memory Management Hooks: New pack_hook and unpack_hook functions were implemented to facilitate the asynchronous movement of activations between GPU and CPU memory during the forward and backward passes, respectively.
  • Dynamic Offloading Integration: The set_trainable method in ComposedPipelineBase was modified to conditionally apply the activation offloading mechanism to identified model blocks (e.g., layers, blocks, transformer_blocks) when the activation_offloading flag is enabled.
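The third highlight can be sketched as follows. This is an illustrative reconstruction under stated assumptions, not the PR's actual code: `apply_activation_offloading` is a hypothetical helper name, and the attribute-name search mirrors the summary's description of `set_trainable`.

```python
import functools
import torch
import torch.nn as nn

def offloaded_forward(original_forward, *args, **kwargs):
    # Wrap one forward call so activations saved inside it are offloaded to CPU.
    def pack(t):
        return (t.device, t.to("cpu", non_blocking=True))
    def unpack(packed):
        device, t = packed
        return t.to(device, non_blocking=True)
    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        return original_forward(*args, **kwargs)

def apply_activation_offloading(model: nn.Module) -> None:
    # Hypothetical helper: patch the first block container the module
    # exposes under one of the conventional attribute names.
    for name in ("layers", "blocks", "transformer_blocks"):
        blocks = getattr(model, name, None)
        if blocks is not None:
            for layer in blocks:
                # Bind the ORIGINAL forward, not the layer, to avoid recursion.
                layer.forward = functools.partial(offloaded_forward, layer.forward)
            return

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(4, 4) for _ in range(2))
    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

model = ToyModel()
apply_activation_offloading(model)
model(torch.randn(2, 4)).sum().backward()
assert all(p.grad is not None for p in model.parameters())
```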


Changelog
  • fastvideo/fastvideo_args.py
    • Added activation_offloading boolean field to TrainingArgs.
    • Added --activation-offloading CLI argument with StoreBoolean action.
  • fastvideo/pipelines/composed_pipeline_base.py
    • Imported the functools module.
    • Defined pack_hook to move tensors from GPU to CPU.
    • Defined unpack_hook to move tensors from CPU back to GPU.
    • Defined offloaded_forward as a wrapper using torch.autograd.graph.saved_tensors_hooks.
    • Modified set_trainable to conditionally apply offloaded_forward to layers, blocks, or transformer_blocks within modules if activation_offloading is enabled.
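The CLI wiring can be approximated with stdlib argparse. The PR uses a custom StoreBoolean action; `argparse.BooleanOptionalAction` shown below is only a stand-in with similar on/off semantics.

```python
import argparse

parser = argparse.ArgumentParser()
# Stand-in for the PR's StoreBoolean action: accepts
# --activation-offloading / --no-activation-offloading, defaulting to off.
parser.add_argument(
    "--activation-offloading",
    action=argparse.BooleanOptionalAction,
    default=False,
)

assert parser.parse_args(["--activation-offloading"]).activation_offloading is True
assert parser.parse_args([]).activation_offloading is False
```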
Activity
  • No specific activity (comments, reviews, progress) has been recorded for this pull request yet.

@ketakitank ketakitank marked this pull request as ready for review February 16, 2026 01:37

@gemini-code-assist bot left a comment


Code Review

This pull request introduces activation offloading to reduce VRAM consumption during training, which is a valuable feature. The implementation correctly uses PyTorch's saved_tensors_hooks. However, I've identified a critical bug in the monkey-patching logic that will cause infinite recursion during the forward pass. I have provided suggestions to resolve this issue. Additionally, I've included a recommendation to improve the robustness of the CPU/GPU memory transfers in multi-GPU environments.

Comment on lines +48 to +53

```python
def offloaded_forward(module, *args, **kwargs):
    """
    A wrapper for a module's forward pass that enables activation offloading.
    """
    with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
        return module(*args, **kwargs)
```

critical

The current implementation of offloaded_forward is designed to take a module instance and call it. When used to monkey-patch a module's forward method, this will lead to infinite recursion because module(*args, **kwargs) eventually calls module.forward(). To fix this, the wrapper should accept the original forward method to be called, rather than the module instance itself. This is part of a two-part fix for a critical recursion bug.

Suggested change

```diff
-def offloaded_forward(module, *args, **kwargs):
+def offloaded_forward(original_forward, *args, **kwargs):
     """
     A wrapper for a module's forward pass that enables activation offloading.
     """
     with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
-        return module(*args, **kwargs)
+        return original_forward(*args, **kwargs)
```

Comment on lines +138 to +141

```python
for i, layer in enumerate(blocks):
    # Apply the monkeypatch
    layer.forward = functools.partial(
        offloaded_forward, layer)
```

critical

This monkey-patching logic will cause infinite recursion. It should be updated to pass the original forward method to the offloaded_forward wrapper, not the layer instance. This change, combined with the suggested change to offloaded_forward, will resolve the recursion. Additionally, the loop variable i from enumerate is unused and can be removed for cleaner code.

Suggested change

```diff
-for i, layer in enumerate(blocks):
+for layer in blocks:
     # Apply the monkeypatch
     layer.forward = functools.partial(
-        offloaded_forward, layer)
+        offloaded_forward, layer.forward)
```
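The bot's diagnosis is easy to reproduce without PyTorch. The toy class below stands in for `nn.Module` (illustrative sketch, not the PR's code) and shows why binding the module instance recurses while binding the original bound `forward` does not.

```python
import functools

class Layer:
    def forward(self, x):
        return x + 1

    def __call__(self, x):
        # Mimics nn.Module: calling the instance dispatches to self.forward.
        return self.forward(x)

def offloaded_forward(original_forward, x):
    # The real wrapper would open saved_tensors_hooks around this call.
    return original_forward(x)

# Buggy patch: binding the layer itself loops
# layer(x) -> layer.forward(x) -> offloaded_forward(layer, x) -> layer(x) -> ...
buggy = Layer()
buggy.forward = functools.partial(offloaded_forward, buggy)
saw_recursion = False
try:
    buggy(1)
except RecursionError:
    saw_recursion = True
assert saw_recursion

# Fixed patch: capture the ORIGINAL bound forward before overwriting it.
fixed = Layer()
fixed.forward = functools.partial(offloaded_forward, fixed.forward)
assert fixed(1) == 2
```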

Comment on lines +34 to +45

```python
def pack_hook(tensor: torch.Tensor):
    """
    Moves activations from GPU to CPU memory during the forward pass.
    """
    return tensor.to("cpu", non_blocking=True)


def unpack_hook(packed_tensor):
    """
    Fetches activations back to the GPU from CPU memory during the backward pass.
    """
    return packed_tensor.to("cuda", non_blocking=True)
```

medium

The current unpack_hook moves tensors to the current CUDA device via to("cuda"). This may not be robust in complex multi-GPU scenarios (e.g., model parallelism) where the current device might not be the tensor's original device. A safer approach is to store the tensor's original device in pack_hook and use it during unpack_hook. This also provides an opportunity to add the missing type hint for the unpack_hook's argument.

Suggested change

```diff
 def pack_hook(tensor: torch.Tensor):
     """
     Moves activations from GPU to CPU memory during the forward pass.
     """
-    return tensor.to("cpu", non_blocking=True)
+    return (tensor.to("cpu", non_blocking=True), tensor.device)


-def unpack_hook(packed_tensor):
+def unpack_hook(packed: tuple[torch.Tensor, torch.device]):
     """
     Fetches activations back to the GPU from CPU memory during the backward pass.
     """
-    return packed_tensor.to("cuda", non_blocking=True)
+    tensor, device = packed
+    return tensor.to(device, non_blocking=True)
```
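A quick sanity check of the tuple-based hooks (CPU-safe: the device round-trip degenerates to a copy without a GPU), confirming gradients survive the pack/unpack round trip:

```python
import torch

def pack_hook(tensor: torch.Tensor):
    # Remember where the tensor lived so unpack can restore it exactly.
    return (tensor.to("cpu", non_blocking=True), tensor.device)

def unpack_hook(packed):
    tensor, device = packed
    return tensor.to(device, non_blocking=True)

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(4, device=device, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    y = (x * x).sum()

y.backward()
assert x.grad.device == x.device
assert torch.allclose(x.grad, 2 * x)  # d/dx sum(x^2) = 2x
```

One further caveat worth noting: `non_blocking=True` copies only overlap with GPU compute when the CPU side uses pinned memory; otherwise the transfer is effectively synchronous, which would limit the "asynchronous" benefit claimed in the PR description.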

@SolitaryThinker
Collaborator

thanks! could you add some profiling for total memory usage?
