diff --git a/src/twinkle/model/multi_lora.py b/src/twinkle/model/multi_lora.py
index d8be4832..f0135fd8 100644
--- a/src/twinkle/model/multi_lora.py
+++ b/src/twinkle/model/multi_lora.py
@@ -5,6 +5,7 @@
 from dataclasses import dataclass, field
 from peft import LoraConfig, PeftModel, get_peft_model
 from peft.tuners.lora import Embedding, Linear, LoraLayer
+from torch.distributed.tensor import distribute_tensor
 from types import MethodType
 from typing import Any, Callable, Dict, List, Optional, Union
 
@@ -42,6 +43,49 @@ def _get_available_lora(self) -> Optional[LoraTenant]:
                 return _lora
         return None
 
+    def _read_param_tensor(self, parameter):
+        return torch_util.to_local_tensor(parameter)
+
+    def _write_param_tensor(self, parameter, value):
+        if value is None:
+            return
+        value = value.detach().to(dtype=parameter.dtype)
+        if hasattr(parameter, 'device_mesh') and hasattr(parameter, 'placements'):
+            value = distribute_tensor(value.to(parameter.device), parameter.device_mesh, parameter.placements)
+        else:
+            value = value.to(parameter.device)
+        parameter.data.copy_(value)
+
+    @staticmethod
+    def _slice_rank_tensor(name: str, tensor, rank: int):
+        if tensor is None:
+            return None
+        if 'embedding_A' in name:
+            return tensor[:, :rank]
+        if 'embedding_B' in name:
+            return tensor[:rank, :]
+        if '_A' in name:
+            return tensor[:rank, :]
+        if '_B' in name:
+            return tensor[:, :rank]
+        return tensor
+
+    @staticmethod
+    def _copy_rank_tensor(name: str, target, value):
+        if target is None or value is None:
+            return None
+        if 'embedding_A' in name:
+            target[:, :value.shape[1]].copy_(value)
+        elif 'embedding_B' in name:
+            target[:value.shape[0], :].copy_(value)
+        elif '_A' in name:
+            target[:value.shape[0], :].copy_(value)
+        elif '_B' in name:
+            target[:, :value.shape[1]].copy_(value)
+        else:
+            target.copy_(value)
+        return target
+
     def _count_available_loras(self):
         return len([_lora for _lora in self.loras if _lora.tenant_adapter_name is None])
 
@@ -435,7 +479,7 @@ def save_initial_weights(self):
         def _store_weights(_module):
             for name, parameter in _module.named_parameters():
                 if pattern.search(name):
-                    lora_tenant.lora_A_weights[name] = parameter.data.clone().to('cpu')
+                    lora_tenant.lora_A_weights[name] = self._read_param_tensor(parameter).clone().to('cpu')
 
         if isinstance(self.module, list):
             for _module in self.module:
@@ -483,17 +527,7 @@ def save_lora_converter(self, name, parameter, adapter_name):
            return None
        if re.search(rf'\.lora_\w+\.({adapter_name}|weight)', name) and self.match_target_modules(
                name, _lora.tenant_config.target_modules):
-            _param = torch_util.to_local_tensor(parameter)
-            if _param is None:
-                pass
-            elif 'embedding_A' in name:
-                _param = _param[:, :_lora.tenant_config.r]
-            elif 'embedding_B' in name:
-                _param = _param[:_lora.tenant_config.r, :]
-            elif '_A' in name:
-                _param = _param[:_lora.tenant_config.r, :]
-            elif '_B' in name:
-                _param = _param[:, :_lora.tenant_config.r]
+            _param = self._slice_rank_tensor(name, self._read_param_tensor(parameter), _lora.tenant_config.r)
            name = name.replace(f'.{_lora.adapter_name}.', '.')
            return name, _param
        else:
@@ -506,20 +540,14 @@ def set_state_dict(self, tenant_adapter_name, state_dict):
         def _load_weights(_module):
             for name, parameter in _module.named_parameters():
                 if pattern.search(name) and self.match_target_modules(name, _lora.tenant_config.target_modules):
-                    name = name.replace(f'.{_lora.adapter_name}.', '.')
-                    src_tensor = state_dict[name]
-                    if 'embedding_A' in name:
-                        r_saved = src_tensor.shape[1]
-                        parameter.data[:, :r_saved].copy_(src_tensor)
-                    elif 'embedding_B' in name:
-                        r_saved = src_tensor.shape[0]
-                        parameter.data[:r_saved, :].copy_(src_tensor)
-                    elif '_A' in name:
-                        r_saved = src_tensor.shape[0]
-                        parameter.data[:r_saved, :].copy_(src_tensor)
-                    elif '_B' in name:
-                        r_saved = src_tensor.shape[1]
-                        parameter.data[:, :r_saved].copy_(src_tensor)
+                    state_key = name.replace(f'.{_lora.adapter_name}.', '.')
+                    target_tensor = self._read_param_tensor(parameter)
+                    if target_tensor is None:
+                        continue
+                    target_tensor = target_tensor.clone()
+                    src_tensor = state_dict[state_key].to(dtype=target_tensor.dtype, device=target_tensor.device)
+                    self._copy_rank_tensor(name, target_tensor, src_tensor)
+                    self._write_param_tensor(parameter, target_tensor)
 
         if isinstance(self.module, list):
             for _module in self.module:
@@ -536,15 +564,9 @@ def _get_weights(_module):
             state_dict = {}
             for name, parameter in _module.named_parameters():
                 if pattern.search(name) and self.match_target_modules(name, _lora.tenant_config.target_modules):
-                    _param = torch_util.to_local_tensor(parameter)
-                    if 'embedding_A' in name:
-                        _param = _param[:, :_lora.tenant_config.r]
-                    elif 'embedding_B' in name:
-                        _param = _param[:_lora.tenant_config.r, :]
-                    elif '_A' in name:
-                        _param = _param[:_lora.tenant_config.r, :]
-                    elif '_B' in name:
-                        _param = _param[:, :_lora.tenant_config.r]
+                    _param = self._slice_rank_tensor(name, self._read_param_tensor(parameter), _lora.tenant_config.r)
+                    if _param is None:
+                        continue
                     name = name.replace(f'.{_lora.adapter_name}.', '.')
                     state_dict[name] = _param
             return state_dict
@@ -564,9 +586,14 @@ def _load_initial_weights(self, origin_adapter_name):
         def _load_initial_weights(_module):
             for name, parameter in _module.named_parameters():
                 if pattern_A.search(name):
-                    parameter.data.copy_(_lora.lora_A_weights[name])
+                    local_param = self._read_param_tensor(parameter)
+                    if local_param is not None:
+                        value = _lora.lora_A_weights[name].to(dtype=parameter.dtype, device=local_param.device)
+                        self._write_param_tensor(parameter, value)
                 if pattern_B.search(name):
-                    parameter.data.copy_(torch.zeros_like(parameter.data).to(parameter.data.dtype))
+                    local_param = self._read_param_tensor(parameter)
+                    if local_param is not None:
+                        self._write_param_tensor(parameter, torch.zeros_like(local_param))
 
         if isinstance(self.module, list):
             for _module in self.module:
diff --git a/src/twinkle/model/transformers/multi_lora_transformers.py b/src/twinkle/model/transformers/multi_lora_transformers.py
index fc2b53cd..ad4e4843 100644
--- a/src/twinkle/model/transformers/multi_lora_transformers.py
+++ b/src/twinkle/model/transformers/multi_lora_transformers.py
@@ -1,5 +1,6 @@
 # Copyright (c) ModelScope Contributors. All rights reserved.
 import os
+import torch.distributed as dist
 import transformers
 from peft import LoraConfig, PeftConfig, PeftModel, load_peft_weights
 from torch.optim import Optimizer
@@ -15,7 +16,6 @@
 from twinkle.metric import Metric
 from twinkle.processor import InputProcessor
 
 from ..multi_lora import MultiLora
-from .strategy import AccelerateStrategy
 from .transformers import OptimizerGroup, TransformersModel
 
@@ -29,16 +29,27 @@ def __init__(
             config: Optional[PretrainedConfig] = None,
             device_mesh: Optional[DeviceMesh] = None,
             mixed_precision: Literal['no', 'fp8', 'fp16', 'bf16'] = 'bf16',
+            strategy: Literal['accelerate', 'native_fsdp'] = 'accelerate',
+            ddp_config: Dict[str, Any] = None,
+            fsdp_config: Dict[str, Any] = None,
             grad_scaler_config: Dict[str, Any] = None,
+            memory_efficient_init: bool = False,
             max_loras: int = 5,
             max_r: int = 32,
             max_length: int = 8192,
             **kwargs):
-        assert device_mesh.fsdp_world_size <= 0, f'MultiLora does not support FSDP, current is: {str(device_mesh)}'
         os.environ['TOKENIZERS_PARALLELISM'] = 'true'
         self._try_init_process_group()
         super(PreTrainedModel, self).__init__()
-        model_id = HubOperation.download_model(model_id)
+        self.device_mesh = device_mesh
+        self.mixed_precision = mixed_precision
+        self._fsdp_config = dict(fsdp_config or {})
+        self._ddp_config = ddp_config or {}
+        self._memory_efficient_init = memory_efficient_init
+        self._decide_strategy(strategy)
+        self.grad_scaler_config = grad_scaler_config
+        if model_id is not None:
+            model_id = HubOperation.download_model(model_id)
         self.model_id = model_id
         if config is None:
             from transformers import AutoConfig
@@ -51,24 +62,20 @@ def __init__(
             model_cls = AutoModelForCausalLM
         if isinstance(model_cls, str):
             model_cls = getattr(transformers, model_cls)
-        self.model = model_cls.from_pretrained(model_id, config=self.hf_config, **kwargs)
-        self.model_id = model_id
+        if model_id is None:
+            self.model = model_cls.from_config(self.hf_config, **kwargs)
+        else:
+            with self.strategy.pretrained_load_context():
+                self.model = model_cls.from_pretrained(model_id, config=self.hf_config, **kwargs)
         self.tokenizer_id = kwargs.get('tokenizer_id', self.model_id)
-        self.device_mesh = device_mesh
-        self.mixed_precision = mixed_precision
-        self.grad_scaler_config = grad_scaler_config
+        self._default_tokenizer = None
         self._model_wrapped = False
         self.sp_strategy = None
         # Initialize expert parallel attributes (required by set_optimizer in TransformersModel)
-        self._expert_parallel_config = None
-        self._enable_expert_parallel = False
-        self._expert_parallel_applied = False
         self.optimizer_group: Dict[str, OptimizerGroup] = {}
         self.multi_adapter = MultiLora(max_loras=max_loras, max_r=max_r, max_length=max_length)
         self.model.gradient_checkpointing_enable()
         self.model = self.multi_adapter.patch(self.model)
-        self.strategy = AccelerateStrategy(mixed_precision=mixed_precision, device_mesh=None)
-        self.model = self.strategy.wrap_model(self.model)
         self.multi_adapter.save_initial_weights()
         # Active group for compatibility with single adapter
         self.active_group = None
@@ -98,7 +105,7 @@ def unregister_mm_forward_hook(self, optimizer_group: OptimizerGroup):
         pass
 
     def _lazy_wrap_model(self):
-        pass
+        return super()._lazy_wrap_model()
 
     @remote_function(dispatch='slice_dp', collect=collect_tensor_dict)
     def forward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]], **kwargs):
@@ -230,7 +237,10 @@ def get_state_dict(self, **kwargs):
     def save(self, name, output_dir: Optional[str] = None, interval=1, **kwargs):
         self._check_adapter_valid(kwargs.get('adapter_name'))
         with self.multi_adapter.save_context(kwargs.get('adapter_name')):
-            return super().save(name, output_dir, interval, **kwargs)
+            checkpoint_dir = super().save(name, output_dir, interval, **kwargs)
+            if dist.is_initialized():
+                dist.barrier()
+            return checkpoint_dir
 
     @remote_function()
     def load(self, name: Optional[str] = None, output_dir: Optional[str] = None, **kwargs):
@@ -252,6 +262,8 @@ def load(self, name: Optional[str] = None, output_dir: Optional[str] = None, **kwargs):
         if load_optimizer:
             self._load_optimizer(checkpoint_dir, adapter_name=adapter_name)
+        if dist.is_initialized():
+            dist.barrier()
 
     @remote_function()
     def set_grad_scaler(self, **kwargs):
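
As a reviewer aid (not part of the patch): a minimal sketch of the padding/slicing convention that `_slice_rank_tensor` and `_copy_rank_tensor` centralize for the Linear case. Each tenant's LoRA matrices live inside buffers allocated at `max_r`, so saving keeps only the leading `r` rows of `lora_A` (columns of `lora_B`), and loading copies the saved slice back into the padded buffer. The shapes and names below are illustrative assumptions, not twinkle APIs.

```python
import torch

MAX_R, TENANT_R, IN_F, OUT_F = 32, 8, 16, 24

# lora_A is padded to (max_r, in_features); lora_B to (out_features, max_r).
lora_A = torch.randn(MAX_R, IN_F)
lora_B = torch.randn(OUT_F, MAX_R)

# Saving a tenant: keep only the leading `r` rows of A / columns of B,
# mirroring _slice_rank_tensor('...lora_A...', tensor, r).
saved_A = lora_A[:TENANT_R, :].clone()
saved_B = lora_B[:, :TENANT_R].clone()

# Loading a tenant: copy the saved slices back into padded buffers,
# mirroring _copy_rank_tensor; the remaining rows/columns stay untouched.
restored_A = torch.zeros(MAX_R, IN_F)
restored_B = torch.zeros(OUT_F, MAX_R)
restored_A[:saved_A.shape[0], :].copy_(saved_A)
restored_B[:, :saved_B.shape[1]].copy_(saved_B)

assert torch.equal(restored_A[:TENANT_R], lora_A[:TENANT_R])
assert torch.equal(restored_B[:, :TENANT_R], lora_B[:, :TENANT_R])
```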
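Also for reviewers, a rough sketch of why `distribute_tensor` enters the picture: with the FSDP strategy the LoRA parameters can be DTensors, so a full replacement tensor has to be re-sharded with the parameter's own `device_mesh`/`placements` before the in-place copy, which is what `_write_param_tensor` does. The read side here assumes `torch_util.to_local_tensor` yields the full unsharded tensor and uses `full_tensor()` as a stand-in; that assumption, the `gloo`/CPU setup, and the helper names are illustrative only. Run under e.g. `torchrun --nproc_per_node=2`.

```python
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor


def write_param(parameter, value: torch.Tensor) -> None:
    # Mirrors _write_param_tensor: re-shard the replacement exactly like the
    # existing parameter before copying in place.
    value = value.detach().to(dtype=parameter.dtype)
    if hasattr(parameter, 'device_mesh') and hasattr(parameter, 'placements'):
        value = distribute_tensor(value.to(parameter.device), parameter.device_mesh, parameter.placements)
    else:
        value = value.to(parameter.device)
    parameter.data.copy_(value)


def read_full(parameter) -> torch.Tensor:
    # Stand-in for the assumed to_local_tensor semantics: gather DTensors,
    # pass plain tensors through unchanged.
    data = parameter.data
    return data.full_tensor() if hasattr(data, 'full_tensor') else data


if __name__ == '__main__':
    dist.init_process_group('gloo')
    mesh = init_device_mesh('cpu', (dist.get_world_size(),))
    weight = torch.nn.Parameter(distribute_tensor(torch.zeros(32, 16), mesh, [Shard(0)]))
    write_param(weight, torch.ones(32, 16))
    assert torch.equal(read_full(weight), torch.ones(32, 16))
    dist.destroy_process_group()
```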