From 5f05817a0c62b4f246863c612536d2217abc2500 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Wed, 8 Apr 2026 10:00:14 +0800 Subject: [PATCH 01/13] docs: add multi-lora fsdp2 design --- .../2026-04-08-multi-lora-fsdp2-design.md | 132 ++++++++++++++++++ 1 file changed, 132 insertions(+) create mode 100644 docs/superpowers/specs/2026-04-08-multi-lora-fsdp2-design.md diff --git a/docs/superpowers/specs/2026-04-08-multi-lora-fsdp2-design.md b/docs/superpowers/specs/2026-04-08-multi-lora-fsdp2-design.md new file mode 100644 index 00000000..717dc87d --- /dev/null +++ b/docs/superpowers/specs/2026-04-08-multi-lora-fsdp2-design.md @@ -0,0 +1,132 @@ +# Multi-LoRA Transformers FSDP2 Support Design + +## Summary + +Enable `MultiLoraTransformersModel` to run with FSDP2 model sharding while keeping `AccelerateStrategy` as the default strategy. The initial scope is intentionally narrow: support the training critical path and LoRA weight persistence for the transformers multi-LoRA backend, without expanding into sampler sync or broader checkpoint compatibility work. + +## Goal + +Make the following flow work when `device_mesh.fsdp_world_size > 1`: + +- construct `MultiLoraTransformersModel` +- add a LoRA adapter +- run `forward`, `calculate_loss`, `backward`, `step`, and `zero_grad` +- save and load LoRA adapter weights +- remove an adapter and restore the slot to its initial state + +## Non-Goals + +- No new sampler or checkpoint-engine synchronization behavior +- No dedicated `native_fsdp` path for multi-LoRA in this iteration +- No broader refactor to merge `MultiLoraTransformersModel` into the single-adapter transformers stack +- No new guarantees around optimizer state migration across sharding layouts +- No changes to megatron multi-LoRA behavior + +## Current Problem + +`MultiLoraTransformersModel` currently blocks FSDP usage and bypasses the normal transformers wrapping lifecycle: + +- it asserts that FSDP is unsupported during construction +- it always uses `AccelerateStrategy(device_mesh=None)` +- it eagerly wraps the model in `__init__` +- multi-LoRA save/load helpers assume local tensors and perform direct `parameter.data.copy_` writes + +That combination prevents Accelerate FSDP2 from sharding the model and makes adapter state handling unsafe once LoRA parameters become sharded tensors. + +## Proposed Approach + +Keep the existing class and multi-slot adapter design, but make it FSDP2-compatible under the default Accelerate path. + +### 1. Strategy and wrapping lifecycle + +Update `MultiLoraTransformersModel` so it no longer hard-disables FSDP. + +- remove the constructor assert that rejects FSDP +- instantiate `AccelerateStrategy` with the real `device_mesh` +- keep `multi_adapter.patch(self.model)` before any wrapping so the sharded model includes all LoRA slots +- stop eager wrapping in `__init__` +- implement `_lazy_wrap_model()` by reusing the parent lifecycle so wrapping happens after optimizers are created + +This preserves the current default strategy choice while allowing Accelerate's FSDP2 plugin to own sharding. + +### 2. DTensor-safe multi-LoRA weight access + +Adjust `MultiLora` helper methods so they work when LoRA parameters are represented as sharded tensors. 
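For orientation, a minimal sketch of what sharding-safe access could look like is shown below. It assumes the public `torch.distributed.tensor` DTensor API (the module path varies across torch versions), and the helper names are illustrative rather than the final `MultiLora` API.

```python
# Hedged sketch: assumes torch.distributed.tensor exposes DTensor/distribute_tensor
# (older torch versions ship it as torch.distributed._tensor); names are illustrative.
import torch
from torch.distributed.tensor import DTensor, distribute_tensor


def read_lora_tensor(parameter: torch.nn.Parameter) -> torch.Tensor:
    """Return a plain tensor view whether or not the parameter is sharded."""
    data = parameter.data
    if isinstance(data, DTensor):
        # full_tensor() gathers the unsharded value; acceptable for small LoRA weights.
        return data.full_tensor()
    return data


def write_lora_tensor(parameter: torch.nn.Parameter, value: torch.Tensor) -> None:
    """Copy a full (e.g. CPU checkpoint) tensor into a possibly-sharded parameter."""
    data = parameter.data
    if isinstance(data, DTensor):
        # Re-shard the incoming tensor to the destination layout before copying.
        value = distribute_tensor(value.to(data.device), data.device_mesh, data.placements)
    data.copy_(value)
```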
+ +Methods that need FSDP2-aware handling: + +- `save_initial_weights` +- `_load_initial_weights` +- `set_state_dict` +- `get_state_dict` +- `save_lora_converter` + +Design rules: + +- reading weights should operate on a local or reconstructed tensor view rather than assuming a plain parameter tensor +- writing weights should detect DTensor-like parameters and transform incoming checkpoint tensors to the target layout before copy +- LoRA rank slicing rules remain unchanged for `lora_A`, `lora_B`, `lora_embedding_A`, and `lora_embedding_B` + +The intent is not to create a fully generic distributed checkpoint layer, only to make the current LoRA slot persistence logic safe for FSDP2. + +### 3. Multi-LoRA load path + +Extend `MultiLoraTransformersModel.load()` to mirror the existing single-adapter transformers FSDP2 behavior. + +Current single-adapter transformers code already converts CPU adapter weights into the destination distributed layout before applying them. The multi-LoRA path should reuse the same idea, but route tensors into the tenant-owned slot inside `MultiLora` instead of using `set_peft_model_state_dict` directly. + +Expected behavior: + +- checkpoint weights load on CPU first +- keys are mapped into the real internal adapter slot +- tensors are distributed as needed to match the wrapped parameter layout +- values are copied into the correct LoRA slot for the tenant adapter + +### 4. Test coverage + +Add the smallest set of tests that proves the supported scope. + +Required checks: + +- constructing `MultiLoraTransformersModel` with an FSDP-enabled mesh no longer fails +- adding an adapter and running one training step succeeds under Accelerate FSDP2 +- `get_state_dict` followed by `load` round-trips LoRA weights correctly +- removing an adapter restores the slot to its initial values + +Test strategy: + +- prefer a focused regression test near the transformers model tests +- keep the model and world size small +- assert behavior on LoRA slot tensors, not just that no exception was raised + +## Implementation Outline + +1. Update `MultiLoraTransformersModel.__init__` to keep the default Accelerate path but pass through `device_mesh` +2. Move model wrapping out of construction and back into `_lazy_wrap_model` +3. Add FSDP2-safe tensor conversion helpers in `MultiLora` +4. Update multi-LoRA load and slot-reset code to use those helpers +5. 
Add regression tests for FSDP2 construction, train step, save/load, and remove + +## Risks + +- LoRA parameter layouts under Accelerate FSDP2 may differ from the assumptions used by direct slicing in patched LoRA forward code +- some helper methods may reconstruct full tensors when only local shards are needed, which can increase test-time memory usage +- eager assumptions about adapter activation order may be exposed once wrapping becomes lazy again + +## Deferred Work + +These should stay out of this change unless the minimal support cannot be made correct without them: + +- sampler LoRA sync compatibility +- optimizer state load/save guarantees under FSDP2 +- memory-efficient initialization specific to multi-LoRA +- a deeper unification with `TransformersModel` adapter management + +## Acceptance Criteria + +This design is complete when: + +- `MultiLoraTransformersModel` can run the narrow training flow under Accelerate FSDP2 +- multi-LoRA adapter save/load works for the supported LoRA tensor types +- regression tests cover the new supported behavior +- unsupported areas remain explicitly unchanged From b19373f724f33b0d27c619fb3f733693d8d5518a Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Wed, 8 Apr 2026 10:16:46 +0800 Subject: [PATCH 02/13] docs: expand multi-lora fsdp2 staged design --- .../2026-04-08-multi-lora-fsdp2-design.md | 351 ++++++++++++++---- 1 file changed, 271 insertions(+), 80 deletions(-) diff --git a/docs/superpowers/specs/2026-04-08-multi-lora-fsdp2-design.md b/docs/superpowers/specs/2026-04-08-multi-lora-fsdp2-design.md index 717dc87d..e3df61bd 100644 --- a/docs/superpowers/specs/2026-04-08-multi-lora-fsdp2-design.md +++ b/docs/superpowers/specs/2026-04-08-multi-lora-fsdp2-design.md @@ -1,132 +1,323 @@ -# Multi-LoRA Transformers FSDP2 Support Design +# Multi-LoRA Transformers FSDP2 Staged Support Design ## Summary -Enable `MultiLoraTransformersModel` to run with FSDP2 model sharding while keeping `AccelerateStrategy` as the default strategy. The initial scope is intentionally narrow: support the training critical path and LoRA weight persistence for the transformers multi-LoRA backend, without expanding into sampler sync or broader checkpoint compatibility work. +Enable `MultiLoraTransformersModel` to run under FSDP2 for both `AccelerateStrategy` and `NativeFSDPStrategy`, with staged delivery. The end goal is to run both SFT and GRPO training under FSDP2, but implementation will proceed in this order: -## Goal +1. `AccelerateStrategy + SFT` +2. `AccelerateStrategy + GRPO` +3. `native_fsdp + SFT` +4. `native_fsdp + GRPO` -Make the following flow work when `device_mesh.fsdp_world_size > 1`: +The design centers on a shared FSDP2 compatibility layer for transformers multi-LoRA, so each later stage builds on the same model lifecycle, adapter-slot semantics, and distributed weight handling. -- construct `MultiLoraTransformersModel` -- add a LoRA adapter -- run `forward`, `calculate_loss`, `backward`, `step`, and `zero_grad` -- save and load LoRA adapter weights -- remove an adapter and restore the slot to its initial state +## Goals + +### Final Goal + +Make `MultiLoraTransformersModel` work under FSDP2 for both SFT and GRPO, across both supported transformers strategies: + +- `AccelerateStrategy` +- `NativeFSDPStrategy` + +### Delivery Goal + +Ship the capability in four stages: + +1. `AccelerateStrategy + SFT` +2. `AccelerateStrategy + GRPO` +3. `native_fsdp + SFT` +4. 
`native_fsdp + GRPO` + +Each stage must leave the shared foundations in a state that later stages can reuse without strategy-specific rewrites. ## Non-Goals -- No new sampler or checkpoint-engine synchronization behavior -- No dedicated `native_fsdp` path for multi-LoRA in this iteration -- No broader refactor to merge `MultiLoraTransformersModel` into the single-adapter transformers stack -- No new guarantees around optimizer state migration across sharding layouts -- No changes to megatron multi-LoRA behavior +- No megatron multi-LoRA changes in this workstream +- No sampler or checkpoint-engine LoRA sync expansion in this workstream +- No attempt to solve all distributed checkpoint migration cases across arbitrary sharding layouts +- No requirement to support multi-adapter concurrent training in the initial stages +- No large-scale performance tuning or RL throughput optimization as part of correctness work +- No deep rewrite that merges `MultiLoraTransformersModel` into the single-adapter transformers implementation ## Current Problem -`MultiLoraTransformersModel` currently blocks FSDP usage and bypasses the normal transformers wrapping lifecycle: +`MultiLoraTransformersModel` currently does not participate correctly in the transformers FSDP2 lifecycle. + +Current blockers: + +- construction hard-rejects FSDP via an assert +- the class always uses `AccelerateStrategy(device_mesh=None)` +- the model is eagerly wrapped in `__init__` instead of participating in lazy wrap +- multi-LoRA weight helpers assume local tensors and directly mutate `parameter.data` +- save, load, and slot-reset behavior is not strategy-aware + +Because of that: + +- `AccelerateStrategy` cannot shard the model for multi-LoRA training +- `NativeFSDPStrategy` is not wired into multi-LoRA at all +- adapter slot persistence is unsafe once LoRA parameters become DTensors or other sharded parameter forms +- GRPO-specific paths such as `forward_only`, `disable_lora`, and server-side forward-backward entrypoints cannot be trusted under FSDP2 + +## Design Principles + +- Keep the existing multi-slot adapter model: tenants bind to preallocated internal LoRA slots +- Build one shared FSDP2 compatibility layer instead of separate accelerate-only and native-only implementations +- Let strategy differences stay in strategy code paths, not in duplicated multi-LoRA business logic +- Stage rollout by training scenario, but avoid temporary patches that block later phases +- Prefer the smallest regression tests that prove correctness at each phase + +## Proposed Architecture + +The work is split into two layers: + +1. A shared FSDP2 compatibility foundation for transformers multi-LoRA +2. A staged rollout over SFT and GRPO for accelerate and native FSDP2 + +### Shared Foundation + +The shared foundation must be correct before later phases can be added safely. + +#### 1. Unified strategy selection and lazy wrap lifecycle + +`MultiLoraTransformersModel` should stop managing wrapping as a special case. 
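The toy below (not Twinkle API) illustrates the lifecycle ordering this section argues for, with the concrete changes listed after it: construction patches the LoRA slots and defers wrapping, and a later `_lazy_wrap_model()` call performs the strategy wrap.

```python
# Toy illustration of the deferred-wrap lifecycle; names mirror the design, not real code.
class ToyStrategy:

    def wrap_model(self, model):
        return ('sharded', model)  # stands in for accelerate/native FSDP2 wrapping


class ToyMultiLoraModel:

    def __init__(self, strategy):
        self.strategy = strategy
        self.model = 'base-model-with-lora-slots'  # multi_adapter.patch() would run here
        self._model_wrapped = False                # no eager wrapping in __init__

    def _lazy_wrap_model(self):
        if not self._model_wrapped:
            # Wrapping runs after optimizer creation, matching the parent lifecycle.
            self.model = self.strategy.wrap_model(self.model)
            self._model_wrapped = True


m = ToyMultiLoraModel(ToyStrategy())
assert m._model_wrapped is False  # construction leaves the model unwrapped
m._lazy_wrap_model()
assert m._model_wrapped is True
```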
+ +Required changes: + +- remove the constructor assert that blocks FSDP +- stop forcing `AccelerateStrategy(device_mesh=None)` +- let the class honor the requested strategy: + - default remains `AccelerateStrategy` + - `strategy='native_fsdp'` must instantiate `NativeFSDPStrategy` +- keep `multi_adapter.patch(self.model)` before wrapping so LoRA slots exist in the wrapped model graph +- move wrapping back into `_lazy_wrap_model()` so optimizer creation and strategy wrapping follow the same lifecycle as transformers models + +This is the main prerequisite for supporting both accelerate and native FSDP2 without forking the class. + +#### 2. Strategy-aware multi-LoRA parameter access + +`MultiLora` needs a small internal abstraction for reading and writing LoRA slot tensors under both unsharded and sharded parameter representations. + +This abstraction should support: + +- reading a saveable tensor view from LoRA slot parameters +- writing checkpoint tensors into LoRA slot parameters with the correct target layout +- restoring initial slot values when an adapter is removed +- preserving existing rank slicing rules for: + - `lora_A` + - `lora_B` + - `lora_embedding_A` + - `lora_embedding_B` + +The goal is not a generic distributed checkpoint layer. The goal is to make the current multi-LoRA slot logic safe under FSDP2. + +#### 3. Stable adapter-slot state machine after wrapping + +The tenant-to-slot model is part of the multi-LoRA contract and should remain unchanged. + +The following behaviors must stay valid after wrapping: -- it asserts that FSDP is unsupported during construction -- it always uses `AccelerateStrategy(device_mesh=None)` -- it eagerly wraps the model in `__init__` -- multi-LoRA save/load helpers assume local tensors and perform direct `parameter.data.copy_` writes +- `activate_adapter` +- `deactivate_adapter` +- `save_context` +- `remove_adapter` +- `disable_lora` inference paths -That combination prevents Accelerate FSDP2 from sharding the model and makes adapter state handling unsafe once LoRA parameters become sharded tensors. +This is especially important for GRPO, where policy-style and LoRA-disabled paths need to coexist without corrupting adapter state. -## Proposed Approach +#### 4. Unified save/load/remove semantics -Keep the existing class and multi-slot adapter design, but make it FSDP2-compatible under the default Accelerate path. +The same slot-aware semantics should apply regardless of strategy. -### 1. Strategy and wrapping lifecycle +Required behaviors: -Update `MultiLoraTransformersModel` so it no longer hard-disables FSDP. +- `get_state_dict` returns the tenant adapter's LoRA state with correct rank slicing +- `load` maps checkpoint weights into the correct internal slot +- `remove_adapter` restores the slot to its initial weights +- save/load logic works under both wrapped and unwrapped model states -- remove the constructor assert that rejects FSDP -- instantiate `AccelerateStrategy` with the real `device_mesh` -- keep `multi_adapter.patch(self.model)` before any wrapping so the sharded model includes all LoRA slots -- stop eager wrapping in `__init__` -- implement `_lazy_wrap_model()` by reusing the parent lifecycle so wrapping happens after optimizers are created +Single-adapter transformers already has FSDP2-aware load behavior. Multi-LoRA should reuse that idea, but route tensors through tenant-owned slots instead of direct single-adapter PEFT application. 
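To make the slot contract concrete, the runnable toy below shows the rank-slicing a max-rank slot implies: a tenant of rank `r` reads only the leading slice of its `lora_A`/`lora_B` slot, and a restore writes back into the same slice. All shapes are illustrative.

```python
# Toy demonstration of slot rank slicing for save/load; all shapes are illustrative.
import torch

max_r, in_features, out_features, tenant_r = 8, 16, 16, 4

slot_A = torch.randn(max_r, in_features)   # lora_A slot: tenant slice is [:r, :]
slot_B = torch.randn(out_features, max_r)  # lora_B slot: tenant slice is [:, :r]

# Saving reads only the tenant's slice of each slot.
saved_A, saved_B = slot_A[:tenant_r, :].clone(), slot_B[:, :tenant_r].clone()

# Restoring writes back into the same leading slice, leaving the rest of the slot alone.
slot_A[:tenant_r, :].copy_(saved_A)
slot_B[:, :tenant_r].copy_(saved_B)

assert torch.equal(slot_A[:tenant_r, :], saved_A)
assert torch.equal(slot_B[:, :tenant_r], saved_B)
```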
-This preserves the current default strategy choice while allowing Accelerate's FSDP2 plugin to own sharding. +#### 5. Reusable test scaffolding -### 2. DTensor-safe multi-LoRA weight access +Even though rollout is staged, the test base should be reusable from the start. -Adjust `MultiLora` helper methods so they work when LoRA parameters are represented as sharded tensors. +Shared fixtures/helpers should cover: -Methods that need FSDP2-aware handling: +- a minimal FSDP2-capable device mesh +- a minimal multi-LoRA transformers model builder +- adapter-slot inspection helpers +- minimal SFT input samples +- minimal GRPO input samples -- `save_initial_weights` -- `_load_initial_weights` -- `set_state_dict` -- `get_state_dict` -- `save_lora_converter` +This avoids rebuilding test infrastructure at every phase. -Design rules: +## Staged Rollout -- reading weights should operate on a local or reconstructed tensor view rather than assuming a plain parameter tensor -- writing weights should detect DTensor-like parameters and transform incoming checkpoint tensors to the target layout before copy -- LoRA rank slicing rules remain unchanged for `lora_A`, `lora_B`, `lora_embedding_A`, and `lora_embedding_B` +### Phase 1: `AccelerateStrategy + SFT` -The intent is not to create a fully generic distributed checkpoint layer, only to make the current LoRA slot persistence logic safe for FSDP2. +This is the first delivery milestone and the narrowest supported training loop. -### 3. Multi-LoRA load path +Supported flow: -Extend `MultiLoraTransformersModel.load()` to mirror the existing single-adapter transformers FSDP2 behavior. +- construct `MultiLoraTransformersModel` with accelerate FSDP2 +- add a single LoRA adapter +- run SFT training: + - `forward` + - `calculate_loss(CrossEntropyLoss)` + - `backward` + - `clip_grad_norm` + - `step` + - `zero_grad` +- save and load LoRA adapter state +- remove the adapter and restore the slot -Current single-adapter transformers code already converts CPU adapter weights into the destination distributed layout before applying them. The multi-LoRA path should reuse the same idea, but route tensors into the tenant-owned slot inside `MultiLora` instead of using `set_peft_model_state_dict` directly. +Implementation focus: -Expected behavior: +- strategy/lazy-wrap integration +- FSDP2-safe slot state persistence +- SFT regression coverage -- checkpoint weights load on CPU first -- keys are mapped into the real internal adapter slot -- tensors are distributed as needed to match the wrapped parameter layout -- values are copied into the correct LoRA slot for the tenant adapter +Explicitly out of scope for this phase: -### 4. Test coverage +- GRPO server entrypoints +- multi-adapter concurrent training +- sampler sync -Add the smallest set of tests that proves the supported scope. +### Phase 2: `AccelerateStrategy + GRPO` -Required checks: +Build on Phase 1 and add GRPO-specific correctness. 
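For context on why `old_logps` and `advantages` appear in the supported flow below, a hedged sketch of a clipped GRPO-style objective follows; the clipping epsilon and mean reduction are illustrative, not the project's exact `GRPOLoss`.

```python
# Hedged sketch of a clipped GRPO-style surrogate; not the project's exact GRPOLoss.
import torch


def grpo_loss(logps, old_logps, advantages, eps=0.2):
    ratio = torch.exp(logps - old_logps)              # new-vs-old policy ratio
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1 - eps, 1 + eps) * advantages
    return -torch.minimum(unclipped, clipped).mean()  # pessimistic clipped objective


loss = grpo_loss(
    logps=torch.tensor([-1.0, -0.5]),
    old_logps=torch.tensor([-1.1, -0.4]),
    advantages=torch.tensor([0.5, -0.2]),
)
```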
-- constructing `MultiLoraTransformersModel` with an FSDP-enabled mesh no longer fails -- adding an adapter and running one training step succeeds under Accelerate FSDP2 -- `get_state_dict` followed by `load` round-trips LoRA weights correctly -- removing an adapter restores the slot to its initial values +Supported flow: -Test strategy: +- `forward_only` under wrapped accelerate FSDP2 +- `disable_lora` behavior for reference-style inference +- `GRPOLoss` training inputs, including `old_logps` and `advantages` +- minimal GRPO forward-backward-step flow through: + - twinkle-native path + - tinker-compatible server path -- prefer a focused regression test near the transformers model tests -- keep the model and world size small -- assert behavior on LoRA slot tensors, not just that no exception was raised +Implementation focus: -## Implementation Outline +- active adapter switching during GRPO paths +- correctness of `disable_lora` under wrapped PEFT model state +- minimal regression coverage for GRPO entrypoints -1. Update `MultiLoraTransformersModel.__init__` to keep the default Accelerate path but pass through `device_mesh` -2. Move model wrapping out of construction and back into `_lazy_wrap_model` -3. Add FSDP2-safe tensor conversion helpers in `MultiLora` -4. Update multi-LoRA load and slot-reset code to use those helpers -5. Add regression tests for FSDP2 construction, train step, save/load, and remove +Explicitly out of scope for this phase: + +- online sampler orchestration +- large-scale RL performance tuning + +### Phase 3: `native_fsdp + SFT` + +Add native FSDP2 support by reusing the shared foundation. + +Supported flow: + +- construct `MultiLoraTransformersModel` with `strategy='native_fsdp'` +- add a single LoRA adapter +- run the same minimal SFT loop as Phase 1 +- save, load, and remove LoRA adapters correctly + +Implementation focus: + +- compatibility with `NativeFSDPStrategy.wrap_model` +- parameter layout handling after `fully_shard` +- optimizer rebinding and lazy-wrap correctness +- native SFT regression coverage + +Key risk areas: + +- wrapped parameter representation under native FSDP2 +- optimizer param-group rebinding after wrapping +- LoRA forward patch assumptions when parameters are sharded + +### Phase 4: `native_fsdp + GRPO` + +Complete the matrix by validating GRPO under native FSDP2. + +Supported flow: + +- `forward_only` +- `disable_lora` +- `GRPOLoss` +- `clip_grad_norm` +- `step` +- minimal GRPO server/backend path + +Implementation focus: + +- native FSDP interaction with GRPO control flow +- slot-state correctness while toggling LoRA-enabled and LoRA-disabled execution +- final GRPO regression coverage under native FSDP2 + +## Test Plan + +Tests should expand phase by phase, but stay narrow and behavior-oriented. 
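As a concrete instance of "assert adapter-slot tensor behavior, not only absence of exceptions", the regression tests rely on a comparison helper along these lines (a minimal sketch; the name matches the helper the later test scaffold introduces):

```python
# Minimal slot-state comparison helper: compare tensors, not just exit codes.
import torch


def assert_same_lora_state(before: dict, after: dict) -> None:
    assert before.keys() == after.keys(), 'adapter state keys diverged'
    for name, tensor in before.items():
        assert torch.equal(tensor, after[name]), f'slot weight drifted: {name}'
```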
+ +### Shared test requirements + +- use the smallest model and mesh that still exercises the target strategy +- assert adapter-slot tensor behavior, not only absence of exceptions +- validate save/load round-trip and slot reset where relevant + +### Phase-specific checks + +#### Phase 1 + +- accelerate FSDP2 construction succeeds +- SFT forward-backward-step succeeds +- LoRA state round-trips through save/load +- `remove_adapter` restores initial slot values + +#### Phase 2 + +- accelerate FSDP2 GRPO forward-backward path succeeds +- `disable_lora` behavior stays correct under wrapped execution +- minimal server/backend GRPO entrypoint succeeds + +#### Phase 3 + +- native FSDP2 construction succeeds +- native SFT forward-backward-step succeeds +- native save/load/remove semantics are correct + +#### Phase 4 + +- native FSDP2 GRPO forward-backward path succeeds +- native `disable_lora` path remains correct +- minimal native GRPO server/backend entrypoint succeeds ## Risks -- LoRA parameter layouts under Accelerate FSDP2 may differ from the assumptions used by direct slicing in patched LoRA forward code -- some helper methods may reconstruct full tensors when only local shards are needed, which can increase test-time memory usage -- eager assumptions about adapter activation order may be exposed once wrapping becomes lazy again +- patched LoRA forward code may assume local tensor access patterns that do not hold after sharding +- accelerate and native FSDP2 may expose LoRA parameters through different tensor representations +- some helper logic may accidentally reconstruct full tensors when only local shards are needed +- lazy wrap may surface ordering issues around optimizer creation, adapter activation, or template hooks +- GRPO adds path complexity because it mixes training, inference-style forward passes, and LoRA-disabled execution ## Deferred Work -These should stay out of this change unless the minimal support cannot be made correct without them: +These should remain outside this design unless later stages prove them necessary for correctness: -- sampler LoRA sync compatibility -- optimizer state load/save guarantees under FSDP2 -- memory-efficient initialization specific to multi-LoRA -- a deeper unification with `TransformersModel` adapter management +- sampler or checkpoint-engine LoRA synchronization enhancements +- memory-efficient-init customization specific to transformers multi-LoRA +- megatron multi-LoRA parity work +- broader adapter lifecycle redesign beyond the current slot model +- large-scale performance benchmarking or throughput tuning ## Acceptance Criteria -This design is complete when: +This design is complete when all four stages are delivered in order and each stage has dedicated regression coverage: + +1. `AccelerateStrategy + SFT` +2. `AccelerateStrategy + GRPO` +3. `native_fsdp + SFT` +4. 
`native_fsdp + GRPO` + +At the end of the full rollout: -- `MultiLoraTransformersModel` can run the narrow training flow under Accelerate FSDP2 -- multi-LoRA adapter save/load works for the supported LoRA tensor types -- regression tests cover the new supported behavior +- `MultiLoraTransformersModel` supports FSDP2 under both accelerate and native strategies +- SFT and GRPO both run through the supported transformers multi-LoRA paths +- save, load, and remove adapter semantics remain correct under FSDP2 - unsupported areas remain explicitly unchanged From 6f46dc3f3924bbd59697c7858ccc70d2a51b5d4a Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Wed, 8 Apr 2026 10:18:55 +0800 Subject: [PATCH 03/13] docs: narrow multi-lora fsdp2 scope to sft --- .../2026-04-08-multi-lora-fsdp2-design.md | 102 ++++-------------- 1 file changed, 20 insertions(+), 82 deletions(-) diff --git a/docs/superpowers/specs/2026-04-08-multi-lora-fsdp2-design.md b/docs/superpowers/specs/2026-04-08-multi-lora-fsdp2-design.md index e3df61bd..674b9d51 100644 --- a/docs/superpowers/specs/2026-04-08-multi-lora-fsdp2-design.md +++ b/docs/superpowers/specs/2026-04-08-multi-lora-fsdp2-design.md @@ -1,33 +1,29 @@ -# Multi-LoRA Transformers FSDP2 Staged Support Design +# Multi-LoRA Transformers FSDP2 SFT Support Design ## Summary -Enable `MultiLoraTransformersModel` to run under FSDP2 for both `AccelerateStrategy` and `NativeFSDPStrategy`, with staged delivery. The end goal is to run both SFT and GRPO training under FSDP2, but implementation will proceed in this order: +Enable `MultiLoraTransformersModel` to run under FSDP2 for SFT across both `AccelerateStrategy` and `NativeFSDPStrategy`, with staged delivery. Implementation will proceed in this order: 1. `AccelerateStrategy + SFT` -2. `AccelerateStrategy + GRPO` -3. `native_fsdp + SFT` -4. `native_fsdp + GRPO` +2. `native_fsdp + SFT` -The design centers on a shared FSDP2 compatibility layer for transformers multi-LoRA, so each later stage builds on the same model lifecycle, adapter-slot semantics, and distributed weight handling. +The design centers on a shared FSDP2 compatibility layer for transformers multi-LoRA, so the native FSDP stage builds on the same model lifecycle, adapter-slot semantics, and distributed weight handling established by the accelerate stage. ## Goals ### Final Goal -Make `MultiLoraTransformersModel` work under FSDP2 for both SFT and GRPO, across both supported transformers strategies: +Make `MultiLoraTransformersModel` work under FSDP2 for SFT across both supported transformers strategies: - `AccelerateStrategy` - `NativeFSDPStrategy` ### Delivery Goal -Ship the capability in four stages: +Ship the capability in two stages: 1. `AccelerateStrategy + SFT` -2. `AccelerateStrategy + GRPO` -3. `native_fsdp + SFT` -4. `native_fsdp + GRPO` +2. `native_fsdp + SFT` Each stage must leave the shared foundations in a state that later stages can reuse without strategy-specific rewrites. 
@@ -37,7 +33,8 @@ Each stage must leave the shared foundations in a state that later stages can re - No sampler or checkpoint-engine LoRA sync expansion in this workstream - No attempt to solve all distributed checkpoint migration cases across arbitrary sharding layouts - No requirement to support multi-adapter concurrent training in the initial stages -- No large-scale performance tuning or RL throughput optimization as part of correctness work +- No GRPO support in this workstream +- No large-scale performance tuning as part of correctness work - No deep rewrite that merges `MultiLoraTransformersModel` into the single-adapter transformers implementation ## Current Problem @@ -57,14 +54,14 @@ Because of that: - `AccelerateStrategy` cannot shard the model for multi-LoRA training - `NativeFSDPStrategy` is not wired into multi-LoRA at all - adapter slot persistence is unsafe once LoRA parameters become DTensors or other sharded parameter forms -- GRPO-specific paths such as `forward_only`, `disable_lora`, and server-side forward-backward entrypoints cannot be trusted under FSDP2 +- SFT under FSDP2 is currently blocked for both strategy paths ## Design Principles - Keep the existing multi-slot adapter model: tenants bind to preallocated internal LoRA slots - Build one shared FSDP2 compatibility layer instead of separate accelerate-only and native-only implementations - Let strategy differences stay in strategy code paths, not in duplicated multi-LoRA business logic -- Stage rollout by training scenario, but avoid temporary patches that block later phases +- Stage rollout by strategy, but avoid temporary patches that block later phases - Prefer the smallest regression tests that prove correctness at each phase ## Proposed Architecture @@ -72,7 +69,7 @@ Because of that: The work is split into two layers: 1. A shared FSDP2 compatibility foundation for transformers multi-LoRA -2. A staged rollout over SFT and GRPO for accelerate and native FSDP2 +2. A staged rollout over SFT for accelerate and native FSDP2 ### Shared Foundation @@ -121,9 +118,8 @@ The following behaviors must stay valid after wrapping: - `deactivate_adapter` - `save_context` - `remove_adapter` -- `disable_lora` inference paths -This is especially important for GRPO, where policy-style and LoRA-disabled paths need to coexist without corrupting adapter state. +This is important even in SFT because adapter activation, save/load, and slot reset must behave the same before and after wrapping. #### 4. Unified save/load/remove semantics @@ -148,7 +144,6 @@ Shared fixtures/helpers should cover: - a minimal multi-LoRA transformers model builder - adapter-slot inspection helpers - minimal SFT input samples -- minimal GRPO input samples This avoids rebuilding test infrastructure at every phase. @@ -180,35 +175,11 @@ Implementation focus: Explicitly out of scope for this phase: -- GRPO server entrypoints +- GRPO - multi-adapter concurrent training - sampler sync -### Phase 2: `AccelerateStrategy + GRPO` - -Build on Phase 1 and add GRPO-specific correctness. 
- -Supported flow: - -- `forward_only` under wrapped accelerate FSDP2 -- `disable_lora` behavior for reference-style inference -- `GRPOLoss` training inputs, including `old_logps` and `advantages` -- minimal GRPO forward-backward-step flow through: - - twinkle-native path - - tinker-compatible server path - -Implementation focus: - -- active adapter switching during GRPO paths -- correctness of `disable_lora` under wrapped PEFT model state -- minimal regression coverage for GRPO entrypoints - -Explicitly out of scope for this phase: - -- online sampler orchestration -- large-scale RL performance tuning - -### Phase 3: `native_fsdp + SFT` +### Phase 2: `native_fsdp + SFT` Add native FSDP2 support by reusing the shared foundation. @@ -232,25 +203,6 @@ Key risk areas: - optimizer param-group rebinding after wrapping - LoRA forward patch assumptions when parameters are sharded -### Phase 4: `native_fsdp + GRPO` - -Complete the matrix by validating GRPO under native FSDP2. - -Supported flow: - -- `forward_only` -- `disable_lora` -- `GRPOLoss` -- `clip_grad_norm` -- `step` -- minimal GRPO server/backend path - -Implementation focus: - -- native FSDP interaction with GRPO control flow -- slot-state correctness while toggling LoRA-enabled and LoRA-disabled execution -- final GRPO regression coverage under native FSDP2 - ## Test Plan Tests should expand phase by phase, but stay narrow and behavior-oriented. @@ -272,34 +224,22 @@ Tests should expand phase by phase, but stay narrow and behavior-oriented. #### Phase 2 -- accelerate FSDP2 GRPO forward-backward path succeeds -- `disable_lora` behavior stays correct under wrapped execution -- minimal server/backend GRPO entrypoint succeeds - -#### Phase 3 - - native FSDP2 construction succeeds - native SFT forward-backward-step succeeds - native save/load/remove semantics are correct -#### Phase 4 - -- native FSDP2 GRPO forward-backward path succeeds -- native `disable_lora` path remains correct -- minimal native GRPO server/backend entrypoint succeeds - ## Risks - patched LoRA forward code may assume local tensor access patterns that do not hold after sharding - accelerate and native FSDP2 may expose LoRA parameters through different tensor representations - some helper logic may accidentally reconstruct full tensors when only local shards are needed - lazy wrap may surface ordering issues around optimizer creation, adapter activation, or template hooks -- GRPO adds path complexity because it mixes training, inference-style forward passes, and LoRA-disabled execution ## Deferred Work These should remain outside this design unless later stages prove them necessary for correctness: +- GRPO support for either strategy - sampler or checkpoint-engine LoRA synchronization enhancements - memory-efficient-init customization specific to transformers multi-LoRA - megatron multi-LoRA parity work @@ -308,16 +248,14 @@ These should remain outside this design unless later stages prove them necessary ## Acceptance Criteria -This design is complete when all four stages are delivered in order and each stage has dedicated regression coverage: +This design is complete when both stages are delivered in order and each stage has dedicated regression coverage: 1. `AccelerateStrategy + SFT` -2. `AccelerateStrategy + GRPO` -3. `native_fsdp + SFT` -4. `native_fsdp + GRPO` +2. 
`native_fsdp + SFT` -At the end of the full rollout: +At the end of the rollout: - `MultiLoraTransformersModel` supports FSDP2 under both accelerate and native strategies -- SFT and GRPO both run through the supported transformers multi-LoRA paths +- SFT runs through the supported transformers multi-LoRA paths under both strategies - save, load, and remove adapter semantics remain correct under FSDP2 - unsupported areas remain explicitly unchanged From 2694bdc0f23ddbd53050bfbe43d1132ca818fa35 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Wed, 8 Apr 2026 16:38:54 +0800 Subject: [PATCH 04/13] docs: add multi-lora fsdp2 sft plan --- .../plans/2026-04-08-multi-lora-fsdp2-sft.md | 324 ++++++++++++++++++ 1 file changed, 324 insertions(+) create mode 100644 docs/superpowers/plans/2026-04-08-multi-lora-fsdp2-sft.md diff --git a/docs/superpowers/plans/2026-04-08-multi-lora-fsdp2-sft.md b/docs/superpowers/plans/2026-04-08-multi-lora-fsdp2-sft.md new file mode 100644 index 00000000..248b7a8e --- /dev/null +++ b/docs/superpowers/plans/2026-04-08-multi-lora-fsdp2-sft.md @@ -0,0 +1,324 @@ +# Multi-LoRA FSDP2 SFT Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Make `MultiLoraTransformersModel` support SFT under FSDP2 for both `AccelerateStrategy` and `NativeFSDPStrategy`, including adapter save/load/remove semantics. + +**Architecture:** Reuse one shared FSDP2 compatibility layer instead of building separate accelerate-only and native-only implementations. Keep the existing multi-slot adapter model, but move model wrapping back into the shared transformers lifecycle and teach `MultiLora` how to read and write LoRA slot weights safely when parameters are sharded. + +**Tech Stack:** Python 3.11+, PyTorch distributed FSDP2, Accelerate, PEFT LoRA, pytest, Twinkle transformers model stack + +--- + +## References + +- Spec: `docs/superpowers/specs/2026-04-08-multi-lora-fsdp2-design.md` +- Use `@test-driven-development` for every behavior change. +- Use `@verification-before-completion` before claiming any phase is done. + +## File Map + +- Modify: `src/twinkle/model/transformers/multi_lora_transformers.py` + - Rejoin the shared strategy-selection and lazy-wrap lifecycle. +- Modify: `src/twinkle/model/multi_lora.py` + - Add sharding-safe helpers for LoRA slot reads, writes, save/load, and reset. +- Create: `tests/model/transformers/test_multi_lora_fsdp2_sft.py` + - End-to-end regression coverage for accelerate and native FSDP2 SFT. +- Optional Create: `tests/model/test_multi_lora_state.py` + - Fast, non-distributed coverage for slot save/load/remove semantics if the distributed tests are too expensive to debug first. + +## Assumptions + +- GPU-backed tests are acceptable for the FSDP2 integration path. +- Test model path should resolve from `TEST_MODEL_ID` with an offline-cache fallback before attempting any network access. +- GRPO is out of scope for this plan. 
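The tasks below call a `get_model_path()` helper. A hedged sketch of the resolution order the assumptions describe (env var first, then an offline cache, never the network) could look like this; the cache location is an assumption, not a real path:

```python
# Hedged sketch of TEST_MODEL_ID resolution; the cache directory is a hypothetical example.
import os


def get_model_path() -> str:
    model_id = os.environ.get('TEST_MODEL_ID')
    if model_id:
        return model_id
    cached = os.path.expanduser('~/.cache/twinkle/test-model')  # assumed cache location
    if os.path.isdir(cached):
        return cached
    raise RuntimeError('set TEST_MODEL_ID or provide an offline cache; no network access')
```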
+ +### Task 1: Create the Accelerate FSDP2 SFT regression test scaffold + +**Files:** +- Create: `tests/model/transformers/test_multi_lora_fsdp2_sft.py` +- Reference: `cookbook/transformers/fsdp2.py` +- Reference: `cookbook/transformers/sp_fsdp_dense.py` + +- [ ] **Step 1: Write the failing accelerate SFT test** + +```python +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason='Requires 2+ GPUs') +def test_multi_lora_accelerate_fsdp2_sft_round_trip(tmp_path): + model = build_multi_lora_model(strategy='accelerate') + model.add_adapter_to_model('default', build_lora_config(), gradient_accumulation_steps=1) + model.set_loss('CrossEntropyLoss', adapter_name='default') + model.set_optimizer('AdamW', lr=1e-4, adapter_name='default') + + batch = build_sft_batch() + model.forward_backward(inputs=batch, adapter_name='default') + model.clip_grad_and_step(adapter_name='default') + + state_before = model.get_state_dict(adapter_name='default') + model.save('ckpt', output_dir=str(tmp_path), adapter_name='default') + model.remove_adapter('default') + model.add_adapter_to_model('default', build_lora_config(), gradient_accumulation_steps=1) + model.load('ckpt', output_dir=str(tmp_path), adapter_name='default') + state_after = model.get_state_dict(adapter_name='default') + + assert_same_lora_state(state_before, state_after) +``` + +- [ ] **Step 2: Run the accelerate test to verify it fails** + +Run: `python -m pytest tests/model/transformers/test_multi_lora_fsdp2_sft.py -k accelerate -v` +Expected: FAIL because `MultiLoraTransformersModel` still blocks or bypasses FSDP2 wrapping. + +- [ ] **Step 3: Add test helpers inside the same file** + +```python +def build_multi_lora_model(strategy: str): + model_path = get_model_path() + mesh = DeviceMesh.from_sizes(world_size=2, fsdp_size=2) + return MultiLoraTransformersModel( + model_id=model_path, + device_mesh=mesh, + strategy=strategy, + ) + + +def build_sft_batch(): + return [{ + 'input_ids': [1, 2, 3, 4], + 'labels': [1, 2, 3, 4], + }] +``` + +- [ ] **Step 4: Re-run the accelerate test to confirm the failure is the intended one** + +Run: `python -m pytest tests/model/transformers/test_multi_lora_fsdp2_sft.py -k accelerate -v` +Expected: FAIL in model setup or FSDP2 execution, not from missing helper functions or syntax errors. + +- [ ] **Step 5: Commit the red test scaffold** + +```bash +git add tests/model/transformers/test_multi_lora_fsdp2_sft.py +git commit -m "test: add accelerate multi-lora fsdp2 sft regression" +``` + +### Task 2: Make `MultiLoraTransformersModel` participate in the shared wrap lifecycle + +**Files:** +- Modify: `src/twinkle/model/transformers/multi_lora_transformers.py` +- Reference: `src/twinkle/model/transformers/transformers.py` +- Test: `tests/model/transformers/test_multi_lora_fsdp2_sft.py` + +- [ ] **Step 1: Write the smallest failing assertion for strategy selection and lazy wrap** + +```python +def test_multi_lora_accelerate_fsdp2_uses_device_mesh(): + model = build_multi_lora_model(strategy='accelerate') + assert model.strategy.device_mesh is not None + assert model._model_wrapped is False +``` + +- [ ] **Step 2: Run the targeted test to verify it fails** + +Run: `python -m pytest tests/model/transformers/test_multi_lora_fsdp2_sft.py -k uses_device_mesh -v` +Expected: FAIL because the class still hardcodes `AccelerateStrategy(device_mesh=None)` and eagerly wraps in `__init__`. 
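Step 3 below calls `_decide_strategy`; whether that helper already exists on the `TransformersModel` parent is an assumption, but its dispatch would look roughly like this sketch (strategy constructor signatures are assumed, not verified):

```python
# Hedged sketch of strategy dispatch; constructor signatures are assumptions.
def _decide_strategy(self, strategy: str) -> None:
    if strategy == 'native_fsdp':
        self.strategy = NativeFSDPStrategy(  # assumed import from .strategy
            mixed_precision=self.mixed_precision,
            device_mesh=self.device_mesh,
            **self._fsdp_config,
        )
    else:
        # Default path keeps AccelerateStrategy but passes the real mesh through.
        self.strategy = AccelerateStrategy(
            mixed_precision=self.mixed_precision,
            device_mesh=self.device_mesh,
        )
```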
+ +- [ ] **Step 3: Implement the minimal lifecycle fix** + +```python +class MultiLoraTransformersModel(TransformersModel, PreTrainedModel): + def __init__(..., strategy: Literal['accelerate', 'native_fsdp'] = 'accelerate', fsdp_config=None, ...): + self._fsdp_config = dict(fsdp_config or {}) + self._decide_strategy(strategy) + ... + self.model = self.multi_adapter.patch(self.model) + self._model_wrapped = False + + def _lazy_wrap_model(self): + return super()._lazy_wrap_model() +``` + +- [ ] **Step 4: Run the accelerate tests again** + +Run: `python -m pytest tests/model/transformers/test_multi_lora_fsdp2_sft.py -k "uses_device_mesh or accelerate" -v` +Expected: first lifecycle assertion passes; full accelerate round-trip test still fails in slot save/load or sharded tensor handling. + +- [ ] **Step 5: Commit the lifecycle change** + +```bash +git add src/twinkle/model/transformers/multi_lora_transformers.py tests/model/transformers/test_multi_lora_fsdp2_sft.py +git commit -m "feat: reuse shared fsdp lifecycle for multi-lora transformers" +``` + +### Task 3: Add sharding-safe LoRA slot tensor helpers in `MultiLora` + +**Files:** +- Modify: `src/twinkle/model/multi_lora.py` +- Test: `tests/model/transformers/test_multi_lora_fsdp2_sft.py` +- Optional Test: `tests/model/test_multi_lora_state.py` + +- [ ] **Step 1: Write a failing state round-trip test that isolates slot semantics** + +```python +def test_multi_lora_state_dict_round_trip_preserves_rank_slices(tmp_path): + model = build_multi_lora_model(strategy='accelerate') + model.add_adapter_to_model('default', build_lora_config(r=4), gradient_accumulation_steps=1) + state = model.get_state_dict(adapter_name='default') + assert state + assert all('.default.' not in key for key in state) +``` + +- [ ] **Step 2: Run the state round-trip test to verify it fails for the expected reason** + +Run: `python -m pytest tests/model/transformers/test_multi_lora_fsdp2_sft.py -k round_trip_preserves_rank_slices -v` +Expected: FAIL because current multi-LoRA save/load helpers assume local tensors and direct `parameter.data` writes. + +- [ ] **Step 3: Add minimal helper methods for slot IO** + +```python +def _read_param_tensor(self, parameter): + return torch_util.to_local_tensor(parameter) + + +def _write_param_tensor(self, parameter, value): + if hasattr(parameter, 'device_mesh') and hasattr(parameter, 'placements'): + value = distribute_tensor(value.to(parameter.device), parameter.device_mesh, parameter.placements) + parameter.data.copy_(value) +``` + +- [ ] **Step 4: Refactor `save_initial_weights`, `_load_initial_weights`, `set_state_dict`, `get_state_dict`, and `save_lora_converter` to use the new helpers** + +Run: `python -m pytest tests/model/transformers/test_multi_lora_fsdp2_sft.py -k accelerate -v` +Expected: the accelerate SFT round-trip test passes. 
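As one shape the Step 4 refactor could take, the `get_state_dict` inner loop might route every read through the new helpers. This sketch mirrors the existing loop structure (`pattern`, `_lora`, and `self` come from the enclosing method) and is not final code:

```python
# Hedged sketch of the Step 4 refactor for the get_state_dict inner loop.
def _get_weights(_module):
    state_dict = {}
    for name, parameter in _module.named_parameters():
        if pattern.search(name) and self.match_target_modules(name, _lora.tenant_config.target_modules):
            tensor = self._read_param_tensor(parameter)  # sharding-safe read
            tensor = self._slice_rank_tensor(name, tensor, _lora.tenant_config.r)
            state_dict[name.replace(f'.{_lora.adapter_name}.', '.')] = tensor
    return state_dict
```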
+ +- [ ] **Step 5: Commit the shared slot-IO layer** + +```bash +git add src/twinkle/model/multi_lora.py tests/model/transformers/test_multi_lora_fsdp2_sft.py +git commit -m "feat: support sharded multi-lora slot state io" +``` + +### Task 4: Add native FSDP2 SFT regression coverage + +**Files:** +- Modify: `tests/model/transformers/test_multi_lora_fsdp2_sft.py` +- Reference: `cookbook/transformers/sp_fsdp_dense.py` + +- [ ] **Step 1: Write the failing native FSDP SFT test** + +```python +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason='Requires 2+ GPUs') +def test_multi_lora_native_fsdp_sft_round_trip(tmp_path): + model = build_multi_lora_model(strategy='native_fsdp') + model.add_adapter_to_model('default', build_lora_config(), gradient_accumulation_steps=1) + model.set_loss('CrossEntropyLoss', adapter_name='default') + model.set_optimizer('AdamW', lr=1e-4, adapter_name='default') + model.forward_backward(inputs=build_sft_batch(), adapter_name='default') + model.clip_grad_and_step(adapter_name='default') + model.save('native-ckpt', output_dir=str(tmp_path), adapter_name='default') +``` + +- [ ] **Step 2: Run the native test to verify it fails** + +Run: `python -m pytest tests/model/transformers/test_multi_lora_fsdp2_sft.py -k native -v` +Expected: FAIL because native FSDP has not yet been exercised through multi-LoRA wrapping, optimizer rebinding, or slot restore. + +- [ ] **Step 3: Add native-specific assertions before implementation** + +```python +assert model._model_wrapped is False +assert model.device_mesh.fsdp_world_size == 2 +``` + +- [ ] **Step 4: Re-run the native test to verify the failure still points at native FSDP support** + +Run: `python -m pytest tests/model/transformers/test_multi_lora_fsdp2_sft.py -k native -v` +Expected: FAIL in native FSDP wrapping or native slot state handling, not in test setup. + +- [ ] **Step 5: Commit the native regression scaffold** + +```bash +git add tests/model/transformers/test_multi_lora_fsdp2_sft.py +git commit -m "test: add native multi-lora fsdp2 sft regression" +``` + +### Task 5: Make the native FSDP2 SFT path pass + +**Files:** +- Modify: `src/twinkle/model/transformers/multi_lora_transformers.py` +- Modify: `src/twinkle/model/multi_lora.py` +- Test: `tests/model/transformers/test_multi_lora_fsdp2_sft.py` + +- [ ] **Step 1: Implement the smallest native FSDP2 compatibility change** + +```python +model = MultiLoraTransformersModel( + ..., + strategy='native_fsdp', +) +``` + +Required behavior: +- strategy is selected through `_decide_strategy` +- wrapping still happens only through `_lazy_wrap_model` +- optimizer binding remains valid after wrap + +- [ ] **Step 2: Run the native regression to verify the new failure, if any** + +Run: `python -m pytest tests/model/transformers/test_multi_lora_fsdp2_sft.py -k native -v` +Expected: either PASS or a narrower failure around native slot IO / sharded forward behavior. + +- [ ] **Step 3: Fix the minimal remaining native-specific issue** + +```python +if strategy == 'native_fsdp': + # Keep multi-LoRA slot tensors readable/writable after fully_shard. + ... +``` + +- [ ] **Step 4: Run the full SFT regression file** + +Run: `python -m pytest tests/model/transformers/test_multi_lora_fsdp2_sft.py -v` +Expected: PASS for both accelerate and native SFT tests. 
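Before committing, a quick structural probe can confirm the native path actually sharded the LoRA slots. This is a hedged diagnostic sketch; the `torch.distributed.tensor` module path varies across torch versions, and the name predicate is illustrative:

```python
# Hedged diagnostic: count LoRA slot parameters that became DTensors after wrapping.
from torch.distributed.tensor import DTensor  # older torch: torch.distributed._tensor


def count_sharded_lora_params(model) -> int:
    return sum(
        1
        for name, parameter in model.named_parameters()
        if 'lora_' in name and isinstance(parameter.data, DTensor)
    )
```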
+ +- [ ] **Step 5: Commit the native support** + +```bash +git add src/twinkle/model/transformers/multi_lora_transformers.py src/twinkle/model/multi_lora.py tests/model/transformers/test_multi_lora_fsdp2_sft.py +git commit -m "feat: support native fsdp2 sft for multi-lora transformers" +``` + +### Task 6: Final verification and cleanup + +**Files:** +- Modify: `src/twinkle/model/transformers/multi_lora_transformers.py` +- Modify: `src/twinkle/model/multi_lora.py` +- Modify: `tests/model/transformers/test_multi_lora_fsdp2_sft.py` + +- [ ] **Step 1: Run the shared targeted verification suite** + +Run: `python -m pytest tests/model/transformers/test_multi_lora_fsdp2_sft.py -v` +Expected: PASS + +- [ ] **Step 2: Run one adjacent regression if save/load behavior was touched deeply** + +Run: `python -m pytest tests/sampler/test_weight_sync.py -k lora -v` +Expected: PASS or SKIP, with no new failures caused by LoRA state handling changes. + +- [ ] **Step 3: Inspect the final diff for scope discipline** + +Run: `git diff --stat HEAD~1..HEAD` +Expected: only multi-LoRA transformers model code and the new SFT regression tests are touched. + +- [ ] **Step 4: Document verification evidence in the final handoff** + +Required notes: +- exact test commands run +- whether each command passed, failed, or skipped +- any remaining risks, especially GPU-only coverage limitations + +- [ ] **Step 5: Commit cleanup if needed** + +```bash +git add src/twinkle/model/transformers/multi_lora_transformers.py src/twinkle/model/multi_lora.py tests/model/transformers/test_multi_lora_fsdp2_sft.py +git commit -m "test: verify multi-lora fsdp2 sft coverage" +``` From ec5f6a946ece48454b892e9c9e18b05cbb88df53 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Thu, 9 Apr 2026 08:45:58 +0800 Subject: [PATCH 05/13] wip --- src/twinkle/model/multi_lora.py | 85 ++++--- .../transformers/multi_lora_transformers.py | 29 ++- .../transformers/test_multi_lora_fsdp2_sft.py | 229 ++++++++++++++++++ 3 files changed, 296 insertions(+), 47 deletions(-) create mode 100644 tests/model/transformers/test_multi_lora_fsdp2_sft.py diff --git a/src/twinkle/model/multi_lora.py b/src/twinkle/model/multi_lora.py index 38b253ab..f46d3b3c 100644 --- a/src/twinkle/model/multi_lora.py +++ b/src/twinkle/model/multi_lora.py @@ -42,6 +42,44 @@ def _get_available_lora(self) -> Optional[LoraTenant]: return _lora return None + def _read_param_tensor(self, parameter): + return torch_util.to_local_tensor(parameter) + + def _write_param_tensor(self, parameter, value): + value = value.to(dtype=parameter.dtype) + if hasattr(parameter, 'device_mesh') and hasattr(parameter, 'placements'): + from torch.distributed.tensor import distribute_tensor + value = distribute_tensor(value.to(parameter.device), parameter.device_mesh, parameter.placements) + else: + value = value.to(parameter.device) + parameter.data.copy_(value) + + @staticmethod + def _slice_rank_tensor(name: str, tensor: torch.Tensor, rank: int) -> torch.Tensor: + if 'embedding_A' in name: + return tensor[:, :rank] + if 'embedding_B' in name: + return tensor[:rank, :] + if '_A' in name: + return tensor[:rank, :] + if '_B' in name: + return tensor[:, :rank] + return tensor + + @staticmethod + def _copy_rank_tensor(name: str, target: torch.Tensor, value: torch.Tensor) -> torch.Tensor: + if 'embedding_A' in name: + target[:, :value.shape[1]].copy_(value) + elif 'embedding_B' in name: + target[:value.shape[0], :].copy_(value) + elif '_A' in name: + target[:value.shape[0], :].copy_(value) + elif '_B' in name: + 
target[:, :value.shape[1]].copy_(value) + else: + target.copy_(value) + return target + def _count_available_loras(self): return len([_lora for _lora in self.loras if _lora.tenant_adapter_name is None]) @@ -435,7 +473,7 @@ def save_initial_weights(self): def _store_weights(_module): for name, parameter in _module.named_parameters(): if pattern.search(name): - lora_tenant.lora_A_weights[name] = parameter.data.clone().to('cpu') + lora_tenant.lora_A_weights[name] = self._read_param_tensor(parameter).clone().to('cpu') if isinstance(self.module, list): for _module in self.module: @@ -482,15 +520,7 @@ def save_lora_converter(self, name, parameter, adapter_name): pattern_no_adapter = re.compile(r'\.lora_\w+\.weight') if (pattern.search(name) or pattern_no_adapter.search(name)) and self.match_target_modules( name, _lora.tenant_config.target_modules): - _param = torch_util.to_local_tensor(parameter) - if 'embedding_A' in name: - _param = _param[:, :_lora.tenant_config.r] - elif 'embedding_B' in name: - _param = _param[:_lora.tenant_config.r, :] - elif '_A' in name: - _param = _param[:_lora.tenant_config.r, :] - elif '_B' in name: - _param = _param[:, :_lora.tenant_config.r] + _param = self._slice_rank_tensor(name, self._read_param_tensor(parameter), _lora.tenant_config.r) name = name.replace(f'.{_lora.adapter_name}.', '.') return name, _param else: @@ -503,20 +533,11 @@ def set_state_dict(self, tenant_adapter_name, state_dict): def _load_weights(_module): for name, parameter in _module.named_parameters(): if pattern.search(name) and self.match_target_modules(name, _lora.tenant_config.target_modules): - name = name.replace(f'.{_lora.adapter_name}.', '.') - src_tensor = state_dict[name] - if 'embedding_A' in name: - r_saved = src_tensor.shape[1] - parameter.data[:, :r_saved].copy_(src_tensor) - elif 'embedding_B' in name: - r_saved = src_tensor.shape[0] - parameter.data[:r_saved, :].copy_(src_tensor) - elif '_A' in name: - r_saved = src_tensor.shape[0] - parameter.data[:r_saved, :].copy_(src_tensor) - elif '_B' in name: - r_saved = src_tensor.shape[1] - parameter.data[:, :r_saved].copy_(src_tensor) + state_key = name.replace(f'.{_lora.adapter_name}.', '.') + target_tensor = self._read_param_tensor(parameter).clone() + src_tensor = state_dict[state_key].to(dtype=target_tensor.dtype, device=target_tensor.device) + self._copy_rank_tensor(name, target_tensor, src_tensor) + self._write_param_tensor(parameter, target_tensor) if isinstance(self.module, list): for _module in self.module: @@ -533,15 +554,7 @@ def _get_weights(_module): state_dict = {} for name, parameter in _module.named_parameters(): if pattern.search(name) and self.match_target_modules(name, _lora.tenant_config.target_modules): - _param = torch_util.to_local_tensor(parameter) - if 'embedding_A' in name: - _param = _param[:, :_lora.tenant_config.r] - elif 'embedding_B' in name: - _param = _param[:_lora.tenant_config.r, :] - elif '_A' in name: - _param = _param[:_lora.tenant_config.r, :] - elif '_B' in name: - _param = _param[:, :_lora.tenant_config.r] + _param = self._slice_rank_tensor(name, self._read_param_tensor(parameter), _lora.tenant_config.r) name = name.replace(f'.{_lora.adapter_name}.', '.') state_dict[name] = _param return state_dict @@ -561,9 +574,11 @@ def _load_initial_weights(self, origin_adapter_name): def _load_initial_weights(_module): for name, parameter in _module.named_parameters(): if pattern_A.search(name): - parameter.data.copy_(_lora.lora_A_weights[name]) + target_device = self._read_param_tensor(parameter).device + 
value = _lora.lora_A_weights[name].to(dtype=parameter.dtype, device=target_device) + self._write_param_tensor(parameter, value) if pattern_B.search(name): - parameter.data.copy_(torch.zeros_like(parameter.data).to(parameter.data.dtype)) + self._write_param_tensor(parameter, torch.zeros_like(self._read_param_tensor(parameter))) if isinstance(self.module, list): for _module in self.module: diff --git a/src/twinkle/model/transformers/multi_lora_transformers.py b/src/twinkle/model/transformers/multi_lora_transformers.py index f7573f41..1ace5854 100644 --- a/src/twinkle/model/transformers/multi_lora_transformers.py +++ b/src/twinkle/model/transformers/multi_lora_transformers.py @@ -15,7 +15,6 @@ from twinkle.metric import Metric from twinkle.processor import InputProcessor from ..multi_lora import MultiLora -from .strategy import AccelerateStrategy from .transformers import OptimizerGroup, TransformersModel @@ -29,36 +28,42 @@ def __init__( config: Optional[PretrainedConfig] = None, device_mesh: Optional[DeviceMesh] = None, mixed_precision: Literal['no', 'fp8', 'fp16', 'bf16'] = 'bf16', + strategy: Literal['accelerate', 'native_fsdp'] = 'accelerate', + ddp_config: Dict[str, Any] = None, + fsdp_config: Dict[str, Any] = None, grad_scaler_config: Dict[str, Any] = None, + memory_efficient_init: bool = False, max_loras: int = 5, max_r: int = 32, max_length: int = 8192, **kwargs): - assert device_mesh.fsdp_world_size <= 0, f'MultiLora does not support FSDP, current is: {str(device_mesh)}' os.environ['TOKENIZERS_PARALLELISM'] = 'true' self._try_init_process_group() super(PreTrainedModel, self).__init__() - model_id = HubOperation.download_model(model_id) - if isinstance(model_cls, str): - model_cls = getattr(transformers, model_cls) - self.model = model_cls.from_pretrained(model_id, config=config, **kwargs) self.model_id = model_id self.tokenizer_id = kwargs.get('tokenizer_id', self.model_id) + self._default_tokenizer = None self.device_mesh = device_mesh self.mixed_precision = mixed_precision + self._fsdp_config = dict(fsdp_config or {}) + self._ddp_config = ddp_config or {} + self._memory_efficient_init = memory_efficient_init + self._decide_strategy(strategy) self.grad_scaler_config = grad_scaler_config + if isinstance(model_cls, str): + model_cls = getattr(transformers, model_cls) + model_id = HubOperation.download_model(model_id) + with self.strategy.pretrained_load_context(): + self.model = model_cls.from_pretrained(model_id, config=config, **kwargs) + self.model_id = model_id + self.tokenizer_id = kwargs.get('tokenizer_id', self.model_id) self._model_wrapped = False self.sp_strategy = None # Initialize expert parallel attributes (required by set_optimizer in TransformersModel) - self._expert_parallel_config = None - self._enable_expert_parallel = False - self._expert_parallel_applied = False self.optimizer_group: Dict[str, OptimizerGroup] = {} self.multi_adapter = MultiLora(max_loras=max_loras, max_r=max_r, max_length=max_length) self.model.gradient_checkpointing_enable() self.model = self.multi_adapter.patch(self.model) - self.strategy = AccelerateStrategy(mixed_precision=mixed_precision, device_mesh=None) - self.model = self.strategy.wrap_model(self.model) self.multi_adapter.save_initial_weights() # Active group for compatibility with single adapter self.active_group = None @@ -88,7 +93,7 @@ def unregister_mm_forward_hook(self, optimizer_group: OptimizerGroup): pass def _lazy_wrap_model(self): - pass + return super()._lazy_wrap_model() @remote_function(dispatch='slice_dp', 
collect=collect_tensor_dict) def forward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]], **kwargs): diff --git a/tests/model/transformers/test_multi_lora_fsdp2_sft.py b/tests/model/transformers/test_multi_lora_fsdp2_sft.py new file mode 100644 index 00000000..3ad08473 --- /dev/null +++ b/tests/model/transformers/test_multi_lora_fsdp2_sft.py @@ -0,0 +1,229 @@ +import os +import shutil +import tempfile +from pathlib import Path + +import pytest +import torch +import twinkle +from peft import LoraConfig +from tokenizers import Tokenizer +from tokenizers.models import WordLevel +from tokenizers.pre_tokenizers import Whitespace +from transformers import LlamaConfig, LlamaForCausalLM, PreTrainedTokenizerFast + +from twinkle import DeviceMesh +from twinkle.model.multi_lora import MultiLora +from twinkle.model.transformers.multi_lora_transformers import MultiLoraTransformersModel + +TEST_MODEL_ID = os.environ.get('TEST_MODEL_ID') + + +def build_lora_config(r: int = 4) -> LoraConfig: + return LoraConfig( + r=r, + lora_alpha=max(8, r * 2), + target_modules='all-linear', + init_lora_weights=False, + ) + + +def build_sft_batch(): + return [{'input_ids': [1, 3, 4, 2], 'labels': [1, 3, 4, 2]}] + + +def assert_same_lora_state(state_before, state_after): + assert state_before.keys() == state_after.keys() + for name, value in state_before.items(): + assert torch.equal(value, state_after[name]), name + + +def _write_tiny_model_dir(model_dir: Path) -> str: + model_dir.mkdir(parents=True, exist_ok=True) + + config = LlamaConfig( + vocab_size=16, + hidden_size=16, + intermediate_size=32, + num_hidden_layers=1, + num_attention_heads=2, + num_key_value_heads=2, + max_position_embeddings=32, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + ) + LlamaForCausalLM(config).save_pretrained(model_dir) + + vocab = { + '': 0, + '': 1, + '': 2, + 'hello': 3, + 'world': 4, + 'adapter': 5, + 'state': 6, + '': 7, + } + tokenizer = Tokenizer(WordLevel(vocab=vocab, unk_token='')) + tokenizer.pre_tokenizer = Whitespace() + fast_tokenizer = PreTrainedTokenizerFast( + tokenizer_object=tokenizer, + unk_token='', + pad_token='', + bos_token='', + eos_token='', + ) + fast_tokenizer.save_pretrained(model_dir) + return str(model_dir) + + +@pytest.fixture +def model_path() -> str: + if TEST_MODEL_ID: + return TEST_MODEL_ID + return _write_tiny_model_dir(make_workspace_temp_dir('tiny-llama')) + + +def make_workspace_temp_dir(prefix: str) -> Path: + base_dir = Path.cwd() / '.codex_test_tmp' + base_dir.mkdir(parents=True, exist_ok=True) + return Path(tempfile.mkdtemp(prefix=f'{prefix}-', dir=base_dir)) + + +def build_device_mesh(fsdp_size: int = 2) -> DeviceMesh: + if torch.cuda.device_count() >= fsdp_size: + return DeviceMesh.from_sizes(fsdp_size=fsdp_size, device_type='cuda') + return DeviceMesh.from_sizes(world_size=1, dp_size=1, device_type='cpu') + + +def build_multi_lora_model(model_path: str, strategy: str, fsdp_size: int = 2): + mesh = build_device_mesh(fsdp_size=fsdp_size) + twinkle.initialize(mode='local', global_device_mesh=mesh) + return MultiLoraTransformersModel( + model_id=model_path, + device_mesh=mesh, + strategy=strategy, + ) + + +def build_multi_lora_state() -> MultiLora: + config = LlamaConfig( + vocab_size=16, + hidden_size=16, + intermediate_size=32, + num_hidden_layers=1, + num_attention_heads=2, + num_key_value_heads=2, + max_position_embeddings=32, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + ) + multi_lora = MultiLora(max_loras=2, max_r=8, max_length=32) + 
multi_lora.patch(LlamaForCausalLM(config)) + multi_lora.save_initial_weights() + multi_lora.acquire_lora('default', build_lora_config(r=4)) + return multi_lora + + +def test_multi_lora_accelerate_fsdp2_uses_device_mesh(model_path): + model = build_multi_lora_model(model_path, strategy='accelerate') + assert model.strategy.device_mesh is not None + assert model._model_wrapped is False + + +def test_multi_lora_native_fsdp2_uses_lazy_wrap(model_path): + model = build_multi_lora_model(model_path, strategy='native_fsdp') + assert model.strategy.device_mesh is not None + assert model._model_wrapped is False + + +def test_multi_lora_state_dict_round_trip_preserves_rank_slices(): + multi_lora = build_multi_lora_state() + calls = [] + + def fake_read(parameter): + calls.append(tuple(parameter.shape)) + return parameter.detach().clone() + + multi_lora._read_param_tensor = fake_read + state = multi_lora.get_state_dict('default') + + assert state + assert all('.default.' not in key for key in state) + assert calls + + +def test_multi_lora_set_state_dict_uses_tensor_write_helper(): + multi_lora = build_multi_lora_state() + state = multi_lora.get_state_dict('default') + calls = [] + + def fake_write(parameter, value): + calls.append((tuple(parameter.shape), tuple(value.shape))) + parameter.data.copy_(value) + + multi_lora._write_param_tensor = fake_write + multi_lora.set_state_dict('default', state) + + assert calls + + +def test_multi_lora_load_initial_weights_uses_tensor_write_helper(): + multi_lora = build_multi_lora_state() + calls = [] + + def fake_write(parameter, value): + calls.append((tuple(parameter.shape), tuple(value.shape))) + parameter.data.copy_(value) + + multi_lora._write_param_tensor = fake_write + multi_lora._load_initial_weights('lora_0') + + assert calls + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason='Requires 2+ GPUs') +def test_multi_lora_accelerate_fsdp2_sft_round_trip(model_path): + output_dir = make_workspace_temp_dir('accelerate-ckpt') + model = build_multi_lora_model(model_path, strategy='accelerate') + model.add_adapter_to_model('default', build_lora_config(), gradient_accumulation_steps=1) + model.set_loss('CrossEntropyLoss', adapter_name='default') + model.set_optimizer('AdamW', lr=1e-4, adapter_name='default') + + batch = build_sft_batch() + model.forward_backward(inputs=batch, adapter_name='default') + model.clip_grad_and_step(adapter_name='default') + + state_before = model.get_state_dict(adapter_name='default') + model.save('ckpt', output_dir=str(output_dir), adapter_name='default') + model.remove_adapter('default') + model.add_adapter_to_model('default', build_lora_config(), gradient_accumulation_steps=1) + model.load('ckpt', output_dir=str(output_dir), adapter_name='default') + state_after = model.get_state_dict(adapter_name='default') + + assert_same_lora_state(state_before, state_after) + shutil.rmtree(output_dir, ignore_errors=True) + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason='Requires 2+ GPUs') +def test_multi_lora_native_fsdp_sft_round_trip(model_path): + output_dir = make_workspace_temp_dir('native-ckpt') + model = build_multi_lora_model(model_path, strategy='native_fsdp') + model.add_adapter_to_model('default', build_lora_config(), gradient_accumulation_steps=1) + model.set_loss('CrossEntropyLoss', adapter_name='default') + model.set_optimizer('AdamW', lr=1e-4, adapter_name='default') + + model.forward_backward(inputs=build_sft_batch(), adapter_name='default') + model.clip_grad_and_step(adapter_name='default') + + 
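+    # Round-trip check: snapshot the tenant's LoRA state, save it, drop and
+    # re-add the adapter (which resets the slot), then load and require
+    # bitwise-equal tensors via assert_same_lora_state.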
state_before = model.get_state_dict(adapter_name='default') + model.save('native-ckpt', output_dir=str(output_dir), adapter_name='default') + model.remove_adapter('default') + model.add_adapter_to_model('default', build_lora_config(), gradient_accumulation_steps=1) + model.load('native-ckpt', output_dir=str(output_dir), adapter_name='default') + state_after = model.get_state_dict(adapter_name='default') + + assert_same_lora_state(state_before, state_after) + shutil.rmtree(output_dir, ignore_errors=True) From ee46e602bbaa7597851e4f97321abb00d62f4d36 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Thu, 9 Apr 2026 09:25:29 +0800 Subject: [PATCH 06/13] wip --- .../transformers/test_multi_lora_fsdp2_sft.py | 67 ++++++++++++++++++- 1 file changed, 65 insertions(+), 2 deletions(-) diff --git a/tests/model/transformers/test_multi_lora_fsdp2_sft.py b/tests/model/transformers/test_multi_lora_fsdp2_sft.py index 3ad08473..fc9fcbfe 100644 --- a/tests/model/transformers/test_multi_lora_fsdp2_sft.py +++ b/tests/model/transformers/test_multi_lora_fsdp2_sft.py @@ -91,9 +91,35 @@ def make_workspace_temp_dir(prefix: str) -> Path: return Path(tempfile.mkdtemp(prefix=f'{prefix}-', dir=base_dir)) +def _npu_device_count() -> int: + npu = getattr(torch, 'npu', None) + if npu is None: + try: + import torch_npu # noqa: F401 + except ImportError: + return 0 + npu = getattr(torch, 'npu', None) + + if npu is None: + return 0 + + try: + if npu.is_available(): + return npu.device_count() + except Exception: + return 0 + return 0 + + +def accelerator_device_count() -> int: + return max(torch.cuda.device_count(), _npu_device_count()) + + def build_device_mesh(fsdp_size: int = 2) -> DeviceMesh: if torch.cuda.device_count() >= fsdp_size: return DeviceMesh.from_sizes(fsdp_size=fsdp_size, device_type='cuda') + if _npu_device_count() >= fsdp_size: + return DeviceMesh.from_sizes(fsdp_size=fsdp_size, device_type='npu') return DeviceMesh.from_sizes(world_size=1, dp_size=1, device_type='cpu') @@ -139,6 +165,43 @@ def test_multi_lora_native_fsdp2_uses_lazy_wrap(model_path): assert model._model_wrapped is False +def test_build_device_mesh_prefers_npu_when_available(monkeypatch): + class FakeNPU: + + @staticmethod + def is_available(): + return True + + @staticmethod + def device_count(): + return 2 + + monkeypatch.setattr(torch.cuda, 'device_count', lambda: 0) + monkeypatch.setattr(torch, 'npu', FakeNPU(), raising=False) + + mesh = build_device_mesh(fsdp_size=2) + + assert mesh.device_type == 'npu' + assert mesh.fsdp_world_size == 2 + + +def test_accelerator_device_count_uses_npu_when_cuda_absent(monkeypatch): + class FakeNPU: + + @staticmethod + def is_available(): + return True + + @staticmethod + def device_count(): + return 2 + + monkeypatch.setattr(torch.cuda, 'device_count', lambda: 0) + monkeypatch.setattr(torch, 'npu', FakeNPU(), raising=False) + + assert accelerator_device_count() == 2 + + def test_multi_lora_state_dict_round_trip_preserves_rank_slices(): multi_lora = build_multi_lora_state() calls = [] @@ -184,7 +247,7 @@ def fake_write(parameter, value): assert calls -@pytest.mark.skipif(torch.cuda.device_count() < 2, reason='Requires 2+ GPUs') +@pytest.mark.skipif(accelerator_device_count() < 2, reason='Requires 2+ CUDA GPUs or NPUs') def test_multi_lora_accelerate_fsdp2_sft_round_trip(model_path): output_dir = make_workspace_temp_dir('accelerate-ckpt') model = build_multi_lora_model(model_path, strategy='accelerate') @@ -207,7 +270,7 @@ def test_multi_lora_accelerate_fsdp2_sft_round_trip(model_path): 
shutil.rmtree(output_dir, ignore_errors=True) -@pytest.mark.skipif(torch.cuda.device_count() < 2, reason='Requires 2+ GPUs') +@pytest.mark.skipif(accelerator_device_count() < 2, reason='Requires 2+ CUDA GPUs or NPUs') def test_multi_lora_native_fsdp_sft_round_trip(model_path): output_dir = make_workspace_temp_dir('native-ckpt') model = build_multi_lora_model(model_path, strategy='native_fsdp') From bd02a96aa654a2e8b96c9b078547b79183b63d28 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Thu, 9 Apr 2026 10:56:06 +0800 Subject: [PATCH 07/13] wip --- .../transformers/test_multi_lora_fsdp2_sft.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/tests/model/transformers/test_multi_lora_fsdp2_sft.py b/tests/model/transformers/test_multi_lora_fsdp2_sft.py index fc9fcbfe..4e3a1b81 100644 --- a/tests/model/transformers/test_multi_lora_fsdp2_sft.py +++ b/tests/model/transformers/test_multi_lora_fsdp2_sft.py @@ -29,7 +29,12 @@ def build_lora_config(r: int = 4) -> LoraConfig: def build_sft_batch(): - return [{'input_ids': [1, 3, 4, 2], 'labels': [1, 3, 4, 2]}] + return [{ + 'input_ids': [1, 3, 4, 2], + 'attention_mask': [1, 1, 1, 1], + 'position_ids': [0, 1, 2, 3], + 'labels': [1, 3, 4, 2], + }] def assert_same_lora_state(state_before, state_after): @@ -202,6 +207,17 @@ def device_count(): assert accelerator_device_count() == 2 +def test_build_sft_batch_includes_processor_fields(): + batch = build_sft_batch() + + assert batch == [{ + 'input_ids': [1, 3, 4, 2], + 'attention_mask': [1, 1, 1, 1], + 'position_ids': [0, 1, 2, 3], + 'labels': [1, 3, 4, 2], + }] + + def test_multi_lora_state_dict_round_trip_preserves_rank_slices(): multi_lora = build_multi_lora_state() calls = [] From 6c359e0bdbb2e6ba65f3c23d16839ba903bea8d8 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Thu, 9 Apr 2026 11:03:40 +0800 Subject: [PATCH 08/13] wip --- src/twinkle/model/multi_lora.py | 2 +- .../transformers/test_multi_lora_fsdp2_sft.py | 28 +++++++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/src/twinkle/model/multi_lora.py b/src/twinkle/model/multi_lora.py index f46d3b3c..2f6a18e1 100644 --- a/src/twinkle/model/multi_lora.py +++ b/src/twinkle/model/multi_lora.py @@ -46,7 +46,7 @@ def _read_param_tensor(self, parameter): return torch_util.to_local_tensor(parameter) def _write_param_tensor(self, parameter, value): - value = value.to(dtype=parameter.dtype) + value = value.detach().to(dtype=parameter.dtype) if hasattr(parameter, 'device_mesh') and hasattr(parameter, 'placements'): from torch.distributed.tensor import distribute_tensor value = distribute_tensor(value.to(parameter.device), parameter.device_mesh, parameter.placements) diff --git a/tests/model/transformers/test_multi_lora_fsdp2_sft.py b/tests/model/transformers/test_multi_lora_fsdp2_sft.py index 4e3a1b81..5d8f109a 100644 --- a/tests/model/transformers/test_multi_lora_fsdp2_sft.py +++ b/tests/model/transformers/test_multi_lora_fsdp2_sft.py @@ -263,6 +263,34 @@ def fake_write(parameter, value): assert calls +def test_multi_lora_write_param_tensor_distributes_leaf_tensor(monkeypatch): + multi_lora = MultiLora(max_loras=1, max_r=4, max_length=8) + parameter = torch.nn.Parameter(torch.zeros(2, 2)) + parameter.device_mesh = object() + parameter.placements = ('shard0', ) + + recorded = {} + + def fake_distribute_tensor(value, device_mesh, placements): + if not value.is_leaf: + raise RuntimeError('`distribute_tensor` should be used to distribute leaf tensors!') + recorded['is_leaf'] = value.is_leaf + 
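+            # Record the mesh and placements so the assertions below can
+            # verify they were taken from the parameter being written.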
recorded['device_mesh'] = device_mesh + recorded['placements'] = placements + return value + + monkeypatch.setattr('torch.distributed.tensor.distribute_tensor', fake_distribute_tensor) + + source = torch.ones(2, 2, requires_grad=True) * 3 + + multi_lora._write_param_tensor(parameter, source) + + assert recorded['is_leaf'] is True + assert recorded['device_mesh'] is parameter.device_mesh + assert recorded['placements'] == parameter.placements + assert torch.equal(parameter.data, torch.full((2, 2), 3.0)) + + @pytest.mark.skipif(accelerator_device_count() < 2, reason='Requires 2+ CUDA GPUs or NPUs') def test_multi_lora_accelerate_fsdp2_sft_round_trip(model_path): output_dir = make_workspace_temp_dir('accelerate-ckpt') From 26cada3fc5f47b8e12598b3db3a370a33c22b95b Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Mon, 13 Apr 2026 09:42:54 +0800 Subject: [PATCH 09/13] wip --- .../transformers/multi_lora_transformers.py | 8 ++- .../transformers/test_multi_lora_fsdp2_sft.py | 68 +++++++++++++++++++ 2 files changed, 75 insertions(+), 1 deletion(-) diff --git a/src/twinkle/model/transformers/multi_lora_transformers.py b/src/twinkle/model/transformers/multi_lora_transformers.py index 1ace5854..1a3a13ea 100644 --- a/src/twinkle/model/transformers/multi_lora_transformers.py +++ b/src/twinkle/model/transformers/multi_lora_transformers.py @@ -1,5 +1,6 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import os +import torch.distributed as dist import transformers from peft import LoraConfig, PeftConfig, PeftModel, load_peft_weights from torch.optim import Optimizer @@ -225,7 +226,10 @@ def get_state_dict(self, **kwargs): def save(self, name, output_dir: Optional[str] = None, interval=1, **kwargs): self._check_adapter_valid(kwargs.get('adapter_name')) with self.multi_adapter.save_context(kwargs.get('adapter_name')): - return super().save(name, output_dir, interval, **kwargs) + checkpoint_dir = super().save(name, output_dir, interval, **kwargs) + if dist.is_initialized(): + dist.barrier() + return checkpoint_dir @remote_function() def load(self, name: Optional[str] = None, output_dir: Optional[str] = None, **kwargs): @@ -247,6 +251,8 @@ def load(self, name: Optional[str] = None, output_dir: Optional[str] = None, **k if load_optimizer: self._load_optimizer(checkpoint_dir, adapter_name=adapter_name) + if dist.is_initialized(): + dist.barrier() @remote_function() def set_grad_scaler(self, **kwargs): diff --git a/tests/model/transformers/test_multi_lora_fsdp2_sft.py b/tests/model/transformers/test_multi_lora_fsdp2_sft.py index 5d8f109a..68836c2c 100644 --- a/tests/model/transformers/test_multi_lora_fsdp2_sft.py +++ b/tests/model/transformers/test_multi_lora_fsdp2_sft.py @@ -1,3 +1,4 @@ +from contextlib import nullcontext import os import shutil import tempfile @@ -15,6 +16,7 @@ from twinkle import DeviceMesh from twinkle.model.multi_lora import MultiLora from twinkle.model.transformers.multi_lora_transformers import MultiLoraTransformersModel +from twinkle.model.transformers.transformers import TransformersModel TEST_MODEL_ID = os.environ.get('TEST_MODEL_ID') @@ -291,6 +293,72 @@ def fake_distribute_tensor(value, device_mesh, placements): assert torch.equal(parameter.data, torch.full((2, 2), 3.0)) +def _build_stub_multi_lora_model(): + model = object.__new__(MultiLoraTransformersModel) + model._check_adapter_valid = lambda adapter_name: None + model.multi_adapter = type( + 'DummyMultiAdapter', + (), + { + 'save_context': staticmethod(lambda adapter_name: nullcontext()), + 'set_state_dict': lambda 
self, adapter_name, state_dict: None, + }, + )() + model.strategy = type('DummyStrategy', (), {'unwrap_model': lambda self, wrapped: wrapped})() + model.model = object() + model._load_optimizer = lambda checkpoint_dir, adapter_name=None: None + return model + + +def test_multi_lora_save_barriers_after_checkpoint_write(monkeypatch): + model = _build_stub_multi_lora_model() + events = [] + + def fake_save(self, name, output_dir=None, interval=1, **kwargs): + events.append('save') + return 'ckpt' + + monkeypatch.setattr(TransformersModel, 'save', fake_save) + monkeypatch.setattr('torch.distributed.is_initialized', lambda: True) + monkeypatch.setattr('torch.distributed.barrier', lambda: events.append('barrier')) + + checkpoint_dir = model.save('ckpt', output_dir='output', adapter_name='default') + + assert checkpoint_dir == 'ckpt' + assert events == ['save', 'barrier'] + + +def test_multi_lora_load_barriers_after_adapter_restore(monkeypatch): + model = _build_stub_multi_lora_model() + events = [] + + class FakePeftModel: + pass + + fake_peft_model = FakePeftModel() + model.model = fake_peft_model + + def fake_set_state_dict(adapter_name, state_dict): + events.append(('set_state_dict', adapter_name, state_dict)) + + model.multi_adapter.set_state_dict = fake_set_state_dict + + monkeypatch.setattr('twinkle.model.transformers.multi_lora_transformers.PeftModel', FakePeftModel) + monkeypatch.setattr('twinkle.model.transformers.multi_lora_transformers.load_peft_weights', + lambda checkpoint_dir, device='cpu': events.append(('load_peft_weights', checkpoint_dir, + device)) or {'layer.weight': torch.ones(1)}) + monkeypatch.setattr('torch.distributed.is_initialized', lambda: True) + monkeypatch.setattr('torch.distributed.barrier', lambda: events.append('barrier')) + + model.load('ckpt', output_dir='output', adapter_name='default') + + assert events == [ + ('load_peft_weights', os.path.join('output', 'ckpt'), 'cpu'), + ('set_state_dict', 'default', {'layer.weight': torch.ones(1)}), + 'barrier', + ] + + @pytest.mark.skipif(accelerator_device_count() < 2, reason='Requires 2+ CUDA GPUs or NPUs') def test_multi_lora_accelerate_fsdp2_sft_round_trip(model_path): output_dir = make_workspace_temp_dir('accelerate-ckpt') From f5127c71a7d47aa2f15f68e125fd73acf4fcd3fe Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Tue, 14 Apr 2026 11:06:48 +0800 Subject: [PATCH 10/13] wip --- .../plans/2026-04-08-multi-lora-fsdp2-sft.md | 324 -------------- .../2026-04-08-multi-lora-fsdp2-design.md | 261 ----------- .../transformers/test_multi_lora_fsdp2_sft.py | 404 ------------------ 3 files changed, 989 deletions(-) delete mode 100644 docs/superpowers/plans/2026-04-08-multi-lora-fsdp2-sft.md delete mode 100644 docs/superpowers/specs/2026-04-08-multi-lora-fsdp2-design.md delete mode 100644 tests/model/transformers/test_multi_lora_fsdp2_sft.py diff --git a/docs/superpowers/plans/2026-04-08-multi-lora-fsdp2-sft.md b/docs/superpowers/plans/2026-04-08-multi-lora-fsdp2-sft.md deleted file mode 100644 index 248b7a8e..00000000 --- a/docs/superpowers/plans/2026-04-08-multi-lora-fsdp2-sft.md +++ /dev/null @@ -1,324 +0,0 @@ -# Multi-LoRA FSDP2 SFT Implementation Plan - -> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. 
- -**Goal:** Make `MultiLoraTransformersModel` support SFT under FSDP2 for both `AccelerateStrategy` and `NativeFSDPStrategy`, including adapter save/load/remove semantics. - -**Architecture:** Reuse one shared FSDP2 compatibility layer instead of building separate accelerate-only and native-only implementations. Keep the existing multi-slot adapter model, but move model wrapping back into the shared transformers lifecycle and teach `MultiLora` how to read and write LoRA slot weights safely when parameters are sharded. - -**Tech Stack:** Python 3.11+, PyTorch distributed FSDP2, Accelerate, PEFT LoRA, pytest, Twinkle transformers model stack - ---- - -## References - -- Spec: `docs/superpowers/specs/2026-04-08-multi-lora-fsdp2-design.md` -- Use `@test-driven-development` for every behavior change. -- Use `@verification-before-completion` before claiming any phase is done. - -## File Map - -- Modify: `src/twinkle/model/transformers/multi_lora_transformers.py` - - Rejoin the shared strategy-selection and lazy-wrap lifecycle. -- Modify: `src/twinkle/model/multi_lora.py` - - Add sharding-safe helpers for LoRA slot reads, writes, save/load, and reset. -- Create: `tests/model/transformers/test_multi_lora_fsdp2_sft.py` - - End-to-end regression coverage for accelerate and native FSDP2 SFT. -- Optional Create: `tests/model/test_multi_lora_state.py` - - Fast, non-distributed coverage for slot save/load/remove semantics if the distributed tests are too expensive to debug first. - -## Assumptions - -- GPU-backed tests are acceptable for the FSDP2 integration path. -- Test model path should resolve from `TEST_MODEL_ID` with an offline-cache fallback before attempting any network access. -- GRPO is out of scope for this plan. - -### Task 1: Create the Accelerate FSDP2 SFT regression test scaffold - -**Files:** -- Create: `tests/model/transformers/test_multi_lora_fsdp2_sft.py` -- Reference: `cookbook/transformers/fsdp2.py` -- Reference: `cookbook/transformers/sp_fsdp_dense.py` - -- [ ] **Step 1: Write the failing accelerate SFT test** - -```python -@pytest.mark.skipif(torch.cuda.device_count() < 2, reason='Requires 2+ GPUs') -def test_multi_lora_accelerate_fsdp2_sft_round_trip(tmp_path): - model = build_multi_lora_model(strategy='accelerate') - model.add_adapter_to_model('default', build_lora_config(), gradient_accumulation_steps=1) - model.set_loss('CrossEntropyLoss', adapter_name='default') - model.set_optimizer('AdamW', lr=1e-4, adapter_name='default') - - batch = build_sft_batch() - model.forward_backward(inputs=batch, adapter_name='default') - model.clip_grad_and_step(adapter_name='default') - - state_before = model.get_state_dict(adapter_name='default') - model.save('ckpt', output_dir=str(tmp_path), adapter_name='default') - model.remove_adapter('default') - model.add_adapter_to_model('default', build_lora_config(), gradient_accumulation_steps=1) - model.load('ckpt', output_dir=str(tmp_path), adapter_name='default') - state_after = model.get_state_dict(adapter_name='default') - - assert_same_lora_state(state_before, state_after) -``` - -- [ ] **Step 2: Run the accelerate test to verify it fails** - -Run: `python -m pytest tests/model/transformers/test_multi_lora_fsdp2_sft.py -k accelerate -v` -Expected: FAIL because `MultiLoraTransformersModel` still blocks or bypasses FSDP2 wrapping. 
- -- [ ] **Step 3: Add test helpers inside the same file** - -```python -def build_multi_lora_model(strategy: str): - model_path = get_model_path() - mesh = DeviceMesh.from_sizes(world_size=2, fsdp_size=2) - return MultiLoraTransformersModel( - model_id=model_path, - device_mesh=mesh, - strategy=strategy, - ) - - -def build_sft_batch(): - return [{ - 'input_ids': [1, 2, 3, 4], - 'labels': [1, 2, 3, 4], - }] -``` - -- [ ] **Step 4: Re-run the accelerate test to confirm the failure is the intended one** - -Run: `python -m pytest tests/model/transformers/test_multi_lora_fsdp2_sft.py -k accelerate -v` -Expected: FAIL in model setup or FSDP2 execution, not from missing helper functions or syntax errors. - -- [ ] **Step 5: Commit the red test scaffold** - -```bash -git add tests/model/transformers/test_multi_lora_fsdp2_sft.py -git commit -m "test: add accelerate multi-lora fsdp2 sft regression" -``` - -### Task 2: Make `MultiLoraTransformersModel` participate in the shared wrap lifecycle - -**Files:** -- Modify: `src/twinkle/model/transformers/multi_lora_transformers.py` -- Reference: `src/twinkle/model/transformers/transformers.py` -- Test: `tests/model/transformers/test_multi_lora_fsdp2_sft.py` - -- [ ] **Step 1: Write the smallest failing assertion for strategy selection and lazy wrap** - -```python -def test_multi_lora_accelerate_fsdp2_uses_device_mesh(): - model = build_multi_lora_model(strategy='accelerate') - assert model.strategy.device_mesh is not None - assert model._model_wrapped is False -``` - -- [ ] **Step 2: Run the targeted test to verify it fails** - -Run: `python -m pytest tests/model/transformers/test_multi_lora_fsdp2_sft.py -k uses_device_mesh -v` -Expected: FAIL because the class still hardcodes `AccelerateStrategy(device_mesh=None)` and eagerly wraps in `__init__`. - -- [ ] **Step 3: Implement the minimal lifecycle fix** - -```python -class MultiLoraTransformersModel(TransformersModel, PreTrainedModel): - def __init__(..., strategy: Literal['accelerate', 'native_fsdp'] = 'accelerate', fsdp_config=None, ...): - self._fsdp_config = dict(fsdp_config or {}) - self._decide_strategy(strategy) - ... - self.model = self.multi_adapter.patch(self.model) - self._model_wrapped = False - - def _lazy_wrap_model(self): - return super()._lazy_wrap_model() -``` - -- [ ] **Step 4: Run the accelerate tests again** - -Run: `python -m pytest tests/model/transformers/test_multi_lora_fsdp2_sft.py -k "uses_device_mesh or accelerate" -v` -Expected: first lifecycle assertion passes; full accelerate round-trip test still fails in slot save/load or sharded tensor handling. - -- [ ] **Step 5: Commit the lifecycle change** - -```bash -git add src/twinkle/model/transformers/multi_lora_transformers.py tests/model/transformers/test_multi_lora_fsdp2_sft.py -git commit -m "feat: reuse shared fsdp lifecycle for multi-lora transformers" -``` - -### Task 3: Add sharding-safe LoRA slot tensor helpers in `MultiLora` - -**Files:** -- Modify: `src/twinkle/model/multi_lora.py` -- Test: `tests/model/transformers/test_multi_lora_fsdp2_sft.py` -- Optional Test: `tests/model/test_multi_lora_state.py` - -- [ ] **Step 1: Write a failing state round-trip test that isolates slot semantics** - -```python -def test_multi_lora_state_dict_round_trip_preserves_rank_slices(tmp_path): - model = build_multi_lora_model(strategy='accelerate') - model.add_adapter_to_model('default', build_lora_config(r=4), gradient_accumulation_steps=1) - state = model.get_state_dict(adapter_name='default') - assert state - assert all('.default.' 
not in key for key in state) -``` - -- [ ] **Step 2: Run the state round-trip test to verify it fails for the expected reason** - -Run: `python -m pytest tests/model/transformers/test_multi_lora_fsdp2_sft.py -k round_trip_preserves_rank_slices -v` -Expected: FAIL because current multi-LoRA save/load helpers assume local tensors and direct `parameter.data` writes. - -- [ ] **Step 3: Add minimal helper methods for slot IO** - -```python -def _read_param_tensor(self, parameter): - return torch_util.to_local_tensor(parameter) - - -def _write_param_tensor(self, parameter, value): - if hasattr(parameter, 'device_mesh') and hasattr(parameter, 'placements'): - value = distribute_tensor(value.to(parameter.device), parameter.device_mesh, parameter.placements) - parameter.data.copy_(value) -``` - -- [ ] **Step 4: Refactor `save_initial_weights`, `_load_initial_weights`, `set_state_dict`, `get_state_dict`, and `save_lora_converter` to use the new helpers** - -Run: `python -m pytest tests/model/transformers/test_multi_lora_fsdp2_sft.py -k accelerate -v` -Expected: the accelerate SFT round-trip test passes. - -- [ ] **Step 5: Commit the shared slot-IO layer** - -```bash -git add src/twinkle/model/multi_lora.py tests/model/transformers/test_multi_lora_fsdp2_sft.py -git commit -m "feat: support sharded multi-lora slot state io" -``` - -### Task 4: Add native FSDP2 SFT regression coverage - -**Files:** -- Modify: `tests/model/transformers/test_multi_lora_fsdp2_sft.py` -- Reference: `cookbook/transformers/sp_fsdp_dense.py` - -- [ ] **Step 1: Write the failing native FSDP SFT test** - -```python -@pytest.mark.skipif(torch.cuda.device_count() < 2, reason='Requires 2+ GPUs') -def test_multi_lora_native_fsdp_sft_round_trip(tmp_path): - model = build_multi_lora_model(strategy='native_fsdp') - model.add_adapter_to_model('default', build_lora_config(), gradient_accumulation_steps=1) - model.set_loss('CrossEntropyLoss', adapter_name='default') - model.set_optimizer('AdamW', lr=1e-4, adapter_name='default') - model.forward_backward(inputs=build_sft_batch(), adapter_name='default') - model.clip_grad_and_step(adapter_name='default') - model.save('native-ckpt', output_dir=str(tmp_path), adapter_name='default') -``` - -- [ ] **Step 2: Run the native test to verify it fails** - -Run: `python -m pytest tests/model/transformers/test_multi_lora_fsdp2_sft.py -k native -v` -Expected: FAIL because native FSDP has not yet been exercised through multi-LoRA wrapping, optimizer rebinding, or slot restore. - -- [ ] **Step 3: Add native-specific assertions before implementation** - -```python -assert model._model_wrapped is False -assert model.device_mesh.fsdp_world_size == 2 -``` - -- [ ] **Step 4: Re-run the native test to verify the failure still points at native FSDP support** - -Run: `python -m pytest tests/model/transformers/test_multi_lora_fsdp2_sft.py -k native -v` -Expected: FAIL in native FSDP wrapping or native slot state handling, not in test setup. 
- -- [ ] **Step 5: Commit the native regression scaffold** - -```bash -git add tests/model/transformers/test_multi_lora_fsdp2_sft.py -git commit -m "test: add native multi-lora fsdp2 sft regression" -``` - -### Task 5: Make the native FSDP2 SFT path pass - -**Files:** -- Modify: `src/twinkle/model/transformers/multi_lora_transformers.py` -- Modify: `src/twinkle/model/multi_lora.py` -- Test: `tests/model/transformers/test_multi_lora_fsdp2_sft.py` - -- [ ] **Step 1: Implement the smallest native FSDP2 compatibility change** - -```python -model = MultiLoraTransformersModel( - ..., - strategy='native_fsdp', -) -``` - -Required behavior: -- strategy is selected through `_decide_strategy` -- wrapping still happens only through `_lazy_wrap_model` -- optimizer binding remains valid after wrap - -- [ ] **Step 2: Run the native regression to verify the new failure, if any** - -Run: `python -m pytest tests/model/transformers/test_multi_lora_fsdp2_sft.py -k native -v` -Expected: either PASS or a narrower failure around native slot IO / sharded forward behavior. - -- [ ] **Step 3: Fix the minimal remaining native-specific issue** - -```python -if strategy == 'native_fsdp': - # Keep multi-LoRA slot tensors readable/writable after fully_shard. - ... -``` - -- [ ] **Step 4: Run the full SFT regression file** - -Run: `python -m pytest tests/model/transformers/test_multi_lora_fsdp2_sft.py -v` -Expected: PASS for both accelerate and native SFT tests. - -- [ ] **Step 5: Commit the native support** - -```bash -git add src/twinkle/model/transformers/multi_lora_transformers.py src/twinkle/model/multi_lora.py tests/model/transformers/test_multi_lora_fsdp2_sft.py -git commit -m "feat: support native fsdp2 sft for multi-lora transformers" -``` - -### Task 6: Final verification and cleanup - -**Files:** -- Modify: `src/twinkle/model/transformers/multi_lora_transformers.py` -- Modify: `src/twinkle/model/multi_lora.py` -- Modify: `tests/model/transformers/test_multi_lora_fsdp2_sft.py` - -- [ ] **Step 1: Run the shared targeted verification suite** - -Run: `python -m pytest tests/model/transformers/test_multi_lora_fsdp2_sft.py -v` -Expected: PASS - -- [ ] **Step 2: Run one adjacent regression if save/load behavior was touched deeply** - -Run: `python -m pytest tests/sampler/test_weight_sync.py -k lora -v` -Expected: PASS or SKIP, with no new failures caused by LoRA state handling changes. - -- [ ] **Step 3: Inspect the final diff for scope discipline** - -Run: `git diff --stat HEAD~1..HEAD` -Expected: only multi-LoRA transformers model code and the new SFT regression tests are touched. 
- -- [ ] **Step 4: Document verification evidence in the final handoff** - -Required notes: -- exact test commands run -- whether each command passed, failed, or skipped -- any remaining risks, especially GPU-only coverage limitations - -- [ ] **Step 5: Commit cleanup if needed** - -```bash -git add src/twinkle/model/transformers/multi_lora_transformers.py src/twinkle/model/multi_lora.py tests/model/transformers/test_multi_lora_fsdp2_sft.py -git commit -m "test: verify multi-lora fsdp2 sft coverage" -``` diff --git a/docs/superpowers/specs/2026-04-08-multi-lora-fsdp2-design.md b/docs/superpowers/specs/2026-04-08-multi-lora-fsdp2-design.md deleted file mode 100644 index 674b9d51..00000000 --- a/docs/superpowers/specs/2026-04-08-multi-lora-fsdp2-design.md +++ /dev/null @@ -1,261 +0,0 @@ -# Multi-LoRA Transformers FSDP2 SFT Support Design - -## Summary - -Enable `MultiLoraTransformersModel` to run under FSDP2 for SFT across both `AccelerateStrategy` and `NativeFSDPStrategy`, with staged delivery. Implementation will proceed in this order: - -1. `AccelerateStrategy + SFT` -2. `native_fsdp + SFT` - -The design centers on a shared FSDP2 compatibility layer for transformers multi-LoRA, so the native FSDP stage builds on the same model lifecycle, adapter-slot semantics, and distributed weight handling established by the accelerate stage. - -## Goals - -### Final Goal - -Make `MultiLoraTransformersModel` work under FSDP2 for SFT across both supported transformers strategies: - -- `AccelerateStrategy` -- `NativeFSDPStrategy` - -### Delivery Goal - -Ship the capability in two stages: - -1. `AccelerateStrategy + SFT` -2. `native_fsdp + SFT` - -Each stage must leave the shared foundations in a state that later stages can reuse without strategy-specific rewrites. - -## Non-Goals - -- No megatron multi-LoRA changes in this workstream -- No sampler or checkpoint-engine LoRA sync expansion in this workstream -- No attempt to solve all distributed checkpoint migration cases across arbitrary sharding layouts -- No requirement to support multi-adapter concurrent training in the initial stages -- No GRPO support in this workstream -- No large-scale performance tuning as part of correctness work -- No deep rewrite that merges `MultiLoraTransformersModel` into the single-adapter transformers implementation - -## Current Problem - -`MultiLoraTransformersModel` currently does not participate correctly in the transformers FSDP2 lifecycle. 
- -Current blockers: - -- construction hard-rejects FSDP via an assert -- the class always uses `AccelerateStrategy(device_mesh=None)` -- the model is eagerly wrapped in `__init__` instead of participating in lazy wrap -- multi-LoRA weight helpers assume local tensors and directly mutate `parameter.data` -- save, load, and slot-reset behavior is not strategy-aware - -Because of that: - -- `AccelerateStrategy` cannot shard the model for multi-LoRA training -- `NativeFSDPStrategy` is not wired into multi-LoRA at all -- adapter slot persistence is unsafe once LoRA parameters become DTensors or other sharded parameter forms -- SFT under FSDP2 is currently blocked for both strategy paths - -## Design Principles - -- Keep the existing multi-slot adapter model: tenants bind to preallocated internal LoRA slots -- Build one shared FSDP2 compatibility layer instead of separate accelerate-only and native-only implementations -- Let strategy differences stay in strategy code paths, not in duplicated multi-LoRA business logic -- Stage rollout by strategy, but avoid temporary patches that block later phases -- Prefer the smallest regression tests that prove correctness at each phase - -## Proposed Architecture - -The work is split into two layers: - -1. A shared FSDP2 compatibility foundation for transformers multi-LoRA -2. A staged rollout over SFT for accelerate and native FSDP2 - -### Shared Foundation - -The shared foundation must be correct before later phases can be added safely. - -#### 1. Unified strategy selection and lazy wrap lifecycle - -`MultiLoraTransformersModel` should stop managing wrapping as a special case. - -Required changes: - -- remove the constructor assert that blocks FSDP -- stop forcing `AccelerateStrategy(device_mesh=None)` -- let the class honor the requested strategy: - - default remains `AccelerateStrategy` - - `strategy='native_fsdp'` must instantiate `NativeFSDPStrategy` -- keep `multi_adapter.patch(self.model)` before wrapping so LoRA slots exist in the wrapped model graph -- move wrapping back into `_lazy_wrap_model()` so optimizer creation and strategy wrapping follow the same lifecycle as transformers models - -This is the main prerequisite for supporting both accelerate and native FSDP2 without forking the class. - -#### 2. Strategy-aware multi-LoRA parameter access - -`MultiLora` needs a small internal abstraction for reading and writing LoRA slot tensors under both unsharded and sharded parameter representations. - -This abstraction should support: - -- reading a saveable tensor view from LoRA slot parameters -- writing checkpoint tensors into LoRA slot parameters with the correct target layout -- restoring initial slot values when an adapter is removed -- preserving existing rank slicing rules for: - - `lora_A` - - `lora_B` - - `lora_embedding_A` - - `lora_embedding_B` - -The goal is not a generic distributed checkpoint layer. The goal is to make the current multi-LoRA slot logic safe under FSDP2. - -#### 3. Stable adapter-slot state machine after wrapping - -The tenant-to-slot model is part of the multi-LoRA contract and should remain unchanged. - -The following behaviors must stay valid after wrapping: - -- `activate_adapter` -- `deactivate_adapter` -- `save_context` -- `remove_adapter` - -This is important even in SFT because adapter activation, save/load, and slot reset must behave the same before and after wrapping. - -#### 4. Unified save/load/remove semantics - -The same slot-aware semantics should apply regardless of strategy. 
- -Required behaviors: - -- `get_state_dict` returns the tenant adapter's LoRA state with correct rank slicing -- `load` maps checkpoint weights into the correct internal slot -- `remove_adapter` restores the slot to its initial weights -- save/load logic works under both wrapped and unwrapped model states - -Single-adapter transformers already has FSDP2-aware load behavior. Multi-LoRA should reuse that idea, but route tensors through tenant-owned slots instead of direct single-adapter PEFT application. - -#### 5. Reusable test scaffolding - -Even though rollout is staged, the test base should be reusable from the start. - -Shared fixtures/helpers should cover: - -- a minimal FSDP2-capable device mesh -- a minimal multi-LoRA transformers model builder -- adapter-slot inspection helpers -- minimal SFT input samples - -This avoids rebuilding test infrastructure at every phase. - -## Staged Rollout - -### Phase 1: `AccelerateStrategy + SFT` - -This is the first delivery milestone and the narrowest supported training loop. - -Supported flow: - -- construct `MultiLoraTransformersModel` with accelerate FSDP2 -- add a single LoRA adapter -- run SFT training: - - `forward` - - `calculate_loss(CrossEntropyLoss)` - - `backward` - - `clip_grad_norm` - - `step` - - `zero_grad` -- save and load LoRA adapter state -- remove the adapter and restore the slot - -Implementation focus: - -- strategy/lazy-wrap integration -- FSDP2-safe slot state persistence -- SFT regression coverage - -Explicitly out of scope for this phase: - -- GRPO -- multi-adapter concurrent training -- sampler sync - -### Phase 2: `native_fsdp + SFT` - -Add native FSDP2 support by reusing the shared foundation. - -Supported flow: - -- construct `MultiLoraTransformersModel` with `strategy='native_fsdp'` -- add a single LoRA adapter -- run the same minimal SFT loop as Phase 1 -- save, load, and remove LoRA adapters correctly - -Implementation focus: - -- compatibility with `NativeFSDPStrategy.wrap_model` -- parameter layout handling after `fully_shard` -- optimizer rebinding and lazy-wrap correctness -- native SFT regression coverage - -Key risk areas: - -- wrapped parameter representation under native FSDP2 -- optimizer param-group rebinding after wrapping -- LoRA forward patch assumptions when parameters are sharded - -## Test Plan - -Tests should expand phase by phase, but stay narrow and behavior-oriented. 
- -### Shared test requirements - -- use the smallest model and mesh that still exercises the target strategy -- assert adapter-slot tensor behavior, not only absence of exceptions -- validate save/load round-trip and slot reset where relevant - -### Phase-specific checks - -#### Phase 1 - -- accelerate FSDP2 construction succeeds -- SFT forward-backward-step succeeds -- LoRA state round-trips through save/load -- `remove_adapter` restores initial slot values - -#### Phase 2 - -- native FSDP2 construction succeeds -- native SFT forward-backward-step succeeds -- native save/load/remove semantics are correct - -## Risks - -- patched LoRA forward code may assume local tensor access patterns that do not hold after sharding -- accelerate and native FSDP2 may expose LoRA parameters through different tensor representations -- some helper logic may accidentally reconstruct full tensors when only local shards are needed -- lazy wrap may surface ordering issues around optimizer creation, adapter activation, or template hooks - -## Deferred Work - -These should remain outside this design unless later stages prove them necessary for correctness: - -- GRPO support for either strategy -- sampler or checkpoint-engine LoRA synchronization enhancements -- memory-efficient-init customization specific to transformers multi-LoRA -- megatron multi-LoRA parity work -- broader adapter lifecycle redesign beyond the current slot model -- large-scale performance benchmarking or throughput tuning - -## Acceptance Criteria - -This design is complete when both stages are delivered in order and each stage has dedicated regression coverage: - -1. `AccelerateStrategy + SFT` -2. `native_fsdp + SFT` - -At the end of the rollout: - -- `MultiLoraTransformersModel` supports FSDP2 under both accelerate and native strategies -- SFT runs through the supported transformers multi-LoRA paths under both strategies -- save, load, and remove adapter semantics remain correct under FSDP2 -- unsupported areas remain explicitly unchanged diff --git a/tests/model/transformers/test_multi_lora_fsdp2_sft.py b/tests/model/transformers/test_multi_lora_fsdp2_sft.py deleted file mode 100644 index 68836c2c..00000000 --- a/tests/model/transformers/test_multi_lora_fsdp2_sft.py +++ /dev/null @@ -1,404 +0,0 @@ -from contextlib import nullcontext -import os -import shutil -import tempfile -from pathlib import Path - -import pytest -import torch -import twinkle -from peft import LoraConfig -from tokenizers import Tokenizer -from tokenizers.models import WordLevel -from tokenizers.pre_tokenizers import Whitespace -from transformers import LlamaConfig, LlamaForCausalLM, PreTrainedTokenizerFast - -from twinkle import DeviceMesh -from twinkle.model.multi_lora import MultiLora -from twinkle.model.transformers.multi_lora_transformers import MultiLoraTransformersModel -from twinkle.model.transformers.transformers import TransformersModel - -TEST_MODEL_ID = os.environ.get('TEST_MODEL_ID') - - -def build_lora_config(r: int = 4) -> LoraConfig: - return LoraConfig( - r=r, - lora_alpha=max(8, r * 2), - target_modules='all-linear', - init_lora_weights=False, - ) - - -def build_sft_batch(): - return [{ - 'input_ids': [1, 3, 4, 2], - 'attention_mask': [1, 1, 1, 1], - 'position_ids': [0, 1, 2, 3], - 'labels': [1, 3, 4, 2], - }] - - -def assert_same_lora_state(state_before, state_after): - assert state_before.keys() == state_after.keys() - for name, value in state_before.items(): - assert torch.equal(value, state_after[name]), name - - -def 
_write_tiny_model_dir(model_dir: Path) -> str: - model_dir.mkdir(parents=True, exist_ok=True) - - config = LlamaConfig( - vocab_size=16, - hidden_size=16, - intermediate_size=32, - num_hidden_layers=1, - num_attention_heads=2, - num_key_value_heads=2, - max_position_embeddings=32, - pad_token_id=0, - bos_token_id=1, - eos_token_id=2, - ) - LlamaForCausalLM(config).save_pretrained(model_dir) - - vocab = { - '': 0, - '': 1, - '': 2, - 'hello': 3, - 'world': 4, - 'adapter': 5, - 'state': 6, - '': 7, - } - tokenizer = Tokenizer(WordLevel(vocab=vocab, unk_token='')) - tokenizer.pre_tokenizer = Whitespace() - fast_tokenizer = PreTrainedTokenizerFast( - tokenizer_object=tokenizer, - unk_token='', - pad_token='', - bos_token='', - eos_token='', - ) - fast_tokenizer.save_pretrained(model_dir) - return str(model_dir) - - -@pytest.fixture -def model_path() -> str: - if TEST_MODEL_ID: - return TEST_MODEL_ID - return _write_tiny_model_dir(make_workspace_temp_dir('tiny-llama')) - - -def make_workspace_temp_dir(prefix: str) -> Path: - base_dir = Path.cwd() / '.codex_test_tmp' - base_dir.mkdir(parents=True, exist_ok=True) - return Path(tempfile.mkdtemp(prefix=f'{prefix}-', dir=base_dir)) - - -def _npu_device_count() -> int: - npu = getattr(torch, 'npu', None) - if npu is None: - try: - import torch_npu # noqa: F401 - except ImportError: - return 0 - npu = getattr(torch, 'npu', None) - - if npu is None: - return 0 - - try: - if npu.is_available(): - return npu.device_count() - except Exception: - return 0 - return 0 - - -def accelerator_device_count() -> int: - return max(torch.cuda.device_count(), _npu_device_count()) - - -def build_device_mesh(fsdp_size: int = 2) -> DeviceMesh: - if torch.cuda.device_count() >= fsdp_size: - return DeviceMesh.from_sizes(fsdp_size=fsdp_size, device_type='cuda') - if _npu_device_count() >= fsdp_size: - return DeviceMesh.from_sizes(fsdp_size=fsdp_size, device_type='npu') - return DeviceMesh.from_sizes(world_size=1, dp_size=1, device_type='cpu') - - -def build_multi_lora_model(model_path: str, strategy: str, fsdp_size: int = 2): - mesh = build_device_mesh(fsdp_size=fsdp_size) - twinkle.initialize(mode='local', global_device_mesh=mesh) - return MultiLoraTransformersModel( - model_id=model_path, - device_mesh=mesh, - strategy=strategy, - ) - - -def build_multi_lora_state() -> MultiLora: - config = LlamaConfig( - vocab_size=16, - hidden_size=16, - intermediate_size=32, - num_hidden_layers=1, - num_attention_heads=2, - num_key_value_heads=2, - max_position_embeddings=32, - pad_token_id=0, - bos_token_id=1, - eos_token_id=2, - ) - multi_lora = MultiLora(max_loras=2, max_r=8, max_length=32) - multi_lora.patch(LlamaForCausalLM(config)) - multi_lora.save_initial_weights() - multi_lora.acquire_lora('default', build_lora_config(r=4)) - return multi_lora - - -def test_multi_lora_accelerate_fsdp2_uses_device_mesh(model_path): - model = build_multi_lora_model(model_path, strategy='accelerate') - assert model.strategy.device_mesh is not None - assert model._model_wrapped is False - - -def test_multi_lora_native_fsdp2_uses_lazy_wrap(model_path): - model = build_multi_lora_model(model_path, strategy='native_fsdp') - assert model.strategy.device_mesh is not None - assert model._model_wrapped is False - - -def test_build_device_mesh_prefers_npu_when_available(monkeypatch): - class FakeNPU: - - @staticmethod - def is_available(): - return True - - @staticmethod - def device_count(): - return 2 - - monkeypatch.setattr(torch.cuda, 'device_count', lambda: 0) - monkeypatch.setattr(torch, 'npu', 
FakeNPU(), raising=False) - - mesh = build_device_mesh(fsdp_size=2) - - assert mesh.device_type == 'npu' - assert mesh.fsdp_world_size == 2 - - -def test_accelerator_device_count_uses_npu_when_cuda_absent(monkeypatch): - class FakeNPU: - - @staticmethod - def is_available(): - return True - - @staticmethod - def device_count(): - return 2 - - monkeypatch.setattr(torch.cuda, 'device_count', lambda: 0) - monkeypatch.setattr(torch, 'npu', FakeNPU(), raising=False) - - assert accelerator_device_count() == 2 - - -def test_build_sft_batch_includes_processor_fields(): - batch = build_sft_batch() - - assert batch == [{ - 'input_ids': [1, 3, 4, 2], - 'attention_mask': [1, 1, 1, 1], - 'position_ids': [0, 1, 2, 3], - 'labels': [1, 3, 4, 2], - }] - - -def test_multi_lora_state_dict_round_trip_preserves_rank_slices(): - multi_lora = build_multi_lora_state() - calls = [] - - def fake_read(parameter): - calls.append(tuple(parameter.shape)) - return parameter.detach().clone() - - multi_lora._read_param_tensor = fake_read - state = multi_lora.get_state_dict('default') - - assert state - assert all('.default.' not in key for key in state) - assert calls - - -def test_multi_lora_set_state_dict_uses_tensor_write_helper(): - multi_lora = build_multi_lora_state() - state = multi_lora.get_state_dict('default') - calls = [] - - def fake_write(parameter, value): - calls.append((tuple(parameter.shape), tuple(value.shape))) - parameter.data.copy_(value) - - multi_lora._write_param_tensor = fake_write - multi_lora.set_state_dict('default', state) - - assert calls - - -def test_multi_lora_load_initial_weights_uses_tensor_write_helper(): - multi_lora = build_multi_lora_state() - calls = [] - - def fake_write(parameter, value): - calls.append((tuple(parameter.shape), tuple(value.shape))) - parameter.data.copy_(value) - - multi_lora._write_param_tensor = fake_write - multi_lora._load_initial_weights('lora_0') - - assert calls - - -def test_multi_lora_write_param_tensor_distributes_leaf_tensor(monkeypatch): - multi_lora = MultiLora(max_loras=1, max_r=4, max_length=8) - parameter = torch.nn.Parameter(torch.zeros(2, 2)) - parameter.device_mesh = object() - parameter.placements = ('shard0', ) - - recorded = {} - - def fake_distribute_tensor(value, device_mesh, placements): - if not value.is_leaf: - raise RuntimeError('`distribute_tensor` should be used to distribute leaf tensors!') - recorded['is_leaf'] = value.is_leaf - recorded['device_mesh'] = device_mesh - recorded['placements'] = placements - return value - - monkeypatch.setattr('torch.distributed.tensor.distribute_tensor', fake_distribute_tensor) - - source = torch.ones(2, 2, requires_grad=True) * 3 - - multi_lora._write_param_tensor(parameter, source) - - assert recorded['is_leaf'] is True - assert recorded['device_mesh'] is parameter.device_mesh - assert recorded['placements'] == parameter.placements - assert torch.equal(parameter.data, torch.full((2, 2), 3.0)) - - -def _build_stub_multi_lora_model(): - model = object.__new__(MultiLoraTransformersModel) - model._check_adapter_valid = lambda adapter_name: None - model.multi_adapter = type( - 'DummyMultiAdapter', - (), - { - 'save_context': staticmethod(lambda adapter_name: nullcontext()), - 'set_state_dict': lambda self, adapter_name, state_dict: None, - }, - )() - model.strategy = type('DummyStrategy', (), {'unwrap_model': lambda self, wrapped: wrapped})() - model.model = object() - model._load_optimizer = lambda checkpoint_dir, adapter_name=None: None - return model - - -def 
test_multi_lora_save_barriers_after_checkpoint_write(monkeypatch): - model = _build_stub_multi_lora_model() - events = [] - - def fake_save(self, name, output_dir=None, interval=1, **kwargs): - events.append('save') - return 'ckpt' - - monkeypatch.setattr(TransformersModel, 'save', fake_save) - monkeypatch.setattr('torch.distributed.is_initialized', lambda: True) - monkeypatch.setattr('torch.distributed.barrier', lambda: events.append('barrier')) - - checkpoint_dir = model.save('ckpt', output_dir='output', adapter_name='default') - - assert checkpoint_dir == 'ckpt' - assert events == ['save', 'barrier'] - - -def test_multi_lora_load_barriers_after_adapter_restore(monkeypatch): - model = _build_stub_multi_lora_model() - events = [] - - class FakePeftModel: - pass - - fake_peft_model = FakePeftModel() - model.model = fake_peft_model - - def fake_set_state_dict(adapter_name, state_dict): - events.append(('set_state_dict', adapter_name, state_dict)) - - model.multi_adapter.set_state_dict = fake_set_state_dict - - monkeypatch.setattr('twinkle.model.transformers.multi_lora_transformers.PeftModel', FakePeftModel) - monkeypatch.setattr('twinkle.model.transformers.multi_lora_transformers.load_peft_weights', - lambda checkpoint_dir, device='cpu': events.append(('load_peft_weights', checkpoint_dir, - device)) or {'layer.weight': torch.ones(1)}) - monkeypatch.setattr('torch.distributed.is_initialized', lambda: True) - monkeypatch.setattr('torch.distributed.barrier', lambda: events.append('barrier')) - - model.load('ckpt', output_dir='output', adapter_name='default') - - assert events == [ - ('load_peft_weights', os.path.join('output', 'ckpt'), 'cpu'), - ('set_state_dict', 'default', {'layer.weight': torch.ones(1)}), - 'barrier', - ] - - -@pytest.mark.skipif(accelerator_device_count() < 2, reason='Requires 2+ CUDA GPUs or NPUs') -def test_multi_lora_accelerate_fsdp2_sft_round_trip(model_path): - output_dir = make_workspace_temp_dir('accelerate-ckpt') - model = build_multi_lora_model(model_path, strategy='accelerate') - model.add_adapter_to_model('default', build_lora_config(), gradient_accumulation_steps=1) - model.set_loss('CrossEntropyLoss', adapter_name='default') - model.set_optimizer('AdamW', lr=1e-4, adapter_name='default') - - batch = build_sft_batch() - model.forward_backward(inputs=batch, adapter_name='default') - model.clip_grad_and_step(adapter_name='default') - - state_before = model.get_state_dict(adapter_name='default') - model.save('ckpt', output_dir=str(output_dir), adapter_name='default') - model.remove_adapter('default') - model.add_adapter_to_model('default', build_lora_config(), gradient_accumulation_steps=1) - model.load('ckpt', output_dir=str(output_dir), adapter_name='default') - state_after = model.get_state_dict(adapter_name='default') - - assert_same_lora_state(state_before, state_after) - shutil.rmtree(output_dir, ignore_errors=True) - - -@pytest.mark.skipif(accelerator_device_count() < 2, reason='Requires 2+ CUDA GPUs or NPUs') -def test_multi_lora_native_fsdp_sft_round_trip(model_path): - output_dir = make_workspace_temp_dir('native-ckpt') - model = build_multi_lora_model(model_path, strategy='native_fsdp') - model.add_adapter_to_model('default', build_lora_config(), gradient_accumulation_steps=1) - model.set_loss('CrossEntropyLoss', adapter_name='default') - model.set_optimizer('AdamW', lr=1e-4, adapter_name='default') - - model.forward_backward(inputs=build_sft_batch(), adapter_name='default') - model.clip_grad_and_step(adapter_name='default') - - state_before = 
model.get_state_dict(adapter_name='default') - model.save('native-ckpt', output_dir=str(output_dir), adapter_name='default') - model.remove_adapter('default') - model.add_adapter_to_model('default', build_lora_config(), gradient_accumulation_steps=1) - model.load('native-ckpt', output_dir=str(output_dir), adapter_name='default') - state_after = model.get_state_dict(adapter_name='default') - - assert_same_lora_state(state_before, state_after) - shutil.rmtree(output_dir, ignore_errors=True) From af731b5d027ed22db88648cf5c560bb235725be5 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Tue, 14 Apr 2026 11:45:42 +0800 Subject: [PATCH 11/13] wip --- src/twinkle/model/multi_lora.py | 28 +++++++++++++++++++++------- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/src/twinkle/model/multi_lora.py b/src/twinkle/model/multi_lora.py index 849df535..614fe518 100644 --- a/src/twinkle/model/multi_lora.py +++ b/src/twinkle/model/multi_lora.py @@ -46,6 +46,8 @@ def _read_param_tensor(self, parameter): return torch_util.to_local_tensor(parameter) def _write_param_tensor(self, parameter, value): + if value is None: + return value = value.detach().to(dtype=parameter.dtype) if hasattr(parameter, 'device_mesh') and hasattr(parameter, 'placements'): from torch.distributed.tensor import distribute_tensor @@ -55,7 +57,9 @@ def _write_param_tensor(self, parameter, value): parameter.data.copy_(value) @staticmethod - def _slice_rank_tensor(name: str, tensor: torch.Tensor, rank: int) -> torch.Tensor: + def _slice_rank_tensor(name: str, tensor, rank: int): + if tensor is None: + return None if 'embedding_A' in name: return tensor[:, :rank] if 'embedding_B' in name: @@ -67,7 +71,9 @@ def _slice_rank_tensor(name: str, tensor: torch.Tensor, rank: int) -> torch.Tens return tensor @staticmethod - def _copy_rank_tensor(name: str, target: torch.Tensor, value: torch.Tensor) -> torch.Tensor: + def _copy_rank_tensor(name: str, target, value): + if target is None or value is None: + return None if 'embedding_A' in name: target[:, :value.shape[1]].copy_(value) elif 'embedding_B' in name: @@ -535,7 +541,10 @@ def _load_weights(_module): for name, parameter in _module.named_parameters(): if pattern.search(name) and self.match_target_modules(name, _lora.tenant_config.target_modules): state_key = name.replace(f'.{_lora.adapter_name}.', '.') - target_tensor = self._read_param_tensor(parameter).clone() + target_tensor = self._read_param_tensor(parameter) + if target_tensor is None: + continue + target_tensor = target_tensor.clone() src_tensor = state_dict[state_key].to(dtype=target_tensor.dtype, device=target_tensor.device) self._copy_rank_tensor(name, target_tensor, src_tensor) self._write_param_tensor(parameter, target_tensor) @@ -556,6 +565,8 @@ def _get_weights(_module): for name, parameter in _module.named_parameters(): if pattern.search(name) and self.match_target_modules(name, _lora.tenant_config.target_modules): _param = self._slice_rank_tensor(name, self._read_param_tensor(parameter), _lora.tenant_config.r) + if _param is None: + continue name = name.replace(f'.{_lora.adapter_name}.', '.') state_dict[name] = _param return state_dict @@ -575,11 +586,14 @@ def _load_initial_weights(self, origin_adapter_name): def _load_initial_weights(_module): for name, parameter in _module.named_parameters(): if pattern_A.search(name): - target_device = self._read_param_tensor(parameter).device - value = _lora.lora_A_weights[name].to(dtype=parameter.dtype, device=target_device) - self._write_param_tensor(parameter, value) + 
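+                # _read_param_tensor may return None (assumption: no local
+                # tensor is materialized on this rank); skip the write then.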
local_param = self._read_param_tensor(parameter) + if local_param is not None: + value = _lora.lora_A_weights[name].to(dtype=parameter.dtype, device=local_param.device) + self._write_param_tensor(parameter, value) if pattern_B.search(name): - self._write_param_tensor(parameter, torch.zeros_like(self._read_param_tensor(parameter))) + local_param = self._read_param_tensor(parameter) + if local_param is not None: + self._write_param_tensor(parameter, torch.zeros_like(local_param)) if isinstance(self.module, list): for _module in self.module: From 62c496c5c0eb96dbb47c1ae14bb8cf25503ea7ee Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Tue, 14 Apr 2026 14:33:37 +0800 Subject: [PATCH 12/13] fix --- .../transformers/multi_lora_transformers.py | 21 +++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/src/twinkle/model/transformers/multi_lora_transformers.py b/src/twinkle/model/transformers/multi_lora_transformers.py index c8efd687..ad4e4843 100644 --- a/src/twinkle/model/transformers/multi_lora_transformers.py +++ b/src/twinkle/model/transformers/multi_lora_transformers.py @@ -48,12 +48,25 @@ def __init__( self._memory_efficient_init = memory_efficient_init self._decide_strategy(strategy) self.grad_scaler_config = grad_scaler_config + if model_id is not None: + model_id = HubOperation.download_model(model_id) + self.model_id = model_id + if config is None: + from transformers import AutoConfig + self.hf_config = AutoConfig.from_pretrained(model_id, trust_remote_code=True) + else: + self.hf_config = config + if model_cls is None and hasattr(self.hf_config, 'architectures'): + model_cls = self.hf_config.architectures[0] + if model_cls is None: + model_cls = AutoModelForCausalLM if isinstance(model_cls, str): model_cls = getattr(transformers, model_cls) - model_id = HubOperation.download_model(model_id) - with self.strategy.pretrained_load_context(): - self.model = model_cls.from_pretrained(model_id, config=config, **kwargs) - self.model_id = model_id + if model_id is None: + self.model = model_cls.from_config(self.hf_config, **kwargs) + else: + with self.strategy.pretrained_load_context(): + self.model = model_cls.from_pretrained(model_id, config=self.hf_config, **kwargs) self.tokenizer_id = kwargs.get('tokenizer_id', self.model_id) self._default_tokenizer = None self._model_wrapped = False From d43f541b859dde7a9d19e283ae3d30f4b762974b Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Tue, 14 Apr 2026 15:47:45 +0800 Subject: [PATCH 13/13] fix --- src/twinkle/model/multi_lora.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/twinkle/model/multi_lora.py b/src/twinkle/model/multi_lora.py index 614fe518..f0135fd8 100644 --- a/src/twinkle/model/multi_lora.py +++ b/src/twinkle/model/multi_lora.py @@ -5,6 +5,7 @@ from dataclasses import dataclass, field from peft import LoraConfig, PeftModel, get_peft_model from peft.tuners.lora import Embedding, Linear, LoraLayer +from torch.distributed.tensor import distribute_tensor from types import MethodType from typing import Any, Callable, Dict, List, Optional, Union @@ -50,7 +51,6 @@ def _write_param_tensor(self, parameter, value): return value = value.detach().to(dtype=parameter.dtype) if hasattr(parameter, 'device_mesh') and hasattr(parameter, 'placements'): - from torch.distributed.tensor import distribute_tensor value = distribute_tensor(value.to(parameter.device), parameter.device_mesh, parameter.placements) else: value = value.to(parameter.device)
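For quick local verification of the write path these last patches settle on, a minimal standalone sketch follows. It assumes only `torch`; `write_param_tensor` is a hypothetical free-function copy of `MultiLora._write_param_tensor` as of PATCH 13, and only the unsharded fallback branch executes here, since `device_mesh`/`placements` exist only on DTensor-backed parameters:

```python
# Hypothetical standalone copy of MultiLora._write_param_tensor (PATCH 13
# state) for local experimentation; not part of the patch series itself.
import torch
from torch.nn import Parameter


def write_param_tensor(parameter: Parameter, value: torch.Tensor) -> None:
    if value is None:
        return
    # Detach before any distribution so distribute_tensor sees a leaf tensor.
    value = value.detach().to(dtype=parameter.dtype)
    if hasattr(parameter, 'device_mesh') and hasattr(parameter, 'placements'):
        # DTensor-backed parameter: reshape the incoming full tensor into the
        # parameter's own sharded layout before copying.
        from torch.distributed.tensor import distribute_tensor
        value = distribute_tensor(value.to(parameter.device),
                                  parameter.device_mesh, parameter.placements)
    else:
        value = value.to(parameter.device)
    parameter.data.copy_(value)


# Unsharded round trip, mirroring the leaf-tensor regression test:
p = Parameter(torch.zeros(2, 2))
write_param_tensor(p, torch.ones(2, 2, requires_grad=True) * 3)
assert torch.equal(p.data, torch.full((2, 2), 3.0))
```

The sketch keeps `distribute_tensor` as a local import so it runs on torch builds without distributed support, whereas PATCH 13 hoists it to module level. On sharded parameters the same helper routes through `distribute_tensor`, which is why PATCH 08 added the `detach()`: `distribute_tensor` rejects non-leaf tensors, exactly the failure mode `test_multi_lora_write_param_tensor_distributes_leaf_tensor` guards against.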