32 changes: 31 additions & 1 deletion src/art/megatron/jobs.py
@@ -21,6 +21,31 @@ class MegatronTrainingJob(BaseModel):
log_path: str = DEFAULT_TRAINING_LOG_PATH


class MergedWeightTransferInitInfo(BaseModel):
master_address: str
master_port: int
rank_offset: int
world_size: int


class MergedWeightTransferSpec(BaseModel):
init_info: MergedWeightTransferInitInfo
vllm_base_url: str
served_model_name: str


class MegatronMergedTrainJob(MegatronTrainingJob):
job_type: Literal["merged"] = "merged"
merged_weight_transfer: MergedWeightTransferSpec


class MegatronSyncJob(BaseModel):
job_type: Literal["sync"] = "sync"
lora_path: str
merged_weight_transfer: MergedWeightTransferSpec
log_path: str = DEFAULT_TRAINING_LOG_PATH


class MegatronSFTTrainingJob(BaseModel):
job_type: Literal["sft"] = "sft"
lora_path: str
@@ -35,4 +60,9 @@ class MegatronSFTTrainingJob(BaseModel):
log_path: str = DEFAULT_TRAINING_LOG_PATH


-MegatronJob = MegatronTrainingJob | MegatronSFTTrainingJob
+MegatronJob = (
+    MegatronTrainingJob
+    | MegatronMergedTrainJob
+    | MegatronSyncJob
+    | MegatronSFTTrainingJob
+)
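Because every job model carries a `Literal` `job_type`, the widened `MegatronJob` union can be parsed back from JSON by discriminator. A minimal sketch with stand-in models (the `TypeAdapter` wiring is illustrative, not part of this diff):

```python
# Illustrative only: shows how Literal job_type fields support
# discriminated parsing; these stand-in models mirror the diff's shape.
from typing import Annotated, Literal

from pydantic import BaseModel, Field, TypeAdapter


class SyncJob(BaseModel):
    job_type: Literal["sync"] = "sync"
    lora_path: str


class SFTJob(BaseModel):
    job_type: Literal["sft"] = "sft"
    lora_path: str


Job = Annotated[SyncJob | SFTJob, Field(discriminator="job_type")]
job = TypeAdapter(Job).validate_json(
    '{"job_type": "sync", "lora_path": "/tmp/lora"}'
)
assert isinstance(job, SyncJob)
```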
1 change: 1 addition & 0 deletions src/art/megatron/provider.py
@@ -328,6 +328,7 @@ def get_provider(
)
)
provider = bridge.to_megatron_provider()
setattr(provider, "art_bridge", bridge)
base_layer_spec = provider.transformer_layer_spec

def _flex_attention_layer_spec(
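The one-line change above stashes the bridge on the provider object so later stages can recover it without threading it through every call site. A sketch of the assumed consumer side (not shown in this PR):

```python
from typing import Any


def get_art_bridge(provider: Any) -> Any:
    # "art_bridge" is attached in get_provider via setattr; treat a
    # missing attribute as a wiring error rather than failing later.
    bridge = getattr(provider, "art_bridge", None)
    if bridge is None:
        raise RuntimeError("provider has no art_bridge attached")
    return bridge
```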
161 changes: 139 additions & 22 deletions src/art/megatron/service.py
@@ -8,6 +8,7 @@
from pathlib import Path
import shlex
import shutil
import signal
import socket
import subprocess
import sys
@@ -28,12 +29,17 @@
from ..unsloth.service import do_sleep, do_wake_up, gc_and_empty_cuda_cache
from ..utils.convert_moe_lora import convert_checkpoint_if_needed
from ..utils.get_model_step import get_step_from_dir
from ..utils.network import find_free_tcp_port
from ..utils.output_dirs import get_step_checkpoint_dir
from ..vllm import get_llm, openai_server_task, run_on_workers
from .client import create_megatron_job_paths, stream_megatron_job, write_megatron_job
from .jobs import (
MegatronMergedTrainJob,
MegatronSFTTrainingJob,
MegatronSyncJob,
MegatronTrainingJob,
MergedWeightTransferInitInfo,
MergedWeightTransferSpec,
)
from .lora import LORA_ALPHA, LORA_RANK
from .sft_batches import materialize_sft_batches
@@ -148,6 +154,10 @@ class MegatronService:
_vllm_log_file: Any = field(default=None, repr=False)
_vllm_host: str = "127.0.0.1"
_vllm_port: int = 0
_merged_weight_transfer_init_info: MergedWeightTransferInitInfo | None = field(
default=None,
repr=False,
)

@property
def is_dedicated(self) -> bool:
@@ -247,17 +257,59 @@ def _ensure_lora_adapter_config(
return
self._default_lora_adapter_config().save_pretrained(lora_path)

def _build_merged_weight_transfer_spec(self, step: int) -> MergedWeightTransferSpec:
init_info = self._merged_weight_transfer_init_info
assert init_info is not None
return MergedWeightTransferSpec(
init_info=init_info,
vllm_base_url=self._vllm_base_url,
served_model_name=f"{self.model_name}@{step}",
)

def _resolve_active_lora_path(self) -> str:
lora_path = get_last_checkpoint_dir(self.output_dir)
if lora_path is None:
lora_path = get_step_checkpoint_dir(self.output_dir, 0)
self._latest_step = 0
else:
self._latest_step = get_step_from_dir(self.output_dir)
-        self._ensure_identity_lora(lora_path)
+        if self.is_dedicated or self.rollout_weights_mode == "lora":
+            self._ensure_identity_lora(lora_path)
self._ensure_lora_adapter_config(lora_path)
return lora_path

async def _set_served_model_name(self, step: int) -> None:
import httpx

async with httpx.AsyncClient() as client:
response = await client.post(
f"{self._vllm_base_url}/art/set_served_model_name",
json={"name": f"{self.model_name}@{step}"},
timeout=30.0,
)
response.raise_for_status()
self._latest_step = step

async def _init_merged_weight_transfer(self) -> None:
import httpx

if self._merged_weight_transfer_init_info is not None:
return
assert len(self.config["trainer_gpu_ids"]) == 1
async with httpx.AsyncClient() as client:
response = await client.get(
f"{self._vllm_base_url}/get_world_size",
timeout=30.0,
)
response.raise_for_status()
inference_world_size = int(response.json()["world_size"])
self._merged_weight_transfer_init_info = MergedWeightTransferInitInfo(
master_address="127.0.0.1",
master_port=find_free_tcp_port(),
rank_offset=1,
world_size=inference_world_size + 1,
)
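With a single trainer GPU (asserted above) and rank_offset=1, the trainer takes rank 0 of the weight-transfer group and the vLLM workers occupy ranks 1 through world_size - 1. The imported find_free_tcp_port helper is not shown in this diff; a plausible implementation, assuming the usual bind-to-port-zero trick:

```python
import socket


def find_free_tcp_port() -> int:
    # Bind to port 0 so the OS picks an unused port, then release the
    # socket. There is a small race before the caller rebinds the port,
    # which is acceptable for a locally coordinated NCCL rendezvous.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("127.0.0.1", 0))
        return sock.getsockname()[1]
```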

async def _start_vllm_subprocess(
self,
lora_path: str,
@@ -285,8 +337,13 @@ async def _start_vllm_subprocess(
if config and "engine_args" in config:
engine_args.update(dict(config["engine_args"]))
engine_args.setdefault("generation_config", "vllm")
-        engine_args["enable_lora"] = True
-        engine_args.setdefault("max_loras", 2)
+        if self.rollout_weights_mode == "merged":
+            engine_args["weight_transfer_config"] = {"backend": "nccl"}
+            engine_args.pop("enable_lora", None)
+            engine_args.pop("max_loras", None)
+        else:
+            engine_args["enable_lora"] = True
+            engine_args.setdefault("max_loras", 2)
for key in ("model", "served_model_name", "enable_sleep_mode"):
engine_args.pop(key, None)

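A condensed view of the divergence above (illustration only; the weight_transfer_config key is assumed to be understood by the vLLM build ART targets, not stock vLLM):

```python
# Condensed from _start_vllm_subprocess above; surrounding defaults
# and key pops omitted.
def rollout_engine_args(mode: str) -> dict:
    if mode == "merged":
        # Full-weight sync over NCCL; LoRA plumbing is disabled.
        return {"weight_transfer_config": {"backend": "nccl"}}
    # LoRA mode keeps adapter hot-swapping enabled.
    return {"enable_lora": True, "max_loras": 2}
```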
@@ -366,6 +423,25 @@ async def _reload_adapter(self, checkpoint_path: str, step: int) -> None:
response.raise_for_status()
self._latest_step = step

async def _sync_dedicated_merged_weights(
self,
*,
lora_path: str,
step: int,
) -> None:
await self._ensure_megatron_running()
await self._init_merged_weight_transfer()
job_path, log_path = self._create_megatron_job_paths()
job = MegatronSyncJob(
lora_path=lora_path,
merged_weight_transfer=self._build_merged_weight_transfer_spec(step),
log_path=log_path,
)
write_megatron_job(job, job_path=job_path)
async for _ in stream_megatron_job(job, job_path=job_path):
pass
self._latest_step = step

def _stop_vllm_subprocess(self) -> None:
if self._vllm_process is not None:
self._vllm_process.terminate()
@@ -378,12 +454,13 @@ def _stop_vllm_subprocess(self) -> None:
if self._vllm_log_file is not None:
self._vllm_log_file.close()
self._vllm_log_file = None
self._merged_weight_transfer_init_info = None

def _stop_megatron_process(self) -> None:
if self._megatron_process is None:
return
if self._megatron_process.returncode is None:
-            self._megatron_process.terminate()
+            os.killpg(os.getpgid(self._megatron_process.pid), signal.SIGTERM)
self._megatron_process = None

async def _add_lora_aliases(
@@ -402,8 +479,10 @@ async def _add_lora_aliases(

async def register_lora_for_step(self, step: int, checkpoint_dir: str) -> None:
if self.is_dedicated:
-            assert self.rollout_weights_mode == "lora"
-            await self._reload_adapter(checkpoint_dir, step)
+            if self.rollout_weights_mode == "merged":
+                await self._set_served_model_name(step)
+            else:
+                await self._reload_adapter(checkpoint_dir, step)
return
llm = await self.llm
await llm.pause_generation()
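In merged mode a training step is surfaced to clients only through the served model name ("{model_name}@{step}", set via _set_served_model_name above) rather than a LoRA alias. A hypothetical client-side view, with base URL, model name, and step all illustrative:

```python
import asyncio

from openai import AsyncOpenAI


async def main() -> None:
    # After training step 3, merged weights are addressed by the
    # step-suffixed alias rather than a LoRA adapter name.
    client = AsyncOpenAI(base_url="http://127.0.0.1:8000/v1", api_key="EMPTY")
    completion = await client.chat.completions.create(
        model="my-model@3",  # hypothetical model name and step
        messages=[{"role": "user", "content": "Hello"}],
    )
    print(completion.choices[0].message.content)


asyncio.run(main())
```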
@@ -458,6 +537,7 @@ async def _ensure_megatron_running(self) -> None:
command,
cwd=str(project_root),
env=launch_env,
start_new_session=True,
)

def _clear_pending_jobs(self) -> None:
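start_new_session=True above pairs with the os.killpg(..., SIGTERM) change in _stop_megatron_process: the Megatron launcher leads its own process group, so a single signal reaches every worker it forks. A condensed, POSIX-only sketch of the pattern:

```python
import os
import signal
import subprocess
import sys

# Simplified from the async subprocess handling in this diff.
proc = subprocess.Popen(
    [sys.executable, "-c", "import time; time.sleep(60)"],
    start_new_session=True,  # child leads a fresh process group
)
# Signal the whole group, including any workers the child forked:
os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
proc.wait()
```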
@@ -535,9 +615,15 @@ async def start_openai_server(
lora_path = self._resolve_active_lora_path()

if self.is_dedicated:
-            assert self.rollout_weights_mode == "lora"
            port = (config or {}).get("server_args", {}).get("port", 8000)
-            return await self._start_vllm_subprocess(lora_path, port, config)
+            location = await self._start_vllm_subprocess(lora_path, port, config)
+            if self.rollout_weights_mode == "merged":
+                self._clear_pending_jobs()
+                await self._sync_dedicated_merged_weights(
+                    lora_path=lora_path,
+                    step=self._latest_step,
+                )
+            return location

lora_path_for_server = (
lora_path if self._adapter_has_weights(lora_path) else None
Expand Down Expand Up @@ -575,7 +661,6 @@ async def train(
verbose: bool = False,
) -> AsyncIterator[dict[str, float]]:
if self.is_dedicated:
-            assert self.rollout_weights_mode == "lora"
await self._ensure_megatron_running()

lora_path = self._resolve_active_lora_path()
Expand All @@ -586,24 +671,56 @@ async def train(
"MegatronService subprocess jobs must use moe_routing_replay_path."
)
job_path, log_path = self._create_megatron_job_paths()
-            job = MegatronTrainingJob(
-                lora_path=lora_path,
-                optimizer_state_path=self._get_optimizer_state_path("rl"),
-                disk_packed_tensors=disk_packed_tensors,
-                config=config,
-                experimental_config=cast(dict[str, Any], _config),
-                moe_routing_replay_path=_config.get("moe_routing_replay_path"),
-                moe_routing_replay_strict=_config.get(
-                    "moe_routing_replay_strict", True
-                ),
-                log_path=log_path,
-            )
+            next_step = self._latest_step + 1
+            if self.rollout_weights_mode == "merged":
+                await self._init_merged_weight_transfer()
+                job = MegatronMergedTrainJob(
+                    lora_path=lora_path,
+                    optimizer_state_path=self._get_optimizer_state_path("rl"),
+                    disk_packed_tensors=disk_packed_tensors,
+                    config=config,
+                    experimental_config=cast(dict[str, Any], _config),
+                    moe_routing_replay_path=_config.get("moe_routing_replay_path"),
+                    moe_routing_replay_strict=_config.get(
+                        "moe_routing_replay_strict", True
+                    ),
+                    merged_weight_transfer=self._build_merged_weight_transfer_spec(
+                        next_step
+                    ),
+                    log_path=log_path,
+                )
+            else:
+                job = MegatronTrainingJob(
+                    lora_path=lora_path,
+                    optimizer_state_path=self._get_optimizer_state_path("rl"),
+                    disk_packed_tensors=disk_packed_tensors,
+                    config=config,
+                    experimental_config=cast(dict[str, Any], _config),
+                    moe_routing_replay_path=_config.get("moe_routing_replay_path"),
+                    moe_routing_replay_strict=_config.get(
+                        "moe_routing_replay_strict", True
+                    ),
+                    log_path=log_path,
+                )
write_megatron_job(job, job_path=job_path)

async for result in stream_megatron_job(job, job_path=job_path):
yield {key: float(value) for key, value in result.items()}

-            await self._publish_dedicated_training_checkpoint(lora_path=lora_path)
+            if self.rollout_weights_mode == "merged":
+                new_checkpoint_dir = get_step_checkpoint_dir(self.output_dir, next_step)
+                os.makedirs(new_checkpoint_dir, exist_ok=True)
+                shutil.copy(
+                    f"{lora_path}/adapter_model.safetensors",
+                    f"{new_checkpoint_dir}/adapter_model.safetensors",
+                )
+                self._ensure_lora_adapter_config(
+                    new_checkpoint_dir,
+                    source_path=lora_path,
+                )
+                self._latest_step = next_step
+            else:
+                await self._publish_dedicated_training_checkpoint(lora_path=lora_path)
return
llm, lora_path = await self._prepare_for_training()
if _config.get("moe_routing_replay_bundle") is not None: