diff --git a/notebook/dpo.ipynb b/notebook/dpo.ipynb new file mode 100644 index 00000000..34a9247d --- /dev/null +++ b/notebook/dpo.ipynb @@ -0,0 +1,598 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "181f7b7e", + "metadata": {}, + "source": [ + "# Twinkle DPO 训练\n", + "\n", + "> 🎯 **训练目标**:通过 DPO 偏好优化,让模型学会用 **中文 + Emoji 风格** 回答问题,生成更生动友好的回复。\n", + "\n", + "本 Notebook 演示如何使用 **DPO(Direct Preference Optimization,直接偏好优化)** 算法通过 Twinkle 对语言模型进行微调。\n", + "\n", + "### 什么是 DPO?\n", + "\n", + "DPO 是一种从人类偏好数据中直接学习的训练方法,无需显式奖励模型。训练流程如下:\n", + "1. 准备 **偏好数据集**,每条样本包含一个 prompt,以及一条「更好的回答」(chosen)和一条「较差的回答」(rejected)\n", + "2. 用 **参考模型**(base model,禁用 LoRA)对 chosen/rejected 做前向推理,得到参考 log 概率 `ref_logps`\n", + "3. 将 `ref_logps` 附加到训练数据中,再用 **策略模型**(带 LoRA)做前向反向传播,计算 DPO 损失\n", + "4. 执行 **优化器更新**,让模型更倾向于生成 chosen 回答\n", + "\n", + "### 整体流程\n", + "\n", + "```\n", + "准备数据集 → 创建 LoRA 训练客户端 → 训练循环:\n", + " 参考前向(disable_lora=True)→ 附加 ref_logps → DPO 前向反向 → 优化器步骤\n", + "```\n", + "\n", + "### DPO 批次格式\n", + "\n", + "为了让每个 DP(Data Parallel)分片都含有完整的 chosen/rejected 对,批次采用**交错排列**格式:\n", + "\n", + "```\n", + "[pos_1, neg_1, pos_2, neg_2, ..., pos_N, neg_N]\n", + "```\n", + "\n", + "### 前置条件\n", + "\n", + "| 条件 | 说明 |\n", + "|------|------|\n", + "| 环境变量 | 设置 `MODELSCOPE_TOKEN` |\n", + "| 依赖安装 | `pip install twinkle-kit[tinker]` |\n", + "\n", + "> 💡 **获取 Token**:访问 [ModelScope Token 页面](https://www.modelscope.cn/my/access/token) 获取你的 `MODELSCOPE_TOKEN`,并设置为环境变量:`export MODELSCOPE_TOKEN=<你的Token>`\n" + ] + }, + { + "cell_type": "markdown", + "id": "18ea5343", + "metadata": {}, + "source": [ + "## 🚀 零卡训练服务化(Serverless Training)\n", + "\n", + "本 Notebook 运行在 **ModelScope 零卡训练平台** 上。你无需自备 GPU,只需在 Notebook 中编写训练逻辑,平台会自动调度云端 GPU 资源完成训练。\n", + "\n", + "### 架构示意图\n", + "\n", + "```\n", + "┌─────────────────────────────────────────────────────────────┐\n", + "│ 你的 Notebook(CPU 环境) │\n", + "│ │\n", + "│ ┌──────────┐ HTTP / gRPC ┌──────────────────────┐ │\n", + "│ │ Twinkle │ 
─────────────────► │ ModelScope 云端 │ │\n", + "│ │ Client │ ◄───────────────── │ GPU 训练集群 │ │\n", + "│ └──────────┘ 训练结果返回 │ │ │\n", + "│ │ │ ┌────┐ ┌────┐ ┌────┐│ │\n", + "│ │ 构造数据 │ │GPU0│ │GPU1│ │... ││ │\n", + "│ │ 发送训练请求 │ └────┘ └────┘ └────┘│ │\n", + "│ │ 接收指标/检查点 │ 模型加载 + LoRA 训练 │ │\n", + "│ ▼ └──────────────────────┘ │\n", + "│ ┌──────────┐ │\n", + "│ │ 数据准备 │ Dataset / DataLoader / Preprocessor │\n", + "│ └──────────┘ │\n", + "└─────────────────────────────────────────────────────────────┘\n", + "```\n", + "\n", + "### 核心优势\n", + "\n", + "| 特性 | 说明 |\n", + "|------|------|\n", + "| **零卡启动** | Notebook 本身不需要 GPU,训练在云端自动执行 |\n", + "| **按需付费** | 仅在训练时占用 GPU 资源 |\n", + "| **开箱即用** | 预置主流模型,无需下载权重 |\n", + "| **LoRA 微调** | 高效参数微调,几分钟即可完成小规模训练 |\n", + "\n", + "> 🔗 本项目由 [Twinkle](https://github.com/modelscope/twinkle) 框架提供支持 | [GitHub](https://github.com/modelscope/twinkle)" + ] + }, + { + "cell_type": "markdown", + "id": "8ab093f1", + "metadata": {}, + "source": [ + "## 第一步:导入依赖与全局配置\n", + "\n", + "| 配置项 | 默认值 | 说明 |\n", + "|--------|--------|------|\n", + "| `BASE_MODEL` | Qwen/Qwen3.6-35B-A3B | 基座模型 |\n", + "| `BATCH_SIZE` | 4 | 每步处理的 DPO 样本对数 |\n", + "| `LEARNING_RATE` | 1e-4 | 学习率 |\n", + "| `DPO_BETA` | 0.1 | DPO 温度系数,控制偏好强度 |\n", + "| `SFT_WEIGHT` | 1.0 | SFT 损失权重(辅助监督信号) |\n", + "| `MAX_LENGTH` | 2048 | 序列最大长度 |\n", + "| `LORA_RANK` | 8 | LoRA 秩 |\n", + "| `DATA_NUM` | 5000 | 使用的数据集样本数量 |" + ] + }, + { + "cell_type": "markdown", + "id": "f6441022", + "metadata": {}, + "source": [ + "## 环境安装\n", + "\n", + "首次运行前,请先执行以下安装命令。如已安装可跳过此步。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b210863c", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install twinkle-kit[tinker] -q" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "77116976", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import numpy as np\n", + "import torch\n", + "from tqdm import tqdm\n", + "from typing import Any, Dict, 
List\n", + "\n", + "\n", + "from tinker import types\n", + "from twinkle import init_tinker_client, get_logger\n", + "from twinkle.dataset import Dataset, DatasetMeta, LazyDataset\n", + "from twinkle.dataloader import DataLoader\n", + "from twinkle.preprocessor import EmojiDPOProcessor\n", + "from twinkle.server.common import input_feature_to_datum\n", + "\n", + "logger = get_logger()\n", + "\n", + "# ========== 全局配置 ==========\n", + "BASE_MODEL = 'Qwen/Qwen3.6-35B-A3B'\n", + "BASE_URL = 'http://www.modelscope.cn/twinkle'\n", + "# 与前置条件保持一致:从环境变量读取 ModelScope Token\n", + "API_KEY = os.environ.get('MODELSCOPE_TOKEN')\n", + "DATASET_ID = 'ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji'\n", + "\n", + "BATCH_SIZE = 4\n", + "LEARNING_RATE = 1e-4\n", + "DPO_BETA = 0.1\n", + "SFT_WEIGHT = 1.0\n", + "MAX_LENGTH = 2048\n", + "LORA_RANK = 8\n", + "DATA_NUM = 5000\n", + "SYSTEM_PROMPT = 'You are a helpful assistant.'" + ] + }, + { + "cell_type": "markdown", + "id": "c6cdddc3", + "metadata": {}, + "source": [ + "## 第二步:准备数据集\n", + "\n", + "本示例使用 ModelScope 上的 `shareAI-Llama3-DPO-zh-en-emoji` 数据集,格式为:\n", + "\n", + "```json\n", + "{\n", + " \"prompt\": \"问题文本\",\n", + " \"answer_zh\": \"中文回答(chosen)\",\n", + " \"answer_en\": \"英文回答(rejected)\"\n", + "}\n", + "```\n", + "\n", + "`EmojiDPOProcessor` 会将每条原始样本转换为:\n", + "- `positive`:包含 chosen 回答的 `Trajectory`\n", + "- `negative`:包含 rejected 回答的 `Trajectory`\n", + "\n", + "`dataset.encode()` 会自动识别 `positive`/`negative` 格式,将双轨迹编码为对应的 `InputFeature`。\n", + "\n", + "> **注意**:使用 `LazyDataset` 可以在迭代时按需加载数据,节省内存。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e8bfe11b", + "metadata": {}, + "outputs": [], + "source": [ + "def create_dpo_dataset():\n", + " \"\"\"创建 DPO 数据集,返回包含 positive/negative 对的编码数据集。\"\"\"\n", + " dataset = LazyDataset(DatasetMeta(DATASET_ID, data_slice=range(DATA_NUM)))\n", + " dataset.set_template('Qwen3_5Template', model_id=f'ms://{BASE_MODEL}', max_length=MAX_LENGTH)\n", + " dataset.map(\n", + " EmojiDPOProcessor,\n", + " init_args={'system': 
SYSTEM_PROMPT},\n", + " )\n", + " # EmojiDPOProcessor 返回 {'positive': InputFeature, 'negative': InputFeature}\n", + " # encode 会自动处理该格式\n", + " dataset.encode()\n", + " return dataset\n", + "\n", + "\n", + "dataset = create_dpo_dataset()\n", + "dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE)\n", + "print(f'数据集大小: {len(dataset)} 条')\n", + "print(f'每轮训练步数: {len(dataloader)} 步')" + ] + }, + { + "cell_type": "markdown", + "id": "c1fda3c0", + "metadata": {}, + "source": [ + "## 第三步:构造 DPO 批次\n", + "\n", + "DPO 要求每个训练批次同时包含 chosen 和 rejected 样本,且必须**成对出现**。\n", + "\n", + "为确保 Data Parallel 切分后每个设备都能看到完整的 chosen/rejected 对,我们将批次重排为交错格式:\n", + "\n", + "```\n", + "原始批次:[{positive: A+, negative: A-}, {positive: B+, negative: B-}]\n", + "交错后: [A+, A-, B+, B-]\n", + "```\n", + "\n", + "这样无论按多少个 DP 分片切割,每个分片都包含完整的偏好对。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c3589c44", + "metadata": {}, + "outputs": [], + "source": [ + "def prepare_dpo_batch(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]:\n", + " \"\"\"将批次重排为交错格式 [pos_1, neg_1, pos_2, neg_2, ...]。\n", + "\n", + " Args:\n", + " batch: 原始批次,每条含 'positive' 和 'negative' InputFeature。\n", + "\n", + " Returns:\n", + " 交错后的列表,保证每个 DP 切片包含完整的 chosen/rejected 对。\n", + " \"\"\"\n", + " result = []\n", + " for row in batch:\n", + " base_fields = {k: v for k, v in row.items() if k not in ('positive', 'negative')}\n", + " pos_sample = {**base_fields, **row['positive']}\n", + " neg_sample = {**base_fields, **row['negative']}\n", + " result.append(pos_sample)\n", + " result.append(neg_sample)\n", + " return result\n", + "\n", + "\n", + "# 验证:取第一个 batch 演示交错逻辑\n", + "sample_batch = next(iter(dataloader))\n", + "dpo_batch = prepare_dpo_batch(sample_batch)\n", + "print(f'原始批次大小: {len(sample_batch)} 对')\n", + "print(f'交错后批次大小: {len(dpo_batch)} 条(每对展开为 2 条)')" + ] + }, + { + "cell_type": "markdown", + "id": "5a182365", + "metadata": {}, + "source": [ + "## 第四步:初始化客户端\n", + "\n", + "### 
客户端初始化顺序\n", + "\n", + "> **重要**:必须先调用 `init_tinker_client()` 完成 Tinker 运行时初始化,再 import `ServiceClient`。\n", + "\n", + "### LoRA 训练客户端\n", + "\n", + "`create_lora_training_client` 会在服务端创建一个携带 LoRA 适配器的训练会话:\n", + "- `base_model`:指定基座模型(需与服务端已加载模型一致)\n", + "- `rank`:LoRA 秩,秩越大表达能力越强,但显存消耗也越大\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "75f41699", + "metadata": {}, + "outputs": [], + "source": [ + "# 必须先初始化 Tinker 客户端,再 import ServiceClient\n", + "init_tinker_client()\n", + "\n", + "from tinker import ServiceClient # noqa: E402\n", + "\n", + "service_client = ServiceClient(base_url=BASE_URL, api_key=API_KEY)\n", + "training_client = service_client.create_lora_training_client(\n", + " base_model=BASE_MODEL,\n", + " rank=LORA_RANK,\n", + ")\n", + "logger.info(f'LoRA 训练客户端已创建(rank={LORA_RANK})')\n", + "logger.info(f'开始 DPO 训练:beta={DPO_BETA}, lr={LEARNING_RATE}')" + ] + }, + { + "cell_type": "markdown", + "id": "41078b1d", + "metadata": {}, + "source": [ + "## 第五步:DPO 训练主循环\n", + "\n", + "每个训练步骤包含四个阶段:\n", + "\n", + "### 5.1 参考前向传播(Reference Forward)\n", + "\n", + "调用 `training_client.forward()` 并设置 `disable_lora=True`,让服务端以**纯基座模型**进行前向计算:\n", + "- LoRA 权重不参与计算图 → 反向传播产生零梯度,**不影响 LoRA 参数**\n", + "- 返回每个 token 的 log 概率 `ref_logps`,作为 DPO 损失的参考基线\n", + "\n", + "### 5.2 附加参考 Log 概率\n", + "\n", + "将参考前向的输出 `ref_logps` 附加到每条 Datum 的 `loss_fn_inputs` 中。\n", + "服务端检测到 `ref_logps` 字段后,会自动切换为 `DPOLoss + DPOMetric`。\n", + "\n", + "### 5.3 DPO 前向反向传播\n", + "\n", + "调用 `training_client.forward_backward()` 使用 `importance_sampling` 损失:\n", + "- `dpo_beta`:控制策略偏离参考模型的程度,越大越保守\n", + "- `dpo_sft_weight`:SFT 辅助损失权重,有助于稳定训练\n", + "\n", + "### 5.4 优化器步骤\n", + "\n", + "`optim_step` 更新 LoRA 参数,并返回 DPO 相关指标(chosen/rejected reward、reward margin 等)。\n", + "\n", + "> **提示**:DPO 训练中 `reward_margin`(chosen reward - rejected reward)持续增大,说明模型正在正确学习偏好。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dd65b207", + "metadata": {}, + "outputs": [], 
+ "source": [ + "for step, batch in tqdm(enumerate(dataloader), total=len(dataloader)):\n", + "\n", + " # ---- 数据预处理:numpy/torch → Python list(序列化需要)----\n", + " for row in batch:\n", + " for key in list(row.keys()):\n", + " if isinstance(row[key], np.ndarray):\n", + " row[key] = row[key].tolist()\n", + " elif isinstance(row[key], torch.Tensor):\n", + " row[key] = row[key].cpu().numpy().tolist()\n", + "\n", + " # 将批次重排为交错格式 [pos, neg, pos, neg, ...]\n", + " dpo_batch = prepare_dpo_batch(batch)\n", + "\n", + " # 将每条 InputFeature dict 转换为 Tinker Datum\n", + " input_datums = [input_feature_to_datum(row) for row in dpo_batch]\n", + "\n", + " # =================================================================\n", + " # A. 参考前向传播(base model,disable_lora=True)\n", + " # LoRA 不在计算图中 → 反向不更新 LoRA,安全执行\n", + " # =================================================================\n", + " ref_result = training_client.forward(\n", + " input_datums,\n", + " 'cross_entropy',\n", + " loss_fn_config={'disable_lora': True},\n", + " ).result()\n", + "\n", + " # =================================================================\n", + " # B. 将参考 log 概率附加到每条 Datum 的 loss_fn_inputs\n", + " # 服务端检测到 ref_logps 后自动切换 DPOLoss + DPOMetric\n", + " # =================================================================\n", + " for datum, ref_out in zip(input_datums, ref_result.loss_fn_outputs):\n", + " ref_logprobs_np = np.array(ref_out['logprobs'].tolist(), dtype=np.float32)\n", + " datum.loss_fn_inputs['ref_logps'] = types.TensorData.from_numpy(ref_logprobs_np)\n", + "\n", + " # =================================================================\n", + " # C. 
DPO 前向反向传播\n", + " # 服务端自动计算 DPO 损失,无需手动实现\n", + " # =================================================================\n", + " fwdbwd_result = training_client.forward_backward(\n", + " input_datums,\n", + " 'importance_sampling',\n", + " loss_fn_config={\n", + " 'dpo_beta': DPO_BETA,\n", + " 'dpo_sft_weight': SFT_WEIGHT,\n", + " },\n", + " ).result()\n", + "\n", + " # =================================================================\n", + " # D. 优化器步骤\n", + " # 更新 LoRA 参数,DPOMetric 在服务端自动计算并随结果返回\n", + " # =================================================================\n", + " optim_result = training_client.optim_step(\n", + " types.AdamParams(learning_rate=LEARNING_RATE)\n", + " ).result()\n", + "\n", + " logger.info(f'[Step {step}] metrics={optim_result.metrics}')\n" + ] + }, + { + "cell_type": "markdown", + "id": "6c78de34", + "metadata": {}, + "source": [ + "## 第六步:保存与导出检查点\n", + "\n", + "训练完成后,使用 `save_state` 将 LoRA 权重保存到服务端:\n", + "- 保存路径由服务端配置决定,返回的 `save_result.path` 为实际路径\n", + "- 该检查点包含 LoRA 适配器权重,可直接加载到基座模型上进行推理\n", + "\n", + "### 可选:上传到 ModelScope Hub\n", + "\n", + "取消注释下方代码,可将检查点一键发布到 ModelScope 模型库,方便共享和部署。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fa097e99", + "metadata": {}, + "outputs": [], + "source": [ + "# 保存 LoRA 检查点\n", + "save_result = training_client.save_state('dpo-lora-final').result()\n", + "logger.info(f'检查点已保存:{save_result.path}')\n", + "\n", + "# (可选)上传到 ModelScope Hub\n", + "# YOUR_USER_NAME = 'your_username'\n", + "# hub_model_id = f'{YOUR_USER_NAME}/twinkle-tinker-dpo-lora'\n", + "# training_client.publish_checkpoint_from_tinker_path(save_result.path).result()\n", + "# logger.info(f'检查点已上传至 Hub: {hub_model_id}')" + ] + }, + { + "cell_type": "markdown", + "id": "d10bef25", + "metadata": {}, + "source": [ + "## 推理(Inference)\n", + "\n", + "训练完成后,可以直接使用 **线上服务** 进行推理,无需本地 GPU。\n", + "\n", + "通过 `save_weights_and_get_sampling_client` 或 `create_sampling_client` 加载训练好的 LoRA 检查点,即可在线采样生成。\n", + "\n", + "> 
将下方 `weight_path` 替换为训练输出的检查点路径(`twinkle://...` 格式)。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cd70c735", + "metadata": {}, + "outputs": [], + "source": [ + "# 推理示例(使用线上服务,无需本地 GPU)\n", + "import os\n", + "from tinker import types\n", + "from twinkle import init_tinker_client, get_logger\n", + "from twinkle.data_format import Message, Trajectory\n", + "from twinkle.template import Template\n", + "\n", + "logger = get_logger()\n", + "\n", + "BASE_MODEL = 'Qwen/Qwen3.6-35B-A3B'\n", + "\n", + "# TODO: 替换为训练输出的检查点路径\n", + "weight_path = '<替换为你的 twinkle:// 检查点路径>' # 例如: save_result.path\n", + "\n", + "init_tinker_client()\n", + "from tinker import ServiceClient\n", + "\n", + "service_client = ServiceClient(\n", + " base_url='http://www.modelscope.cn/twinkle',\n", + " api_key=os.environ.get('MODELSCOPE_TOKEN'),\n", + ")\n", + "\n", + "# 加载 LoRA 检查点并创建采样客户端\n", + "sampling_client = service_client.create_sampling_client(\n", + " model_path=weight_path,\n", + " base_model=BASE_MODEL,\n", + ")\n", + "\n", + "# 构造 Prompt\n", + "template = Template(model_id=f'ms://{BASE_MODEL}')\n", + "trajectory = Trajectory(\n", + " messages=[\n", + " Message(role='system', content='You are a helpful assistant.'),\n", + " Message(role='user', content='你好,请介绍一下你自己。'),\n", + " ]\n", + ")\n", + "\n", + "input_feature = template.encode(trajectory, add_generation_prompt=True)\n", + "input_ids = input_feature['input_ids'].tolist()\n", + "\n", + "# 采样\n", + "prompt = types.ModelInput.from_ints(input_ids)\n", + "params = types.SamplingParams(max_tokens=256, temperature=0.7)\n", + "\n", + "print('Sampling...')\n", + "future = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=3)\n", + "result = future.result()\n", + "\n", + "# 输出结果\n", + "print('Responses:')\n", + "for i, seq in enumerate(result.sequences):\n", + " print(f'{i}: {repr(template.decode(seq.tokens))}')" + ] + }, + { + "cell_type": "markdown", + "id": "13d43fcc", + "metadata": {}, + 
"source": [ + "## 合并权重并导出\n", + "\n", + "训练得到的 LoRA 权重可以与原始模型合并,导出为完整的 HuggingFace 模型,方便后续部署和推理。\n", + "\n", + "> **注意**:合并操作需要 GPU 资源(需要加载完整模型),请在有足够显存的环境下执行。\n", + "\n", + "```bash\n", + "CUDA_VISIBLE_DEVICES=0,1,2,3 \\\n", + "NPROC_PER_NODE=4 \\\n", + "/opt/conda/envs/twinkle/bin/megatron export \\\n", + " --model Qwen/Qwen3.6-35B-A3B \\\n", + " --adapters <替换为你的 LoRA 检查点路径> \\\n", + " --output_dir <替换为输出目录> \\\n", + " --merge_lora true \\\n", + " --to_hf true \\\n", + " --tensor_model_parallel_size 2 \\\n", + " --expert_model_parallel_size 2 \\\n", + " --pipeline_model_parallel_size 2\n", + "```\n", + "\n", + "**参数说明**:\n", + "\n", + "| 参数 | 说明 |\n", + "|------|------|\n", + "| `--model` | 基座模型 ID |\n", + "| `--adapters` | 训练保存的 LoRA 检查点路径 |\n", + "| `--output_dir` | 合并后的完整模型输出目录 |\n", + "| `--merge_lora true` | 将 LoRA 权重合并到基座模型中 |\n", + "| `--to_hf true` | 导出为 HuggingFace 格式 |\n", + "| `--tensor_model_parallel_size` | 张量并行大小 |\n", + "| `--expert_model_parallel_size` | 专家并行大小(MoE 模型专用) |\n", + "| `--pipeline_model_parallel_size` | 流水线并行大小 |\n", + "\n", + "合并完成后,输出目录中即为完整的 HuggingFace 模型,可直接用于推理或部署。" + ] + }, + { + "cell_type": "markdown", + "id": "322cd5c8", + "metadata": {}, + "source": [ + "## 总结\n", + "\n", + "本 Notebook 展示了使用 Twinkle 进行 DPO 训练的完整流程:\n", + "\n", + "| 步骤 | 操作 | 关键 API |\n", + "|------|------|----------|\n", + "| 数据准备 | 加载偏好数据集,用 `EmojiDPOProcessor` 生成 chosen/rejected 对 | `LazyDataset`, `EmojiDPOProcessor` |\n", + "| 批次构造 | 交错排列 `[pos, neg, pos, neg, ...]` | `prepare_dpo_batch` |\n", + "| 参考前向 | 用基座模型计算 `ref_logps` | `training_client.forward(..., disable_lora=True)` |\n", + "| DPO 训练 | 附加 `ref_logps`,自动触发 DPO 损失计算 | `training_client.forward_backward(..., 'importance_sampling')` |\n", + "| 参数更新 | 更新 LoRA,获取 DPO 指标 | `training_client.optim_step(AdamParams)` |\n", + "| 保存导出 | 保存 LoRA 检查点 | `training_client.save_state` |\n", + "\n", + "### 调参建议\n", + "\n", + "- **`DPO_BETA`**:通常在 0.01 ~ 0.5 之间,越小越激进,越大越保守\n", + "- 
**`SFT_WEIGHT`**:增大该值可防止模型忘记基础能力(灾难性遗忘)\n", + "- **`LORA_RANK`**:rank=8 适合大多数场景;数据量大或任务复杂时可尝试 16/32" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.12.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebook/multi_modal.ipynb b/notebook/multi_modal.ipynb new file mode 100644 index 00000000..081d7de8 --- /dev/null +++ b/notebook/multi_modal.ipynb @@ -0,0 +1,514 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "81c2a3f2", + "metadata": {}, + "source": [ + "# Twinkle 客户端 - 多模态 LoRA 训练(LaTeX OCR)\n", + "\n", + "> 🎯 **训练目标**:通过多模态 LoRA 微调,让模型学会 **LaTeX OCR** —— 从图片中识别并输出 LaTeX 数学公式。\n", + "\n", + "本 Notebook 演示如何通过 **Twinkle 客户端** 使用 LoRA 微调一个多模态语言模型,实现 **LaTeX OCR**(从图片识别 LaTeX 公式)。\n", + "\n", + "### 与文本训练的区别\n", + "\n", + "相比 `self_cognition.ipynb`(纯文本训练),多模态训练的主要差异:\n", + "- 数据集包含 **图片**,通过 Message 的 `images` 字段传入图片\n", + "- 使用 `Qwen3_5Template`(支持视觉输入)而非通用 `Template`\n", + "- 使用 `LazyDataset`(延迟加载)而非 `Dataset`,避免图片数据一次性加载到内存\n", + "- 训练循环中需要将 numpy/torch 张量转为列表,适配网络传输格式\n", + "\n", + "### 整体流程\n", + "\n", + "```\n", + "初始化客户端 -> 定义数据预处理器 -> 准备数据集 -> 配置模型 -> 训练循环 -> 保存检查点\n", + "```\n", + "\n", + "### 前置条件\n", + "\n", + "| 条件 | 说明 |\n", + "|------|------|\n", + "| 服务端已启动 | 需支持多模态模型(如 Qwen3.6 系列) |\n", + "| 环境变量 | `.env` 文件或环境中设置 `MODELSCOPE_TOKEN` |\n", + "| 依赖安装 | `pip install twinkle-kit[tinker]` |\n", + "\n", + "> 💡 **获取 Token**:访问 [ModelScope Token 页面](https://www.modelscope.cn/my/access/token) 获取你的 `MODELSCOPE_TOKEN`,并设置为环境变量:`export MODELSCOPE_TOKEN=<你的Token>`\n" + ] + }, + { + "cell_type": "markdown", + "id": "01e0cb9f", + "metadata": {}, + "source": [ + "## 🚀 零卡训练服务化(Serverless Training)\n", + "\n", + "本 Notebook 运行在 **ModelScope 零卡训练平台** 上。你无需自备 GPU,只需在 Notebook 中编写训练逻辑,平台会自动调度云端 GPU 资源完成训练。\n", + "\n", + "### 架构示意图\n", + "\n", + "```\n", + 
"┌─────────────────────────────────────────────────────────────┐\n", + "│ 你的 Notebook(CPU 环境) │\n", + "│ │\n", + "│ ┌──────────┐ HTTP / gRPC ┌──────────────────────┐ │\n", + "│ │ Twinkle │ ─────────────────► │ ModelScope 云端 │ │\n", + "│ │ Client │ ◄───────────────── │ GPU 训练集群 │ │\n", + "│ └──────────┘ 训练结果返回 │ │ │\n", + "│ │ │ ┌────┐ ┌────┐ ┌────┐│ │\n", + "│ │ 构造数据 │ │GPU0│ │GPU1│ │... ││ │\n", + "│ │ 发送训练请求 │ └────┘ └────┘ └────┘│ │\n", + "│ │ 接收指标/检查点 │ 模型加载 + LoRA 训练 │ │\n", + "│ ▼ └──────────────────────┘ │\n", + "│ ┌──────────┐ │\n", + "│ │ 数据准备 │ Dataset / DataLoader / Preprocessor │\n", + "│ └──────────┘ │\n", + "└─────────────────────────────────────────────────────────────┘\n", + "```\n", + "\n", + "### 核心优势\n", + "\n", + "| 特性 | 说明 |\n", + "|------|------|\n", + "| **零卡启动** | Notebook 本身不需要 GPU,训练在云端自动执行 |\n", + "| **按需付费** | 仅在训练时占用 GPU 资源 |\n", + "| **开箱即用** | 预置主流模型,无需下载权重 |\n", + "| **LoRA 微调** | 高效参数微调,几分钟即可完成小规模训练 |\n", + "\n", + "> 🔗 本项目由 [Twinkle](https://github.com/modelscope/twinkle) 框架提供支持 | [GitHub](https://github.com/modelscope/twinkle)" + ] + }, + { + "cell_type": "markdown", + "id": "69b8a2df", + "metadata": {}, + "source": [ + "## 第一步:导入依赖并加载环境变量" + ] + }, + { + "cell_type": "markdown", + "id": "b491d97c", + "metadata": {}, + "source": [ + "## 环境安装\n", + "\n", + "首次运行前,请先执行以下安装命令。如已安装可跳过此步。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "562e9aee", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install twinkle-kit[tinker] -q" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ff4fbfd0", + "metadata": {}, + "outputs": [], + "source": [ + "import dotenv\n", + "import os\n", + "dotenv.load_dotenv('.env')\n", + "\n", + "import numpy as np\n", + "import torch\n", + "from peft import LoraConfig\n", + "\n", + "from twinkle import get_logger\n", + "from twinkle.data_format import Trajectory, Message\n", + "from twinkle.preprocessor import Preprocessor\n", + "from twinkle.dataset import 
DatasetMeta\n", + "from twinkle_client import init_twinkle_client\n", + "from twinkle.dataloader import DataLoader\n", + "from twinkle.dataset import LazyDataset\n", + "from twinkle_client.model import MultiLoraTransformersModel\n", + "\n", + "logger = get_logger()" + ] + }, + { + "cell_type": "markdown", + "id": "f6427b1c", + "metadata": {}, + "source": [ + "## 第二步:初始化客户端并查询历史训练" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f285a570", + "metadata": {}, + "outputs": [], + "source": [ + "base_model = 'Qwen/Qwen3.6-35B-A3B'\n", + "base_url = 'http://www.modelscope.cn/twinkle'\n", + "\n", + "client = init_twinkle_client(base_url=base_url, api_key=os.environ.get('MODELSCOPE_TOKEN'))\n", + "\n", + "runs = client.list_training_runs()\n", + "resume_path = None\n", + "for run in runs:\n", + " logger.info(run.model_dump_json(indent=2))\n", + " checkpoints = client.list_checkpoints(run.training_run_id)\n", + " for checkpoint in checkpoints:\n", + " logger.info(checkpoint.model_dump_json(indent=2))\n", + " # resume_path = checkpoint.twinkle_path\n", + "\n", + "print(f'Found {len(runs)} training run(s)')" + ] + }, + { + "cell_type": "markdown", + "id": "abb5e7d7", + "metadata": {}, + "source": [ + "## 第三步:定义数据预处理器\n", + "\n", + "`LatexOCRProcessor` 将原始数据转为多模态对话格式:\n", + "- **user 消息**:包含图片(`` 标记)和指令文本\n", + "- **assistant 消息**:图片对应的 LaTeX 公式(训练目标)\n", + "\n", + "数据集中每条样本包含:\n", + "- `image`:公式图片\n", + "- `text`:对应的 LaTeX 代码" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "18b48bc4", + "metadata": {}, + "outputs": [], + "source": [ + "class LatexOCRProcessor(Preprocessor):\n", + "\n", + " def __call__(self, rows):\n", + " rows = self.map_col_to_row(rows)\n", + " rows = [self.preprocess(row) for row in rows]\n", + " rows = self.map_row_to_col(rows)\n", + " return rows\n", + "\n", + " def preprocess(self, row) -> Trajectory:\n", + " return Trajectory(\n", + " messages=[\n", + " Message(\n", + " role='user',\n", + " content='Using 
LaTeX to perform OCR on the image.',\n", + " images=[row['image']]\n", + " ),\n", + " Message(\n", + " role='assistant',\n", + " content=row['text']\n", + " ),\n", + " ]\n", + " )" + ] + }, + { + "cell_type": "markdown", + "id": "c29cd5bd", + "metadata": {}, + "source": [ + "## 第四步:准备数据集\n", + "\n", + "使用 ModelScope 上的 `AI-ModelScope/LaTeX_OCR` 数据集。\n", + "\n", + "**关键区别**:这里使用 `LazyDataset` 而非 `Dataset`,因为图片数据量较大,延迟加载可以节省内存。\n", + "\n", + "| 参数 | 值 | 说明 |\n", + "|------|-----|------|\n", + "| 模板 | `Qwen3_5Template` | 支持多模态输入(图片+文本) |\n", + "| max_length | 512 | 最大 token 数 |\n", + "| data_slice | range(500) | 取前 500 条样本 |" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "457f1faf", + "metadata": {}, + "outputs": [], + "source": [ + "dataset = LazyDataset(dataset_meta=DatasetMeta('ms://AI-ModelScope/LaTeX_OCR', data_slice=range(500)))\n", + "\n", + "# 使用 Qwen3.5 专用多模态模板\n", + "dataset.set_template('Qwen3_5Template', model_id=f'ms://{base_model}', max_length=512)\n", + "\n", + "# 应用 LaTeX OCR 预处理\n", + "dataset.map(LatexOCRProcessor)\n", + "\n", + "# 编码\n", + "dataset.encode(batched=True)\n", + "\n", + "dataloader = DataLoader(dataset=dataset, batch_size=4)\n", + "print(f'Dataset size: {len(dataset)}')" + ] + }, + { + "cell_type": "markdown", + "id": "243026ce", + "metadata": {}, + "source": [ + "## 第五步:配置模型\n", + "\n", + "与文本训练基本一致,唯一区别是模板使用 `Qwen3_5Template`。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "75cf9403", + "metadata": {}, + "outputs": [], + "source": [ + "model = MultiLoraTransformersModel(model_id=f'ms://{base_model}')\n", + "\n", + "lora_config = LoraConfig(target_modules='all-linear')\n", + "model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=2)\n", + "\n", + "model.set_template('Qwen3_5Template')\n", + "model.set_processor('InputProcessor', padding_side='right')\n", + "model.set_loss('CrossEntropyLoss')\n", + "model.set_optimizer('Adam', lr=1e-4)\n", + "\n", + "# 断点续训\n", + 
"if resume_path:\n", + " logger.info(f'Resuming from {resume_path}')\n", + " model.load(resume_path, load_optimizer=True)\n", + "\n", + "logger.info(model.get_train_configs().model_dump())\n", + "print('Model configured')" + ] + }, + { + "cell_type": "markdown", + "id": "6f509e0d", + "metadata": {}, + "source": [ + "## 第六步:训练循环\n", + "\n", + "与文本训练的训练循环几乎一样,但有一个关键区别:**数据格式转换**。\n", + "\n", + "多模态数据中包含 numpy 数组和 torch 张量(图片特征),在通过网络发送到服务端前需要转为 Python 列表。\n", + "\n", + "```python\n", + "# 这段转换逻辑是多模态训练特有的\n", + "for sample in batch:\n", + " for key in sample:\n", + " if isinstance(sample[key], np.ndarray):\n", + " sample[key] = sample[key].tolist()\n", + " elif isinstance(sample[key], torch.Tensor):\n", + " sample[key] = sample[key].cpu().numpy().tolist()\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f91a88dc", + "metadata": {}, + "outputs": [], + "source": [ + "for epoch in range(3):\n", + " logger.info(f'Starting epoch {epoch}')\n", + " for step, batch in enumerate(dataloader):\n", + " # 多模态特有:将张量转为列表以适配网络传输\n", + " for sample in batch:\n", + " for key in sample:\n", + " if isinstance(sample[key], np.ndarray):\n", + " sample[key] = sample[key].tolist()\n", + " elif isinstance(sample[key], torch.Tensor):\n", + " sample[key] = sample[key].cpu().numpy().tolist()\n", + "\n", + " # 前向 + 反向\n", + " model.forward_backward(inputs=batch)\n", + "\n", + " # 梯度裁剪 + 优化器更新\n", + " model.clip_grad_and_step()\n", + "\n", + " # 每 2 步打印指标\n", + " if step % 2 == 0:\n", + " metric = model.calculate_metric(is_training=True)\n", + " logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric.result}')\n", + "\n", + " # 保存检查点\n", + " twinkle_path = model.save(name=f'twinkle-epoch-{epoch}', save_optimizer=True)\n", + " logger.info(f'Saved checkpoint: {twinkle_path}')" + ] + }, + { + "cell_type": "markdown", + "id": "196a3809", + "metadata": {}, + "source": [ + "## 第七步:上传到 ModelScope Hub(可选)" + ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "id": "b1376582", + "metadata": {}, + "outputs": [], + "source": [ + "# YOUR_USER_NAME = \"your_username\"\n", + "# hub_model_id = f'{YOUR_USER_NAME}/twinkle-multi-modal'\n", + "# model.upload_to_hub(\n", + "# checkpoint_dir=twinkle_path,\n", + "# hub_model_id=hub_model_id,\n", + "# async_upload=False\n", + "# )\n", + "# logger.info(f'Uploaded checkpoint to hub: {hub_model_id}')\n", + "\n", + "print('Training complete!')" + ] + }, + { + "cell_type": "markdown", + "id": "24713a4d", + "metadata": {}, + "source": [ + "## 推理(Inference)\n", + "\n", + "训练完成后,可以直接使用 **线上服务** 进行推理,无需本地 GPU。\n", + "\n", + "通过 `save_weights_and_get_sampling_client` 或 `create_sampling_client` 加载训练好的 LoRA 检查点,即可在线采样生成。\n", + "\n", + "> 将下方 `weight_path` 替换为训练输出的检查点路径(`twinkle://...` 格式)。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "df109191", + "metadata": {}, + "outputs": [], + "source": [ + "# 推理示例(使用线上服务,无需本地 GPU)\n", + "import os\n", + "from tinker import types\n", + "from twinkle import init_tinker_client, get_logger\n", + "from twinkle.data_format import Message, Trajectory\n", + "from twinkle.template import Qwen3_5Template\n", + "\n", + "logger = get_logger()\n", + "\n", + "BASE_MODEL = 'Qwen/Qwen3.6-35B-A3B'\n", + "\n", + "# TODO: 替换为训练输出的检查点路径\n", + "weight_path = '<替换为你的 twinkle:// 检查点路径>'\n", + "\n", + "init_tinker_client()\n", + "from tinker import ServiceClient\n", + "\n", + "service_client = ServiceClient(\n", + " base_url='http://www.modelscope.cn/twinkle',\n", + " api_key=os.environ.get('MODELSCOPE_TOKEN'),\n", + ")\n", + "\n", + "sampling_client = service_client.create_sampling_client(\n", + " model_path=weight_path,\n", + " base_model=BASE_MODEL,\n", + ")\n", + "\n", + "# 构造 Prompt(多模态需使用 Qwen3_5Template)\n", + "template = Qwen3_5Template(model_id=f'ms://{BASE_MODEL}')\n", + "trajectory = Trajectory(\n", + " messages=[\n", + " Message(\n", + " role='user',\n", + " content='Using LaTeX to perform OCR on the image.',\n", + 
" ),\n", + " ]\n", + ")\n", + "\n", + "input_feature = template.encode(trajectory, add_generation_prompt=True)\n", + "input_ids = input_feature['input_ids'].tolist()\n", + "\n", + "prompt = types.ModelInput.from_ints(input_ids)\n", + "params = types.SamplingParams(max_tokens=256, temperature=0.2)\n", + "\n", + "print('Sampling...')\n", + "future = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=1)\n", + "result = future.result()\n", + "\n", + "print('Responses:')\n", + "for i, seq in enumerate(result.sequences):\n", + " print(f'{i}: {repr(template.decode(seq.tokens))}')" + ] + }, + { + "cell_type": "markdown", + "id": "792c879b", + "metadata": {}, + "source": [ + "## 合并权重并导出\n", + "\n", + "训练得到的 LoRA 权重可以与原始模型合并,导出为完整的 HuggingFace 模型,方便后续部署和推理。\n", + "\n", + "> **注意**:合并操作需要 GPU 资源(需要加载完整模型),请在有足够显存的环境下执行。\n", + "\n", + "```bash\n", + "CUDA_VISIBLE_DEVICES=0,1,2,3 \\\n", + "NPROC_PER_NODE=4 \\\n", + "/opt/conda/envs/twinkle/bin/megatron export \\\n", + " --model Qwen/Qwen3.6-35B-A3B \\\n", + " --adapters <替换为你的 LoRA 检查点路径> \\\n", + " --output_dir <替换为输出目录> \\\n", + " --merge_lora true \\\n", + " --to_hf true \\\n", + " --tensor_model_parallel_size 2 \\\n", + " --expert_model_parallel_size 2 \\\n", + " --pipeline_model_parallel_size 2\n", + "```\n", + "\n", + "**参数说明**:\n", + "\n", + "| 参数 | 说明 |\n", + "|------|------|\n", + "| `--model` | 基座模型 ID |\n", + "| `--adapters` | 训练保存的 LoRA 检查点路径 |\n", + "| `--output_dir` | 合并后的完整模型输出目录 |\n", + "| `--merge_lora true` | 将 LoRA 权重合并到基座模型中 |\n", + "| `--to_hf true` | 导出为 HuggingFace 格式 |\n", + "| `--tensor_model_parallel_size` | 张量并行大小 |\n", + "| `--expert_model_parallel_size` | 专家并行大小(MoE 模型专用) |\n", + "| `--pipeline_model_parallel_size` | 流水线并行大小 |\n", + "\n", + "合并完成后,输出目录中即为完整的 HuggingFace 模型,可直接用于推理或部署。" + ] + }, + { + "cell_type": "markdown", + "id": "c0ff8403", + "metadata": {}, + "source": [ + "## 常见问题\n", + "\n", + "| 问题 | 可能原因 | 解决方法 |\n", + "|------|----------|----------|\n", + "| 图片加载失败 | 
数据集下载不完整 | 重新下载或检查网络 |\n", + "| OOM 内存不足 | 图片特征占用显存 | 减小 batch_size 或 max_length |\n", + "| 张量序列化错误 | 忘记转换为列表 | 确保训练循环中有 numpy/torch 转 list 逻辑 |\n", + "| 模板不匹配 | 使用了通用 Template | 多模态需用 `Qwen3_5Template` |" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebook/sample.ipynb b/notebook/sample.ipynb new file mode 100644 index 00000000..22926ac5 --- /dev/null +++ b/notebook/sample.ipynb @@ -0,0 +1,300 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "24ecac17", + "metadata": {}, + "source": [ + "# Twinkle 推理示例(Inference with LoRA)\n", + "\n", + "> 🎯 **本 Notebook 目标**:对训练好的 LoRA 模型进行 **在线推理**,验证训练效果。\n", + "\n", + "本 Notebook 演示如何使用 Twinkle **线上服务** 对训练好的 LoRA 模型进行推理。无需本地 GPU。\n", + "\n", + "### 整体流程\n", + "\n", + "```\n", + "初始化客户端 → 加载 LoRA 检查点 → 构造 Prompt → 在线采样 → 输出结果\n", + "```\n", + "\n", + "### 前置条件\n", + "\n", + "| 条件 | 说明 |\n", + "|------|------|\n", + "| 环境变量 | 设置 `MODELSCOPE_TOKEN` |\n", + "| LoRA 检查点 | 训练产出的 `twinkle://` 路径 |\n", + "| 依赖安装 | `pip install twinkle-kit[tinker]` |\n", + "> 💡 **获取 Token**:访问 [ModelScope Token 页面](https://www.modelscope.cn/my/access/token) 获取你的 `MODELSCOPE_TOKEN`,并设置为环境变量:`export MODELSCOPE_TOKEN=<你的Token>`\n" + ] + }, + { + "cell_type": "markdown", + "id": "4572d725", + "metadata": {}, + "source": [ + "## 🚀 线上推理服务\n", + "\n", + "本 Notebook 通过 **ModelScope 线上服务** 进行推理,你的 Notebook 环境不需要 GPU。\n", + "\n", + "### 架构示意图\n", + "\n", + "```\n", + "┌───────────────────────────────────────────────────────┐\n", + "│ 你的 Notebook(CPU 环境) │\n", + "│ │\n", + "│ ┌──────────────┐ HTTP ┌────────────────────┐ │\n", + "│ │ Tinker │ ─────────► │ ModelScope 云端 │ │\n", + "│ │ ServiceClient│ ◄───────── │ 推理集群 │ │\n", + "│ └──────────────┘ 生成结果 │ │ │\n", + "│ │ │ ┌────┐ ┌────┐ │ │\n", + "│ │ 发送 Prompt │ │GPU0│ │GPU1│ ... 
│ │\n", + "│ │ 接收生成文本 │ └────┘ └────┘ │ │\n", + "│ ▼ │ 基座模型 + LoRA │ │\n", + "│ ┌──────────────┐ └────────────────────┘ │\n", + "│ │ Template │ 本地编码 Prompt / 解码结果 │\n", + "│ └──────────────┘ │\n", + "└───────────────────────────────────────────────────────┘\n", + "```\n", + "\n", + "> 🔗 本项目由 [Twinkle](https://github.com/modelscope/twinkle) 框架提供支持 | [GitHub](https://github.com/modelscope/twinkle)" + ] + }, + { + "cell_type": "markdown", + "id": "fffe2f41", + "metadata": {}, + "source": [ + "## 第一步:导入依赖与配置\n", + "\n", + "| 配置项 | 说明 |\n", + "|--------|------|\n", + "| `BASE_MODEL` | 基座模型 ID |\n", + "| `weight_path` | 训练产出的 LoRA 检查点路径(`twinkle://...` 格式) |\n", + "| `MODELSCOPE_TOKEN` | ModelScope API Token(环境变量) |" + ] + }, + { + "cell_type": "markdown", + "id": "8c4b0c77", + "metadata": {}, + "source": [ + "## 环境安装\n", + "\n", + "首次运行前,请先执行以下安装命令。如已安装可跳过此步。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3afce4b9", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install twinkle-kit[tinker] -q" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "78ff4238", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from tinker import types\n", + "from twinkle import init_tinker_client, get_logger\n", + "from twinkle.data_format import Message, Trajectory\n", + "from twinkle.template import Template\n", + "\n", + "logger = get_logger()\n", + "\n", + "BASE_MODEL = 'Qwen/Qwen3.6-35B-A3B'\n", + "\n", + "# TODO: 替换为你的训练检查点路径\n", + "weight_path = '<替换为你的 twinkle:// 检查点路径>' # 例如: 'twinkle://xxx/weights/twinkle-lora-2'" + ] + }, + { + "cell_type": "markdown", + "id": "22db0edc", + "metadata": {}, + "source": [ + "## 第二步:初始化客户端\n", + "\n", + "连接 ModelScope 线上推理服务,并加载训练好的 LoRA 检查点。\n", + "\n", + "> **重要**:必须先调用 `init_tinker_client()` 完成运行时初始化,再 import `ServiceClient`。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7b5e0def", + "metadata": {}, + "outputs": [], + "source": [ + 
"init_tinker_client()\n", + "from tinker import ServiceClient\n", + "\n", + "service_client = ServiceClient(\n", + " base_url='http://www.modelscope.cn/twinkle',\n", + " api_key=os.environ.get('MODELSCOPE_TOKEN'),\n", + ")\n", + "\n", + "# 加载 LoRA 检查点并创建采样客户端\n", + "sampling_client = service_client.create_sampling_client(\n", + " model_path=weight_path,\n", + " base_model=BASE_MODEL,\n", + ")\n", + "print('采样客户端创建成功')" + ] + }, + { + "cell_type": "markdown", + "id": "3cfbc0a1", + "metadata": {}, + "source": [ + "## 第三步:构造 Prompt\n", + "\n", + "使用 Template 将对话格式的 Trajectory 编码为 token 序列。\n", + "\n", + "> Template 需要在本地加载 tokenizer(会自动从 ModelScope 下载),但 **不需要 GPU**。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0747f171", + "metadata": {}, + "outputs": [], + "source": [ + "template = Template(model_id=f'ms://{BASE_MODEL}')\n", + "\n", + "# 构造多条 Prompt\n", + "prompts = [\n", + " Trajectory(\n", + " messages=[\n", + " Message(role='system', content='You are a helpful assistant.'),\n", + " Message(role='user', content='什么是强化学习?请简单解释。'),\n", + " ]\n", + " ),\n", + " Trajectory(\n", + " messages=[\n", + " Message(role='user', content='Write a short poem about the moon.'),\n", + " ]\n", + " ),\n", + " Trajectory(\n", + " messages=[\n", + " Message(role='user', content='求解方程 2x + 3 = 11,x 等于多少?'),\n", + " ]\n", + " ),\n", + "]\n", + "\n", + "print(f'共 {len(prompts)} 条 Prompt')" + ] + }, + { + "cell_type": "markdown", + "id": "9aa6ace6", + "metadata": {}, + "source": [ + "## 第四步:采样推理\n", + "\n", + "对每条 Prompt 编码后发送到线上服务进行采样。\n", + "\n", + "| 参数 | 说明 |\n", + "|------|------|\n", + "| `max_tokens` | 最大生成 token 数 |\n", + "| `temperature` | 采样温度,越高越多样 |\n", + "| `num_samples` | 每条 Prompt 生成几条回答 |" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "196e6027", + "metadata": {}, + "outputs": [], + "source": [ + "params = types.SamplingParams(\n", + " max_tokens=256,\n", + " temperature=0.7,\n", + ")\n", + "\n", + "for i, trajectory in 
enumerate(prompts):\n", + " # 编码 Prompt\n", + " input_feature = template.encode(trajectory, add_generation_prompt=True)\n", + " input_ids = input_feature['input_ids'].tolist()\n", + " prompt = types.ModelInput.from_ints(input_ids)\n", + "\n", + " # 发送采样请求\n", + " future = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=1)\n", + " result = future.result()\n", + "\n", + " # 解码并输出\n", + " user_msg = trajectory['messages'][-1]['content']\n", + " print(f'\\n{\"=\"*60}')\n", + " print(f'Prompt {i}: {user_msg}')\n", + " print(f'{\"─\"*60}')\n", + " for j, seq in enumerate(result.sequences):\n", + " print(f'{template.decode(seq.tokens)}')" + ] + }, + { + "cell_type": "markdown", + "id": "6643d9ae", + "metadata": {}, + "source": [ + "## 合并权重并导出\n", + "\n", + "训练得到的 LoRA 权重可以与原始模型合并,导出为完整的 HuggingFace 模型,方便后续部署和推理。\n", + "\n", + "> **注意**:合并操作需要 GPU 资源(需要加载完整模型),请在有足够显存的环境下执行。\n", + "\n", + "```bash\n", + "CUDA_VISIBLE_DEVICES=0,1,2,3 \\\n", + "NPROC_PER_NODE=4 \\\n", + "/opt/conda/envs/twinkle/bin/megatron export \\\n", + " --model Qwen/Qwen3.6-35B-A3B \\\n", + " --adapters <替换为你的 LoRA 检查点路径> \\\n", + " --output_dir <替换为输出目录> \\\n", + " --merge_lora true \\\n", + " --to_hf true \\\n", + " --tensor_model_parallel_size 2 \\\n", + " --expert_model_parallel_size 2 \\\n", + " --pipeline_model_parallel_size 2\n", + "```\n", + "\n", + "合并完成后,输出目录中即为完整的 HuggingFace 模型,可直接用于推理或部署。" + ] + }, + { + "cell_type": "markdown", + "id": "fcffd0d3", + "metadata": {}, + "source": [ + "## 常见问题\n", + "\n", + "| 问题 | 可能原因 | 解决方法 |\n", + "|------|----------|----------|\n", + "| 连接超时 | 网络问题或服务端繁忙 | 检查网络并重试 |\n", + "| 检查点不存在 | weight_path 路径错误 | 确认训练完成并检查 save_result.path |\n", + "| 输出质量差 | LoRA 训练不充分 | 增加训练步数或调整学习率 |\n", + "| Token 认证失败 | MODELSCOPE_TOKEN 未设置 | 检查环境变量 |" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.10.0" + } + }, + 
"nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebook/self_cognition.ipynb b/notebook/self_cognition.ipynb new file mode 100644 index 00000000..7fc69897 --- /dev/null +++ b/notebook/self_cognition.ipynb @@ -0,0 +1,436 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "21286f2b", + "metadata": {}, + "source": [ + "# Twinkle 自我认知训练与评估\n", + "\n", + "> 🎯 **训练目标**:通过 SFT 微调,让模型 **改变自我认知** —— 在回答「你是谁」时使用你指定的名字和团队名。\n", + "\n", + "本 Notebook 演示如何 **微调** 一个大语言模型,使其学会自定义身份(名称、团队),然后 **评估** 训练效果。\n", + "\n", + "### 什么是自我认知训练?\n", + "\n", + "自我认知训练让模型在回答「你是谁」「谁创建了你」等问题时,使用你指定的名字和团队名,而非默认身份。适用场景:\n", + "- 构建品牌化的 AI 助手\n", + "- 为特定组织定制模型行为\n", + "- 快速演示 LoRA 微调能力\n", + "\n", + "### 整体流程\n", + "\n", + "```\n", + "Part 1(训练): 准备数据集 → 初始化训练客户端 → 训练循环 → 保存检查点\n", + "Part 2(评估): 加载检查点 → 构造提示词 → 采样生成 → 验证身份\n", + "```\n", + "\n", + "### 前置条件\n", + "\n", + "| 条件 | 说明 |\n", + "|------|------|\n", + "| 环境变量 | 设置 `MODELSCOPE_TOKEN` 为你的 ModelScope API Token |\n", + "| 依赖安装 | `pip install twinkle-kit[tinker]` |\n", + "\n", + "> 💡 **获取 Token**:访问 [ModelScope Token 页面](https://www.modelscope.cn/my/access/token) 获取你的 `MODELSCOPE_TOKEN`,并设置为环境变量:`export MODELSCOPE_TOKEN=<你的Token>`\n" + ] + }, + { + "cell_type": "markdown", + "id": "41e237cf", + "metadata": {}, + "source": [ + "## 🚀 零卡训练服务化(Serverless Training)\n", + "\n", + "本 Notebook 运行在 **ModelScope 零卡训练平台** 上。你无需自备 GPU,只需在 Notebook 中编写训练逻辑,平台会自动调度云端 GPU 资源完成训练。\n", + "\n", + "### 架构示意图\n", + "\n", + "```\n", + "┌─────────────────────────────────────────────────────────────┐\n", + "│ 你的 Notebook(CPU 环境) │\n", + "│ │\n", + "│ ┌──────────┐ HTTP / gRPC ┌──────────────────────┐ │\n", + "│ │ Twinkle │ ─────────────────► │ ModelScope 云端 │ │\n", + "│ │ Client │ ◄───────────────── │ GPU 训练集群 │ │\n", + "│ └──────────┘ 训练结果返回 │ │ │\n", + "│ │ │ ┌────┐ ┌────┐ ┌────┐│ │\n", + "│ │ 构造数据 │ │GPU0│ │GPU1│ │... 
││ │\n", + "│ │ 发送训练请求 │ └────┘ └────┘ └────┘│ │\n", + "│ │ 接收指标/检查点 │ 模型加载 + LoRA 训练 │ │\n", + "│ ▼ └──────────────────────┘ │\n", + "│ ┌──────────┐ │\n", + "│ │ 数据准备 │ Dataset / DataLoader / Preprocessor │\n", + "│ └──────────┘ │\n", + "└─────────────────────────────────────────────────────────────┘\n", + "```\n", + "\n", + "### 核心优势\n", + "\n", + "| 特性 | 说明 |\n", + "|------|------|\n", + "| **零卡启动** | Notebook 本身不需要 GPU,训练在云端自动执行 |\n", + "| **按需付费** | 仅在训练时占用 GPU 资源 |\n", + "| **开箱即用** | 预置主流模型,无需下载权重 |\n", + "| **LoRA 微调** | 高效参数微调,几分钟即可完成小规模训练 |\n", + "\n", + "> 🔗 本项目由 [Twinkle](https://github.com/modelscope/twinkle) 框架提供支持 | [GitHub](https://github.com/modelscope/twinkle)" + ] + }, + { + "cell_type": "markdown", + "id": "2cfbde0a", + "metadata": {}, + "source": [ + "---\n", + "## Part 1:训练\n", + "\n", + "### 1.1 导入依赖" + ] + }, + { + "cell_type": "markdown", + "id": "ff560956", + "metadata": {}, + "source": [ + "## 环境安装\n", + "\n", + "首次运行前,请先执行以下安装命令。如已安装可跳过此步。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5be2b154", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install twinkle-kit[tinker] -q" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "eaf38b0b", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from tqdm import tqdm\n", + "from tinker import types\n", + "from twinkle import init_tinker_client\n", + "from twinkle.data_format import Message, Trajectory\n", + "from twinkle.template import Template\n", + "from twinkle.dataloader import DataLoader\n", + "from twinkle.dataset import Dataset, DatasetMeta\n", + "from twinkle.preprocessor import SelfCognitionProcessor\n", + "from twinkle.server.common import input_feature_to_datum" + ] + }, + { + "cell_type": "markdown", + "id": "ba4b4a49", + "metadata": {}, + "source": [ + "### 1.2 初始化客户端并配置模型\n", + "\n", + "| 参数 | 说明 |\n", + "|------|------|\n", + "| `base_model` | 基座模型,服务端必须已加载该模型 |\n", + "| `base_url` | 服务端地址 |\n", + "| 
`api_key` | ModelScope Token |" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "28071991", + "metadata": {}, + "outputs": [], + "source": [ + "init_tinker_client()\n", + "\n", + "from tinker import ServiceClient\n", + "\n", + "base_model = 'Qwen/Qwen3.6-35B-A3B'\n", + "base_url = 'http://www.modelscope.cn/twinkle'" + ] + }, + { + "cell_type": "markdown", + "id": "08ebad60", + "metadata": {}, + "source": [ + "### 1.3 准备自我认知数据集\n", + "\n", + "数据集来自 ModelScope 上的 `swift/self-cognition`,包含中英文的「你是谁」「谁创建了你」等问答对。\n", + "\n", + "处理流程:\n", + "1. **加载数据**:取前 500 条样本\n", + "2. **应用模板**:使用基座模型对应的 chat template,最大长度 256 token\n", + "3. **替换身份**:用 `SelfCognitionProcessor` 将占位符替换为自定义名称\n", + "4. **编码**:将文本转为 token 序列\n", + "\n", + "> **可自定义**:修改 `model_name` 和 `author_name` 来设置你想要的身份。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "168e57b1", + "metadata": {}, + "outputs": [], + "source": [ + "# 加载自我认知数据集(取前 500 条)\n", + "dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(500)))\n", + "\n", + "# 应用 chat template\n", + "dataset.set_template('Template', model_id=f'ms://{base_model}', max_length=256)\n", + "\n", + "# 替换身份占位符\n", + "# model_name: 模型回答「你是谁」时使用的名字\n", + "# author_name: 模型回答「谁创建了你」时使用的团队名\n", + "dataset.map(SelfCognitionProcessor('twinkle模型', 'twinkle团队'), load_from_cache_file=False)\n", + "\n", + "# 编码为 token 序列\n", + "dataset.encode(batched=True, load_from_cache_file=False)\n", + "\n", + "# 构建 DataLoader,batch_size=8\n", + "dataloader = DataLoader(dataset=dataset, batch_size=8)\n", + "\n", + "print(f'数据集大小: {len(dataset)} 条')" + ] + }, + { + "cell_type": "markdown", + "id": "a0aff53f", + "metadata": {}, + "source": [ + "### 1.4 创建训练客户端\n", + "\n", + "使用 LoRA(Low-Rank Adaptation)进行高效微调,只训练少量额外参数,不修改原始模型权重。\n", + "\n", + "| 参数 | 值 | 说明 |\n", + "|------|-----|------|\n", + "| `rank` | 16 | LoRA 秩,越大表达能力越强但参数越多。自我认知任务 16 足够 |" + ] + }, + { + "cell_type": "code", + "execution_count": 
null, + "id": "611adbea", + "metadata": {}, + "outputs": [], + "source": [ + "service_client = ServiceClient(\n", + " base_url=base_url,\n", + " api_key=os.environ.get('MODELSCOPE_TOKEN')\n", + ")\n", + "\n", + "# 创建 LoRA 训练客户端\n", + "training_client = service_client.create_lora_training_client(base_model=base_model, rank=16)\n", + "print('训练客户端创建成功')" + ] + }, + { + "cell_type": "markdown", + "id": "3cae5ffe", + "metadata": {}, + "source": [ + "### 1.5 执行训练循环\n", + "\n", + "训练 3 个 epoch,每个 epoch 遍历整个数据集。每个 batch 的处理流程:\n", + "\n", + "1. **`forward_backward`**:将数据发送到服务端,执行前向传播 + 反向传播,计算梯度\n", + "2. **`optim_step`**:使用 Adam 优化器更新模型权重\n", + "3. **`save_state`**:每个 epoch 结束后保存一个 LoRA 检查点\n", + "\n", + "| 参数 | 值 | 说明 |\n", + "|------|-----|------|\n", + "| epoch 数 | 3 | 训练轮数,自我认知任务通常 2-3 轮即可收敛 |\n", + "| learning_rate | 1e-4 | Adam 学习率 |\n", + "| loss 函数 | cross_entropy | 标准交叉熵损失 |\n", + "\n", + "> **预期输出**:每个 step 打印训练指标,每个 epoch 结束打印检查点保存路径。请记录最终的路径,Part 2 评估需要用到。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fb882e3b", + "metadata": {}, + "outputs": [], + "source": [ + "for epoch in range(3):\n", + " print(f'Epoch {epoch}')\n", + " for step, batch in tqdm(enumerate(dataloader)):\n", + " # 将 InputFeature 转为 Tinker API 所需的 Datum 格式\n", + " input_datum = [input_feature_to_datum(input_feature) for input_feature in batch]\n", + "\n", + " # 前向 + 反向传播(计算梯度)\n", + " fwdbwd_future = training_client.forward_backward(input_datum, 'cross_entropy')\n", + "\n", + " # 优化器更新权重\n", + " optim_future = training_client.optim_step(types.AdamParams(learning_rate=1e-4))\n", + "\n", + " # 等待完成\n", + " fwdbwd_result = fwdbwd_future.result()\n", + " optim_result = optim_future.result()\n", + "\n", + " print(f'Training Metrics: {optim_result}')\n", + "\n", + " # 每个 epoch 保存检查点\n", + " save_future = training_client.save_state(f'twinkle-lora-{epoch}')\n", + " save_result = save_future.result()\n", + " print(f'Saved checkpoint to {save_result.path}')" + ] + }, + { + 
"cell_type": "markdown", + "id": "1b245842", + "metadata": {}, + "source": [ + "---\n", + "## Part 2:评估\n", + "\n", + "加载训练好的 LoRA 检查点,向模型提问「你是谁?」,观察模型是否以自定义身份回答。\n", + "\n", + "### 2.1 加载检查点并创建采样客户端\n", + "\n", + "> 将下方 `weight_path` 替换为 Part 1 训练输出的检查点路径。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3c19eceb", + "metadata": {}, + "outputs": [], + "source": [ + "# TODO: 替换为 Part 1 输出的检查点路径\n", + "weight_path = 'twinkle://20260212_174205-Qwen_Qwen3_6-35B-A3B-51edc9ed/weights/twinkle-lora-2'\n", + "\n", + "service_client = ServiceClient(base_url=base_url, api_key=os.environ.get('MODELSCOPE_TOKEN'))\n", + "sampling_client = service_client.create_sampling_client(model_path=weight_path, base_model=base_model)\n", + "print('采样客户端创建成功')" + ] + }, + { + "cell_type": "markdown", + "id": "53dc6a4b", + "metadata": {}, + "source": [ + "### 2.2 构造提示词并采样\n", + "\n", + "向模型提问「你是谁?」,生成 8 条独立回复,观察回答的一致性。\n", + "\n", + "| 参数 | 值 | 说明 |\n", + "|------|-----|------|\n", + "| `max_tokens` | 50 | 自我认知回答通常很短 |\n", + "| `temperature` | 0.2 | 低温度使回答更聚焦一致 |\n", + "| `num_samples` | 8 | 生成 8 条独立回复验证一致性 |\n", + "\n", + "**预期效果**:\n", + "- 训练成功:8 条回复都应包含「twinkle模型」或「twinkle团队」\n", + "- 训练不足:部分回复可能仍使用原始身份" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "230fccb8", + "metadata": {}, + "outputs": [], + "source": [ + "template = Template(model_id=f'ms://{base_model}')\n", + "\n", + "trajectory = Trajectory(\n", + " messages=[\n", + " Message(role='system', content='You are a helpful assistant'),\n", + " Message(role='user', content='你是谁?'),\n", + " ]\n", + ")\n", + "\n", + "input_feature = template.encode(trajectory, add_generation_prompt=True)\n", + "input_ids = input_feature['input_ids'].tolist()\n", + "\n", + "prompt = types.ModelInput.from_ints(input_ids)\n", + "params = types.SamplingParams(\n", + " max_tokens=50,\n", + " temperature=0.2,\n", + ")\n", + "\n", + "print('Sampling...')\n", + "future = sampling_client.sample(prompt=prompt, 
sampling_params=params, num_samples=8)\n", + "result = future.result()\n", + "\n", + "print('Responses:')\n", + "for i, seq in enumerate(result.sequences):\n", + " print(f'{i}: {repr(template.decode(seq.tokens))}')" + ] + }, + { + "cell_type": "markdown", + "id": "a57cf366", + "metadata": {}, + "source": [ + "## 合并权重并导出\n", + "\n", + "训练得到的 LoRA 权重可以与原始模型合并,导出为完整的 HuggingFace 模型,方便后续部署和推理。\n", + "\n", + "> **注意**:合并操作需要 GPU 资源(需要加载完整模型),请在有足够显存的环境下执行。\n", + "\n", + "```bash\n", + "CUDA_VISIBLE_DEVICES=0,1,2,3 \\\n", + "NPROC_PER_NODE=4 \\\n", + "/opt/conda/envs/twinkle/bin/megatron export \\\n", + " --model Qwen/Qwen3.6-35B-A3B \\\n", + " --adapters <替换为你的 LoRA 检查点路径> \\\n", + " --output_dir <替换为输出目录> \\\n", + " --merge_lora true \\\n", + " --to_hf true \\\n", + " --tensor_model_parallel_size 2 \\\n", + " --expert_model_parallel_size 2 \\\n", + " --pipeline_model_parallel_size 2\n", + "```\n", + "\n", + "**参数说明**:\n", + "\n", + "| 参数 | 说明 |\n", + "|------|------|\n", + "| `--model` | 基座模型 ID |\n", + "| `--adapters` | 训练保存的 LoRA 检查点路径 |\n", + "| `--output_dir` | 合并后的完整模型输出目录 |\n", + "| `--merge_lora true` | 将 LoRA 权重合并到基座模型中 |\n", + "| `--to_hf true` | 导出为 HuggingFace 格式 |\n", + "| `--tensor_model_parallel_size` | 张量并行大小 |\n", + "| `--expert_model_parallel_size` | 专家并行大小(MoE 模型专用) |\n", + "| `--pipeline_model_parallel_size` | 流水线并行大小 |\n", + "\n", + "合并完成后,输出目录中即为完整的 HuggingFace 模型,可直接用于推理或部署。" + ] + }, + { + "cell_type": "markdown", + "id": "1be6cbdb", + "metadata": {}, + "source": [ + "## 常见问题\n", + "\n", + "| 问题 | 可能原因 | 解决方法 |\n", + "|------|----------|----------|\n", + "| 模型仍以原始身份回答 | 训练不充分 | 增加 epoch 数或 data_slice 范围 |\n", + "| Loss 不下降 | 学习率不合适 | 尝试调整 learning_rate(如 5e-5 或 2e-4) |\n", + "| 回答不稳定 | temperature 太高 | 评估时降低 temperature 到 0.1 |\n", + "| 连接超时 | 服务端问题 | 确认服务端正常运行 |" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebook/short_math_grpo.ipynb 
b/notebook/short_math_grpo.ipynb new file mode 100644 index 00000000..7bd4c87e --- /dev/null +++ b/notebook/short_math_grpo.ipynb @@ -0,0 +1,740 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "cf7d65c6", + "metadata": {}, + "source": [ + "# Twinkle 数学 GRPO 训练\n", + "\n", + "> 🎯 **训练目标**:通过 GRPO 强化学习,让模型在解数学题时 **减少 thinking 长度**,用更简短的推理过程得出正确答案。\n", + "\n", + "本 Notebook 演示如何使用 **GRPO(Group Relative Policy Optimization)** 算法在数学问题上训练语言模型。\n", + "\n", + "### 什么是 GRPO?\n", + "\n", + "GRPO 是一种强化学习算法,训练流程如下:\n", + "1. 对每个数学题,让模型 **生成多个回答**\n", + "2. 用奖励函数 **评分** 每个回答(答案是否正确、格式是否规范)\n", + "3. 在同组内 **计算优势值**(advantage):好回答得正值,差回答得负值\n", + "4. 用优势值 **更新策略**:让模型更倾向于生成好答案\n", + "\n", + "### 整体流程\n", + "\n", + "```\n", + "准备数据集 → 初始化训练/采样客户端 → 训练循环:\n", + " 同步权重 → 采样生成 → 计算奖励 → 计算优势 → GRPO 训练 → 日志记录\n", + "```\n", + "\n", + "### 前置条件\n", + "\n", + "| 条件 | 说明 |\n", + "|------|------|\n", + "| 环境变量 | 设置 `MODELSCOPE_TOKEN` |\n", + "| 依赖安装 | `pip install twinkle-kit[tinker]` |\n", + "\n", + "> 💡 **获取 Token**:访问 [ModelScope Token 页面](https://www.modelscope.cn/my/access/token) 获取你的 `MODELSCOPE_TOKEN`,并设置为环境变量:`export MODELSCOPE_TOKEN=<你的Token>`\n" + ] + }, + { + "cell_type": "markdown", + "id": "a5d9101f", + "metadata": {}, + "source": [ + "## 🚀 零卡训练服务化(Serverless Training)\n", + "\n", + "本 Notebook 运行在 **ModelScope 零卡训练平台** 上。你无需自备 GPU,只需在 Notebook 中编写训练逻辑,平台会自动调度云端 GPU 资源完成训练。\n", + "\n", + "### 架构示意图\n", + "\n", + "```\n", + "┌─────────────────────────────────────────────────────────────┐\n", + "│ 你的 Notebook(CPU 环境) │\n", + "│ │\n", + "│ ┌──────────┐ HTTP / gRPC ┌──────────────────────┐ │\n", + "│ │ Twinkle │ ─────────────────► │ ModelScope 云端 │ │\n", + "│ │ Client │ ◄───────────────── │ GPU 训练集群 │ │\n", + "│ └──────────┘ 训练结果返回 │ │ │\n", + "│ │ │ ┌────┐ ┌────┐ ┌────┐│ │\n", + "│ │ 构造数据 │ │GPU0│ │GPU1│ │... 
││ │\n", + "│ │ 发送训练请求 │ └────┘ └────┘ └────┘│ │\n", + "│ │ 接收指标/检查点 │ 模型加载 + LoRA 训练 │ │\n", + "│ ▼ └──────────────────────┘ │\n", + "│ ┌──────────┐ │\n", + "│ │ 数据准备 │ Dataset / DataLoader / Preprocessor │\n", + "│ └──────────┘ │\n", + "└─────────────────────────────────────────────────────────────┘\n", + "```\n", + "\n", + "### 核心优势\n", + "\n", + "| 特性 | 说明 |\n", + "|------|------|\n", + "| **零卡启动** | Notebook 本身不需要 GPU,训练在云端自动执行 |\n", + "| **按需付费** | 仅在训练时占用 GPU 资源 |\n", + "| **开箱即用** | 预置主流模型,无需下载权重 |\n", + "| **LoRA 微调** | 高效参数微调,几分钟即可完成小规模训练 |\n", + "\n", + "> 🔗 本项目由 [Twinkle](https://github.com/modelscope/twinkle) 框架提供支持 | [GitHub](https://github.com/modelscope/twinkle)" + ] + }, + { + "cell_type": "markdown", + "id": "f7bdee2e", + "metadata": {}, + "source": [ + "## 第一步:导入依赖与全局配置\n", + "\n", + "> **为什么使用 Twinkle 客户端语法?**\n", + "> Twinkle 提供 `tinker` 和 `twinkle` 两套客户端 API。其中 **tinker** 接口不支持设置 `target_modules`、`LoraConfig` 等细节调控,而 GRPO 训练在 MoE 模型上需要显式指定 LoRA 的 target modules(否则会触发 vLLM 兼容性问题)。\n", + "> 因此本 Notebook 使用 **twinkle 客户端语法**,以获得对训练参数的完整控制。\n", + "\n", + "| 配置项 | 默认值 | 说明 |\n", + "|--------|--------|------|\n", + "| `MODEL_ID` | ms://Qwen/Qwen3.6-35B-A3B | 基座模型(需加 `ms://` 前缀) |\n", + "| `NUM_GENERATIONS` | 4 | 每个 prompt 生成几条回答 |\n", + "| `MAX_NEW_TOKENS` | 1024 | 单条回答最大 token 数 |\n", + "| `LEARNING_RATE` | 2e-5 | 学习率 |\n", + "| `MAX_STEPS` | 100 | 最大训练步数 |\n", + "| `BATCH_SIZE` | 2 | 每步的 prompt 数量(实际训练样本 = BATCH_SIZE × NUM_GENERATIONS) |\n", + "| `TEMPERATURE` | 1.0 | 采样温度,RL 训练中通常设为 1.0 保持多样性 |\n", + "| `SYNC_INTERVAL` | 1 | 每隔多少步同步权重到采样端 |" + ] + }, + { + "cell_type": "markdown", + "id": "ef3af352", + "metadata": {}, + "source": [ + "### ⚠️ MoE 模型 LoRA 注意事项\n", + "\n", + "由于 `Qwen/Qwen3.6-35B-A3B` 是 MoE(Mixture of Experts)架构,在配合 vLLM 采样时存在已知兼容性问题。\n", + "如果你在本地使用 Megatron 进行 GRPO 训练,建议显式指定 `target_modules`(而非 `all-linear`):\n", + "\n", + "```yaml\n", + "target_modules:\n", + " - mlp.linear_fc1\n", + " - mlp.linear_fc2\n", + " - attn.proj\n", + 
" - shared_experts.linear_fc1\n", + " - shared_experts.linear_fc2\n", + " - linear_qkv\n", + " - in_proj\n", + " - out_proj\n", + " - linear_proj\n", + "```\n", + "\n", + "> **注意**:此配置是一个示例,由于问题来自 vLLM 侧的 MoE LoRA 支持尚不完善,实际训练效果可能受限。\n", + "> 如果不需要在线采样(vLLM),使用 `all-linear` 仍然可以正常训练。" + ] + }, + { + "cell_type": "markdown", + "id": "8b6828b2", + "metadata": {}, + "source": [ + "## 环境安装\n", + "\n", + "首次运行前,请先执行以下安装命令。如已安装可跳过此步。" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "9ecc0cbb", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.\u001b[0m\u001b[33m\n", + "\u001b[0m" + ] + } + ], + "source": [ + "!pip install twinkle-kit[tinker] -q" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "a5811355", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/root/miniconda3/envs/vllm19/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "import dotenv\n", + "dotenv.load_dotenv('.env')\n", + "\n", + "import gc\n", + "import os\n", + "import re\n", + "from peft import LoraConfig\n", + "from typing import List, Tuple, Dict, Any\n", + "\n", + "from twinkle import get_logger, init_twinkle_client\n", + "from twinkle.reward.base import Reward\n", + "from twinkle.advantage import GRPOAdvantage\n", + "from twinkle.dataset import DatasetMeta, Dataset\n", + "from twinkle.metric import CompletionRewardMetric\n", + "from twinkle.dataloader import DataLoader\n", + "from twinkle.preprocessor.llm import GSM8KProcessor\n", + "from twinkle_client.model import MultiLoraTransformersModel\n", + "from twinkle_client.sampler import vLLMSampler\n", + "\n", + "logger = get_logger()\n", + "\n", + "# ========== 全局配置 ==========\n", + "MODEL_ID = 'ms://Qwen/Qwen3.6-35B-A3B'\n", + "NUM_GENERATIONS = 4\n", + "MAX_NEW_TOKENS = 1024\n", + "LEARNING_RATE = 2e-5\n", + "MAX_STEPS = 100\n", + "BATCH_SIZE = 2\n", + "TEMPERATURE = 1.0\n", + "SYNC_INTERVAL = 1\n", + "GRADIENT_ACCUMULATION_STEPS = 1\n", + "DATA_NUM = 2000\n", + "\n", + "SYSTEM_PROMPT = (\n", + " 'You are a helpful math assistant. 
Solve the problem with minimal but correct reasoning '\n", + "    'and put your final answer in the format #### <answer>.'\n", + ")\n" + ] + }, + { + "cell_type": "markdown", + "id": "082fb9bd", + "metadata": {}, + "source": [ + "## 第二步:定义奖励函数\n", + "\n", + "GRPO 需要奖励函数来评判每条回答的质量。本例使用两个奖励函数:\n", + "\n", + "### 准确性奖励 (MathAccuracyReward)\n", + "- 从模型输出中提取 `#### 数字` 格式的答案\n", + "- 与标准答案做数值比较\n", + "- 正确得 1.0 分,错误得 0.0 分\n", + "\n", + "### 格式奖励 (MathFormatReward)\n", + "- 检查输出是否包含 `<think>...</think>` 推理标签和 `####` 答案标记\n", + "- 格式正确时,回答越短得分越高(鼓励简洁推理)\n", + "- 格式不正确得 0.0 分\n", + "\n", + "**总奖励 = 准确性奖励 + 格式奖励**,最高 2.0 分。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f9f8a4a5", + "metadata": {}, + "outputs": [], + "source": [ + "class MathAccuracyReward(Reward):\n", + "    \"\"\"准确性奖励:检查模型答案是否与标准答案一致\"\"\"\n", + "\n", + "    @staticmethod\n", + "    def extract_answer(completion: str) -> str:\n", + "        text = completion[-500:] if len(completion) > 500 else completion\n", + "        matches = re.findall(r'####\\s*([\\-\\d,\\.\\s]+)', text)\n", + "        if matches:\n", + "            return matches[-1].replace(',', '').replace(' ', '').strip()\n", + "        return ''\n", + "\n", + "    def __call__(self, trajectories: List[Dict[str, Any]], ground_truths: List[Dict[str, Any]]) -> List[float]:\n", + "        rewards = []\n", + "        for trajectory in trajectories:\n", + "            messages = trajectory.get('messages', [])\n", + "            completion = ''\n", + "            for msg in reversed(messages):\n", + "                if msg.get('role') == 'assistant':\n", + "                    completion = msg.get('content', '')\n", + "                    break\n", + "\n", + "            gt = ''\n", + "            user_data = trajectory.get('user_data', [])\n", + "            if isinstance(user_data, list):\n", + "                for item in user_data:\n", + "                    if isinstance(item, (list, tuple)) and len(item) == 2:\n", + "                        if item[0] == 'ground_truth':\n", + "                            gt = str(item[1])\n", + "                            break\n", + "\n", + "            predicted = self.extract_answer(completion)\n", + "            correct = False\n", + "            if predicted and gt:\n", + "                try:\n", + "                    correct = abs(float(predicted) - 
float(gt)) < 1e-5\n", + "                except (ValueError, OverflowError):\n", + "                    correct = predicted == gt\n", + "\n", + "            rewards.append(1.0 if correct else 0.0)\n", + "        return rewards\n", + "\n", + "\n", + "class MathFormatReward(Reward):\n", + "    \"\"\"格式奖励:检查格式并奖励简短回答\"\"\"\n", + "\n", + "    def __call__(self, trajectories: List[Dict[str, Any]], ground_truths: List[Dict[str, Any]]) -> List[float]:\n", + "        rewards = []\n", + "        for trajectory in trajectories:\n", + "            messages = trajectory.get('messages', [])\n", + "            completion = ''\n", + "            for msg in reversed(messages):\n", + "                if msg.get('role') == 'assistant':\n", + "                    completion = msg.get('content', '')\n", + "                    break\n", + "\n", + "            has_think = bool(re.search(r'<think>.*?</think>', completion, re.DOTALL))\n", + "            has_answer = bool(re.search(r'####\\s*[\\-\\d,\\.]+', completion))\n", + "\n", + "            if not (has_think and has_answer):\n", + "                rewards.append(0.0)\n", + "            else:\n", + "                length = len(completion)\n", + "                if length <= 100:\n", + "                    rewards.append(1.0)\n", + "                else:\n", + "                    reward = max(0.0, 1.0 - (length - 100) / 2000)\n", + "                    rewards.append(reward)\n", + "        return rewards\n", + "\n", + "\n", + "def compute_rewards(trajectories: List[Dict[str, Any]]) -> Tuple[List[float], List[float], List[float]]:\n", + "    \"\"\"计算总奖励 = 准确性 + 格式\"\"\"\n", + "    accuracy_reward_fn = MathAccuracyReward()\n", + "    format_reward_fn = MathFormatReward()\n", + "    accuracy_rewards = accuracy_reward_fn(trajectories, [])\n", + "    format_rewards = format_reward_fn(trajectories, [])\n", + "    total_rewards = [a + f for a, f in zip(accuracy_rewards, format_rewards)]\n", + "    return total_rewards, format_rewards, accuracy_rewards\n" + ] + }, + { + "cell_type": "markdown", + "id": "82d58ef3", + "metadata": {}, + "source": [ + "## 第三步:准备数据集\n", + "\n", + "加载 ModelScope 上的 `gsm8k` 数学数据集,并进行预处理和编码。\n", + "\n", + "- `GSM8KProcessor`:提取题目和标准答案(`####` 格式),构造 system + user 对话\n", + "- `add_generation_prompt=True`:编码时在末尾加上 assistant 前缀,准备让模型生成回答\n", + "- 
`truncation_strategy='delete'`:超过最大长度的样本直接删除而非截断" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "837c4112", + "metadata": {}, + "outputs": [], + "source": [ + "def create_math_dataset():\n", + " meta = DatasetMeta(\n", + " 'ms://modelscope/gsm8k',\n", + " subset_name='main',\n", + " split='train',\n", + " data_slice=range(DATA_NUM),\n", + " )\n", + " dataset = Dataset(meta)\n", + " dataset.set_template('Qwen3_5Template', model_id=MODEL_ID, max_length=2048, enable_thinking=False)\n", + " dataset.map(GSM8KProcessor(system=SYSTEM_PROMPT))\n", + " dataset.encode(add_generation_prompt=True, truncation_strategy='delete')\n", + " return dataset\n", + "\n", + "dataset = create_math_dataset()\n", + "dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, num_workers=0)\n", + "print(f'数据集加载完成,共 {len(dataset)} 条')" + ] + }, + { + "cell_type": "markdown", + "id": "5d9e33c5", + "metadata": {}, + "source": [ + "## 第四步:初始化 Twinkle 客户端与模型\n", + "\n", + "Twinkle 客户端直接与训练服务通信,支持完整的模型配置:\n", + "\n", + "- **`MultiLoraTransformersModel`**:支持 LoRA 适配器、损失函数、优化器、模板等全部设置\n", + "- **`vLLMSampler`**:采样端,支持 `adapter_uri` 动态加载最新 LoRA 权重\n", + "- **`LoraConfig`**:可以精确控制 `target_modules`,这是使用 twinkle 语法的关键优势\n", + "\n", + "> 对于 MoE 模型,必须显式指定 `target_modules` 而非 `all-linear`,以避免 vLLM 兼容性问题。\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f49839e9", + "metadata": {}, + "outputs": [], + "source": [ + "# 初始化 Twinkle 客户端\n", + "client = init_twinkle_client(\n", + " base_url='http://www.modelscope.cn/twinkle',\n", + " api_key=os.environ.get('MODELSCOPE_TOKEN', 'EMPTY_TOKEN'),\n", + ")\n", + "\n", + "# 配置训练模型\n", + "model = MultiLoraTransformersModel(model_id=MODEL_ID)\n", + "\n", + "# LoRA 配置 —— 显式指定 target_modules(MoE 模型关键)\n", + "lora_config = LoraConfig(\n", + " target_modules=[\n", + " 'mlp.linear_fc1', 'mlp.linear_fc2',\n", + " 'attn.proj',\n", + " 'shared_experts.linear_fc1', 'shared_experts.linear_fc2',\n", + " 'linear_qkv', 'in_proj', 
'out_proj', 'linear_proj',\n", + " ],\n", + " r=8,\n", + " lora_alpha=32,\n", + " lora_dropout=0.05,\n", + ")\n", + "model.add_adapter_to_model(\n", + " 'default',\n", + " lora_config,\n", + " gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,\n", + ")\n", + "\n", + "# 设置 GRPO 损失函数(RL 训练的核心)\n", + "model.set_loss('GRPOLoss', epsilon=0.2, beta=0.0)\n", + "\n", + "# 设置优化器\n", + "model.set_optimizer('Adam', lr=LEARNING_RATE)\n", + "\n", + "# 设置输入处理器和模板\n", + "model.set_processor('InputProcessor')\n", + "model.set_template('Qwen3_5Template', model_id=MODEL_ID)\n", + "\n", + "# 配置采样端\n", + "sampler = vLLMSampler(model_id=MODEL_ID)\n", + "sampler.set_template('Qwen3_5Template', model_id=MODEL_ID)\n", + "\n", + "# 设置指标和优势函数\n", + "advantage_fn = GRPOAdvantage()\n", + "metrics = CompletionRewardMetric()\n", + "\n", + "sampling_params = {\n", + " 'max_tokens': MAX_NEW_TOKENS,\n", + " 'temperature': TEMPERATURE,\n", + " 'top_p': 0.95,\n", + " 'num_samples': NUM_GENERATIONS,\n", + " 'logprobs': 1,\n", + "}\n", + "\n", + "print('模型和采样端配置完成')" + ] + }, + { + "cell_type": "markdown", + "id": "b00959b2", + "metadata": {}, + "source": [ + "## 第五步:GRPO 训练主循环\n", + "\n", + "每个训练步骤包含以下阶段:\n", + "\n", + "### 5.1 保存权重\n", + "每隔 `SYNC_INTERVAL` 步调用 `model.save()` 保存 LoRA 权重,获取 `twinkle_path` 用于采样端加载。\n", + "\n", + "### 5.2 采样生成\n", + "通过 `sampler.sample(inputs, adapter_uri=twinkle_path)` 使用最新 LoRA 权重生成回答。\n", + "\n", + "### 5.3 计算奖励与优势\n", + "- 对每条回答计算准确性和格式奖励\n", + "- 用 `GRPOAdvantage` 在同组内标准化,得到优势值\n", + "\n", + "### 5.4 训练步\n", + "调用 `model.forward_backward(inputs, advantages, old_logps)` 执行 GRPO 策略优化,然后 `model.clip_grad_and_step()` 更新参数。\n", + "\n", + "> **与 tinker 语法的区别**:twinkle 客户端将 GRPO 的 forward_backward 和 optimizer step 封装为高级 API,无需手动构造 `Datum` 和 `loss_fn_inputs`。\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "844b40d6", + "metadata": {}, + "outputs": [], + "source": [ + "current_adapter_uri = None\n", + "step = 0\n", + "\n", + "for batch in 
dataloader:\n", + "    if step >= MAX_STEPS:\n", + "        break\n", + "\n", + "    metrics.reset()\n", + "    prompts = batch if isinstance(batch, list) else [batch]\n", + "\n", + "    # ===== 5.1 Save weights and update adapter_uri =====\n", + "    if step % SYNC_INTERVAL == 0:\n", + "        logger.info(f'Step {step}: Saving weights for sampler...')\n", + "        result = model.save(\n", + "            name=f'grpo-sampler-step-{step}',\n", + "            save_optimizer=False,\n", + "        )\n", + "        current_adapter_uri = result.twinkle_path\n", + "        logger.info(f'Step {step}: Saved weights to {current_adapter_uri}')\n", + "\n", + "    # ===== 5.2 Sampling =====\n", + "    sample_responses = sampler.sample(\n", + "        inputs=prompts,\n", + "        sampling_params=sampling_params,\n", + "        adapter_uri=current_adapter_uri,\n", + "    )\n", + "\n", + "    all_input_data: List[Dict[str, Any]] = []\n", + "    all_old_logps: List[List[float]] = []\n", + "    all_completion_lengths: List[int] = []\n", + "\n", + "    for sample_response in sample_responses:\n", + "        for sequence in sample_response.sequences:\n", + "            all_input_data.append(sequence.new_input_feature)\n", + "            all_old_logps.append([logprob[0][1] for logprob in sequence.logprobs])\n", + "            all_completion_lengths.append(len(sequence.tokens))\n", + "\n", + "    # ===== 5.3 Compute rewards =====\n", + "    total_rewards, format_rewards, accuracy_rewards = compute_rewards(all_input_data)\n", + "    metrics.accumulate(\n", + "        completion_lengths=all_completion_lengths,\n", + "        rewards={\n", + "            'total': total_rewards,\n", + "            'format': format_rewards,\n", + "            'accuracy': accuracy_rewards,\n", + "        },\n", + "    )\n", + "\n", + "    # ===== 5.4 Compute advantages =====\n", + "    advantages = advantage_fn(\n", + "        total_rewards,\n", + "        num_generations=NUM_GENERATIONS,\n", + "        scale='group',\n", + "    ).tolist()\n", + "\n", + "    frac_zero_std = (1.0 if all(abs(a) < 1e-8 for a in advantages) else 0.0)\n", + "    if frac_zero_std == 1.0:\n", + "        logger.info(f'Step {step}: All advantages are zero, skipping training')\n", + "        step += 1\n", + "        continue\n", + "\n", + "    # ===== 5.5 
GRPO training step =====\n", + "    model.forward_backward(\n", + "        inputs=all_input_data,\n", + "        advantages=advantages,\n", + "        old_logps=all_old_logps,\n", + "    )\n", + "    model.clip_grad_and_step()\n", + "\n", + "    gc.collect()\n", + "\n", + "    # ===== 5.6 Logging =====\n", + "    log_dict = metrics.calculate()\n", + "    log_dict.update(model.calculate_metric(is_training=True).result)\n", + "    log_dict['train/frac_reward_zero_std'] = frac_zero_std\n", + "    logger.info(f'Step {step}: {log_dict}')\n", + "    step += 1\n" + ] + }, + { + "cell_type": "markdown", + "id": "463836a8", + "metadata": {}, + "source": [ + "## Step 6: Save the Final Checkpoint\n", + "\n", + "After training completes, save the final LoRA checkpoint for later inference (see `sample.ipynb`).\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c1500f83", + "metadata": {}, + "outputs": [], + "source": [ + "result = model.save(name='grpo-math-final', save_optimizer=True)\n", + "logger.info(f'Saved final checkpoint: {result}')\n", + "print(f'Training complete! Checkpoint path: {result.twinkle_path}')" + ] + }, + { + "cell_type": "markdown", + "id": "f19f7050", + "metadata": {}, + "source": [ + "## Inference\n", + "\n", + "After training, you can run inference directly against the **online service**, with no local GPU required.\n", + "\n", + "Load the trained LoRA checkpoint via `save_weights_and_get_sampling_client` or `create_sampling_client` to sample online.\n", + "\n", + "> Replace `weight_path` below with the checkpoint path produced by training (in `twinkle://...` format)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ca087072", + "metadata": {}, + "outputs": [], + "source": [ + "# Inference example (uses the online service, no local GPU needed)\n", + "import os\n", + "from tinker import types\n", + "from twinkle import init_tinker_client, get_logger\n", + "from twinkle.data_format import Message, Trajectory\n", + "from twinkle.template import Template\n", + "\n", + "logger = get_logger()\n", + "\n", + "BASE_MODEL = 'Qwen/Qwen3.6-35B-A3B'\n", + "\n", + "# TODO: replace with the checkpoint path produced by training\n", + "weight_path = '<your twinkle:// checkpoint path>'  # e.g., result.twinkle_path from the save step\n", + "\n", + "init_tinker_client()\n", + "from tinker import ServiceClient\n", + "\n", + "service_client = 
ServiceClient(\n", + "    base_url='http://www.modelscope.cn/twinkle',\n", + "    api_key=os.environ.get('MODELSCOPE_TOKEN'),\n", + ")\n", + "\n", + "# Load the LoRA checkpoint and create a sampling client\n", + "sampling_client = service_client.create_sampling_client(\n", + "    model_path=weight_path,\n", + "    base_model=BASE_MODEL,\n", + ")\n", + "\n", + "# Build the prompt\n", + "template = Template(model_id=f'ms://{BASE_MODEL}')\n", + "trajectory = Trajectory(\n", + "    messages=[\n", + "        Message(role='system', content='You are a helpful assistant.'),\n", + "        Message(role='user', content='Hello, please introduce yourself.'),\n", + "    ]\n", + ")\n", + "\n", + "input_feature = template.encode(trajectory, add_generation_prompt=True)\n", + "input_ids = input_feature['input_ids'].tolist()\n", + "\n", + "# Sample\n", + "prompt = types.ModelInput.from_ints(input_ids)\n", + "params = types.SamplingParams(max_tokens=256, temperature=0.7)\n", + "\n", + "print('Sampling...')\n", + "future = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=3)\n", + "result = future.result()\n", + "\n", + "# Print results\n", + "print('Responses:')\n", + "for i, seq in enumerate(result.sequences):\n", + "    print(f'{i}: {repr(template.decode(seq.tokens))}')" + ] + }, + { + "cell_type": "markdown", + "id": "4bc85a32", + "metadata": {}, + "source": [ + "## Merge Weights and Export\n", + "\n", + "The trained LoRA weights can be merged into the base model and exported as a complete HuggingFace model for deployment and inference.\n", + "\n", + "> **Note**: merging requires GPU resources (the full model must be loaded); run it in an environment with sufficient GPU memory.\n", + "\n", + "```bash\n", + "CUDA_VISIBLE_DEVICES=0,1,2,3 \\\n", + "NPROC_PER_NODE=4 \\\n", + "/opt/conda/envs/twinkle/bin/megatron export \\\n", + "    --model Qwen/Qwen3.6-35B-A3B \\\n", + "    --adapters <your LoRA checkpoint path> \\\n", + "    --output_dir <your output directory> \\\n", + "    --merge_lora true \\\n", + "    --to_hf true \\\n", + "    --tensor_model_parallel_size 2 \\\n", + "    --expert_model_parallel_size 2 \\\n", + "    --pipeline_model_parallel_size 2\n", + "```\n", + "\n", + "**Parameters**:\n", + "\n", + "| Parameter | Description |\n", + "|------|------|\n", + "| `--model` | Base model ID |\n", + "| 
`--adapters` | Path of the LoRA checkpoint saved during training |\n", + "| `--output_dir` | Output directory for the merged full model |\n", + "| `--merge_lora true` | Merge the LoRA weights into the base model |\n", + "| `--to_hf true` | Export in HuggingFace format |\n", + "| `--tensor_model_parallel_size` | Tensor parallel size |\n", + "| `--expert_model_parallel_size` | Expert parallel size (MoE models only) |\n", + "| `--pipeline_model_parallel_size` | Pipeline parallel size |\n", + "\n", + "Once the merge finishes, the output directory contains a complete HuggingFace model that can be used directly for inference or deployment." + ] + }, + { + "cell_type": "markdown", + "id": "521ade97", + "metadata": {}, + "source": [ + "## Interpreting Key Metrics\n", + "\n", + "The following metrics are logged during training:\n", + "\n", + "| Metric | Meaning | Expected trend |\n", + "|------|------|----------|\n", + "| `accuracy` | Answer accuracy | Rises gradually |\n", + "| `format` | Format correctness rate | Reaches a high value quickly |\n", + "| `total` | Total reward (accuracy + format) | Rises gradually |\n", + "| `frac_reward_zero_std` | Fraction of steps where the within-group reward std is zero | Falls gradually (the model is distinguishing good answers from bad ones) |\n", + "| `completion_lengths` | Average answer length | Shrinks gradually (the effect of the conciseness reward) |\n", + "\n", + "## FAQ\n", + "\n", + "| Problem | Likely cause | Fix |\n", + "|------|----------|----------|\n", + "| Accuracy not improving | Learning rate too low or too high | Try 5e-5 or 2e-4 |\n", + "| All advantages are 0 | Identical rewards within each group | Increase NUM_GENERATIONS or raise the temperature |\n", + "| Out of memory (OOM) | Generations too long | Reduce MAX_NEW_TOKENS or BATCH_SIZE |\n", + "| Sampling timeout | Sampler not configured on the server | Check the sampler section of server_config.yaml |" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "vllm19", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}