[feat] Add activation offloading to training pipelines #1106

ketakitank wants to merge 7 commits into hao-ai-lab:main from
Conversation
Summary of Changes

Hello @ketakitank, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request introduces a new feature that enables asynchronous activation offloading during model training. By moving intermediate activations from GPU to CPU memory during the forward pass and fetching them back just in time for the backward pass, this change significantly reduces VRAM consumption, enabling the training of larger models or larger batch sizes on memory-constrained GPUs.
Code Review
This pull request introduces activation offloading to reduce VRAM consumption during training, which is a valuable feature. The implementation correctly uses PyTorch's saved_tensors_hooks. However, I've identified a critical bug in the monkey-patching logic that will cause infinite recursion during the forward pass. I have provided suggestions to resolve this issue. Additionally, I've included a recommendation to improve the robustness of the CPU/GPU memory transfers in multi-GPU environments.
```python
def offloaded_forward(module, *args, **kwargs):
    """
    A wrapper for a module's forward pass that enables activation offloading.
    """
    with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
        return module(*args, **kwargs)
```
The current implementation of offloaded_forward is designed to take a module instance and call it. When used to monkey-patch a module's forward method, this will lead to infinite recursion because module(*args, **kwargs) eventually calls module.forward(). To fix this, the wrapper should accept the original forward method to be called, rather than the module instance itself. This is part of a two-part fix for a critical recursion bug.
```suggestion
def offloaded_forward(original_forward, *args, **kwargs):
    """
    A wrapper for a module's forward pass that enables activation offloading.
    """
    with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
        return original_forward(*args, **kwargs)
```
```python
for i, layer in enumerate(blocks):
    # Apply the monkeypatch
    layer.forward = functools.partial(
        offloaded_forward, layer)
```
This monkey-patching logic will cause infinite recursion. It should be updated to pass the original forward method to the offloaded_forward wrapper, not the layer instance. This change, combined with the suggested change to offloaded_forward, will resolve the recursion. Additionally, the loop variable i from enumerate is unused and can be removed for cleaner code.
```suggestion
for layer in blocks:
    # Apply the monkeypatch
    layer.forward = functools.partial(
        offloaded_forward, layer.forward)
```
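The recursion and its fix can be demonstrated without PyTorch. The sketch below uses a hypothetical `ToyLayer` whose `__call__` dispatches to `forward`, mirroring how `torch.nn.Module.__call__` works; all names are illustrative, not from the PR.

```python
import functools

class ToyLayer:
    """Toy stand-in for nn.Module: calling the object dispatches to .forward."""
    def __call__(self, x):
        return self.forward(x)

    def forward(self, x):
        return x + 1

def broken_wrapper(module, *args, **kwargs):
    # Calling the module re-enters the patched forward -> infinite recursion.
    return module(*args, **kwargs)

def fixed_wrapper(original_forward, *args, **kwargs):
    # Delegating to the captured original forward breaks the cycle.
    return original_forward(*args, **kwargs)

layer = ToyLayer()
layer.forward = functools.partial(broken_wrapper, layer)
try:
    layer.forward(1)
except RecursionError:
    print("broken patch recurses")

layer = ToyLayer()
# Binding the *bound method* captures the original forward before the
# instance attribute shadows it, so the wrapper never calls itself.
layer.forward = functools.partial(fixed_wrapper, layer.forward)
print(layer.forward(41))  # 42
```

The key detail is that `functools.partial(fixed_wrapper, layer.forward)` evaluates `layer.forward` before the assignment overwrites it, capturing the original bound method.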
```python
def pack_hook(tensor: torch.Tensor):
    """
    Moves activations from GPU to CPU memory during the forward pass.
    """
    return tensor.to("cpu", non_blocking=True)


def unpack_hook(packed_tensor):
    """
    Fetches activations back to the GPU from CPU memory during the backward pass.
    """
    return packed_tensor.to("cuda", non_blocking=True)
```
The current unpack_hook moves tensors to the current CUDA device via to("cuda"). This may not be robust in complex multi-GPU scenarios (e.g., model parallelism) where the current device might not be the tensor's original device. A safer approach is to store the tensor's original device in pack_hook and use it during unpack_hook. This also provides an opportunity to add the missing type hint for the unpack_hook's argument.
```suggestion
def pack_hook(tensor: torch.Tensor):
    """
    Moves activations from GPU to CPU memory during the forward pass.
    """
    return (tensor.to("cpu", non_blocking=True), tensor.device)


def unpack_hook(packed: tuple[torch.Tensor, torch.device]):
    """
    Fetches activations back to the GPU from CPU memory during the backward pass.
    """
    tensor, device = packed
    return tensor.to(device, non_blocking=True)
```
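As a sanity check, the device-preserving hooks can be exercised end to end on CPU, where packing is a no-op copy but the hook plumbing is identical. This is a minimal sketch using only the public `torch.autograd.graph.saved_tensors_hooks` API; it verifies that gradients survive the pack/unpack round trip.

```python
import torch

def pack_hook(tensor):
    # Store the original device alongside the offloaded tensor.
    return (tensor.to("cpu", non_blocking=True), tensor.device)

def unpack_hook(packed):
    tensor, device = packed
    return tensor.to(device, non_blocking=True)

x = torch.randn(4, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    y = (x * x).sum()  # the saved activation passes through the hooks
y.backward()
print(torch.allclose(x.grad, 2 * x))  # True: d/dx sum(x^2) = 2x
```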
thanks! could you add some profiling for total memory usage?
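One way to get the requested numbers is to read peak allocated memory around a training step using `torch.cuda`'s built-in counters. A hedged sketch, guarded so it degrades gracefully on CPU-only machines; `run_step` is a hypothetical stand-in for the actual training step.

```python
import torch

def peak_memory_mb(run_step):
    """Run one step and return peak CUDA memory in MiB (None on CPU-only hosts)."""
    if not torch.cuda.is_available():
        run_step()
        return None  # CUDA allocator counters are unavailable without a GPU
    torch.cuda.reset_peak_memory_stats()
    run_step()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 2**20

# Toy step; in practice this would be a forward+backward with and without
# offloading enabled, so the two peaks can be compared.
print(peak_memory_mb(lambda: torch.randn(8, 8).sum()))
```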
This PR implements asynchronous activation offloading to reduce VRAM consumption during training. By leveraging PyTorch's saved_tensors_hooks, intermediate activations are moved to CPU memory during the forward pass and fetched back just in time for the backward pass.