Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
271 changes: 271 additions & 0 deletions cookbook/rl/short_math_grpo_moe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,271 @@
"""GRPO training script for GSM8K dataset.

Converted from the Tinker client version to Ray-based training.
Uses short reasoning format: shorter thinking gets higher format reward.
Answer extracted from \\boxed{} or #### format.
"""
import os
import re
from typing import List, Tuple, Dict, Any

from peft import LoraConfig

import twinkle
from twinkle import DeviceMesh, DeviceGroup, get_device_placement, get_logger
from twinkle.advantage import GRPOAdvantage
from twinkle.checkpoint_engine import CheckpointEngineManager
from twinkle.data_format import SamplingParams
from twinkle.dataloader import DataLoader
from twinkle.dataset import Dataset, DatasetMeta
from twinkle.metric import CompletionRewardMetric
from twinkle.model import TransformersModel
from twinkle.processor import InputProcessor
from twinkle.reward import GSM8KAccuracyReward
from twinkle.reward.base import Reward
from twinkle.sampler import vLLMSampler
from twinkle.preprocessor.llm import GSM8KProcessor

logger = get_logger()

# ========== Configuration ==========
MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3.6-35B-A3B')
USE_MEGATRON = bool(int(os.environ.get('USE_MEGATRON', '1')))

MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4))
MODEL_EP = int(os.environ.get('MODEL_EP', 2))
MODEL_TP = int(os.environ.get('MODEL_TP', 2))
MODEL_PP = int(os.environ.get('MODEL_PP', 2))

SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 4))
SAMPLER_TP = int(os.environ.get('SAMPLER_TP', 2))
NUM_GPUS = MODEL_GPUS + SAMPLER_GPUS

NUM_GENERATIONS = int(os.environ.get('NUM_GENERATIONS', 8))
MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 4096))
LEARNING_RATE = float(os.environ.get('LR', 1e-5))
MAX_STEPS = int(os.environ.get('MAX_STEPS', 1000))
BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 4))
MINI_BATCH_SIZE = int(os.environ.get('MINI_BATCH_SIZE', 4))
MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 1))
GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 1))
ADAPTER_NAME = 'default'
SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 1000))
LORA_RANK = int(os.environ.get('LORA_RANK', 16))

SYSTEM_PROMPT = ('You are a helpful math assistant. Solve the problem with minimal but correct reasoning '
'and put your final answer within \\boxed{}.')

# ========== Reward Functions ==========
class GSM8KBrevityReward(Reward):
"""Brevity reward: rewards shorter completions that contain a valid answer.

Returns 0.0 if no valid answer format (\\boxed{} or ####).
Otherwise returns higher score for shorter completions (1.0 at <=200 chars).
"""

def __call__(self, trajectories: List[Dict[str, Any]], **kwargs) -> List[float]:
rewards = []
for traj in trajectories:
messages = traj.get('messages', [])
completion = ''
for msg in reversed(messages):
if msg.get('role') == 'assistant':
completion = msg.get('content', '')
break

has_answer = bool(
re.search(r'\\boxed\{[^}]+\}', completion)
or re.search(r'####\s*[\-\d,\.]+', completion)
)

if not has_answer:
rewards.append(0.0)
else:
length = len(completion)
if length <= 200:
rewards.append(1.0)
else:
rewards.append(max(0.0, 1.0 - (length - 200) / 3000))
return rewards


# ========== Dataset ==========
def create_gsm8k_dataset():
dataset = Dataset(DatasetMeta('ms://modelscope/gsm8k', subset_name='main', split='train'))
dataset.set_template('Qwen3_5Template', model_id=MODEL_ID, max_length=4096, truncation_strategy='delete', enable_thinking=True)
dataset.map(GSM8KProcessor(system=SYSTEM_PROMPT))
dataset.encode(add_generation_prompt=True)
return dataset


def compute_rewards(
trajectories: List[Dict[str, Any]],
) -> Tuple[List[float], List[float], List[float]]:
accuracy_reward_fn = GSM8KAccuracyReward()
brevity_reward_fn = GSM8KBrevityReward()

accuracy_rewards = accuracy_reward_fn(trajectories)
brevity_rewards = brevity_reward_fn(trajectories)
total_rewards = [a + b for a, b in zip(accuracy_rewards, brevity_rewards)]
return total_rewards, brevity_rewards, accuracy_rewards


# ========== Main ==========
def main():
device_groups = [
DeviceGroup(name='model', ranks=list(range(MODEL_GPUS)), device_type='GPU'),
DeviceGroup(name='sampler', ranks=list(range(MODEL_GPUS, NUM_GPUS)), device_type='GPU', gpus_per_worker=SAMPLER_TP),
]
dp_size = MODEL_GPUS // (MODEL_TP * MODEL_PP)
model_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=dp_size, tp_size=MODEL_TP, pp_size=MODEL_PP, ep_size=MODEL_EP, sequence_parallel=True)
sampler_dp_size = SAMPLER_GPUS // (SAMPLER_TP)
sampler_mesh = DeviceMesh.from_sizes(world_size=SAMPLER_GPUS, dp_size=sampler_dp_size, tp_size=SAMPLER_TP)
twinkle.initialize(mode='ray', nproc_per_node=NUM_GPUS, groups=device_groups, lazy_collect=False)

lora_config = LoraConfig(
target_modules=['all-linear'],
r=LORA_RANK,
lora_alpha=LORA_RANK * 2,
lora_dropout=0.05,
)

if USE_MEGATRON:
from twinkle.model.megatron import MegatronModel
model = MegatronModel(
model_id=MODEL_ID,
device_mesh=model_mesh,
remote_group='model',
mixed_precision='bf16',
)
else:
model = TransformersModel(
model_id=MODEL_ID,
device_mesh=model_mesh,
remote_group='model',
)

model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=1)
if USE_MEGATRON:
model.set_optimizer('default', lr=LEARNING_RATE)
model.set_lr_scheduler('default', lr_decay_steps=MAX_STEPS, max_lr=LEARNING_RATE)
else:
model.set_optimizer('AdamW', lr=LEARNING_RATE)
model.set_lr_scheduler('CosineAnnealingLR', T_max=MAX_STEPS, eta_min=0)

model.set_loss('GRPOLoss', epsilon=0.2)
model.set_processor(InputProcessor)
model.set_template('Qwen3_5Template', model_id=MODEL_ID, enable_thinking=True)

sampler = vLLMSampler(
model_id=MODEL_ID,
engine_args={
'tensor_parallel_size': SAMPLER_TP,
'gpu_memory_utilization': 0.6,
'max_model_len': 8192,
'max_lora_rank': LORA_RANK, # save as lora_config
# NOTE: To use enable_lora with qwen3.5, ensure vLLM includes PR https://github.com/vllm-project/vllm/pull/36976
# enable_lora=True used with ckpt_manager.sync_weights(merge_and_sync=False)
# meaning only sync lora weights, if merge_and_sync=True,
# lora will be merged into the base model and sync all weights to vLLM
'enable_lora': True,
'enable_tower_connector_lora': True,
},
device_mesh=sampler_mesh,
remote_group='sampler',
)
sampler.set_template('Qwen3_5Template', model_id=MODEL_ID, enable_thinking=True)

ckpt_manager = CheckpointEngineManager(model=model, sampler=sampler)

GLOBAL_BATCH_SIZE = BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS
dataloader = DataLoader(
dataset=create_gsm8k_dataset,
batch_size=GLOBAL_BATCH_SIZE,
min_batch_size=GLOBAL_BATCH_SIZE,
device_mesh=model_mesh,
remote_group='model',
)

advantage_fn = GRPOAdvantage()
metrics = CompletionRewardMetric()
sampling_params = SamplingParams(max_tokens=MAX_NEW_TOKENS, num_samples=1, logprobs=1, temperature=1.0, top_p=0.95)

optim_step = 0
logger.info('Starting GSM8K GRPO training (short reasoning)')
logger.info(get_device_placement())

for batch in dataloader:
if optim_step >= MAX_STEPS:
break

metrics.reset()
expand_prompts = []
for prompt in batch:
expand_prompts.extend([prompt] * NUM_GENERATIONS)

# enable_lora=True used with ckpt_manager.sync_weights(merge_and_sync=False)
# meaning only sync lora weights, if merge_and_sync=True,
# lora will be merged into the base model and sync all weights to vLLM
ckpt_manager.sync_weights(merge_and_sync=False)
sampler.reset_prefix_cache()

sample_responses = sampler.sample(
expand_prompts,
sampling_params,
)

all_input_data: List[Dict[str, Any]] = []
all_old_logps: List[List[float]] = []
all_completion_lengths: List[int] = []

for sample_response in sample_responses:
for sequence in sample_response.sequences:
all_input_data.append(sequence.new_input_feature)
all_old_logps.append([logprob[0][1] for logprob in sequence.logprobs])
all_completion_lengths.append(len(sequence.tokens))

total_rewards, brevity_rewards, accuracy_rewards = compute_rewards(all_input_data)

metrics.accumulate(
completion_lengths=all_completion_lengths,
rewards={
'total': total_rewards,
'brevity': brevity_rewards,
'accuracy': accuracy_rewards,
},
)

advantages = advantage_fn(total_rewards, num_generations=NUM_GENERATIONS, scale='group').tolist()

total_completions = len(all_input_data)
for mb_start in range(0, total_completions, MINI_BATCH_SIZE):
mb_end = min(mb_start + MINI_BATCH_SIZE, total_completions)
mb_inputs = all_input_data[mb_start:mb_end]
mb_old_logps = all_old_logps[mb_start:mb_end]
mb_advantages = advantages[mb_start:mb_end]

model.forward_backward(
inputs=mb_inputs,
old_logps=mb_old_logps,
advantages=mb_advantages,
micro_batch_size=MICRO_BATCH_SIZE,
)
model.clip_grad_and_step()
optim_step += 1

if optim_step >= MAX_STEPS:
break
if optim_step % SAVE_STEPS == 0:
model.save(f'math-grpo-checkpoint-{optim_step}')

log_dict = metrics.calculate()
log_dict.update(model.calculate_metric(is_training=True))
metrics.reset()
logger.info(f'[Step {optim_step}/{MAX_STEPS}] {log_dict}')

logger.info(f'Training completed. optim_steps={optim_step}')
model.save('math-grpo-final')


if __name__ == '__main__':
main()
4 changes: 4 additions & 0 deletions src/twinkle/patch/vllm_moe_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@ def __call__(self, model, **kwargs):
# (False, 'model.layers.0.mlp.experts.w2_weight') use mlp.experts.weight_loader

# Early return if no MOE models are supported
# expected_lora_modules : up_proj -> experts.0.up_proj
from vllm.model_executor.models.qwen3_5 import Qwen3_5MoeForConditionalGeneration
Qwen3_5MoeForConditionalGeneration.is_3d_moe_weight = False

if not SUPPORTED_MOE_MODELS:
return

Expand Down
Loading
Loading