diff --git a/dev/run_qwen3_5_megatron_yes_no_maybe.py b/dev/run_qwen3_5_megatron_yes_no_maybe.py new file mode 100644 index 00000000..2f8cdd20 --- /dev/null +++ b/dev/run_qwen3_5_megatron_yes_no_maybe.py @@ -0,0 +1,165 @@ +"""Launch a multi-step Qwen3.5 Megatron yes-no-maybe run on SkyPilot.""" + +import argparse +import os +import textwrap + +from dotenv import load_dotenv +import sky +from sky import ClusterStatus + +load_dotenv() + +DEFAULT_IMAGE_ID = "docker:nvidia/cuda:12.8.1-devel-ubuntu22.04" + + +def _format_env_bool(value: bool) -> str: + return "true" if value else "false" + + +def _format_int_list(values: list[int]) -> str: + return ",".join(str(value) for value in values) + + +parser = argparse.ArgumentParser( + description="Launch a Qwen3.5 Megatron yes-no-maybe convergence run." +) +parser.add_argument("--fast", action="store_true") +parser.add_argument("--base-model", type=str, default="Qwen/Qwen3.5-35B-A3B") +parser.add_argument("--accelerator", type=str, default="H200:2") +parser.add_argument( + "--cluster-name", type=str, default="art-qwen35-megatron-yes-no-maybe" +) +parser.add_argument("--image-id", type=str, default=DEFAULT_IMAGE_ID) +parser.add_argument("--project", type=str, default="qwen35-megatron-ynm") +parser.add_argument("--gpu-memory-utilization", type=float, default=0.65) +parser.add_argument("--max-model-len", type=int, default=1024) +parser.add_argument("--max-seq-length", type=int, default=1024) +parser.add_argument("--packed-sequence-length", type=int, default=None) +parser.add_argument("--max-num-seqs", type=int, default=8) +parser.add_argument("--num-steps", type=int, default=10) +parser.add_argument("--rollouts-per-prompt", type=int, default=8) +parser.add_argument("--eval-prompts", type=int, default=24) +parser.add_argument("--max-tokens", type=int, default=5) +parser.add_argument("--timeout", type=float, default=600.0) +parser.add_argument("--learning-rate", type=float, default=5e-5) +parser.add_argument( + "--load-in-4bit", action=argparse.BooleanOptionalAction, default=False +) +parser.add_argument( + "--load-in-16bit", action=argparse.BooleanOptionalAction, default=True +) +parser.add_argument("--trainer-gpu-ids", type=int, nargs="+", default=[0]) +parser.add_argument("--inference-gpu-ids", type=int, nargs="+", default=[1]) +args = parser.parse_args() + +cluster_name = args.cluster_name +cluster_prefix = os.environ.get("CLUSTER_PREFIX") +if cluster_prefix: + cluster_name = f"{cluster_prefix}-{cluster_name}" + +setup_script = textwrap.dedent("""\ + echo 'Setting up environment...' + apt-get update + apt-get install -y python3 python3-pip python-is-python3 git curl + curl -LsSf https://astral.sh/uv/install.sh | sh + source $HOME/.local/bin/env +""") + +env = [ + f"PROJECT={args.project}", + "MODEL_NAME=qwen35-megatron-ynm-$(date +%Y%m%d-%H%M%S)", + f"BASE_MODEL={args.base_model}", + f"GPU_MEMORY_UTILIZATION={args.gpu_memory_utilization}", + f"MAX_MODEL_LEN={args.max_model_len}", + f"MAX_SEQ_LENGTH={args.max_seq_length}", + "PACKED_SEQUENCE_LENGTH=" + + str( + args.packed_sequence_length + if args.packed_sequence_length is not None + else args.max_seq_length + ), + f"MAX_NUM_SEQS={args.max_num_seqs}", + f"LOAD_IN_4BIT={_format_env_bool(args.load_in_4bit)}", + f"LOAD_IN_16BIT={_format_env_bool(args.load_in_16bit)}", + f"NUM_STEPS={args.num_steps}", + f"ROLLOUTS_PER_PROMPT={args.rollouts_per_prompt}", + f"EVAL_PROMPTS={args.eval_prompts}", + f"MAX_TOKENS={args.max_tokens}", + f"TIMEOUT={args.timeout}", + f"LEARNING_RATE={args.learning_rate}", + f"TRAINER_GPU_IDS={_format_int_list(args.trainer_gpu_ids)}", + f"INFERENCE_GPU_IDS={_format_int_list(args.inference_gpu_ids)}", + "ROLLOUT_WEIGHTS_MODE=merged", +] +env_block = " \\\n ".join(env) + +run_script = textwrap.dedent( + f"""\ + source $HOME/.local/bin/env + cd ~/sky_workdir + bash src/art/megatron/setup.sh + {env_block} \\ + ~/.local/bin/uv run dev/yes-no-maybe-megatron.py +""" +) + +task = sky.Task( + name="qwen3.5-megatron-yes-no-maybe", + setup=setup_script, + run=run_script, + workdir=".", +) +task.set_resources( + sky.Resources( + accelerators=args.accelerator, + cloud=sky.clouds.Kubernetes(), + image_id=args.image_id, + ) +) +if os.path.exists(".env"): + task.set_file_mounts({"~/sky_workdir/.env": ".env"}) + +print(f"Launching on cluster: {cluster_name}") +print(f" base_model: {args.base_model}") +print(f" project: {args.project}") +print(f" accelerator: {args.accelerator}") +print(f" image_id: {args.image_id}") +print(f" gpu_memory_utilization: {args.gpu_memory_utilization}") +print(f" max_model_len: {args.max_model_len}") +print(f" max_seq_length: {args.max_seq_length}") +print( + " packed_sequence_length: " + f"{args.packed_sequence_length if args.packed_sequence_length is not None else args.max_seq_length}" +) +print(f" max_num_seqs: {args.max_num_seqs}") +print(f" num_steps: {args.num_steps}") +print(f" rollouts_per_prompt: {args.rollouts_per_prompt}") +print(f" eval_prompts: {args.eval_prompts}") +print(f" max_tokens: {args.max_tokens}") +print(f" timeout: {args.timeout}") +print(f" learning_rate: {args.learning_rate}") +print(f" load_in_4bit: {args.load_in_4bit}") +print(f" load_in_16bit: {args.load_in_16bit}") +print(f" trainer_gpu_ids: {args.trainer_gpu_ids}") +print(f" inference_gpu_ids: {args.inference_gpu_ids}") + +cluster_status = sky.stream_and_get(sky.status(cluster_names=[cluster_name])) +if cluster_status and cluster_status[0]["status"] == ClusterStatus.UP: + print(f"Cluster {cluster_name} is UP. Canceling any active jobs...") + sky.stream_and_get(sky.cancel(cluster_name, all=True)) + +job_id, _ = sky.stream_and_get( + sky.launch( + task, + cluster_name=cluster_name, + retry_until_up=True, + idle_minutes_to_autostop=60, + down=True, + fast=args.fast, + ) +) + +print(f"Job submitted (ID: {job_id}). Streaming logs...") +exit_code = sky.tail_logs(cluster_name=cluster_name, job_id=job_id, follow=True) +print(f"Job {job_id} finished with exit code {exit_code}.") diff --git a/dev/yes-no-maybe-megatron.py b/dev/yes-no-maybe-megatron.py index 7deff71e..b34f6f5d 100644 --- a/dev/yes-no-maybe-megatron.py +++ b/dev/yes-no-maybe-megatron.py @@ -1,86 +1,329 @@ +"""Yes-no-maybe metrics demo for the Megatron backend.""" + +from __future__ import annotations + import asyncio from itertools import permutations +import json import os +import time from dotenv import load_dotenv import openai -import torch import art from art.megatron import MegatronBackend +def _get_env_bool(name: str, default: bool | None = None) -> bool | None: + value = os.environ.get(name) + if value is None: + return default + lowered = value.strip().lower() + if lowered in {"1", "true", "yes", "on"}: + return True + if lowered in {"0", "false", "no", "off"}: + return False + raise ValueError(f"Invalid boolean value for {name}: {value!r}") + + +def _get_env_int_list(name: str, default: list[int] | None = None) -> list[int] | None: + value = os.environ.get(name) + if value is None: + return default + parts = [part.strip() for part in value.split(",") if part.strip()] + if not parts: + raise ValueError(f"Invalid GPU ID list for {name}: {value!r}") + return [int(part) for part in parts] + + +def _chat_completion_extra_body(base_model: str) -> dict[str, object] | None: + if base_model.startswith("Qwen/Qwen3"): + return {"chat_template_kwargs": {"enable_thinking": False}} + return None + + +def with_quotes(word: str) -> str: + return f"'{word}'" + + +def build_prompts() -> list[str]: + prompts: list[str] = [] + for prefix in ["respond", "just respond"]: + for use_quotes in [True, False]: + for length in [3, 2]: + for words in permutations(["yes", "no", "maybe"], length): + rendered_words = ( + [with_quotes(word) for word in words] + if use_quotes + else list(words) + ) + if length == 3: + suffix = ", ".join(rendered_words) + else: + suffix = f"{rendered_words[0]} or {rendered_words[1]}" + prompts.append(f"{prefix} with {suffix}") + return prompts + + +def first_word(content: str | None) -> str: + if not content: + return "" + words = content.strip().lower().split(maxsplit=1) + if not words: + return "" + return words[0].strip(".,!?:;\"'()[]{}") + + +def reward_for_answer(answer: str) -> float: + if answer == "yes": + return 0.5 + if answer == "no": + return 0.75 + if answer == "maybe": + return 1.0 + return 0.0 + + +def summarize(groups: list[art.TrajectoryGroup]) -> dict[str, float]: + trajectories = [trajectory for group in groups for trajectory in group.trajectories] + answers = [str(trajectory.metadata["answer"]) for trajectory in trajectories] + rewards = [trajectory.reward for trajectory in trajectories] + total = len(trajectories) + assert total > 0 + return { + "num_rollouts": float(total), + "avg_reward": sum(rewards) / total, + "yes_rate": answers.count("yes") / total, + "no_rate": answers.count("no") / total, + "maybe_rate": answers.count("maybe") / total, + "invalid_rate": sum(answer not in {"yes", "no", "maybe"} for answer in answers) + / total, + } + + async def rollout( - client: openai.AsyncOpenAI, model_name: str, prompt: str + client: openai.AsyncOpenAI, + model: art.TrainableModel, + prompt: str, + *, + max_tokens: int, + timeout: float, ) -> art.Trajectory: messages: art.Messages = [{"role": "user", "content": prompt}] - chat_completion = await client.chat.completions.create( - messages=messages, model=model_name, max_tokens=100, timeout=100 + completion = await client.chat.completions.create( + model=model.get_inference_name(), + messages=messages, + max_tokens=max_tokens, + timeout=timeout, + extra_body=_chat_completion_extra_body(model.base_model), + ) + choice = completion.choices[0] + answer = first_word(choice.message.content) + return art.Trajectory( + messages_and_choices=[*messages, choice], + reward=reward_for_answer(answer), + metadata={"answer": answer}, ) - choice = chat_completion.choices[0] - content = choice.message.content - assert isinstance(content, str) - if content == "yes": - reward = 0.5 - elif content == "no": - reward = 0.75 - elif content == "maybe": - reward = 1.0 - else: - reward = 0.0 - return art.Trajectory(messages_and_choices=[*messages, choice], reward=reward) -def with_quotes(w: str) -> str: - return f"'{w}'" +async def evaluate( + client: openai.AsyncOpenAI, + model: art.TrainableModel, + prompts: list[str], + *, + max_tokens: int, + timeout: float, +) -> dict[str, float]: + groups = await art.gather_trajectory_groups( + art.TrajectoryGroup( + [rollout(client, model, prompt, max_tokens=max_tokens, timeout=timeout)] + ) + for prompt in prompts + ) + return summarize(groups) + +def build_internal_config() -> art.dev.InternalModelConfig: + trainer_gpu_ids = _get_env_int_list("TRAINER_GPU_IDS") + inference_gpu_ids = _get_env_int_list("INFERENCE_GPU_IDS") + rollout_weights_mode = os.environ.get("ROLLOUT_WEIGHTS_MODE") -async def main(): + internal_config = art.dev.InternalModelConfig( + engine_args=art.dev.EngineArgs( + gpu_memory_utilization=float( + os.environ.get("GPU_MEMORY_UTILIZATION", "0.8") + ), + max_model_len=int(os.environ.get("MAX_MODEL_LEN", "4096")), + max_num_seqs=int(os.environ.get("MAX_NUM_SEQS", "8")), + tensor_parallel_size=int(os.environ.get("TENSOR_PARALLEL_SIZE", "1")), + enforce_eager=_get_env_bool("ENFORCE_EAGER"), + ), + ) + max_seq_length = os.environ.get("MAX_SEQ_LENGTH") + if max_seq_length is not None: + init_args: art.dev.InitArgs = {"max_seq_length": int(max_seq_length)} + load_in_16bit = _get_env_bool("LOAD_IN_16BIT") + if load_in_16bit is not None: + init_args["load_in_16bit"] = load_in_16bit + load_in_4bit = _get_env_bool("LOAD_IN_4BIT") + if load_in_4bit is not None: + init_args["load_in_4bit"] = load_in_4bit + internal_config["init_args"] = init_args + if trainer_gpu_ids is not None: + assert inference_gpu_ids is not None + internal_config["trainer_gpu_ids"] = trainer_gpu_ids + internal_config["inference_gpu_ids"] = inference_gpu_ids + if rollout_weights_mode is not None: + internal_config["rollout_weights_mode"] = rollout_weights_mode + return internal_config + + +async def main() -> None: load_dotenv() - backend = MegatronBackend() base_model = os.environ.get("BASE_MODEL", "Qwen/Qwen3-30B-A3B-Instruct-2507") + project = os.environ.get("PROJECT", "yes-no-maybe-megatron") + model_name = os.environ.get("MODEL_NAME", f"megatron-ynm-{int(time.time())}") + num_steps = int(os.environ.get("NUM_STEPS", "20")) + rollouts_per_prompt = int(os.environ.get("ROLLOUTS_PER_PROMPT", "32")) + max_tokens = int(os.environ.get("MAX_TOKENS", "100")) + timeout = float(os.environ.get("TIMEOUT", "100")) + learning_rate = float(os.environ.get("LEARNING_RATE", "1e-4")) + packed_sequence_length = int( + os.environ.get( + "PACKED_SEQUENCE_LENGTH", + os.environ.get("MAX_SEQ_LENGTH", "4096"), + ) + ) + + backend = MegatronBackend() model = art.TrainableModel( - name=os.environ.get("MODEL_NAME", "megatron-001"), - project="yes-no-maybe-megatron", + name=model_name, + project=project, base_model=base_model, - _internal_config=art.dev.InternalModelConfig( - engine_args=art.dev.EngineArgs( - gpu_memory_utilization=0.8, - tensor_parallel_size=torch.cuda.device_count(), - ), - ), + report_metrics=[], + _internal_config=build_internal_config(), ) - await model.register(backend) - - prompts = [ - f"{prefix} with {', '.join([with_quotes(w) if use_quotes else w for w in words]) if len(words) == 3 else f'{words[0]}' + (f' or {words[1]}' if len(words) > 1 else '')}" - for prefix in ["respond", "just respond"] - for use_quotes in [True, False] - for words in ( - list(p) for n in [3, 2] for p in permutations(["yes", "no", "maybe"], n) + prompts = build_prompts() + prompts = prompts[: int(os.environ.get("PROMPTS_LIMIT", str(len(prompts))))] + eval_prompts = prompts[: int(os.environ.get("EVAL_PROMPTS", "24"))] + + try: + print(json.dumps({"event": "register_start"}), flush=True) + await model.register(backend) + print( + json.dumps( + { + "event": "register_done", + "step": int(await model.get_step()), + "model": model.get_inference_name(), + } + ), + flush=True, ) - ] + client = model.openai_client() - openai_client = model.openai_client() - max_steps = int(os.environ.get("NUM_STEPS", "20")) - start_step = await model.get_step() + print( + json.dumps({"event": "eval_start", "step": int(await model.get_step())}), + flush=True, + ) + initial_eval = await evaluate( + client, + model, + eval_prompts, + max_tokens=max_tokens, + timeout=timeout, + ) + print( + json.dumps( + { + "event": "eval", + "step": int(await model.get_step()), + "model": model.get_inference_name(), + **initial_eval, + } + ), + flush=True, + ) - for step in range(start_step, start_step + max_steps): - print(f"\n=== Step {step + 1} ===") - train_groups = await art.gather_trajectory_groups( - ( + start_step = await model.get_step() + for offset in range(num_steps): + current_step = start_step + offset + print( + json.dumps( + { + "event": "rollout_start", + "step": current_step, + "model": model.get_inference_name(), + } + ), + flush=True, + ) + train_groups = await art.gather_trajectory_groups( art.TrajectoryGroup( - rollout(openai_client, model.get_inference_name(), prompt) - for _ in range(32) + rollout( + client, + model, + prompt, + max_tokens=max_tokens, + timeout=timeout, + ) + for _ in range(rollouts_per_prompt) ) for prompt in prompts ) - ) - await model.train( - train_groups, - config=art.TrainConfig(learning_rate=1e-4), - ) + train_summary = summarize(train_groups) + print( + json.dumps( + { + "event": "train_start", + "step": current_step, + "model": model.get_inference_name(), + **train_summary, + } + ), + flush=True, + ) + result = await backend.train( + model, + train_groups, + learning_rate=learning_rate, + packed_sequence_length=packed_sequence_length, + ) + print( + json.dumps( + { + "event": "train_step", + "step": result.step, + "model": model.get_inference_name(), + **train_summary, + "backend_metrics": result.metrics, + } + ), + flush=True, + ) + + eval_summary = await evaluate( + client, + model, + eval_prompts, + max_tokens=max_tokens, + timeout=timeout, + ) + print( + json.dumps( + { + "event": "eval", + "step": current_step + 1, + "model": model.get_inference_name(), + **eval_summary, + } + ), + flush=True, + ) + finally: + await backend.close() if __name__ == "__main__": diff --git a/pyproject.toml b/pyproject.toml index f9804de9..e93c460e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,7 +50,7 @@ megatron = [ "transformer-engine-torch==2.11.0", "megatron-core==0.16.0rc0", "pybind11>=2.13.6", - "megatron-bridge", + "megatron-bridge @ git+https://github.com/NVIDIA-NeMo/Megatron-Bridge.git@e049cc00c24d03e2ae45d2608c7a44e2d2364e3d", "deep_ep @ git+https://github.com/deepseek-ai/DeepEP.git@v1.2.1 ; sys_platform == 'linux'", "nvidia-ml-py==13.580.82", "ml-dtypes>=0.5.0 ; python_full_version < '3.13'", @@ -241,5 +241,4 @@ dev = [ [tool.uv.sources] panza = { git = "https://github.com/corbt/panza.git" } apex = { git = "https://github.com/NVIDIA/apex.git", branch = "25.09" } -megatron-bridge = { git = "https://github.com/NVIDIA-NeMo/Megatron-Bridge.git", rev = "75f2c5ad4afb702b57b4781a00f5291a66bcf183" } transformer-engine-torch = { git = "https://github.com/NVIDIA/TransformerEngine.git", tag = "v2.11", subdirectory = "transformer_engine/pytorch" } diff --git a/src/art/local/backend.py b/src/art/local/backend.py index cb85bed4..9f82c33f 100644 --- a/src/art/local/backend.py +++ b/src/art/local/backend.py @@ -19,7 +19,6 @@ import aiohttp import numpy as np -from openai import AsyncOpenAI import polars as pl import torch from tqdm import auto as tqdm @@ -497,10 +496,6 @@ async def _monitor_openai_server( self, model: AnyTrainableModel, base_url: str, api_key: str ) -> None: model_name = model.name - openai_client = AsyncOpenAI( - base_url=base_url, - api_key=api_key, - ) consecutive_failures = 0 max_consecutive_failures = 3 async with aiohttp.ClientSession() as session: @@ -526,18 +521,22 @@ async def _monitor_openai_server( running_requests = int(float(line.split()[1])) elif line.startswith("vllm:num_requests_waiting"): pending_requests = int(float(line.split()[1])) - # If there are no running or pending requests, send a health check + # If there are no running or pending requests, send a cheap liveness + # probe rather than a real generation request. Large models can take + # longer than a short completion-based probe while still being healthy. if running_requests == 0 and pending_requests == 0: try: - # Send a health check with a short timeout - await openai_client.completions.create( - model=self._model_inference_name(model), - prompt="Hi", - max_tokens=1, + async with session.get( + f"{base_url.split('/v1')[0]}/health", timeout=float( os.environ.get("ART_SERVER_MONITOR_TIMEOUT", 5.0) ), - ) + ) as health_response: + if health_response.status >= 400: + raise RuntimeError( + "OpenAI server health check failed with " + f"status {health_response.status}" + ) except Exception as e: # If the server is sleeping, a failed health check is okay if await self._services[ diff --git a/src/art/megatron/bridge_adapter_compat.py b/src/art/megatron/bridge_adapter_compat.py new file mode 100644 index 00000000..5e267224 --- /dev/null +++ b/src/art/megatron/bridge_adapter_compat.py @@ -0,0 +1,310 @@ +import math + +from megatron.bridge.models.conversion.model_bridge import MegatronWeightTuple +from megatron.bridge.models.conversion.peft_bridge import AdapterWeight +from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.transformer_layer import TransformerLayer +import torch + +from art.megatron.lora import ( + GatedDeltaNetInProjLoRA, + LoRA, + MLPExpertsLinearFC1LoRA, + MLPExpertsLinearFC2LoRA, + SelfAttentionLinearProjLoRA, + SelfAttentionLinearQKVLoRA, + SharedExpertsLinearFC1LoRA, + SharedExpertsLinearFC2LoRA, +) + + +def _is_language_transformer_layer_name(module_name: str) -> bool: + while module_name.startswith("module."): + module_name = module_name.removeprefix("module.") + return module_name.startswith(("decoder.layers.", "language_model.decoder.layers.")) + + +def _adapter_alpha_dim(lora: LoRA) -> tuple[int, int]: + dim = int(lora.A_T.shape[-1]) + alpha = float(lora.scale) * dim + rounded_alpha = round(alpha) + assert math.isclose(alpha, rounded_alpha) + return rounded_alpha, dim + + +def _adapter_tensors( + lora: LoRA, expert_idx: int | None = None +) -> tuple[torch.Tensor, torch.Tensor]: + a_t = lora.A_T if expert_idx is None else lora.A_T[expert_idx] + b_t = lora.B_T if expert_idx is None else lora.B_T[expert_idx] + return a_t.transpose(-1, -2).contiguous(), b_t.transpose(-1, -2).contiguous() + + +def _adapter_param_prefix(base_prefix: str, adapter_key: str | None) -> str: + if adapter_key is None: + return f"{base_prefix}.adapter" + return f"{base_prefix}.adapter.{adapter_key}" + + +def _adapter_weight( + *, + base_prefix: str, + adapter_key: str | None, + alpha: int, + dim: int, + linear_in: torch.Tensor, + linear_out: torch.Tensor, +) -> AdapterWeight: + param_prefix = _adapter_param_prefix(base_prefix, adapter_key) + return AdapterWeight( + global_base_prefix=base_prefix, + adapter_key=adapter_key, + alpha=alpha, + dim=dim, + linear_in_weight=MegatronWeightTuple( + param_name=f"{param_prefix}.linear_in.weight", + weight=linear_in, + vp_stage=0, + ), + linear_out_weight=MegatronWeightTuple( + param_name=f"{param_prefix}.linear_out.weight", + weight=linear_out, + vp_stage=0, + ), + ) + + +def _simple_adapter_weight( + base_prefix: str, + lora: LoRA, + *, + adapter_key: str | None = None, + expert_idx: int | None = None, +) -> AdapterWeight: + alpha, dim = _adapter_alpha_dim(lora) + linear_in, linear_out = _adapter_tensors(lora, expert_idx) + return _adapter_weight( + base_prefix=base_prefix, + adapter_key=adapter_key, + alpha=alpha, + dim=dim, + linear_in=linear_in, + linear_out=linear_out, + ) + + +def _fused_gdn_adapter_weight( + base_prefix: str, + handler: GatedDeltaNetInProjLoRA, +) -> AdapterWeight: + qkv_linear_in, qkv_linear_out = _adapter_tensors(handler.qkv_lora) + z_linear_in, z_linear_out = _adapter_tensors(handler.z_lora) + assert math.isclose(float(handler.qkv_lora.scale), float(handler.z_lora.scale)) + total_dim = int(qkv_linear_in.shape[0] + z_linear_in.shape[0]) + alpha = round(float(handler.qkv_lora.scale) * total_dim) + + qkv_rank = int(qkv_linear_in.shape[0]) + z_rank = int(z_linear_in.shape[0]) + qkv_out = int(qkv_linear_out.shape[0]) + z_out = int(z_linear_out.shape[0]) + beta_alpha_out = int(handler.num_value_heads_per_partition) + + qkv_padding = qkv_linear_out.new_zeros((qkv_out, z_rank)) + z_padding = z_linear_out.new_zeros((z_out, qkv_rank)) + zeros = qkv_linear_out.new_zeros((beta_alpha_out, total_dim)) + + return _adapter_weight( + base_prefix=base_prefix, + adapter_key=None, + alpha=alpha, + dim=total_dim, + linear_in=torch.cat([qkv_linear_in, z_linear_in], dim=0), + linear_out=torch.cat( + [ + torch.cat([qkv_linear_out, qkv_padding], dim=1), + torch.cat([z_padding, z_linear_out], dim=1), + zeros, + zeros.clone(), + ], + dim=0, + ), + ) + + +def _fused_pair_adapter_weight( + base_prefix: str, + first_lora: LoRA, + second_lora: LoRA, + *, + first_expert_idx: int | None = None, + second_expert_idx: int | None = None, +) -> AdapterWeight: + first_linear_in, first_linear_out = _adapter_tensors(first_lora, first_expert_idx) + second_linear_in, second_linear_out = _adapter_tensors( + second_lora, second_expert_idx + ) + assert math.isclose(float(first_lora.scale), float(second_lora.scale)) + total_dim = int(first_linear_in.shape[0] + second_linear_in.shape[0]) + alpha = round(float(first_lora.scale) * total_dim) + + first_rank = int(first_linear_in.shape[0]) + second_rank = int(second_linear_in.shape[0]) + first_out = int(first_linear_out.shape[0]) + second_out = int(second_linear_out.shape[0]) + + first_padding = first_linear_out.new_zeros((first_out, second_rank)) + second_padding = second_linear_out.new_zeros((second_out, first_rank)) + + return _adapter_weight( + base_prefix=base_prefix, + adapter_key=None, + alpha=alpha, + dim=total_dim, + linear_in=torch.cat([first_linear_in, second_linear_in], dim=0), + linear_out=torch.cat( + [ + torch.cat([first_linear_out, first_padding], dim=1), + torch.cat([second_padding, second_linear_out], dim=1), + ], + dim=0, + ), + ) + + +def build_adapter_weights_by_base( + model_chunks: list[MegatronModule], +) -> dict[str, list[AdapterWeight]]: + adapter_weights_by_base: dict[str, list[AdapterWeight]] = {} + for chunk in model_chunks: + for module_name, module in chunk.named_modules(): + if not isinstance(module, TransformerLayer): + continue + if not _is_language_transformer_layer_name(module_name): + continue + + layer_prefix = f"language_model.decoder.layers.{module.layer_number - 1}" + self_attention = module.self_attention + + linear_proj = getattr(self_attention, "linear_proj", None) + if isinstance(linear_proj, SelfAttentionLinearProjLoRA): + base_prefix = f"{layer_prefix}.self_attention.linear_proj" + adapter_weights_by_base[f"{base_prefix}.weight"] = [ + _simple_adapter_weight(base_prefix, linear_proj.lora) + ] + + linear_qkv = getattr(self_attention, "linear_qkv", None) + if isinstance(linear_qkv, SelfAttentionLinearQKVLoRA): + base_prefix = f"{layer_prefix}.self_attention.linear_qkv" + adapter_weights_by_base[f"{base_prefix}.weight"] = [ + _simple_adapter_weight( + base_prefix, linear_qkv.q_proj_lora, adapter_key="adapter_q" + ), + _simple_adapter_weight( + base_prefix, linear_qkv.k_proj_lora, adapter_key="adapter_k" + ), + _simple_adapter_weight( + base_prefix, linear_qkv.v_proj_lora, adapter_key="adapter_v" + ), + ] + + out_proj = getattr(self_attention, "out_proj", None) + if isinstance(out_proj, SelfAttentionLinearProjLoRA): + base_prefix = f"{layer_prefix}.self_attention.out_proj" + adapter_weights_by_base[f"{base_prefix}.weight"] = [ + _simple_adapter_weight(base_prefix, out_proj.lora) + ] + + in_proj = getattr(self_attention, "in_proj", None) + if isinstance(in_proj, GatedDeltaNetInProjLoRA): + base_prefix = f"{layer_prefix}.self_attention.in_proj" + adapter_weights_by_base[f"{base_prefix}.weight"] = [ + _fused_gdn_adapter_weight(base_prefix, in_proj) + ] + + experts = getattr(module.mlp, "experts", None) + if experts is not None: + if isinstance(experts.linear_fc1, MLPExpertsLinearFC1LoRA): + base_prefix = f"{layer_prefix}.mlp.experts.linear_fc1" + for local_expert_idx in range( + experts.linear_fc1.gate_lora.num_local_experts + ): + global_expert_idx = ( + local_expert_idx + + experts.linear_fc1.gate_lora._expert_offset + ) + adapter_weights_by_base[ + f"{base_prefix}.weight{global_expert_idx}" + ] = [ + _fused_pair_adapter_weight( + base_prefix, + experts.linear_fc1.gate_lora, + experts.linear_fc1.up_lora, + first_expert_idx=local_expert_idx, + second_expert_idx=local_expert_idx, + ) + ] + if isinstance(experts.linear_fc2, MLPExpertsLinearFC2LoRA): + base_prefix = f"{layer_prefix}.mlp.experts.linear_fc2" + for local_expert_idx in range( + experts.linear_fc2.lora.num_local_experts + ): + global_expert_idx = ( + local_expert_idx + experts.linear_fc2.lora._expert_offset + ) + adapter_weights_by_base[ + f"{base_prefix}.weight{global_expert_idx}" + ] = [ + _simple_adapter_weight( + base_prefix, + experts.linear_fc2.lora, + expert_idx=local_expert_idx, + ) + ] + else: + linear_fc1 = getattr(module.mlp, "linear_fc1", None) + if isinstance(linear_fc1, SharedExpertsLinearFC1LoRA): + base_prefix = f"{layer_prefix}.mlp.linear_fc1" + adapter_weights_by_base[f"{base_prefix}.weight"] = [ + _simple_adapter_weight( + base_prefix, + linear_fc1.gate_lora, + adapter_key="adapter_gate", + ), + _simple_adapter_weight( + base_prefix, linear_fc1.up_lora, adapter_key="adapter_up" + ), + ] + linear_fc2 = getattr(module.mlp, "linear_fc2", None) + if isinstance(linear_fc2, SharedExpertsLinearFC2LoRA): + base_prefix = f"{layer_prefix}.mlp.linear_fc2" + adapter_weights_by_base[f"{base_prefix}.weight"] = [ + _simple_adapter_weight( + base_prefix, linear_fc2.row_parallel_lora.lora + ) + ] + + shared_experts = getattr(module.mlp, "shared_experts", None) + if shared_experts is not None: + if isinstance(shared_experts.linear_fc1, SharedExpertsLinearFC1LoRA): + base_prefix = f"{layer_prefix}.mlp.shared_experts.linear_fc1" + adapter_weights_by_base[f"{base_prefix}.weight"] = [ + _simple_adapter_weight( + base_prefix, + shared_experts.linear_fc1.gate_lora, + adapter_key="adapter_gate", + ), + _simple_adapter_weight( + base_prefix, + shared_experts.linear_fc1.up_lora, + adapter_key="adapter_up", + ), + ] + if isinstance(shared_experts.linear_fc2, SharedExpertsLinearFC2LoRA): + base_prefix = f"{layer_prefix}.mlp.shared_experts.linear_fc2" + adapter_weights_by_base[f"{base_prefix}.weight"] = [ + _simple_adapter_weight( + base_prefix, + shared_experts.linear_fc2.row_parallel_lora.lora, + ) + ] + return adapter_weights_by_base diff --git a/src/art/megatron/lora.py b/src/art/megatron/lora.py index 5c4d1242..9fba022f 100644 --- a/src/art/megatron/lora.py +++ b/src/art/megatron/lora.py @@ -1,21 +1,24 @@ from collections.abc import Sequence import math -from typing import Any, Literal +from typing import Any, Literal, cast from megatron.bridge.models.gpt_provider import GPTModelProvider from megatron.core import parallel_state as ps from megatron.core.extensions.transformer_engine import ( TEColumnParallelGroupedLinear, + TEColumnParallelLinear, TELayerNormColumnParallelLinear, TERowParallelGroupedLinear, TERowParallelLinear, ) +from megatron.core.ssm.gated_delta_net import GatedDeltaNet from megatron.core.tensor_parallel.mappings import ( reduce_from_tensor_model_parallel_region, reduce_scatter_to_sequence_parallel_region, ) from megatron.core.transformer.attention import SelfAttention from megatron.core.transformer.moe.experts import TEGroupedMLP +from megatron.core.transformer.moe.shared_experts import SharedExpertMLP from megatron.core.transformer.transformer_layer import TransformerLayer from pydantic import BaseModel, ConfigDict import torch @@ -95,6 +98,14 @@ def _normalize_axis(axis: int, ndim: int) -> int: return axis +def _linear_disables_tensor_parallel_comm(linear: Any) -> bool: + # Shared experts can keep TP-sharded weights while deferring TP comm to the + # overlap path by setting parallel_mode=None / explicit_expert_comm=True. + return getattr(linear, "parallel_mode", "") is None or getattr( + linear, "explicit_expert_comm", False + ) + + def _set_lora_parallel_metadata( param: torch.nn.Parameter, *, @@ -385,10 +396,12 @@ def __init__( rank: int, alpha: float, provider: GPTModelProvider, + reduce_output: bool = True, ) -> None: super().__init__() self.provider = provider self.linear_proj = linear_proj + self.reduce_output = reduce_output assert isinstance(linear_proj.weight, torch.Tensor) a_parallel_spec = LoRAParallelSpec( shard_domain="tp", @@ -424,7 +437,7 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]: assert isinstance(bias_output, (torch.Tensor, type(None))) lora_output = self.lora(x) - if self.provider.tensor_model_parallel_size > 1: + if self.reduce_output and self.provider.tensor_model_parallel_size > 1: if self.provider.sequence_parallel: lora_output = reduce_scatter_to_sequence_parallel_region(lora_output) else: @@ -453,17 +466,31 @@ def __init__( raise ValueError( "num_attention_heads must be divisible by num_query_groups for QKV LoRA" ) - q_out_features = self.provider.kv_channels * self.provider.num_attention_heads + linear_qkv_weight = cast(torch.Tensor, linear_qkv.weight) + total_out_features_per_rank = linear_qkv_weight.shape[0] kv_out_features = self.provider.kv_channels * self.provider.num_query_groups tp_world_size = ps.get_tensor_model_parallel_world_size() assert kv_out_features % tp_world_size == 0, ( "kv_out_features must be divisible by tensor parallel size" ) + q_out_features = self.provider.kv_channels * self.provider.num_attention_heads assert q_out_features % tp_world_size == 0, ( "q_out_features must be divisible by tensor parallel size" ) q_out_features_per_rank = q_out_features // tp_world_size kv_out_features_per_rank = kv_out_features // tp_world_size + self.attention_output_gate = bool( + getattr(self.provider, "attention_output_gate", False) + ) + q_and_gate_out_features_per_rank = total_out_features_per_rank - ( + 2 * kv_out_features_per_rank + ) + expected_q_out_features_per_rank = q_out_features_per_rank * ( + 2 if self.attention_output_gate else 1 + ) + assert q_and_gate_out_features_per_rank == expected_q_out_features_per_rank, ( + "Unexpected per-rank QKV packing for this attention layout" + ) self.num_query_groups_per_partition = ( self.provider.num_query_groups // tp_world_size ) @@ -477,7 +504,7 @@ def __init__( linear_qkv=linear_qkv, rank=rank, alpha=alpha, - out_features=q_out_features_per_rank, + out_features=q_and_gate_out_features_per_rank, ) self.k_proj_lora = self._build_qkv_lora( adapter_model_prefix=f"{adapter_model_prefix}.k_proj", @@ -542,17 +569,17 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]: assert isinstance(layernorm_output, torch.Tensor) assert isinstance(bias, (torch.Tensor, type(None))) - query = self.q_proj_lora(layernorm_output) + query_and_gate = self.q_proj_lora(layernorm_output) key = self.k_proj_lora(layernorm_output) value = self.v_proj_lora(layernorm_output) - # Match Megatron mixed_qkv layout: - # [S, B, nqg, (nah/nqg + 2), hn] where each query-group packs - # [all query heads for that group, key, value]. - query_5d = query.reshape( - query.shape[0], - query.shape[1], + # Qwen3 packs [query, key, value] per group, while Qwen3.5 packs + # [query, gate, key, value]. + query_and_gate_5d = query_and_gate.reshape( + query_and_gate.shape[0], + query_and_gate.shape[1], self.num_query_groups_per_partition, - self.num_attention_heads_per_group, + self.num_attention_heads_per_group + * (2 if self.attention_output_gate else 1), self.hidden_size_per_attention_head, ) key_5d = key.reshape( @@ -569,12 +596,109 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]: 1, self.hidden_size_per_attention_head, ) - qkv_5d = torch.cat([query_5d, key_5d, value_5d], dim=3) + qkv_5d = torch.cat([query_and_gate_5d, key_5d, value_5d], dim=3) adapter_output = qkv_5d.reshape(qkv_5d.shape[0], qkv_5d.shape[1], -1) return linear_output + adapter_output, bias +class GatedDeltaNetInProjLoRA(torch.nn.Module): + def __init__( + self, + adapter_model_prefix: str, + in_proj: TELayerNormColumnParallelLinear, + gated_delta_net: GatedDeltaNet, + rank: int, + alpha: float, + ) -> None: + super().__init__() + in_proj.return_layernorm_output = True + in_proj.return_layernorm_output_gathered = True + self.in_proj = in_proj + self.num_value_heads_per_partition = ( + gated_delta_net.num_value_heads // ps.get_tensor_model_parallel_world_size() + ) + qkv_out_features_per_partition = ( + gated_delta_net.qk_dim * 2 + gated_delta_net.v_dim + ) // ps.get_tensor_model_parallel_world_size() + z_out_features_per_partition = ( + gated_delta_net.v_dim // ps.get_tensor_model_parallel_world_size() + ) + assert isinstance(in_proj.weight, torch.Tensor) + self.qkv_lora = self._build_in_proj_lora( + adapter_model_prefix=f"{adapter_model_prefix}.in_proj_qkv", + in_proj=in_proj, + rank=rank, + alpha=alpha, + out_features=qkv_out_features_per_partition, + ) + self.z_lora = self._build_in_proj_lora( + adapter_model_prefix=f"{adapter_model_prefix}.in_proj_z", + in_proj=in_proj, + rank=rank, + alpha=alpha, + out_features=z_out_features_per_partition, + ) + + @staticmethod + def _build_in_proj_lora( + *, + adapter_model_prefix: str, + in_proj: TELayerNormColumnParallelLinear, + rank: int, + alpha: float, + out_features: int, + ) -> LoRA: + assert isinstance(in_proj.weight, torch.Tensor) + a_parallel_spec = LoRAParallelSpec( + shard_domain="tp", + sharded=False, + shard_dim=None, + grad_sync_domain=TP_DEFAULT_GRAD_SYNC_DOMAIN, + grad_sync_op=GRAD_SYNC_OP_SUM, + ) + b_parallel_spec = a_parallel_spec.model_copy( + update={ + "sharded": True, + "shard_dim": -1, + "grad_sync_op": GRAD_SYNC_OP_NONE, + } + ) + return LoRA( + adapter_model_prefix=adapter_model_prefix, + in_features=in_proj.in_features, + out_features=out_features, + rank=rank, + alpha=alpha, + dtype=in_proj.weight.dtype, + device=in_proj.weight.device, + a_parallel_spec=a_parallel_spec, + b_parallel_spec=b_parallel_spec, + allreduce=True, + ) + + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]: + ( + linear_output_and_layernorm_output, + bias, + ) = self.in_proj(x) + linear_output, layernorm_output = linear_output_and_layernorm_output + assert isinstance(linear_output, torch.Tensor) + assert isinstance(layernorm_output, torch.Tensor) + assert isinstance(bias, (torch.Tensor, type(None))) + + qkv = self.qkv_lora(layernorm_output) + z = self.z_lora(layernorm_output) + beta = qkv.new_zeros( + qkv.shape[0], + qkv.shape[1], + self.num_value_heads_per_partition, + ) + alpha = beta.clone() + adapter_output = torch.cat([qkv, z, beta, alpha], dim=-1) + return linear_output + adapter_output, bias + + class MLPExpertsLinearFC1LoRA(torch.nn.Module): def __init__( self, @@ -667,6 +791,74 @@ def forward( return base_out + adapter_out, bias_out +class SharedExpertsLinearFC1LoRA(torch.nn.Module): + def __init__( + self, + adapter_model_prefix: str, + linear_fc1: TEColumnParallelLinear | TELayerNormColumnParallelLinear, + rank: int, + alpha: float, + ) -> None: + super().__init__() + self.linear_fc1 = linear_fc1 + self.gate_lora = self._build_fc1_lora( + adapter_model_prefix=f"{adapter_model_prefix}.gate_proj", + linear_fc1=linear_fc1, + rank=rank, + alpha=alpha, + ) + self.up_lora = self._build_fc1_lora( + adapter_model_prefix=f"{adapter_model_prefix}.up_proj", + linear_fc1=linear_fc1, + rank=rank, + alpha=alpha, + ) + + @staticmethod + def _build_fc1_lora( + *, + adapter_model_prefix: str, + linear_fc1: TEColumnParallelLinear | TELayerNormColumnParallelLinear, + rank: int, + alpha: float, + ) -> LoRA: + assert isinstance(linear_fc1.weight, torch.Tensor) + a_parallel_spec = LoRAParallelSpec( + shard_domain="tp", + sharded=False, + shard_dim=None, + grad_sync_domain=TP_DEFAULT_GRAD_SYNC_DOMAIN, + grad_sync_op=GRAD_SYNC_OP_SUM, + ) + b_parallel_spec = a_parallel_spec.model_copy( + update={ + "sharded": True, + "shard_dim": -1, + "grad_sync_op": GRAD_SYNC_OP_NONE, + } + ) + return LoRA( + adapter_model_prefix=adapter_model_prefix, + in_features=linear_fc1.in_features, + out_features=linear_fc1.out_features // 2, + rank=rank, + alpha=alpha, + dtype=linear_fc1.weight.dtype, + device=linear_fc1.weight.device, + a_parallel_spec=a_parallel_spec, + b_parallel_spec=b_parallel_spec, + allreduce=True, + ) + + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]: + base_out, bias_out = self.linear_fc1(x) + adapter_out = torch.cat( + [self.gate_lora(x), self.up_lora(x)], + dim=-1, + ) + return base_out + adapter_out, bias_out + + class MLPExpertsLinearFC2LoRA(torch.nn.Module): def __init__( self, @@ -720,11 +912,45 @@ def forward( return base_out + adapter_out, bias_out +class SharedExpertsLinearFC2LoRA(torch.nn.Module): + def __init__( + self, + adapter_model_prefix: str, + linear_fc2: TERowParallelLinear, + rank: int, + alpha: float, + provider: GPTModelProvider, + ) -> None: + super().__init__() + self.row_parallel_lora = SelfAttentionLinearProjLoRA( + adapter_model_prefix=f"{adapter_model_prefix}.down_proj", + linear_proj=linear_fc2, + rank=rank, + alpha=alpha, + provider=provider, + reduce_output=not _linear_disables_tensor_parallel_comm(linear_fc2), + ) + + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]: + return self.row_parallel_lora(x) + + def apply_lora_adapters( model: Sequence[torch.nn.Module], provider: GPTModelProvider, ) -> list[torch.nn.Module]: - def _unwrap_attr(value: Any, attr_name: str, expected_type: type[Any]) -> Any: + def _is_language_transformer_layer(module_name: str) -> bool: + while module_name.startswith("module."): + module_name = module_name.removeprefix("module.") + return module_name.startswith( + ("decoder.layers.", "language_model.decoder.layers.") + ) + + def _unwrap_attr( + value: Any, + attr_name: str, + expected_type: type[Any] | tuple[type[Any], ...], + ) -> Any: if isinstance(value, expected_type): return value unwrapped = getattr(value, attr_name) @@ -732,59 +958,143 @@ def _unwrap_attr(value: Any, attr_name: str, expected_type: type[Any]) -> Any: return unwrapped for chunk in model: - for module in chunk.modules(): + for module_name, module in chunk.named_modules(): if isinstance(module, TransformerLayer): + if not _is_language_transformer_layer(module_name): + continue adapter_model_prefix = ( f"base_model.model.model.layers.{module.layer_number - 1}" ) - assert isinstance(module.self_attention, SelfAttention) - self_attention_linear_proj = _unwrap_attr( - module.self_attention.linear_proj, - "linear_proj", - TERowParallelLinear, - ) - module.self_attention.linear_proj = SelfAttentionLinearProjLoRA( - adapter_model_prefix=f"{adapter_model_prefix}.self_attn.o_proj", - linear_proj=self_attention_linear_proj, - rank=LORA_RANK, - alpha=LORA_ALPHA, - provider=provider, - ) - self_attention_linear_qkv = _unwrap_attr( - module.self_attention.linear_qkv, - "linear_qkv", - TELayerNormColumnParallelLinear, - ) - module.self_attention.linear_qkv = SelfAttentionLinearQKVLoRA( - adapter_model_prefix=f"{adapter_model_prefix}.self_attn", - linear_qkv=self_attention_linear_qkv, - rank=LORA_RANK, - alpha=LORA_ALPHA, - provider=provider, - ) - assert isinstance(module.mlp.experts, TEGroupedMLP) - mlp_experts_linear_fc1 = _unwrap_attr( - module.mlp.experts.linear_fc1, - "linear_fc1", - TEColumnParallelGroupedLinear, # type: ignore[arg-type] - ) - module.mlp.experts.linear_fc1 = MLPExpertsLinearFC1LoRA( - adapter_model_prefix=f"{adapter_model_prefix}.mlp.experts", - linear_fc1=mlp_experts_linear_fc1, - rank=LORA_RANK, - alpha=LORA_ALPHA, - num_local_experts=module.mlp.experts.num_local_experts, - ) - mlp_experts_linear_fc2 = _unwrap_attr( - module.mlp.experts.linear_fc2, - "linear_fc2", - TERowParallelGroupedLinear, # type: ignore[arg-type] - ) - module.mlp.experts.linear_fc2 = MLPExpertsLinearFC2LoRA( - adapter_model_prefix=f"{adapter_model_prefix}.mlp.experts", - linear_fc2=mlp_experts_linear_fc2, - rank=LORA_RANK, - alpha=LORA_ALPHA, - num_local_experts=module.mlp.experts.num_local_experts, - ) + if isinstance(module.self_attention, SelfAttention): + self_attention_linear_proj = _unwrap_attr( + module.self_attention.linear_proj, + "linear_proj", + TERowParallelLinear, + ) + module.self_attention.linear_proj = SelfAttentionLinearProjLoRA( + adapter_model_prefix=f"{adapter_model_prefix}.self_attn.o_proj", + linear_proj=self_attention_linear_proj, + rank=LORA_RANK, + alpha=LORA_ALPHA, + provider=provider, + ) + self_attention_linear_qkv = _unwrap_attr( + module.self_attention.linear_qkv, + "linear_qkv", + TELayerNormColumnParallelLinear, + ) + module.self_attention.linear_qkv = SelfAttentionLinearQKVLoRA( + adapter_model_prefix=f"{adapter_model_prefix}.self_attn", + linear_qkv=self_attention_linear_qkv, + rank=LORA_RANK, + alpha=LORA_ALPHA, + provider=provider, + ) + elif isinstance(module.self_attention, GatedDeltaNet): + gated_delta_net_out_proj = _unwrap_attr( + module.self_attention.out_proj, + "out_proj", + TERowParallelLinear, + ) + module.self_attention.out_proj = SelfAttentionLinearProjLoRA( + adapter_model_prefix=f"{adapter_model_prefix}.linear_attn.out_proj", + linear_proj=gated_delta_net_out_proj, + rank=LORA_RANK, + alpha=LORA_ALPHA, + provider=provider, + ) + gated_delta_net_in_proj = _unwrap_attr( + module.self_attention.in_proj, + "in_proj", + TELayerNormColumnParallelLinear, + ) + module.self_attention.in_proj = GatedDeltaNetInProjLoRA( + adapter_model_prefix=f"{adapter_model_prefix}.linear_attn", + in_proj=gated_delta_net_in_proj, + gated_delta_net=module.self_attention, + rank=LORA_RANK, + alpha=LORA_ALPHA, + ) + else: + raise TypeError( + "Unsupported self_attention module type for Megatron LoRA: " + f"{type(module.self_attention)}" + ) + experts = getattr(module.mlp, "experts", None) + if experts is not None: + assert isinstance(experts, TEGroupedMLP) + mlp_experts_linear_fc1 = _unwrap_attr( + experts.linear_fc1, + "linear_fc1", + TEColumnParallelGroupedLinear, # type: ignore[arg-type] + ) + experts.linear_fc1 = MLPExpertsLinearFC1LoRA( + adapter_model_prefix=f"{adapter_model_prefix}.mlp.experts", + linear_fc1=mlp_experts_linear_fc1, + rank=LORA_RANK, + alpha=LORA_ALPHA, + num_local_experts=experts.num_local_experts, + ) + mlp_experts_linear_fc2 = _unwrap_attr( + experts.linear_fc2, + "linear_fc2", + TERowParallelGroupedLinear, # type: ignore[arg-type] + ) + experts.linear_fc2 = MLPExpertsLinearFC2LoRA( + adapter_model_prefix=f"{adapter_model_prefix}.mlp.experts", + linear_fc2=mlp_experts_linear_fc2, + rank=LORA_RANK, + alpha=LORA_ALPHA, + num_local_experts=experts.num_local_experts, + ) + else: + mlp_linear_fc1 = _unwrap_attr( + module.mlp.linear_fc1, + "linear_fc1", + (TEColumnParallelLinear, TELayerNormColumnParallelLinear), + ) + module.mlp.linear_fc1 = SharedExpertsLinearFC1LoRA( + adapter_model_prefix=f"{adapter_model_prefix}.mlp", + linear_fc1=mlp_linear_fc1, + rank=LORA_RANK, + alpha=LORA_ALPHA, + ) + mlp_linear_fc2 = _unwrap_attr( + module.mlp.linear_fc2, + "linear_fc2", + TERowParallelLinear, + ) + module.mlp.linear_fc2 = SharedExpertsLinearFC2LoRA( + adapter_model_prefix=f"{adapter_model_prefix}.mlp", + linear_fc2=mlp_linear_fc2, + rank=LORA_RANK, + alpha=LORA_ALPHA, + provider=provider, + ) + shared_experts = getattr(module.mlp, "shared_experts", None) + if shared_experts is not None: + assert isinstance(shared_experts, SharedExpertMLP) + shared_experts_linear_fc1 = _unwrap_attr( + shared_experts.linear_fc1, + "linear_fc1", + (TEColumnParallelLinear, TELayerNormColumnParallelLinear), + ) + shared_experts.linear_fc1 = SharedExpertsLinearFC1LoRA( + adapter_model_prefix=f"{adapter_model_prefix}.mlp.shared_expert", + linear_fc1=shared_experts_linear_fc1, + rank=LORA_RANK, + alpha=LORA_ALPHA, + ) + shared_experts_linear_fc2 = _unwrap_attr( + shared_experts.linear_fc2, + "linear_fc2", + TERowParallelLinear, + ) + shared_experts.linear_fc2 = SharedExpertsLinearFC2LoRA( + adapter_model_prefix=f"{adapter_model_prefix}.mlp.shared_expert", + linear_fc2=shared_experts_linear_fc2, + rank=LORA_RANK, + alpha=LORA_ALPHA, + provider=provider, + ) return list(model) diff --git a/src/art/megatron/provider.py b/src/art/megatron/provider.py index eef8679a..4b23f02c 100644 --- a/src/art/megatron/provider.py +++ b/src/art/megatron/provider.py @@ -1,9 +1,11 @@ import copy +from dataclasses import dataclass import inspect import json import os from pathlib import Path -from typing import Callable, Literal, cast +from types import MethodType +from typing import Any, Callable, Literal, cast import warnings from megatron.bridge import AutoBridge @@ -14,11 +16,24 @@ StateSource, ) from megatron.bridge.models.qwen.qwen3_moe_bridge import Qwen3MoEBridge +from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.attention import ( + Qwen3VLSelfAttention, +) +from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.model import Qwen3VLModel +from megatron.bridge.models.qwen_vl.qwen35_vl_bridge import Qwen35VLMoEBridge +from megatron.bridge.models.qwen_vl.qwen35_vl_provider import ( + Qwen35VLMoEModelProvider, + _patch_standard_attention_specs, +) from megatron.bridge.training.flex_dispatcher_backend import ( apply_flex_dispatcher_backend, ) +from megatron.core.models.gpt.experimental_attention_variant_module_specs import ( + get_transformer_block_with_experimental_attention_variant_spec, +) from megatron.core.transformer.enums import AttnBackend from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.transformer.transformer_block import TransformerBlockSubmodules import torch from art.megatron.flex_attention import FlexDotProductAttention @@ -26,6 +41,12 @@ _finalized_env_settings_printed = False +@dataclass(frozen=True) +class ProviderBundle: + provider: GPTModelProvider + bridge: Any + + def _resolve_layer_spec( base_layer_spec: ModuleSpec | Callable[[GPTModelProvider], ModuleSpec], config: GPTModelProvider, @@ -41,6 +62,17 @@ def _resolve_layer_spec( return base_layer_spec(config, **kwargs) +def _patch_core_attention(layer_spec: object) -> None: + submodules = getattr(layer_spec, "submodules", None) + self_attention = getattr(submodules, "self_attention", None) + attention_submodules = getattr(self_attention, "submodules", None) + if attention_submodules is None or not hasattr( + attention_submodules, "core_attention" + ): + return + attention_submodules.core_attention = FlexDotProductAttention + + class _CastingStateSource(StateSource): def __init__(self, source: StateSource, *, dtype: torch.dtype): self._source = source @@ -305,18 +337,18 @@ def _apply_runtime_env_overrides(provider: GPTModelProvider) -> None: provider.recompute_granularity = None -def get_provider( +def get_provider_bundle( model: str, *, torch_dtype: torch.dtype = torch.bfloat16, -) -> GPTModelProvider: +) -> ProviderBundle: bridge = AutoBridge.from_hf_pretrained( model, dtype=torch_dtype, trust_remote_code=True, ) - assert isinstance(bridge._model_bridge, Qwen3MoEBridge), ( - "Only Qwen3 MoE models are supported" + assert isinstance(bridge._model_bridge, (Qwen3MoEBridge, Qwen35VLMoEBridge)), ( + "Only Qwen3 and Qwen3.5 MoE models are supported" ) if torch_dtype != torch.bfloat16: model_name_or_path = bridge.hf_pretrained.model_name_or_path @@ -328,17 +360,85 @@ def get_provider( ) ) provider = bridge.to_megatron_provider() - setattr(provider, "art_bridge", bridge) + if isinstance(provider, Qwen35VLMoEModelProvider): + from megatron.bridge.models.gpt_provider import mtp_block_spec + + def _patch_qwen35_block_spec(block_spec: TransformerBlockSubmodules) -> None: + _patch_standard_attention_specs(block_spec, Qwen3VLSelfAttention) + layer_specs = block_spec.layer_specs + if layer_specs is None: + return + for layer_spec in layer_specs: + _patch_core_attention(layer_spec) + + def _qwen35_layer_spec( + config: GPTModelProvider, vp_stage: int | None = None + ) -> ModuleSpec: + block_spec = get_transformer_block_with_experimental_attention_variant_spec( + config, + vp_stage=vp_stage, + ) + _patch_qwen35_block_spec(block_spec) + return cast(ModuleSpec, block_spec) + + provider.transformer_layer_spec = _qwen35_layer_spec + + def _provide_qwen35_with_flex_attention( + self: Qwen35VLMoEModelProvider, + pre_process: bool | None = None, + post_process: bool | None = None, + vp_stage: int | None = None, + ) -> Qwen3VLModel: + language_transformer_config = cast(Any, self) + hf_vision_config = self.vision_config + hf_vision_config.torch_dtype = self.params_dtype + block_spec = cast( + ModuleSpec, + get_transformer_block_with_experimental_attention_variant_spec( + language_transformer_config, + vp_stage=vp_stage, + ), + ) + _patch_qwen35_block_spec(cast(TransformerBlockSubmodules, block_spec)) + pre_process_value = True if pre_process is None else pre_process + post_process_value = True if post_process is None else post_process + model = Qwen3VLModel( + language_transformer_config=language_transformer_config, + language_transformer_layer_spec=block_spec, + vision_transformer_config=hf_vision_config, + pre_process=pre_process_value, + post_process=post_process_value, + pg_collection=cast(Any, self._pg_collection), + mtp_block_spec=mtp_block_spec(self, vp_stage=vp_stage), + vp_stage=vp_stage, + ) + if ( + self.freeze_language_model + or self.freeze_vision_model + or self.freeze_vision_projection + ): + model.freeze( + freeze_language_model=self.freeze_language_model, + freeze_vision_model=self.freeze_vision_model, + freeze_vision_projection=self.freeze_vision_projection, + ) + return model + + provider.provide = cast( + Any, MethodType(_provide_qwen35_with_flex_attention, provider) + ) base_layer_spec = provider.transformer_layer_spec def _flex_attention_layer_spec( config: GPTModelProvider, vp_stage: int | None = None ) -> ModuleSpec: layer_spec = _resolve_layer_spec(base_layer_spec, config, vp_stage) - # Keep Megatron's standard layer stack and replace only core attention. - layer_spec.submodules.self_attention.submodules.core_attention = ( # ty: ignore[unresolved-attribute] - FlexDotProductAttention - ) + layer_specs = getattr(layer_spec, "layer_specs", None) + if layer_specs is None: + _patch_core_attention(layer_spec) + else: + for block_layer_spec in layer_specs: + _patch_core_attention(block_layer_spec) return layer_spec provider.transformer_layer_spec = _flex_attention_layer_spec @@ -361,6 +461,17 @@ def _flex_attention_layer_spec( # effectively just a flag modifying finalize_model_grads behavior for DPxCP provider.calculate_per_token_loss = True provider.sequence_parallel = provider.tensor_model_parallel_size > 1 + # ART computes its own RL loss, so MTP only adds incompatible postprocess work. + provider.mtp_enabled = False + provider.mtp_num_layers = 0 _maybe_print_finalized_env_settings(provider) provider.finalize() - return provider + return ProviderBundle(provider=provider, bridge=bridge) + + +def get_provider( + model: str, + *, + torch_dtype: torch.dtype = torch.bfloat16, +) -> GPTModelProvider: + return get_provider_bundle(model, torch_dtype=torch_dtype).provider diff --git a/src/art/megatron/service.py b/src/art/megatron/service.py index 058fcabd..a1ce338d 100644 --- a/src/art/megatron/service.py +++ b/src/art/megatron/service.py @@ -76,9 +76,15 @@ def create_identity_lora( if random_state is not None: torch.manual_seed(random_state) base_config = AutoConfig.from_pretrained(base_model, trust_remote_code=True) + model_config = base_config + nested_text_config = getattr(base_config, "text_config", None) + if not hasattr(base_config, "vocab_size") and hasattr( + nested_text_config, "vocab_size" + ): + model_config = nested_text_config with init_empty_weights(): model = AutoModelForCausalLM.from_config( - base_config, torch_dtype=torch.bfloat16, trust_remote_code=True + model_config, torch_dtype=torch.bfloat16, trust_remote_code=True ) model.name_or_path = base_model @@ -96,8 +102,14 @@ def create_identity_lora( "k_proj.weight", "v_proj.weight", "o_proj.weight", + "linear_attn.in_proj_qkv.weight", + "linear_attn.in_proj_z.weight", + "linear_attn.out_proj.weight", "mlp.experts.gate_up_proj", "mlp.experts.down_proj", + "mlp.shared_expert.gate_proj.weight", + "mlp.shared_expert.up_proj.weight", + "mlp.shared_expert.down_proj.weight", ) ) ], diff --git a/src/art/megatron/setup.sh b/src/art/megatron/setup.sh index dcd6ce09..def33024 100755 --- a/src/art/megatron/setup.sh +++ b/src/art/megatron/setup.sh @@ -12,4 +12,9 @@ apt-get install -y libcudnn9-headers-cuda-12 libibverbs-dev ninja-build script_dir="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" && pwd)" repo_root="$(cd -- "${script_dir}/../../.." && pwd)" cd "${repo_root}" -uv sync --extra backend --extra megatron --frozen --active +uv_bin="${HOME}/.local/bin/uv" +if [[ -x "${uv_bin}" ]]; then + "${uv_bin}" sync --extra backend --extra megatron --frozen --active +else + uv sync --extra backend --extra megatron --frozen --active +fi diff --git a/src/art/megatron/train.py b/src/art/megatron/train.py index f5ab2d8a..64a55b19 100644 --- a/src/art/megatron/train.py +++ b/src/art/megatron/train.py @@ -13,6 +13,7 @@ """ from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass import gc import importlib import json @@ -33,9 +34,12 @@ from pydantic import BaseModel, ConfigDict, field_validator import torch from torch._inductor.runtime.cache_dir_utils import cache_dir as inductor_cache_dir +from torch.distributed import all_reduce from art import dev, types +from art.dev.validate import QWEN3_5_MOE_MODELS from art.loss import loss_fn, shift_tensor +from art.megatron.bridge_adapter_compat import build_adapter_weights_by_base from art.megatron.compile_workarounds import install_torch_compile_workarounds from art.megatron.finalize_grads import finalize_model_grads_extended from art.megatron.flex_attention import create_shared_prefix_attention_state @@ -51,11 +55,6 @@ MergedWeightTransferSpec, ) from art.megatron.lora import ( - LoRA, - MLPExpertsLinearFC1LoRA, - MLPExpertsLinearFC2LoRA, - SelfAttentionLinearProjLoRA, - SelfAttentionLinearQKVLoRA, apply_lora_adapters, ) from art.megatron.merge import load_lora_adapter_state_dict, merge_lora_adapter @@ -70,7 +69,7 @@ offload_to_cpu, reload_to_gpu, ) -from art.megatron.provider import _env_flag, get_provider +from art.megatron.provider import _env_flag, get_provider_bundle from art.megatron.routing_replay import ( MoeRoutingReplayBundle, MoeRoutingReplayController, @@ -106,6 +105,7 @@ class TrainingRuntime(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) provider: Any + bridge: Any model: ModelChunks optimizer: Any | None optimizer_config: OptimizerConfig @@ -133,6 +133,15 @@ class TrainStepResult(BaseModel): num_zeros_in_grad: int | None +@dataclass +class MergedWeightExport: + bridge: Any + model: list[MegatronModule] + model_config: Any + conversion_tasks: list[Any] + adapter_weights_by_base: dict[str, list[Any]] + + def print0(rank: int, *values: Any) -> None: if rank == 0: print(*values) @@ -169,10 +178,7 @@ def _fast_backward( (weight,) = ctx.saved_tensors grad_input = _frozen_linear_grad_input(grad_output, weight) if ctx.allreduce_dgrad: - torch.distributed.all_reduce( # ty: ignore[possibly-missing-attribute] - grad_input, - group=ctx.tp_group, - ) + all_reduce(grad_input, group=ctx.tp_group) return grad_input, None, None, None, None setattr(_fast_backward, "__art_fast_output_backward__", True) @@ -191,9 +197,11 @@ def _eager_initialize_optimizer_state(optimizer: Any) -> None: init_state_fn(inner_optimizer, getattr(optimizer, "config", None)) -def _compile_enabled() -> bool: +def _compile_enabled(model_identifier: str) -> bool: disabled = _env_flag("ART_DISABLE_MEGATRON_COMPILE") - return disabled is not True + if disabled is not None: + return disabled is not True + return model_identifier not in QWEN3_5_MOE_MODELS def _install_gpt_preprocess_hook(model_chunks: ModelChunks) -> None: @@ -201,6 +209,8 @@ def _install_gpt_preprocess_hook(model_chunks: ModelChunks) -> None: module: Any = unwrap_megatron_chunk(chunk) while not isinstance(module, GPTModel) and hasattr(module, "module"): module = module.module + if not isinstance(module, GPTModel): + module = getattr(module, "language_model", None) if not isinstance(module, GPTModel): continue preprocess = module._preprocess @@ -212,6 +222,8 @@ def preprocess_hook(*args, _preprocess=preprocess, **kwargs): embedding_dim = table.size(-1) table_flat = table.view(table.size(0), embedding_dim) position_ids = kwargs["position_ids"] # [B, S] + if position_ids.ndim != 2: + return tuple(preproc_output) batch_size, sequence_length = position_ids.shape gathered = table_flat.index_select(0, position_ids.reshape(-1)) gathered = ( @@ -318,6 +330,9 @@ def build_training_runtime( print_env: bool = True, build_optimizer: bool = True, ) -> TrainingRuntime: + resolved_model_identifier = model_identifier or os.environ.get( + "MODEL_IDENTIFIER", DEFAULT_MODEL_IDENTIFIER + ) if random_state := os.environ.get("ART_MEGATRON_RANDOM_STATE"): seed = int(random_state) random.seed(seed) @@ -325,11 +340,11 @@ def build_training_runtime( if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) _install_fast_frozen_output_backward() - provider = get_provider( - model_identifier - or os.environ.get("MODEL_IDENTIFIER", DEFAULT_MODEL_IDENTIFIER), + provider_bundle = get_provider_bundle( + resolved_model_identifier, torch_dtype=provider_torch_dtype, ) + provider = provider_bundle.provider if provider_configure is not None: provider_configure(provider) provider.register_pre_wrap_hook(freeze_model) @@ -362,10 +377,22 @@ def build_training_runtime( print("TRITON_CACHE_DIR:", os.environ["TRITON_CACHE_DIR"]) _install_gpt_preprocess_hook(model) - if _compile_enabled(): + if _compile_enabled(resolved_model_identifier): + if rank == 0: + print("Enabling torch.compile for", resolved_model_identifier) install_torch_compile_workarounds() for chunk in model: _compile_transformer_layers(chunk) + elif ( + rank == 0 + and _env_flag("ART_DISABLE_MEGATRON_COMPILE") is None + and resolved_model_identifier in QWEN3_5_MOE_MODELS + ): + print( + "Disabling torch.compile for", + resolved_model_identifier, + "because Qwen3.5 MoE currently fails in PyTorch compiled backward stream ops.", + ) optimizer_config = optimizer_config or _default_optimizer_config() optimizer = ( @@ -379,6 +406,7 @@ def build_training_runtime( runtime = TrainingRuntime( provider=provider, + bridge=provider_bundle.bridge, model=model, optimizer=optimizer, optimizer_config=optimizer_config, @@ -1075,6 +1103,18 @@ def _move_inputs_to_device(inputs: PackedTensors, device: torch.device) -> None: inputs[key] = value.to(device) # type: ignore[index] +def _attention_block_kwargs( + model_chunk: torch.nn.Module, + attention_state: Any, +) -> dict[str, Any]: + model = model_chunk + while hasattr(model, "module"): + model = model.module # type: ignore[assignment] + if type(model).__name__ == "Qwen3VLModel": + return {"extra_block_kwargs": {"attention_bias": attention_state}} + return {"attention_bias": attention_state} + + def _optimizer_step( optimizer: Any, learning_rate: float, @@ -1304,13 +1344,17 @@ def run_training_step( ) attention_mask = torch.zeros((1, 1, 1, 1), dtype=torch.bool, device=device) - new_logprobs = -model_chunks[0]( + model_output = model_chunks[0]( input_ids=micro["tokens"], position_ids=micro["input_pos"], attention_mask=attention_mask, labels=shift_tensor(micro["tokens"], 0), - extra_block_kwargs={"attention_bias": attention_state}, + extra_block_kwargs=_attention_block_kwargs( + model_chunks[0], + attention_state, + ), ) + new_logprobs = -model_output loss_info = loss_fn( micro, # ty: ignore[invalid-argument-type] @@ -1377,13 +1421,15 @@ def _is_art_adapter_param_name(name: str) -> bool: ".q_proj_lora.", ".k_proj_lora.", ".v_proj_lora.", + ".qkv_lora.", + ".z_lora.", ".gate_lora.", ".up_lora.", ) ) -def _unwrap_art_wrapper_name(name: str) -> str: +def _canonical_art_param_name(name: str) -> str: while name.startswith("module."): name = name[len("module.") :] while name.startswith("_orig_mod."): @@ -1392,14 +1438,41 @@ def _unwrap_art_wrapper_name(name: str) -> str: name = name.replace("._orig_mod.", ".") if name.endswith("._orig_mod"): name = name[: -len("._orig_mod")] - for wrapped, unwrapped in ( - (".linear_proj.linear_proj.", ".linear_proj."), - (".linear_qkv.linear_qkv.", ".linear_qkv."), - (".linear_fc1.linear_fc1.", ".linear_fc1."), - (".linear_fc2.linear_fc2.", ".linear_fc2."), - ): - name = name.replace(wrapped, unwrapped) - return name + segments = name.split(".") + canonical: list[str] = [] + i = 0 + while i < len(segments): + if i + 1 < len(segments): + current = segments[i] + nxt = segments[i + 1] + if ( + current + in { + "linear_proj", + "linear_qkv", + "in_proj", + "linear_fc1", + "linear_fc2", + } + and nxt == current + ): + canonical.append(current) + i += 2 + continue + if current == "out_proj" and nxt == "linear_proj": + canonical.append(current) + i += 2 + continue + if current == "row_parallel_lora" and nxt == "linear_proj": + i += 2 + continue + canonical.append(segments[i]) + i += 1 + return ".".join(canonical) + + +def _unwrap_art_wrapper_name(name: str) -> str: + return _canonical_art_param_name(name) def _mapping_hf_weights_exist(mapping: Any, hf_keys: set[str]) -> bool: @@ -1408,183 +1481,9 @@ def _mapping_hf_weights_exist(mapping: Any, hf_keys: set[str]) -> bool: hf_param = mapping.hf_param if isinstance(hf_param, str): return hf_param in hf_keys - assert isinstance(hf_param, dict) - return all(param in hf_keys for param in hf_param.values()) - - -def _lora_delta(lora: LoRA, expert_idx: int | None = None) -> torch.Tensor: - if lora.A_T.ndim == 3: - assert expert_idx is not None - a_t = lora.A_T[expert_idx] - b_t = lora.B_T[expert_idx] - else: - a_t = lora.A_T - b_t = lora.B_T - return (b_t.T @ a_t.T) * lora.scale - - -def _expert_index_from_hf_name(hf_name: str) -> int: - match = re.search(r"\.experts\.(\d+)\.", hf_name) - assert match is not None - return int(match.group(1)) - - -def _hf_name_has_indexed_expert(hf_name: str) -> bool: - return re.search(r"\.experts\.(\d+)\.", hf_name) is not None - - -def _stack_moe_fc1_deltas(handler: MLPExpertsLinearFC1LoRA) -> torch.Tensor: - return torch.stack( - [ - torch.cat( - [ - _lora_delta(handler.gate_lora, expert_idx), - _lora_delta(handler.up_lora, expert_idx), - ], - dim=0, - ) - for expert_idx in range(handler.gate_lora.num_local_experts) - ], - dim=0, - ) - - -def _stack_moe_fc2_deltas(handler: MLPExpertsLinearFC2LoRA) -> torch.Tensor: - return torch.stack( - [ - _lora_delta(handler.lora, expert_idx) - for expert_idx in range(handler.lora.num_local_experts) - ], - dim=0, - ) - - -def _merge_delta_into_weight( - hf_name: str, - base_weight: torch.Tensor, - delta: torch.Tensor, -) -> torch.Tensor: - delta = delta.to(device=base_weight.device, dtype=base_weight.dtype) - if tuple(base_weight.shape) == tuple(delta.shape): - return base_weight + delta - transposed = delta.transpose(-1, -2) - assert tuple(base_weight.shape) == tuple(transposed.shape), ( - f"{hf_name}: cannot merge delta {tuple(delta.shape)} into {tuple(base_weight.shape)}" - ) - return base_weight + transposed - - -def _build_art_merge_handlers( - model_chunks: list[MegatronModule], -) -> tuple[dict[str, Any], dict[str, Any]]: - exact_handlers: dict[str, Any] = {} - prefix_handlers: dict[str, Any] = {} - for module_name, module in iter_named_modules(model_chunks): - if not isinstance(module, TransformerLayer): - continue - if not _is_language_transformer_layer_name(module_name): - continue - prefixes = ( - f"decoder.layers.{module.layer_number - 1}", - f"language_model.decoder.layers.{module.layer_number - 1}", - ) - linear_proj = getattr(module.self_attention, "linear_proj", None) - if isinstance(linear_proj, SelfAttentionLinearProjLoRA): - for prefix in prefixes: - exact_handlers[f"{prefix}.self_attention.linear_proj.weight"] = ( - linear_proj - ) - linear_qkv = getattr(module.self_attention, "linear_qkv", None) - if isinstance(linear_qkv, SelfAttentionLinearQKVLoRA): - for prefix in prefixes: - exact_handlers[f"{prefix}.self_attention.linear_qkv.weight"] = ( - linear_qkv - ) - experts = getattr(module.mlp, "experts", None) - if experts is None: - continue - if isinstance(experts.linear_fc1, MLPExpertsLinearFC1LoRA): - for prefix in prefixes: - prefix_handlers[f"{prefix}.mlp.experts.linear_fc1.weight"] = ( - experts.linear_fc1 - ) - if isinstance(experts.linear_fc2, MLPExpertsLinearFC2LoRA): - for prefix in prefixes: - prefix_handlers[f"{prefix}.mlp.experts.linear_fc2.weight"] = ( - experts.linear_fc2 - ) - return exact_handlers, prefix_handlers - - -def _merge_art_lora_into_hf_weights( - global_param_name: str, - converted_weights_dict: dict[str, torch.Tensor], - *, - exact_handlers: dict[str, Any], - prefix_handlers: dict[str, Any], -) -> dict[str, torch.Tensor]: - handler = exact_handlers.get(global_param_name) - if handler is None: - for prefix, prefix_handler in prefix_handlers.items(): - if global_param_name.startswith(prefix): - handler = prefix_handler - break - if handler is None: - return converted_weights_dict - if isinstance(handler, SelfAttentionLinearProjLoRA): - hf_name, base_weight = next(iter(converted_weights_dict.items())) - converted_weights_dict[hf_name] = _merge_delta_into_weight( - hf_name, - base_weight, - _lora_delta(handler.lora), - ) - return converted_weights_dict - if isinstance(handler, SelfAttentionLinearQKVLoRA): - deltas = { - "q_proj": _lora_delta(handler.q_proj_lora), - "k_proj": _lora_delta(handler.k_proj_lora), - "v_proj": _lora_delta(handler.v_proj_lora), - } - for hf_name, base_weight in list(converted_weights_dict.items()): - for projection, delta in deltas.items(): - if projection in hf_name: - converted_weights_dict[hf_name] = _merge_delta_into_weight( - hf_name, - base_weight, - delta, - ) - break - return converted_weights_dict - if isinstance(handler, MLPExpertsLinearFC1LoRA): - for hf_name, base_weight in list(converted_weights_dict.items()): - if _hf_name_has_indexed_expert(hf_name): - expert_idx = _expert_index_from_hf_name(hf_name) - if ".gate_proj." in hf_name: - delta = _lora_delta(handler.gate_lora, expert_idx) - else: - assert ".up_proj." in hf_name, hf_name - delta = _lora_delta(handler.up_lora, expert_idx) - else: - delta = _stack_moe_fc1_deltas(handler) - converted_weights_dict[hf_name] = _merge_delta_into_weight( - hf_name, - base_weight, - delta, - ) - return converted_weights_dict - assert isinstance(handler, MLPExpertsLinearFC2LoRA) - for hf_name, base_weight in list(converted_weights_dict.items()): - delta = ( - _lora_delta(handler.lora, _expert_index_from_hf_name(hf_name)) - if _hf_name_has_indexed_expert(hf_name) - else _stack_moe_fc2_deltas(handler) - ) - converted_weights_dict[hf_name] = _merge_delta_into_weight( - hf_name, - base_weight, - delta, - ) - return converted_weights_dict + if isinstance(hf_param, dict): + return all(param in hf_keys for param in hf_param.values()) + return False def _build_art_conversion_tasks(runtime: TrainingRuntime) -> list[Any]: @@ -1599,8 +1498,7 @@ def _build_art_conversion_tasks(runtime: TrainingRuntime) -> list[Any]: persistent_buffers, ) - bridge = getattr(runtime.provider, "art_bridge", None) - assert bridge is not None + bridge = runtime.bridge mapping_registry = bridge._model_bridge.mapping_registry() hf_source = bridge.hf_pretrained.state.source hf_keys = set(hf_source.get_all_keys()) @@ -1614,7 +1512,7 @@ def _build_art_conversion_tasks(runtime: TrainingRuntime) -> list[Any]: global_name = _megatron_local_name_to_global( megatron_chunks, model_config, - _unwrap_art_wrapper_name(local_name), + _canonical_art_param_name(local_name), vp_stage, ) mapping = mapping_registry.megatron_to_hf_lookup(global_name) @@ -1643,31 +1541,55 @@ def _build_art_conversion_tasks(runtime: TrainingRuntime) -> list[Any]: return tasks -def _iter_merged_vllm_weights(runtime: TrainingRuntime) -> Any: +def _build_merged_weight_export(runtime: TrainingRuntime) -> MergedWeightExport: + megatron_chunks = as_megatron_api_chunks(runtime.model) + return MergedWeightExport( + bridge=runtime.bridge, + model=megatron_chunks, + model_config=megatron_chunks[0].config, + conversion_tasks=_build_art_conversion_tasks(runtime), + adapter_weights_by_base=build_adapter_weights_by_base(megatron_chunks), + ) + + +def _iter_merged_vllm_weights(weight_export: MergedWeightExport) -> Any: # vLLM expects HF checkpoint names, but Megatron only has live trainer weights. # Convert through Bridge here, then merge ART's LoRA deltas into those tensors. - bridge = getattr(runtime.provider, "art_bridge", None) - assert bridge is not None + bridge = weight_export.bridge model_bridge = bridge._model_bridge hf_state_dict = bridge.hf_pretrained.state - megatron_chunks = as_megatron_api_chunks(runtime.model) - exact_handlers, prefix_handlers = _build_art_merge_handlers(megatron_chunks) - for task in _build_art_conversion_tasks(runtime): + grouped_buffers: dict[str, dict[int, torch.Tensor]] = {} + for task in weight_export.conversion_tasks: converted_weights_dict = task.mapping.megatron_to_hf( task.param_weight, task.megatron_module, ) - converted_weights_dict = model_bridge.maybe_modify_converted_hf_weight( - task, - converted_weights_dict, - hf_state_dict, - ) - converted_weights_dict = _merge_art_lora_into_hf_weights( - task.global_param_name, - converted_weights_dict, - exact_handlers=exact_handlers, - prefix_handlers=prefix_handlers, + adapter_weights = weight_export.adapter_weights_by_base.get( + task.global_param_name ) + if adapter_weights is not None: + converted_weights_dict = model_bridge._merge_lora_adapter_weights( + weight_export.model, + converted_weights_dict, + adapter_weights, + ) + if getattr(task.mapping, "is_grouped_export", False): + merged_result = model_bridge._accumulate_grouped_export( + task, + converted_weights_dict, + weight_export.model_config, + grouped_buffers, + hf_state_dict, + ) + if merged_result is None: + continue + converted_weights_dict = merged_result + else: + converted_weights_dict = model_bridge.maybe_modify_converted_hf_weight( + task, + converted_weights_dict, + hf_state_dict, + ) for hf_name, tensor in converted_weights_dict.items(): yield hf_name, tensor @@ -1719,10 +1641,11 @@ def _sync_merged_weights_to_vllm( from vllm.distributed.weight_transfer.nccl_engine import NCCLWeightTransferEngine _ensure_merged_weight_transfer_group(runtime, spec) + weight_export = _build_merged_weight_export(runtime) def _send_weights() -> None: NCCLWeightTransferEngine.trainer_send_weights( - _iter_merged_vllm_weights(runtime), + _iter_merged_vllm_weights(weight_export), {"group": runtime.merged_weight_transfer_group}, ) @@ -1739,7 +1662,7 @@ def _send_weights() -> None: names: list[str] = [] dtype_names: list[str] = [] shapes: list[list[int]] = [] - for name, tensor in _iter_merged_vllm_weights(runtime): + for name, tensor in _iter_merged_vllm_weights(weight_export): names.append(name) dtype_names.append(str(tensor.dtype).removeprefix("torch.")) shapes.append(list(tensor.shape)) diff --git a/tests/integration/test_megatron_provider_support.py b/tests/integration/test_megatron_provider_support.py new file mode 100644 index 00000000..e6030c90 --- /dev/null +++ b/tests/integration/test_megatron_provider_support.py @@ -0,0 +1,164 @@ +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any, cast + +import pytest + +pytest.importorskip("megatron.bridge") +pytest.importorskip("megatron.bridge.models.qwen.qwen3_moe_bridge") +pytest.importorskip("megatron.bridge.models.qwen_vl.qwen35_vl_bridge") + +from megatron.bridge.models.qwen.qwen3_moe_bridge import Qwen3MoEBridge +from megatron.bridge.models.qwen_vl.qwen35_vl_bridge import Qwen35VLMoEBridge +from megatron.core.transformer.enums import AttnBackend + +from art.megatron.flex_attention import FlexDotProductAttention +import art.megatron.provider as provider_module + + +class _FakeProvider: + def __init__(self) -> None: + self.transformer_layer_spec = self._base_layer_spec + self.finalized = False + + def _base_layer_spec( + self, config: object, vp_stage: int | None = None + ) -> SimpleNamespace: + return SimpleNamespace( + submodules=SimpleNamespace( + self_attention=SimpleNamespace( + submodules=SimpleNamespace(core_attention=object()) + ) + ), + ) + + def finalize(self) -> None: + self.finalized = True + + +class _FakeHybridProvider(_FakeProvider): + def _base_layer_spec( + self, config: object, vp_stage: int | None = None + ) -> SimpleNamespace: + del config, vp_stage + gdn_layer = SimpleNamespace( + submodules=SimpleNamespace( + self_attention=SimpleNamespace(submodules=SimpleNamespace()) + ) + ) + attention_layer = SimpleNamespace( + submodules=SimpleNamespace( + self_attention=SimpleNamespace( + submodules=SimpleNamespace(core_attention=object()) + ) + ) + ) + return SimpleNamespace(layer_specs=[gdn_layer, attention_layer]) + + +class _FakeBridge: + def __init__(self, *, model_bridge: object, provider: _FakeProvider) -> None: + self._model_bridge = model_bridge + self._provider = provider + self.hf_pretrained = SimpleNamespace(model_name_or_path="unused") + + def to_megatron_provider(self) -> _FakeProvider: + return self._provider + + +@pytest.mark.parametrize("bridge_type", [Qwen3MoEBridge, Qwen35VLMoEBridge]) +def test_get_provider_accepts_supported_qwen_moe_bridges( + monkeypatch: pytest.MonkeyPatch, + bridge_type: type[object], +) -> None: + provider = _FakeProvider() + fake_bridge = _FakeBridge( + model_bridge=object.__new__(bridge_type), + provider=provider, + ) + monkeypatch.setattr( + provider_module.AutoBridge, + "from_hf_pretrained", + lambda *args, **kwargs: fake_bridge, + ) + monkeypatch.setattr(provider_module.torch.cuda, "device_count", lambda: 2) + + resolved = provider_module.get_provider("unused-model") + + assert resolved is provider + assert provider.finalized is True + assert resolved.attention_backend is AttnBackend.auto + assert resolved.recompute_granularity == "full" + assert resolved.recompute_method == "uniform" + assert resolved.recompute_num_layers == 1 + assert resolved.tensor_model_parallel_size == 2 + assert resolved.context_parallel_size == 1 + assert resolved.pipeline_model_parallel_size == 1 + assert resolved.expert_model_parallel_size == 2 + assert resolved.expert_tensor_parallel_size == 1 + assert resolved.sequence_parallel is True + assert resolved.moe_shared_expert_overlap is True + assert resolved.moe_router_dtype == "fp32" + assert resolved.moe_aux_loss_coeff == 0.0 + assert resolved.calculate_per_token_loss is True + + layer_spec = provider_module._resolve_layer_spec( + resolved.transformer_layer_spec, + resolved, + vp_stage=7, + ) + layer_spec = cast(Any, layer_spec) + assert ( + layer_spec.submodules.self_attention.submodules.core_attention + is FlexDotProductAttention + ) + + +def test_get_provider_rejects_unsupported_bridge( + monkeypatch: pytest.MonkeyPatch, +) -> None: + fake_bridge = _FakeBridge(model_bridge=object(), provider=_FakeProvider()) + monkeypatch.setattr( + provider_module.AutoBridge, + "from_hf_pretrained", + lambda *args, **kwargs: fake_bridge, + ) + + with pytest.raises( + AssertionError, + match="Only Qwen3 and Qwen3.5 MoE models are supported", + ): + provider_module.get_provider("unsupported-model") + + +def test_get_provider_preserves_hybrid_qwen35_layer_specs( + monkeypatch: pytest.MonkeyPatch, +) -> None: + provider = _FakeHybridProvider() + fake_bridge = _FakeBridge( + model_bridge=object.__new__(Qwen35VLMoEBridge), + provider=provider, + ) + monkeypatch.setattr( + provider_module.AutoBridge, + "from_hf_pretrained", + lambda *args, **kwargs: fake_bridge, + ) + monkeypatch.setattr(provider_module.torch.cuda, "device_count", lambda: 1) + + resolved = provider_module.get_provider("unused-qwen35") + layer_spec = provider_module._resolve_layer_spec( + resolved.transformer_layer_spec, + resolved, + vp_stage=0, + ) + + layer_specs = getattr(layer_spec, "layer_specs", None) + assert layer_specs is not None + gdn_layer, attention_layer = layer_specs + assert not hasattr(gdn_layer.submodules.self_attention.submodules, "core_attention") + assert ( + attention_layer.submodules.self_attention.submodules.core_attention + is FlexDotProductAttention + ) diff --git a/tests/integration/test_megatron_qwen35_lora_wrapping.py b/tests/integration/test_megatron_qwen35_lora_wrapping.py new file mode 100644 index 00000000..a8571997 --- /dev/null +++ b/tests/integration/test_megatron_qwen35_lora_wrapping.py @@ -0,0 +1,447 @@ +from __future__ import annotations + +from collections.abc import Iterator +from contextlib import contextmanager +import socket + +import pytest + +torch = pytest.importorskip("torch") +pytest.importorskip("megatron.bridge") +pytest.importorskip("megatron.bridge.models.qwen_vl.qwen35_vl_provider") + +from megatron.bridge.models.qwen_vl.qwen35_vl_provider import ( + Qwen3_5MoeVisionConfig, + Qwen35VLMoEModelProvider, +) +from megatron.core import parallel_state as ps +from megatron.core.extensions.transformer_engine import ( + TELayerNormColumnParallelLinear, + TERowParallelLinear, +) +from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed +from megatron.core.transformer.attention import SelfAttention +from megatron.core.transformer.moe.shared_experts import SharedExpertMLP +from megatron.core.transformer.transformer_layer import TransformerLayer +from torch.distributed import destroy_process_group, init_process_group, is_initialized + +from art.megatron.bridge_adapter_compat import build_adapter_weights_by_base +from art.megatron.lora import ( + GatedDeltaNetInProjLoRA, + MLPExpertsLinearFC1LoRA, + MLPExpertsLinearFC2LoRA, + SelfAttentionLinearProjLoRA, + SelfAttentionLinearQKVLoRA, + SharedExpertsLinearFC1LoRA, + SharedExpertsLinearFC2LoRA, + apply_lora_adapters, +) + + +class _DenseMLP(torch.nn.Module): + def __init__( + self, + *, + linear_fc1: TELayerNormColumnParallelLinear, + linear_fc2: TERowParallelLinear, + ) -> None: + super().__init__() + self.linear_fc1 = linear_fc1 + self.linear_fc2 = linear_fc2 + + +def _make_qwen35_provider() -> Qwen35VLMoEModelProvider: + assert Qwen3_5MoeVisionConfig is not None + provider = Qwen35VLMoEModelProvider( + num_layers=4, + hidden_size=64, + ffn_hidden_size=128, + moe_ffn_hidden_size=32, + moe_shared_expert_intermediate_size=16, + num_attention_heads=4, + num_query_groups=1, + kv_channels=16, + linear_key_head_dim=8, + linear_value_head_dim=16, + linear_num_key_heads=2, + linear_num_value_heads=4, + num_moe_experts=4, + moe_router_topk=2, + normalization="RMSNorm", + gated_linear_unit=True, + add_bias_linear=False, + add_qkv_bias=False, + qk_layernorm=True, + hidden_dropout=0.0, + attention_dropout=0.0, + attention_output_gate=True, + experimental_attention_variant="gated_delta_net", + linear_attention_freq=4, + linear_conv_kernel_dim=2, + vocab_size=128, + seq_length=128, + position_embedding_type="mrope", + vision_config=Qwen3_5MoeVisionConfig(), + tensor_model_parallel_size=1, + expert_model_parallel_size=1, + pipeline_model_parallel_size=1, + context_parallel_size=1, + params_dtype=torch.bfloat16, + ) + provider.finalize() + return provider + + +def _find_free_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("127.0.0.1", 0)) + return int(sock.getsockname()[1]) + + +def _adapter_tensors( + lora: torch.nn.Module, + expert_idx: int | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + a_t = lora.A_T if expert_idx is None else lora.A_T[expert_idx] + b_t = lora.B_T if expert_idx is None else lora.B_T[expert_idx] + return a_t.transpose(-1, -2).contiguous(), b_t.transpose(-1, -2).contiguous() + + +@contextmanager +def _single_rank_model_parallel() -> Iterator[None]: + if not torch.cuda.is_available(): + pytest.skip("CUDA is required for Megatron Qwen3.5 LoRA coverage.") + if is_initialized(): + pytest.skip("torch.distributed is already initialized in this process.") + + torch.cuda.set_device(0) + init_process_group( + backend="nccl", + init_method=f"tcp://127.0.0.1:{_find_free_port()}", + rank=0, + world_size=1, + ) + try: + ps.initialize_model_parallel( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + context_parallel_size=1, + expert_model_parallel_size=1, + ) + model_parallel_cuda_manual_seed(1234) + yield + finally: + if getattr(ps, "model_parallel_is_initialized", lambda: False)(): + ps.destroy_model_parallel() + if is_initialized(): + destroy_process_group() + + +@pytest.mark.skipif( + not torch.cuda.is_available(), + reason="No CUDA available in this environment", +) +def test_apply_lora_adapters_wraps_qwen35_gdn_and_shared_experts() -> None: + with _single_rank_model_parallel(): + provider = _make_qwen35_provider() + model = provider.provide_language_model(pre_process=True, post_process=True) + apply_lora_adapters([model], provider) + + gdn_in_proj_qkv_prefixes: list[str] = [] + gdn_in_proj_z_prefixes: list[str] = [] + gdn_out_proj_prefixes: list[str] = [] + shared_fc1_gate_prefixes: list[str] = [] + shared_fc1_up_prefixes: list[str] = [] + shared_fc2_prefixes: list[str] = [] + + for module in model.modules(): + in_proj = getattr(module, "in_proj", None) + if isinstance(in_proj, GatedDeltaNetInProjLoRA): + gdn_in_proj_qkv_prefixes.append(in_proj.qkv_lora.adapter_model_prefix) + gdn_in_proj_z_prefixes.append(in_proj.z_lora.adapter_model_prefix) + + out_proj = getattr(module, "out_proj", None) + if isinstance(out_proj, SelfAttentionLinearProjLoRA): + prefix = out_proj.lora.adapter_model_prefix + if prefix.endswith(".linear_attn.out_proj"): + gdn_out_proj_prefixes.append(prefix) + + linear_fc1 = getattr(module, "linear_fc1", None) + if isinstance(linear_fc1, SharedExpertsLinearFC1LoRA): + shared_fc1_gate_prefixes.append( + linear_fc1.gate_lora.adapter_model_prefix + ) + shared_fc1_up_prefixes.append(linear_fc1.up_lora.adapter_model_prefix) + + linear_fc2 = getattr(module, "linear_fc2", None) + if isinstance(linear_fc2, SharedExpertsLinearFC2LoRA): + shared_fc2_prefixes.append( + linear_fc2.row_parallel_lora.lora.adapter_model_prefix + ) + + assert gdn_in_proj_qkv_prefixes + assert gdn_in_proj_z_prefixes + assert gdn_out_proj_prefixes + assert shared_fc1_gate_prefixes + assert shared_fc1_up_prefixes + assert shared_fc2_prefixes + assert len(gdn_in_proj_qkv_prefixes) == len(gdn_in_proj_z_prefixes) + assert len(gdn_in_proj_qkv_prefixes) == len(gdn_out_proj_prefixes) + assert len(shared_fc1_gate_prefixes) == len(shared_fc1_up_prefixes) + assert len(shared_fc1_gate_prefixes) == len(shared_fc2_prefixes) + assert all( + prefix.startswith("base_model.model.model.layers.") + and prefix.endswith(".linear_attn.in_proj_qkv") + for prefix in gdn_in_proj_qkv_prefixes + ) + assert all( + prefix.startswith("base_model.model.model.layers.") + and prefix.endswith(".linear_attn.in_proj_z") + for prefix in gdn_in_proj_z_prefixes + ) + assert all( + prefix.startswith("base_model.model.model.layers.") + and prefix.endswith(".linear_attn.out_proj") + for prefix in gdn_out_proj_prefixes + ) + assert all( + prefix.startswith("base_model.model.model.layers.") + and prefix.endswith(".mlp.shared_expert.gate_proj") + for prefix in shared_fc1_gate_prefixes + ) + assert all( + prefix.startswith("base_model.model.model.layers.") + and prefix.endswith(".mlp.shared_expert.up_proj") + for prefix in shared_fc1_up_prefixes + ) + assert all( + prefix.startswith("base_model.model.model.layers.") + and prefix.endswith(".mlp.shared_expert.down_proj") + for prefix in shared_fc2_prefixes + ) + + +@pytest.mark.skipif( + not torch.cuda.is_available(), + reason="No CUDA available in this environment", +) +def test_apply_lora_adapters_accepts_layernorm_column_fc1_dense_path() -> None: + with _single_rank_model_parallel(): + provider = _make_qwen35_provider() + model = provider.provide_language_model(pre_process=True, post_process=True) + + target_layer = next( + module + for module in model.modules() + if isinstance(module, TransformerLayer) + and isinstance(module.self_attention, SelfAttention) + and isinstance(getattr(module.mlp, "shared_experts", None), SharedExpertMLP) + ) + dense_fc1 = target_layer.self_attention.linear_qkv + dense_fc2 = target_layer.self_attention.linear_proj + assert isinstance(dense_fc1, TELayerNormColumnParallelLinear) + assert isinstance(dense_fc2, TERowParallelLinear) + target_layer.mlp = _DenseMLP( + linear_fc1=dense_fc1, + linear_fc2=dense_fc2, + ) + + apply_lora_adapters([model], provider) + + assert isinstance(target_layer.mlp.linear_fc1, SharedExpertsLinearFC1LoRA) + assert isinstance(target_layer.mlp.linear_fc2, SharedExpertsLinearFC2LoRA) + + +@pytest.mark.skipif( + not torch.cuda.is_available(), + reason="No CUDA available in this environment", +) +def test_build_adapter_weights_handles_grouped_qwen35_moe_hf_weights() -> None: + with _single_rank_model_parallel(): + provider = _make_qwen35_provider() + model = provider.provide_language_model(pre_process=True, post_process=True) + apply_lora_adapters([model], provider) + + target_layer = next( + module + for module in model.modules() + if isinstance(module, TransformerLayer) + and hasattr(module.mlp, "experts") + and isinstance(module.mlp.experts.linear_fc1, MLPExpertsLinearFC1LoRA) + and isinstance(module.mlp.experts.linear_fc2, MLPExpertsLinearFC2LoRA) + ) + fc1_handler = target_layer.mlp.experts.linear_fc1 + fc2_handler = target_layer.mlp.experts.linear_fc2 + + for lora in (fc1_handler.gate_lora, fc1_handler.up_lora, fc2_handler.lora): + lora.A_T.data.fill_(1) + lora.B_T.data.fill_(1) + + adapter_weights_by_base = build_adapter_weights_by_base([model]) + layer_prefix = ( + f"language_model.decoder.layers.{target_layer.layer_number - 1}.mlp.experts" + ) + + for expert_idx in range(fc1_handler.gate_lora.num_local_experts): + fc1_weights = adapter_weights_by_base[ + f"{layer_prefix}.linear_fc1.weight{expert_idx}" + ] + fc2_weights = adapter_weights_by_base[ + f"{layer_prefix}.linear_fc2.weight{expert_idx}" + ] + + assert len(fc1_weights) == 1 + assert len(fc2_weights) == 1 + + gate_linear_in, gate_linear_out = _adapter_tensors( + fc1_handler.gate_lora, expert_idx + ) + up_linear_in, up_linear_out = _adapter_tensors( + fc1_handler.up_lora, expert_idx + ) + fc2_linear_in, fc2_linear_out = _adapter_tensors( + fc2_handler.lora, expert_idx + ) + + torch.testing.assert_close( + fc1_weights[0].linear_in_weight.weight, + torch.cat([gate_linear_in, up_linear_in], dim=0), + ) + torch.testing.assert_close( + fc1_weights[0].linear_out_weight.weight, + torch.cat( + [ + torch.cat( + [ + gate_linear_out, + torch.zeros( + ( + gate_linear_out.shape[0], + up_linear_in.shape[0], + ), + device=gate_linear_out.device, + dtype=gate_linear_out.dtype, + ), + ], + dim=1, + ), + torch.cat( + [ + torch.zeros( + ( + up_linear_out.shape[0], + gate_linear_in.shape[0], + ), + device=up_linear_out.device, + dtype=up_linear_out.dtype, + ), + up_linear_out, + ], + dim=1, + ), + ], + dim=0, + ), + ) + torch.testing.assert_close( + fc2_weights[0].linear_in_weight.weight, + fc2_linear_in, + ) + torch.testing.assert_close( + fc2_weights[0].linear_out_weight.weight, + fc2_linear_out, + ) + + +@pytest.mark.skipif( + not torch.cuda.is_available(), + reason="No CUDA available in this environment", +) +def test_build_adapter_weights_handles_grouped_qwen35_moe_hf_weights_with_expert_suffix() -> ( + None +): + with _single_rank_model_parallel(): + provider = _make_qwen35_provider() + model = provider.provide_language_model(pre_process=True, post_process=True) + apply_lora_adapters([model], provider) + + target_layer = next( + module + for module in model.modules() + if isinstance(module, TransformerLayer) + and hasattr(module.mlp, "experts") + and isinstance(module.mlp.experts.linear_fc1, MLPExpertsLinearFC1LoRA) + and isinstance(module.mlp.experts.linear_fc2, MLPExpertsLinearFC2LoRA) + ) + fc1_handler = target_layer.mlp.experts.linear_fc1 + fc2_handler = target_layer.mlp.experts.linear_fc2 + + for lora in (fc1_handler.gate_lora, fc1_handler.up_lora, fc2_handler.lora): + lora.A_T.data.fill_(1) + lora.B_T.data.fill_(1) + + adapter_weights_by_base = build_adapter_weights_by_base([model]) + layer_prefix = ( + f"language_model.decoder.layers.{target_layer.layer_number - 1}.mlp.experts" + ) + expert_idx = 0 + fc1_weights = adapter_weights_by_base[ + f"{layer_prefix}.linear_fc1.weight{expert_idx}" + ] + fc2_weights = adapter_weights_by_base[ + f"{layer_prefix}.linear_fc2.weight{expert_idx}" + ] + + assert len(fc1_weights) == 1 + assert len(fc2_weights) == 1 + assert fc1_weights[0].global_base_prefix == f"{layer_prefix}.linear_fc1" + assert fc2_weights[0].global_base_prefix == f"{layer_prefix}.linear_fc2" + + +@pytest.mark.skipif( + not torch.cuda.is_available(), + reason="No CUDA available in this environment", +) +def test_build_adapter_weights_exposes_qwen35_q_proj_adapter() -> None: + with _single_rank_model_parallel(): + provider = _make_qwen35_provider() + model = provider.provide_language_model(pre_process=True, post_process=True) + apply_lora_adapters([model], provider) + + target_layer = next( + module + for module in model.modules() + if isinstance(module, TransformerLayer) + and isinstance( + getattr(module.self_attention, "linear_qkv", None), + SelfAttentionLinearQKVLoRA, + ) + ) + qkv_handler = target_layer.self_attention.linear_qkv + + for lora in ( + qkv_handler.q_proj_lora, + qkv_handler.k_proj_lora, + qkv_handler.v_proj_lora, + ): + lora.A_T.data.fill_(1) + lora.B_T.data.fill_(1) + + adapter_weights = build_adapter_weights_by_base([model])[ + f"language_model.decoder.layers.{target_layer.layer_number - 1}.self_attention.linear_qkv.weight" + ] + adapter_weights_by_key = { + adapter_weight.adapter_key: adapter_weight + for adapter_weight in adapter_weights + } + + assert set(adapter_weights_by_key) == {"adapter_q", "adapter_k", "adapter_v"} + q_linear_in, q_linear_out = _adapter_tensors(qkv_handler.q_proj_lora) + torch.testing.assert_close( + adapter_weights_by_key["adapter_q"].linear_in_weight.weight, + q_linear_in, + ) + torch.testing.assert_close( + adapter_weights_by_key["adapter_q"].linear_out_weight.weight, + q_linear_out, + ) diff --git a/tests/unit/test_local_backend_monitor.py b/tests/unit/test_local_backend_monitor.py new file mode 100644 index 00000000..c9dd7aca --- /dev/null +++ b/tests/unit/test_local_backend_monitor.py @@ -0,0 +1,90 @@ +import asyncio +from pathlib import Path + +import pytest + +from art import TrainableModel +from art.local import LocalBackend + + +class _FakeResponse: + def __init__(self, body: str, status: int = 200) -> None: + self._body = body + self.status = status + + async def __aenter__(self) -> "_FakeResponse": + return self + + async def __aexit__(self, exc_type, exc, tb) -> bool: + return False + + async def text(self) -> str: + return self._body + + +class _FakeSession: + def __init__(self, urls: list[str]) -> None: + self._urls = urls + + async def __aenter__(self) -> "_FakeSession": + return self + + async def __aexit__(self, exc_type, exc, tb) -> bool: + return False + + def get(self, url: str, timeout) -> _FakeResponse: + del timeout + self._urls.append(url) + if url.endswith("/metrics"): + return _FakeResponse( + "vllm:num_requests_running 0\nvllm:num_requests_waiting 0\n" + ) + if url.endswith("/health"): + return _FakeResponse("ok") + raise AssertionError(f"Unexpected URL: {url}") + + +@pytest.mark.asyncio +async def test_monitor_openai_server_uses_health_probe_when_idle( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + backend = LocalBackend(path=str(tmp_path)) + model = TrainableModel( + name="qwen35-monitor", + project="unit-tests", + base_model="Qwen/Qwen3-30B-A3B-Instruct-2507", + base_path=str(tmp_path), + ) + + class _FakeService: + async def vllm_engine_is_sleeping(self) -> bool: + return False + + backend._services[model.name] = _FakeService() # type: ignore[index] + requested_urls: list[str] = [] + sleep_calls = 0 + + async def fake_sleep(_seconds: float) -> None: + nonlocal sleep_calls + sleep_calls += 1 + if sleep_calls > 1: + raise asyncio.CancelledError + + monkeypatch.setattr("art.local.backend.asyncio.sleep", fake_sleep) + monkeypatch.setattr( + "art.local.backend.aiohttp.ClientSession", + lambda: _FakeSession(requested_urls), + ) + + with pytest.raises(asyncio.CancelledError): + await backend._monitor_openai_server( + model, + "http://127.0.0.1:1234/v1", + "default", + ) + + assert requested_urls == [ + "http://127.0.0.1:1234/metrics", + "http://127.0.0.1:1234/health", + ] diff --git a/tests/unit/test_megatron_dedicated.py b/tests/unit/test_megatron_dedicated.py index cc0a808b..dcc56aa2 100644 --- a/tests/unit/test_megatron_dedicated.py +++ b/tests/unit/test_megatron_dedicated.py @@ -1,24 +1,28 @@ import asyncio +from contextlib import nullcontext import os from pathlib import Path import shlex import sys import types as pytypes +from types import SimpleNamespace from typing import Any import pytest +import torch pytest.importorskip("vllm") from art import TrainableModel, types from art.dev.model import InternalModelConfig +from art.dev.validate import QWEN3_5_MOE_MODELS from art.megatron.backend import MegatronBackend from art.megatron.jobs import ( MegatronMergedTrainJob, MergedWeightTransferInitInfo, ) -from art.megatron.service import MegatronService -from art.megatron.train import _unwrap_art_wrapper_name +from art.megatron.service import MegatronService, create_identity_lora +from art.megatron.train import _compile_enabled, _unwrap_art_wrapper_name @pytest.mark.asyncio @@ -148,6 +152,79 @@ def test_unwrap_art_wrapper_name_strips_compiled_wrapper_segments() -> None: ) +def test_compile_enabled_disables_qwen35_moe_by_default() -> None: + assert _compile_enabled("Qwen/Qwen3-30B-A3B-Instruct-2507") is True + assert _compile_enabled("Qwen/Qwen3.5-32B-Instruct") is True + for model_identifier in QWEN3_5_MOE_MODELS: + assert _compile_enabled(model_identifier) is False + + +def test_compile_enabled_honors_env_override( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setenv("ART_DISABLE_MEGATRON_COMPILE", "0") + assert _compile_enabled("Qwen/Qwen3.5-35B-A3B") is True + monkeypatch.setenv("ART_DISABLE_MEGATRON_COMPILE", "1") + assert _compile_enabled("Qwen/Qwen3-30B-A3B-Instruct-2507") is False + + +def test_create_identity_lora_uses_nested_text_config_when_top_level_lacks_vocab_size( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + top_level_config = SimpleNamespace( + text_config=SimpleNamespace(vocab_size=128), + ) + seen: dict[str, Any] = {} + + class FakeModel: + name_or_path = "" + + def named_parameters(self) -> list[tuple[str, torch.Tensor]]: + return [ + ( + "model.layers.0.self_attn.q_proj.weight", + torch.empty(1, device="meta"), + ), + ( + "model.layers.0.linear_attn.in_proj_qkv.weight", + torch.empty(1, device="meta"), + ), + ] + + class FakePeftModel: + def save_pretrained(self, lora_path: str) -> None: + Path(lora_path).mkdir(parents=True, exist_ok=True) + + monkeypatch.setattr( + "transformers.AutoConfig.from_pretrained", + lambda *_args, **_kwargs: top_level_config, + ) + monkeypatch.setattr( + "transformers.AutoModelForCausalLM.from_config", + lambda config, **_kwargs: seen.setdefault("config", config) or FakeModel(), + ) + monkeypatch.setattr("accelerate.init_empty_weights", nullcontext) + monkeypatch.setattr( + "peft.get_peft_model", + lambda _model, lora_config, **_kwargs: ( + seen.setdefault("lora_config", lora_config) or FakePeftModel() + ), + ) + monkeypatch.setattr( + "art.megatron.service.convert_checkpoint_if_needed", + lambda _path: None, + ) + + create_identity_lora("Qwen/Qwen3.5-35B-A3B", str(tmp_path)) + + assert seen["config"] is top_level_config.text_config + assert ( + "model.layers.0.linear_attn.in_proj_qkv.weight" + in seen["lora_config"].target_parameters + ) + + @pytest.mark.asyncio async def test_megatron_service_start_openai_server_dedicated_starts_subprocess( tmp_path: Path, diff --git a/tests/unit/test_megatron_qwen_helpers.py b/tests/unit/test_megatron_qwen_helpers.py new file mode 100644 index 00000000..0c0b7782 --- /dev/null +++ b/tests/unit/test_megatron_qwen_helpers.py @@ -0,0 +1,59 @@ +from types import SimpleNamespace +from typing import Any, cast + +import pytest + +pytest.importorskip("megatron.bridge") + +import torch + +from art.megatron.lora import SelfAttentionLinearQKVLoRA +from art.megatron.train import _canonical_art_param_name + + +def test_canonical_art_param_name_strips_art_wrapper_segments() -> None: + assert ( + _canonical_art_param_name( + "module.language_model.decoder.layers.0.self_attention.out_proj.linear_proj.weight" + ) + == "language_model.decoder.layers.0.self_attention.out_proj.weight" + ) + assert ( + _canonical_art_param_name( + "module.language_model.decoder.layers.0.mlp.linear_fc2.row_parallel_lora.linear_proj.weight" + ) + == "language_model.decoder.layers.0.mlp.linear_fc2.weight" + ) + + +def test_self_attention_linear_qkv_lora_accepts_nongated_qwen3_layout( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setattr( + "art.megatron.lora.ps.get_tensor_model_parallel_world_size", lambda: 1 + ) + provider: Any = SimpleNamespace( + kv_channels=128, + num_query_groups=4, + num_attention_heads=32, + attention_output_gate=False, + ) + q_out_features = provider.kv_channels * provider.num_attention_heads + kv_out_features = provider.kv_channels * provider.num_query_groups + linear_qkv: Any = SimpleNamespace( + weight=torch.empty(q_out_features + 2 * kv_out_features, 16), + in_features=16, + return_layernorm_output=False, + return_layernorm_output_gathered=False, + ) + + wrapped = SelfAttentionLinearQKVLoRA( + adapter_model_prefix="base_model.model.model.layers.0.self_attn", + linear_qkv=cast(Any, linear_qkv), + rank=4, + alpha=8.0, + provider=cast(Any, provider), + ) + + assert wrapped.attention_output_gate is False + assert wrapped.q_proj_lora.B_T.shape[-1] == q_out_features diff --git a/uv.lock b/uv.lock index aa54bd8b..efc42a52 100644 --- a/uv.lock +++ b/uv.lock @@ -4327,7 +4327,7 @@ wheels = [ [[package]] name = "megatron-bridge" version = "0.4.0rc0" -source = { git = "https://github.com/NVIDIA-NeMo/Megatron-Bridge.git?rev=75f2c5ad4afb702b57b4781a00f5291a66bcf183#75f2c5ad4afb702b57b4781a00f5291a66bcf183" } +source = { git = "https://github.com/NVIDIA-NeMo/Megatron-Bridge.git?rev=e049cc00c24d03e2ae45d2608c7a44e2d2364e3d#e049cc00c24d03e2ae45d2608c7a44e2d2364e3d" } dependencies = [ { name = "accelerate" }, { name = "causal-conv1d" }, @@ -4358,7 +4358,7 @@ dependencies = [ [[package]] name = "megatron-core" version = "0.16.0rc0" -source = { git = "https://github.com/NVIDIA-NeMo/Megatron-Bridge.git?subdirectory=3rdparty%2FMegatron-LM&rev=75f2c5ad4afb702b57b4781a00f5291a66bcf183#75f2c5ad4afb702b57b4781a00f5291a66bcf183" } +source = { git = "https://github.com/NVIDIA-NeMo/Megatron-Bridge.git?subdirectory=3rdparty%2FMegatron-LM&rev=e049cc00c24d03e2ae45d2608c7a44e2d2364e3d#e049cc00c24d03e2ae45d2608c7a44e2d2364e3d" } dependencies = [ { name = "numpy" }, { name = "packaging" }, @@ -5591,7 +5591,7 @@ requires-dist = [ { name = "langgraph", marker = "extra == 'langgraph'", specifier = ">=0.6.2" }, { name = "litellm", specifier = ">=1.71.1,<=1.82.0" }, { name = "matplotlib", marker = "extra == 'plotting'", specifier = ">=3.10.1" }, - { name = "megatron-bridge", marker = "extra == 'megatron'", git = "https://github.com/NVIDIA-NeMo/Megatron-Bridge.git?rev=75f2c5ad4afb702b57b4781a00f5291a66bcf183" }, + { name = "megatron-bridge", marker = "extra == 'megatron'", git = "https://github.com/NVIDIA-NeMo/Megatron-Bridge.git?rev=e049cc00c24d03e2ae45d2608c7a44e2d2364e3d" }, { name = "megatron-core", marker = "extra == 'megatron'", specifier = "==0.16.0rc0" }, { name = "ml-dtypes", marker = "python_full_version < '3.13' and extra == 'megatron'", specifier = ">=0.5.0" }, { name = "nbclient", marker = "extra == 'backend'", specifier = ">=0.10.1" },