From cccc36911a81acf3583f3f38ce543160c4922eda Mon Sep 17 00:00:00 2001 From: root Date: Wed, 15 Apr 2026 22:32:31 +0800 Subject: [PATCH 1/4] add base_layer suffix for expert weights --- src/twinkle/model/megatron/megatron.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index 4cdbba84..f087d3a6 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -1432,6 +1432,8 @@ def _add_base_layer_suffix(name): base_layer_name = f'{name[:-5]}.base_layer.bias' if not model_keys or base_layer_name in model_keys: name = base_layer_name + if 'experts' in name: + return base_layer_name return name is_peft_format = (adapter_name != _default_adapter_name) From 457f941836e8b453f95397e5cc48043cc0d19921 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 16 Apr 2026 22:15:16 +0800 Subject: [PATCH 2/4] qwen3.6 grpo --- cookbook/rl/short_math_grpo_moe.py | 271 ++++++++++++++++++ src/twinkle/patch/vllm_moe_loader.py | 4 + .../sampler/vllm_sampler/vllm_engine.py | 25 +- .../vllm_sampler/vllm_worker_extension.py | 31 +- 4 files changed, 323 insertions(+), 8 deletions(-) create mode 100644 cookbook/rl/short_math_grpo_moe.py diff --git a/cookbook/rl/short_math_grpo_moe.py b/cookbook/rl/short_math_grpo_moe.py new file mode 100644 index 00000000..97b7da45 --- /dev/null +++ b/cookbook/rl/short_math_grpo_moe.py @@ -0,0 +1,271 @@ +"""GRPO training script for GSM8K dataset. + +Converted from the Tinker client version to Ray-based training. +Uses short reasoning format: shorter thinking gets higher format reward. +Answer extracted from \\boxed{} or #### format. +""" +import os +import re +from typing import List, Tuple, Dict, Any + +from peft import LoraConfig + +import twinkle +from twinkle import DeviceMesh, DeviceGroup, get_device_placement, get_logger +from twinkle.advantage import GRPOAdvantage +from twinkle.checkpoint_engine import CheckpointEngineManager +from twinkle.data_format import SamplingParams +from twinkle.dataloader import DataLoader +from twinkle.dataset import Dataset, DatasetMeta +from twinkle.metric import CompletionRewardMetric +from twinkle.model import TransformersModel +from twinkle.processor import InputProcessor +from twinkle.reward import GSM8KAccuracyReward +from twinkle.reward.base import Reward +from twinkle.sampler import vLLMSampler +from twinkle.preprocessor.llm import GSM8KProcessor + +logger = get_logger() + +# ========== Configuration ========== +MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3.6-35B-A3B') +USE_MEGATRON = bool(int(os.environ.get('USE_MEGATRON', '1'))) + +MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4)) +MODEL_EP = int(os.environ.get('MODEL_EP', 2)) +MODEL_TP = int(os.environ.get('MODEL_TP', 2)) +MODEL_PP = int(os.environ.get('MODEL_PP', 2)) + +SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 4)) +SAMPLER_TP = int(os.environ.get('SAMPLER_TP', 2)) +NUM_GPUS = MODEL_GPUS + SAMPLER_GPUS + +NUM_GENERATIONS = int(os.environ.get('NUM_GENERATIONS', 8)) +MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 4096)) +LEARNING_RATE = float(os.environ.get('LR', 1e-5)) +MAX_STEPS = int(os.environ.get('MAX_STEPS', 1000)) +BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 4)) +MINI_BATCH_SIZE = int(os.environ.get('MINI_BATCH_SIZE', 4)) +MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 1)) +GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 1)) +ADAPTER_NAME = 'default' +SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 1000)) +LORA_RANK = int(os.environ.get('LORA_RANK', 16)) + +SYSTEM_PROMPT = ('You are a helpful math assistant. Solve the problem with minimal but correct reasoning ' + 'and put your final answer within \\boxed{}.') + +# ========== Reward Functions ========== +class GSM8KBrevityReward(Reward): + """Brevity reward: rewards shorter completions that contain a valid answer. + + Returns 0.0 if no valid answer format (\\boxed{} or ####). + Otherwise returns higher score for shorter completions (1.0 at <=200 chars). + """ + + def __call__(self, trajectories: List[Dict[str, Any]], **kwargs) -> List[float]: + rewards = [] + for traj in trajectories: + messages = traj.get('messages', []) + completion = '' + for msg in reversed(messages): + if msg.get('role') == 'assistant': + completion = msg.get('content', '') + break + + has_answer = bool( + re.search(r'\\boxed\{[^}]+\}', completion) + or re.search(r'####\s*[\-\d,\.]+', completion) + ) + + if not has_answer: + rewards.append(0.0) + else: + length = len(completion) + if length <= 200: + rewards.append(1.0) + else: + rewards.append(max(0.0, 1.0 - (length - 200) / 3000)) + return rewards + + +# ========== Dataset ========== +def create_gsm8k_dataset(): + dataset = Dataset(DatasetMeta('ms://modelscope/gsm8k', subset_name='main', split='train')) + dataset.set_template('Qwen3_5Template', model_id=MODEL_ID, max_length=4096, truncation_strategy='delete', enable_thinking=True) + dataset.map(GSM8KProcessor(system=SYSTEM_PROMPT)) + dataset.encode(add_generation_prompt=True) + return dataset + + +def compute_rewards( + trajectories: List[Dict[str, Any]], +) -> Tuple[List[float], List[float], List[float]]: + accuracy_reward_fn = GSM8KAccuracyReward() + brevity_reward_fn = GSM8KBrevityReward() + + accuracy_rewards = accuracy_reward_fn(trajectories) + brevity_rewards = brevity_reward_fn(trajectories) + total_rewards = [a + b for a, b in zip(accuracy_rewards, brevity_rewards)] + return total_rewards, brevity_rewards, accuracy_rewards + + +# ========== Main ========== +def main(): + device_groups = [ + DeviceGroup(name='model', ranks=list(range(MODEL_GPUS)), device_type='GPU'), + DeviceGroup(name='sampler', ranks=list(range(MODEL_GPUS, NUM_GPUS)), device_type='GPU', gpus_per_worker=SAMPLER_TP), + ] + dp_size = MODEL_GPUS // (MODEL_TP * MODEL_PP) + model_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=dp_size, tp_size=MODEL_TP, pp_size=MODEL_PP, ep_size=MODEL_EP, sequence_parallel=True) + sampler_dp_size = SAMPLER_GPUS // (SAMPLER_TP) + sampler_mesh = DeviceMesh.from_sizes(world_size=SAMPLER_GPUS, dp_size=sampler_dp_size, tp_size=SAMPLER_TP) + twinkle.initialize(mode='ray', nproc_per_node=NUM_GPUS, groups=device_groups, lazy_collect=False) + + lora_config = LoraConfig( + target_modules=['all-linear'], + r=LORA_RANK, + lora_alpha=LORA_RANK * 2, + lora_dropout=0.05, + ) + + if USE_MEGATRON: + from twinkle.model.megatron import MegatronModel + model = MegatronModel( + model_id=MODEL_ID, + device_mesh=model_mesh, + remote_group='model', + mixed_precision='bf16', + ) + else: + model = TransformersModel( + model_id=MODEL_ID, + device_mesh=model_mesh, + remote_group='model', + ) + + model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=1) + if USE_MEGATRON: + model.set_optimizer('default', lr=LEARNING_RATE) + model.set_lr_scheduler('default', lr_decay_steps=MAX_STEPS, max_lr=LEARNING_RATE) + else: + model.set_optimizer('AdamW', lr=LEARNING_RATE) + model.set_lr_scheduler('CosineAnnealingLR', T_max=MAX_STEPS, eta_min=0) + + model.set_loss('GRPOLoss', epsilon=0.2) + model.set_processor(InputProcessor) + model.set_template('Qwen3_5Template', model_id=MODEL_ID, enable_thinking=True) + + sampler = vLLMSampler( + model_id=MODEL_ID, + engine_args={ + 'tensor_parallel_size': SAMPLER_TP, + 'gpu_memory_utilization': 0.7, + 'max_model_len': 8192, + 'max_lora_rank': LORA_RANK, # save as lora_config + # NOTE: To use enable_lora with qwen3.5, ensure vLLM includes PR https://github.com/vllm-project/vllm/pull/36976 + # enable_lora=True used with ckpt_manager.sync_weights(merge_and_sync=False) + # meaning only sync lora weights, if merge_and_sync=True, + # lora will be merged into the base model and sync all weights to vLLM + 'enable_lora': True, + 'enable_tower_connector_lora': True, + }, + device_mesh=sampler_mesh, + remote_group='sampler', + ) + sampler.set_template('Qwen3_5Template', model_id=MODEL_ID, enable_thinking=True) + + ckpt_manager = CheckpointEngineManager(model=model, sampler=sampler) + + GLOBAL_BATCH_SIZE = BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS + dataloader = DataLoader( + dataset=create_gsm8k_dataset, + batch_size=GLOBAL_BATCH_SIZE, + min_batch_size=GLOBAL_BATCH_SIZE, + device_mesh=model_mesh, + remote_group='model', + ) + + advantage_fn = GRPOAdvantage() + metrics = CompletionRewardMetric() + sampling_params = SamplingParams(max_tokens=MAX_NEW_TOKENS, num_samples=1, logprobs=1, temperature=1.0, top_p=0.95) + + optim_step = 0 + logger.info('Starting GSM8K GRPO training (short reasoning)') + logger.info(get_device_placement()) + + for batch in dataloader: + if optim_step >= MAX_STEPS: + break + + metrics.reset() + expand_prompts = [] + for prompt in batch: + expand_prompts.extend([prompt] * NUM_GENERATIONS) + + # enable_lora=True used with ckpt_manager.sync_weights(merge_and_sync=False) + # meaning only sync lora weights, if merge_and_sync=True, + # lora will be merged into the base model and sync all weights to vLLM + ckpt_manager.sync_weights(merge_and_sync=False) + sampler.reset_prefix_cache() + + sample_responses = sampler.sample( + expand_prompts, + sampling_params, + ) + + all_input_data: List[Dict[str, Any]] = [] + all_old_logps: List[List[float]] = [] + all_completion_lengths: List[int] = [] + + for sample_response in sample_responses: + for sequence in sample_response.sequences: + all_input_data.append(sequence.new_input_feature) + all_old_logps.append([logprob[0][1] for logprob in sequence.logprobs]) + all_completion_lengths.append(len(sequence.tokens)) + + total_rewards, brevity_rewards, accuracy_rewards = compute_rewards(all_input_data) + + metrics.accumulate( + completion_lengths=all_completion_lengths, + rewards={ + 'total': total_rewards, + 'brevity': brevity_rewards, + 'accuracy': accuracy_rewards, + }, + ) + + advantages = advantage_fn(total_rewards, num_generations=NUM_GENERATIONS, scale='group').tolist() + + total_completions = len(all_input_data) + for mb_start in range(0, total_completions, MINI_BATCH_SIZE): + mb_end = min(mb_start + MINI_BATCH_SIZE, total_completions) + mb_inputs = all_input_data[mb_start:mb_end] + mb_old_logps = all_old_logps[mb_start:mb_end] + mb_advantages = advantages[mb_start:mb_end] + + model.forward_backward( + inputs=mb_inputs, + old_logps=mb_old_logps, + advantages=mb_advantages, + micro_batch_size=MICRO_BATCH_SIZE, + ) + model.clip_grad_and_step() + optim_step += 1 + + if optim_step >= MAX_STEPS: + break + if optim_step % SAVE_STEPS == 0: + model.save(f'math-grpo-checkpoint-{optim_step}') + + log_dict = metrics.calculate() + log_dict.update(model.calculate_metric(is_training=True)) + metrics.reset() + logger.info(f'[Step {optim_step}/{MAX_STEPS}] {log_dict}') + + logger.info(f'Training completed. optim_steps={optim_step}') + model.save('math-grpo-final') + + +if __name__ == '__main__': + main() diff --git a/src/twinkle/patch/vllm_moe_loader.py b/src/twinkle/patch/vllm_moe_loader.py index 5d064c21..c9b68ac6 100644 --- a/src/twinkle/patch/vllm_moe_loader.py +++ b/src/twinkle/patch/vllm_moe_loader.py @@ -79,6 +79,10 @@ def __call__(self, model, **kwargs): # (False, 'model.layers.0.mlp.experts.w2_weight') use mlp.experts.weight_loader # Early return if no MOE models are supported + # expected_lora_modules : up_proj -> experts.0.up_proj + from vllm.model_executor.models.qwen3_5 import Qwen3_5MoeForConditionalGeneration + Qwen3_5MoeForConditionalGeneration.is_3d_moe_weight = False + if not SUPPORTED_MOE_MODELS: return diff --git a/src/twinkle/sampler/vllm_sampler/vllm_engine.py b/src/twinkle/sampler/vllm_sampler/vllm_engine.py index a1b7123e..87a9fd82 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_engine.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_engine.py @@ -1,6 +1,7 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import inspect import os +import re import torch import uuid from typing import Any, Dict, List, Optional, Union @@ -512,7 +513,14 @@ async def _sync_iter(): sync_id = uuid.uuid4().hex zmq_handle = f'ipc:///tmp/twinkle-ipc-{device_uuid}-{os.getpid()}-{sync_id}.sock' + env_bucket_mb = os.environ.get('TWINKLE_VLLM_BUCKET_SIZE_MB') + if env_bucket_mb is not None: + bucket_size_mb = int(env_bucket_mb) + if bucket_size_mb <= 0: + raise ValueError(f'bucket_size_mb must be > 0, got {bucket_size_mb}') + bucket_size = bucket_size_mb << 20 + lora_mode = bool(base_sync_done and peft_config) # Create transfer buffer buffer = None @@ -575,9 +583,14 @@ async def _chain_first(): offset = 0 bucket_meta: list[dict] = [] n_weights = 0 + current_expert_layer: Optional[str] = None + + def _get_expert_layer_prefix(weight_name: str) -> Optional[str]: + m = re.match(r'^(.*\.mlp\.experts)\.\d+\.', weight_name) + return m.group(1) if m else None async def _flush_bucket(is_last: bool) -> None: - nonlocal offset, bucket_meta + nonlocal offset, bucket_meta, current_expert_layer if not bucket_meta and not is_last: return if buffer.device.type != 'cpu': @@ -593,6 +606,7 @@ async def _flush_bucket(is_last: bool) -> None: ) offset = 0 bucket_meta = [] + current_expert_layer = None async for name, weight in _chain_first(): if use_shm and weight.device.type != 'cpu': @@ -602,6 +616,15 @@ async def _flush_bucket(is_last: bool) -> None: weight_u8 = weight.view(-1).view(torch.uint8) total_nbytes = int(weight_u8.numel()) + expert_layer_prefix = _get_expert_layer_prefix(name) if lora_mode else None + if lora_mode and offset > 0: + # Keep each expert layer in an isolated bucket to avoid sending + # partial expert-layer weights. + if current_expert_layer != expert_layer_prefix: + await _flush_bucket(is_last=False) + if lora_mode: + current_expert_layer = expert_layer_prefix + chunk_offset = 0 while chunk_offset < total_nbytes: if offset >= bucket_size: diff --git a/src/twinkle/sampler/vllm_sampler/vllm_worker_extension.py b/src/twinkle/sampler/vllm_sampler/vllm_worker_extension.py index ac58f92f..31eda787 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_worker_extension.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_worker_extension.py @@ -129,9 +129,6 @@ def update_weights_from_ipc( logger.info(f'vLLM worker bind device: local_rank={local_rank}, device={device_str}') self.device = torch.device(device_str) - if peft_config and base_sync_done: - self.remove_lora(VLLM_LORA_INT_ID) - # Detect TP rank — vLLM sets self.rank on each worker. tp_rank = getattr(self, 'rank', 0) tp_size = 1 @@ -200,6 +197,8 @@ def _broadcast_obj(obj): # ── Step 3: Receive and process weight buckets ── partial_tensors: dict = {} + lora_bucket_accum: list[tuple[str, torch.Tensor]] = [] + lora_mode = bool(peft_config and base_sync_done) while True: # Only the driver receives bucket metadata from VLLMEngine. if is_driver: @@ -245,6 +244,10 @@ def _broadcast_obj(obj): tensor = cpu_u8.view(dtype=dtype).view(shape) else: tensor = raw_u8.view(dtype=dtype).view(shape).clone() + # In LoRA mode we accumulate across buckets; move to CPU + # immediately to avoid GPU memory growth/OOM. + if lora_mode and tensor.device.type != 'cpu': + tensor = tensor.cpu() weights.append((name, tensor)) continue @@ -279,6 +282,8 @@ def _broadcast_obj(obj): tensor = assembled else: tensor = assembled.clone() + if lora_mode and tensor.device.type != 'cpu': + tensor = tensor.cpu() weights.append((name, tensor)) del partial_tensors[name] @@ -292,7 +297,14 @@ def _broadcast_obj(obj): if tp_size > 1: dist.barrier(group=cpu_group) - self._load_weights(weights, peft_config=peft_config, base_sync_done=base_sync_done) + # LoRA weights are streamed in multiple buckets for large adapters. + # Applying add_lora() per-bucket will create incomplete adapters and + # break MoE triplet packing. Accumulate all LoRA tensors and load once + # at stream end. + if lora_mode: + lora_bucket_accum.extend(weights) + else: + self._load_weights(weights, peft_config=peft_config, base_sync_done=base_sync_done) del weights if metadata['is_last']: @@ -300,9 +312,16 @@ def _broadcast_obj(obj): pending = ', '.join(sorted(partial_tensors.keys())[:8]) raise RuntimeError( f'Incomplete chunked weights at stream end: pending {len(partial_tensors)} ({pending})') + if lora_mode: + self._load_weights( + lora_bucket_accum, + peft_config=peft_config, + base_sync_done=base_sync_done, + ) break partial_tensors.clear() + lora_bucket_accum.clear() metadata = None raw_u8 = None cpu_u8 = None @@ -403,9 +422,6 @@ def _load_weights( here. """ if peft_config and base_sync_done: - # Remove existing LoRA before replacing - self.remove_lora(VLLM_LORA_INT_ID) - from twinkle.patch.vllm_lora_weights import TensorLoRARequest converted = {self._convert_peft_to_vllm_lora_name(n): t for n, t in weights} @@ -415,6 +431,7 @@ def _load_weights( lora_path=VLLM_LORA_PATH, peft_config=peft_config, lora_tensors=converted, + load_inplace=True, ) self.add_lora(lora_request) else: From 69684385fd368997fa60263d17651aed07710278 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 16 Apr 2026 23:09:40 +0800 Subject: [PATCH 3/4] adjust gpu_memory_utilization to avoid oom --- cookbook/rl/short_math_grpo_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cookbook/rl/short_math_grpo_moe.py b/cookbook/rl/short_math_grpo_moe.py index 97b7da45..170cca03 100644 --- a/cookbook/rl/short_math_grpo_moe.py +++ b/cookbook/rl/short_math_grpo_moe.py @@ -160,7 +160,7 @@ def main(): model_id=MODEL_ID, engine_args={ 'tensor_parallel_size': SAMPLER_TP, - 'gpu_memory_utilization': 0.7, + 'gpu_memory_utilization': 0.6, 'max_model_len': 8192, 'max_lora_rank': LORA_RANK, # save as lora_config # NOTE: To use enable_lora with qwen3.5, ensure vLLM includes PR https://github.com/vllm-project/vllm/pull/36976 From 44969e8cbf78e8deb830117d670031c34c50cb01 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 17 Apr 2026 16:51:21 +0800 Subject: [PATCH 4/4] reuse ipc buffer to a avoid oom --- .../sampler/vllm_sampler/vllm_engine.py | 39 ++++++++++++- .../vllm_sampler/vllm_worker_extension.py | 55 +++++++++++++++++-- 2 files changed, 87 insertions(+), 7 deletions(-) diff --git a/src/twinkle/sampler/vllm_sampler/vllm_engine.py b/src/twinkle/sampler/vllm_sampler/vllm_engine.py index 87a9fd82..bf52b10a 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_engine.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_engine.py @@ -96,6 +96,18 @@ def __init__( # ``list_loras()`` per request. self._synced_lora_request: Optional[Any] = None + # Long-lived CUDA IPC bucket reused across all update_weights() + # calls. Allocating a new IPC buffer (and hence a new IPC handle) + # per sync forces every worker to create a new CUDA IPC mapping via + # ``rebuild_cuda_tensor`` because PyTorch's ``shared_cache`` cannot + # hit on unseen storage handles. The driver reclaims those mappings + # lazily, which is the root cause of the slow GPU memory drift we + # observed under frequent LoRA syncs. By pinning a single buffer + # and its handle we guarantee the worker-side cache always hits. + self._ipc_buffer: Optional[torch.Tensor] = None + self._ipc_handle: Any = None + self._ipc_buffer_size: int = 0 + # Initialize engine self.engine = self._create_engine() @@ -528,8 +540,25 @@ async def _sync_iter(): if use_gpu_ipc: from torch.multiprocessing.reductions import reduce_tensor - buffer = torch.empty(bucket_size, dtype=torch.uint8, device=first_tensor.device) - ipc_handle = reduce_tensor(buffer) + + # Reuse a long-lived IPC bucket whenever the requested size + # fits. The handle is produced once and shipped to every + # subsequent sync so each worker's ``shared_cache`` stays warm + # and no new CUDA IPC mapping is created per sync. + need_realloc = ( + self._ipc_buffer is None or self._ipc_buffer_size < bucket_size + or self._ipc_buffer.device != first_tensor.device) + if need_realloc: + # Drop the old handle/buffer before allocating a bigger one + # so we do not briefly hold both and double the peak usage. + self._ipc_buffer = None + self._ipc_handle = None + self._ipc_buffer_size = 0 + self._ipc_buffer = torch.empty(bucket_size, dtype=torch.uint8, device=first_tensor.device) + self._ipc_handle = reduce_tensor(self._ipc_buffer) + self._ipc_buffer_size = bucket_size + buffer = self._ipc_buffer + ipc_handle = self._ipc_handle else: from multiprocessing import shared_memory shm_name = f'twinkle_weights_{uuid.uuid4().hex}' @@ -669,7 +698,7 @@ async def _flush_bucket(is_last: bool) -> None: shm.close() shm.unlink() del shm - gc.collect() + gc.collect() elapsed = time.time() - start_time mode = 'LoRA' if base_sync_done and peft_config else 'base' @@ -686,6 +715,10 @@ async def shutdown(self) -> None: logger.info('Shutting down VLLMEngine...') + self._ipc_buffer = None + self._ipc_handle = None + self._ipc_buffer_size = 0 + if self.engine is not None: try: # vLLM v1 AsyncLLM has shutdown() method diff --git a/src/twinkle/sampler/vllm_sampler/vllm_worker_extension.py b/src/twinkle/sampler/vllm_sampler/vllm_worker_extension.py index 31eda787..4566bbe0 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_worker_extension.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_worker_extension.py @@ -59,6 +59,33 @@ def _rebuild_ipc(handle, device_id: Optional[int] = None) -> torch.Tensor: return rebuild_cuda_tensor(*list_args) +def _ipc_handle_signature(handle) -> Optional[tuple]: + """Derive a stable signature for a CUDA IPC handle. + + ``reduce_tensor`` returns ``(func, args)`` where ``args`` contains the + CUDA IPC storage handle bytes, storage size, ref-counter handle, etc. + Two handles are equivalent (i.e. map the same CUDA memory region) when + these inner fields match. We hash only the parts that are picklable and + comparable to avoid accidental mismatches due to local objects. + """ + try: + _, args = handle + except Exception: + return None + sig = [] + for v in args: + if isinstance(v, (bytes, bytearray)): + sig.append(('bytes', bytes(v))) + elif isinstance(v, (int, float, bool, str)) or v is None: + sig.append(('scalar', v)) + else: + try: + sig.append(('repr', repr(v))) + except Exception: + return None + return tuple(sig) + + def _rebuild_shared_memory(name: str, size: int): """Rebuild tensor from shared memory. Returns (tensor, shm).""" from multiprocessing import shared_memory @@ -184,7 +211,27 @@ def _broadcast_obj(obj): handle = comm_metadata # All TP ranks rebuild the IPC buffer from the same handle. # CUDA IPC allows any process on the same node to map the memory. - buffer = _rebuild_ipc(handle, self.device.index) + # Reuse a cached buffer across syncs when the sender reuses the + # same IPC handle: this avoids creating a fresh CUDA IPC mapping + # per sync, which the driver releases lazily and is the root + # cause of the apparent GPU memory growth under frequent syncs. + handle_signature = _ipc_handle_signature(handle) + cached_buffer = getattr(self, '_twinkle_ipc_buffer', None) + cached_signature = getattr(self, '_twinkle_ipc_handle_signature', None) + if cached_buffer is not None and cached_signature == handle_signature: + buffer = cached_buffer + else: + # Drop the previous mapping before creating a new one so the + # driver can reclaim the old shared memory region. + if cached_buffer is not None: + self._twinkle_ipc_buffer = None + self._twinkle_ipc_handle_signature = None + del cached_buffer + gc.collect() + Torch.ipc_collect() + buffer = _rebuild_ipc(handle, self.device.index) + self._twinkle_ipc_buffer = buffer + self._twinkle_ipc_handle_signature = handle_signature else: from multiprocessing import shared_memory buffer, shm = _rebuild_shared_memory( @@ -331,7 +378,6 @@ def _broadcast_obj(obj): if is_driver and socket is not None: socket.close() del buffer - gc.collect() if shm is not None: try: shm.close() @@ -343,8 +389,9 @@ def _broadcast_obj(obj): except BufferError as e: logger.warning(f'SharedMemory close skipped due to exported pointers: {e}') del shm - Torch.ipc_collect() - Torch.empty_cache() + gc.collect() + Torch.ipc_collect() + Torch.empty_cache() def load_synced_weights( self,