101 changes: 64 additions & 37 deletions src/twinkle/model/multi_lora.py
@@ -5,6 +5,7 @@
from dataclasses import dataclass, field
from peft import LoraConfig, PeftModel, get_peft_model
from peft.tuners.lora import Embedding, Linear, LoraLayer
from torch.distributed.tensor import distribute_tensor
from types import MethodType
from typing import Any, Callable, Dict, List, Optional, Union

@@ -42,6 +43,49 @@ def _get_available_lora(self) -> Optional[LoraTenant]:
return _lora
return None

def _read_param_tensor(self, parameter):
return torch_util.to_local_tensor(parameter)

def _write_param_tensor(self, parameter, value):
if value is None:
return
value = value.detach().to(dtype=parameter.dtype)
if hasattr(parameter, 'device_mesh') and hasattr(parameter, 'placements'):
value = distribute_tensor(value.to(parameter.device), parameter.device_mesh, parameter.placements)
else:
value = value.to(parameter.device)
parameter.data.copy_(value)
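# Helper summary: _read_param_tensor pulls a plain local tensor out of a
# (possibly DTensor-sharded) parameter via torch_util.to_local_tensor, which
# appears to return None when there is nothing local to read; _write_param_tensor
# goes the other way, re-sharding a full tensor onto the parameter's device_mesh
# with distribute_tensor when the parameter is distributed, and doing a plain
# device copy otherwise.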

@staticmethod
def _slice_rank_tensor(name: str, tensor, rank: int):
if tensor is None:
return None
if 'embedding_A' in name:
return tensor[:, :rank]
if 'embedding_B' in name:
return tensor[:rank, :]
if '_A' in name:
return tensor[:rank, :]
if '_B' in name:
return tensor[:, :rank]
return tensor

@staticmethod
def _copy_rank_tensor(name: str, target, value):
if target is None or value is None:
return None
if 'embedding_A' in name:
target[:, :value.shape[1]].copy_(value)
elif 'embedding_B' in name:
target[:value.shape[0], :].copy_(value)
elif '_A' in name:
target[:value.shape[0], :].copy_(value)
elif '_B' in name:
target[:, :value.shape[1]].copy_(value)
else:
target.copy_(value)
return target
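# Slicing convention used by the two static helpers above: for linear adapters
# the rank axis is the leading dimension of lora_A and the trailing dimension of
# lora_B, while the embedding adapters are sliced on the opposite axes, so the
# 'embedding_A'/'embedding_B' checks must run before the generic '_A'/'_B' ones
# (the embedding names also contain those substrings). _slice_rank_tensor crops
# the (presumably max_r-sized) parameter down to a tenant's configured rank r;
# _copy_rank_tensor writes a saved rank-r slice back into the padded buffer,
# leaving the remaining rows/columns untouched.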

def _count_available_loras(self):
return len([_lora for _lora in self.loras if _lora.tenant_adapter_name is None])

@@ -435,7 +479,7 @@ def save_initial_weights(self):
def _store_weights(_module):
for name, parameter in _module.named_parameters():
if pattern.search(name):
lora_tenant.lora_A_weights[name] = parameter.data.clone().to('cpu')
lora_tenant.lora_A_weights[name] = self._read_param_tensor(parameter).clone().to('cpu')

if isinstance(self.module, list):
for _module in self.module:
@@ -483,17 +527,7 @@ def save_lora_converter(self, name, parameter, adapter_name):
return None
if re.search(rf'\.lora_\w+\.({adapter_name}|weight)', name) and self.match_target_modules(
name, _lora.tenant_config.target_modules):
_param = torch_util.to_local_tensor(parameter)
if _param is None:
pass
elif 'embedding_A' in name:
_param = _param[:, :_lora.tenant_config.r]
elif 'embedding_B' in name:
_param = _param[:_lora.tenant_config.r, :]
elif '_A' in name:
_param = _param[:_lora.tenant_config.r, :]
elif '_B' in name:
_param = _param[:, :_lora.tenant_config.r]
_param = self._slice_rank_tensor(name, self._read_param_tensor(parameter), _lora.tenant_config.r)
name = name.replace(f'.{_lora.adapter_name}.', '.')
return name, _param
else:
@@ -506,20 +540,14 @@ def set_state_dict(self, tenant_adapter_name, state_dict):
def _load_weights(_module):
for name, parameter in _module.named_parameters():
if pattern.search(name) and self.match_target_modules(name, _lora.tenant_config.target_modules):
name = name.replace(f'.{_lora.adapter_name}.', '.')
src_tensor = state_dict[name]
if 'embedding_A' in name:
r_saved = src_tensor.shape[1]
parameter.data[:, :r_saved].copy_(src_tensor)
elif 'embedding_B' in name:
r_saved = src_tensor.shape[0]
parameter.data[:r_saved, :].copy_(src_tensor)
elif '_A' in name:
r_saved = src_tensor.shape[0]
parameter.data[:r_saved, :].copy_(src_tensor)
elif '_B' in name:
r_saved = src_tensor.shape[1]
parameter.data[:, :r_saved].copy_(src_tensor)
state_key = name.replace(f'.{_lora.adapter_name}.', '.')
target_tensor = self._read_param_tensor(parameter)
if target_tensor is None:
continue
target_tensor = target_tensor.clone()
src_tensor = state_dict[state_key].to(dtype=target_tensor.dtype, device=target_tensor.device)
self._copy_rank_tensor(name, target_tensor, src_tensor)
self._write_param_tensor(parameter, target_tensor)
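# Load path: rather than copying into parameter.data in place (which does not
# account for DTensor-sharded parameters), the new code reads the full local
# tensor, clones it, overlays the saved rank-r slice via _copy_rank_tensor, and
# then pushes the whole tensor back through _write_param_tensor so it is
# re-distributed onto the parameter's mesh.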

if isinstance(self.module, list):
for _module in self.module:
@@ -536,15 +564,9 @@ def _get_weights(_module):
state_dict = {}
for name, parameter in _module.named_parameters():
if pattern.search(name) and self.match_target_modules(name, _lora.tenant_config.target_modules):
_param = torch_util.to_local_tensor(parameter)
if 'embedding_A' in name:
_param = _param[:, :_lora.tenant_config.r]
elif 'embedding_B' in name:
_param = _param[:_lora.tenant_config.r, :]
elif '_A' in name:
_param = _param[:_lora.tenant_config.r, :]
elif '_B' in name:
_param = _param[:, :_lora.tenant_config.r]
_param = self._slice_rank_tensor(name, self._read_param_tensor(parameter), _lora.tenant_config.r)
if _param is None:
continue
name = name.replace(f'.{_lora.adapter_name}.', '.')
state_dict[name] = _param
return state_dict
@@ -564,9 +586,14 @@ def _load_initial_weights(self, origin_adapter_name):
def _load_initial_weights(_module):
for name, parameter in _module.named_parameters():
if pattern_A.search(name):
parameter.data.copy_(_lora.lora_A_weights[name])
local_param = self._read_param_tensor(parameter)
if local_param is not None:
value = _lora.lora_A_weights[name].to(dtype=parameter.dtype, device=local_param.device)
self._write_param_tensor(parameter, value)
if pattern_B.search(name):
parameter.data.copy_(torch.zeros_like(parameter.data).to(parameter.data.dtype))
local_param = self._read_param_tensor(parameter)
if local_param is not None:
self._write_param_tensor(parameter, torch.zeros_like(local_param))
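# Resetting a slot: lora_A is restored from the snapshot taken in
# save_initial_weights, and lora_B is zeroed, matching LoRA's usual init where
# B = 0 makes the adapter a no-op, so a recycled tenant slot presumably starts
# from the same state as a freshly attached adapter.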

if isinstance(self.module, list):
for _module in self.module:
42 changes: 27 additions & 15 deletions src/twinkle/model/transformers/multi_lora_transformers.py
@@ -1,5 +1,6 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
import os
import torch.distributed as dist
import transformers
from peft import LoraConfig, PeftConfig, PeftModel, load_peft_weights
from torch.optim import Optimizer
@@ -15,7 +16,6 @@
from twinkle.metric import Metric
from twinkle.processor import InputProcessor
from ..multi_lora import MultiLora
from .strategy import AccelerateStrategy
from .transformers import OptimizerGroup, TransformersModel


@@ -29,16 +29,27 @@ def __init__(
config: Optional[PretrainedConfig] = None,
device_mesh: Optional[DeviceMesh] = None,
mixed_precision: Literal['no', 'fp8', 'fp16', 'bf16'] = 'bf16',
strategy: Literal['accelerate', 'native_fsdp'] = 'accelerate',
ddp_config: Dict[str, Any] = None,
fsdp_config: Dict[str, Any] = None,
grad_scaler_config: Dict[str, Any] = None,
memory_efficient_init: bool = False,
max_loras: int = 5,
max_r: int = 32,
max_length: int = 8192,
**kwargs):
assert device_mesh.fsdp_world_size <= 0, f'MultiLora does not support FSDP, current is: {str(device_mesh)}'
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
self._try_init_process_group()
super(PreTrainedModel, self).__init__()
model_id = HubOperation.download_model(model_id)
self.device_mesh = device_mesh
self.mixed_precision = mixed_precision
self._fsdp_config = dict(fsdp_config or {})
self._ddp_config = ddp_config or {}
self._memory_efficient_init = memory_efficient_init
self._decide_strategy(strategy)
self.grad_scaler_config = grad_scaler_config
if model_id is not None:
model_id = HubOperation.download_model(model_id)
self.model_id = model_id
if config is None:
from transformers import AutoConfig
@@ -51,24 +62,20 @@
model_cls = AutoModelForCausalLM
if isinstance(model_cls, str):
model_cls = getattr(transformers, model_cls)
self.model = model_cls.from_pretrained(model_id, config=self.hf_config, **kwargs)
self.model_id = model_id
if model_id is None:
self.model = model_cls.from_config(self.hf_config, **kwargs)
else:
with self.strategy.pretrained_load_context():
self.model = model_cls.from_pretrained(model_id, config=self.hf_config, **kwargs)
self.tokenizer_id = kwargs.get('tokenizer_id', self.model_id)
self.device_mesh = device_mesh
self.mixed_precision = mixed_precision
self.grad_scaler_config = grad_scaler_config
self._default_tokenizer = None
self._model_wrapped = False
self.sp_strategy = None
# Initialize expert parallel attributes (required by set_optimizer in TransformersModel)
self._expert_parallel_config = None
self._enable_expert_parallel = False
self._expert_parallel_applied = False
self.optimizer_group: Dict[str, OptimizerGroup] = {}
self.multi_adapter = MultiLora(max_loras=max_loras, max_r=max_r, max_length=max_length)
self.model.gradient_checkpointing_enable()
self.model = self.multi_adapter.patch(self.model)
self.strategy = AccelerateStrategy(mixed_precision=mixed_precision, device_mesh=None)
self.model = self.strategy.wrap_model(self.model)
self.multi_adapter.save_initial_weights()
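# Setup order here appears deliberate: enable gradient checkpointing on the base
# model, patch it with the multi-LoRA adapter slots, wrap it with the selected
# distributed strategy, and only then snapshot the initial adapter weights so
# that _load_initial_weights can later reset a freed tenant slot.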
# Active group for compatibility with single adapter
self.active_group = None
@@ -98,7 +105,7 @@ def unregister_mm_forward_hook(self, optimizer_group: OptimizerGroup):
pass

def _lazy_wrap_model(self):
pass
return super()._lazy_wrap_model()

@remote_function(dispatch='slice_dp', collect=collect_tensor_dict)
def forward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]], **kwargs):
@@ -230,7 +237,10 @@ def get_state_dict(self, **kwargs):
def save(self, name, output_dir: Optional[str] = None, interval=1, **kwargs):
self._check_adapter_valid(kwargs.get('adapter_name'))
with self.multi_adapter.save_context(kwargs.get('adapter_name')):
return super().save(name, output_dir, interval, **kwargs)
checkpoint_dir = super().save(name, output_dir, interval, **kwargs)
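# Presumably the barrier below keeps ranks in step: no process returns from
# save() before every rank has finished writing its part of the checkpoint.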
if dist.is_initialized():
dist.barrier()
return checkpoint_dir

@remote_function()
def load(self, name: Optional[str] = None, output_dir: Optional[str] = None, **kwargs):
@@ -252,6 +262,8 @@ def load(self, name: Optional[str] = None, output_dir: Optional[str] = None, **k

if load_optimizer:
self._load_optimizer(checkpoint_dir, adapter_name=adapter_name)
if dist.is_initialized():
dist.barrier()

@remote_function()
def set_grad_scaler(self, **kwargs):