-
Notifications
You must be signed in to change notification settings - Fork 23
[feat] Resume from ckpt #135
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
5cd3c0f
91eeaeb
6eebda8
cdd9c1b
1542492
9883118
d41a634
21f9918
1e59531
9bb3f39
fdf1f71
6cf5160
144ffe6
e21f870
3359209
70ebe50
483778d
039789b
54de1a4
920ab86
ffd6304
582bd41
9cb6106
c0cf72e
505a75c
a222b5b
7499e00
cd0b094
abf2c2f
8bf7a6a
27e76c6
5d68910
9326e64
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -99,21 +99,29 @@ def train(): | |
| # model.set_lr_scheduler('LinearLR') | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这文件命名有一个typo
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 确实是 typo,但 self_congnition.py 在 main 上就已存在,是否在单独 PR 中修正更合适? |
||
|
|
||
| # Step 6: Optionally resume from a previous checkpoint | ||
| consumed_train_samples = 0 | ||
| global_step = 0 | ||
| if resume_path: | ||
| logger.info(f'Resuming training from {resume_path}') | ||
| model.load(resume_path, load_optimizer=True) | ||
| logger.info(f'Resuming model weights from {resume_path}') | ||
| model.load(resume_path) | ||
| trainer_state = model.load_training_state(resume_path) | ||
| dataloader.skip_consumed_samples(trainer_state['consumed_train_samples']) | ||
| consumed_train_samples = int(trainer_state['consumed_train_samples']) | ||
| global_step = int(trainer_state['cur_step']) | ||
|
|
||
| # Step 7: Run the training loop | ||
| logger.info(model.get_train_configs().model_dump()) | ||
|
|
||
| for epoch in range(3): | ||
| logger.info(f'Starting epoch {epoch}') | ||
| for step, batch in enumerate(dataloader): | ||
| for _, batch in enumerate(dataloader): | ||
| # Forward pass + backward pass (computes gradients) | ||
| model.forward_backward(inputs=batch) | ||
|
|
||
| # Step | ||
| model.clip_grad_and_step() | ||
| consumed_train_samples += len(batch) | ||
| global_step += 1 | ||
| # Equal to the following steps: | ||
| # # Clip gradients to prevent exploding gradients (max norm = 1.0) | ||
| # model.clip_grad_norm(1.0) | ||
|
|
@@ -125,13 +133,17 @@ def train(): | |
| # model.lr_step() | ||
|
|
||
| # Log the loss every 2 steps (aligned with gradient accumulation) | ||
| if step % 2 == 0: | ||
| if global_step % 2 == 0: | ||
| # Print metric | ||
| metric = model.calculate_metric(is_training=True) | ||
| logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric.result}') | ||
| logger.info(f'Current is step {global_step} of {len(dataloader)}, metric: {metric.result}') | ||
|
|
||
| # Step 8: Save the trained checkpoint | ||
| twinkle_path = model.save(name=f'twinkle-epoch-{epoch}', save_optimizer=True) | ||
| twinkle_path = model.save( | ||
| name=f'twinkle-epoch-{epoch}', | ||
| save_optimizer=True, | ||
| consumed_train_samples=consumed_train_samples, | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. dataloader.get_consumed_samples()?
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 或者,dataloader.get_state(),更通用一些
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 另外,这里额外测试下torchrun/ray的兼容性,还有megatron和transformers双模型的兼容性 |
||
| ) | ||
| logger.info(f'Saved checkpoint: {twinkle_path}') | ||
|
|
||
| # Step 9: Upload the checkpoint to ModelScope Hub | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,55 @@ | ||
| from pathlib import Path | ||
| from typing import Any, Optional | ||
|
|
||
| from twinkle import get_logger | ||
|
|
||
|
|
||
| logger = get_logger() | ||
|
|
||
|
|
||
| def _build_model_kwargs(adapter_name: str) -> dict: | ||
| if not adapter_name: | ||
| return {} | ||
| return {'adapter_name': adapter_name} | ||
|
|
||
|
|
||
def resume_from_checkpoint(
    model: Any,
    dataloader: Any,
    checkpoint_path: Path,
    *,
    resume_only_model: bool,
    ignore_data_skip: bool,
    adapter_name: Optional[str] = None) -> int:
    """Restore training state from *checkpoint_path* and return the number
    of already-consumed training samples.

    Args:
        model: Trainer model exposing ``load`` / ``load_training_state`` /
            ``read_training_progress`` / ``optimizer_group``.
        dataloader: Dataloader supporting ``skip_consumed_samples``.
        checkpoint_path: Directory of the saved checkpoint.
        resume_only_model: If True, restore weights only (no optimizer state).
        ignore_data_skip: With ``resume_only_model``, also restart data and
            step counters from zero.
        adapter_name: Optional LoRA/adapter name; empty means full-parameter.

    Returns:
        Consumed-sample count to continue from (0 when progress is reset).
    """
    adapter = adapter_name or ''
    ckpt_dir = str(checkpoint_path)
    extra_kwargs = _build_model_kwargs(adapter)

    # Adapter weights are loaded explicitly by name/output_dir; for the
    # full-parameter case weights are assumed restored at model construction.
    if extra_kwargs:
        model.load(
            name=checkpoint_path.name,
            output_dir=str(checkpoint_path.parent),
            **extra_kwargs,
        )

    if not resume_only_model:
        # Full resume: optimizer states plus recorded training progress.
        trainer_state = model.load_training_state(ckpt_dir, **extra_kwargs)
        n_consumed = int(trainer_state['consumed_train_samples'])
        dataloader.skip_consumed_samples(n_consumed)
        logger.info(f'Restored full training state from step {trainer_state["cur_step"]}.')
        return n_consumed

    if ignore_data_skip:
        # Weights only; data position and step counters start over.
        logger.info('Resumed weights only and restarted progress from step 0.')
        return 0

    # Weights only, but fast-forward dataloader and optimizer counters so
    # training continues from the recorded position.
    progress = model.read_training_progress(ckpt_dir, **extra_kwargs)
    n_consumed = int(progress['consumed_train_samples'])
    dataloader.skip_consumed_samples(n_consumed)
    opt_group = model.optimizer_group[adapter]
    opt_group.cur_step = progress['cur_step']
    opt_group.gradient_accumulation_steps = progress['gradient_accumulation_steps']
    logger.info(f'Skipped {n_consumed} consumed samples.')
    return n_consumed
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -50,3 +50,16 @@ for data in dataloader: | |
| model.forward_backward(...) | ||
| model.clip_grad_and_step(..., gradient_accumulation_steps=16) | ||
| ``` | ||
|
|
||
| ## Checkpoint and Resume | ||
|
|
||
| `TransformersModel.save()` can save either weights only or a resumable training checkpoint. | ||
|
|
||
| - `model.save(name, save_optimizer=True, consumed_train_samples=...)` saves weights together with optimizer, scheduler, scaler, RNG, and `trainer_state.json`. | ||
| - `model.load(name, output_dir=..., adapter_name=...)` restores LoRA / adapter model weights. | ||
| - `model.read_training_progress(checkpoint_dir, ...)` reads checkpoint metadata such as `cur_step`, `gradient_accumulation_steps`, and `consumed_train_samples`. | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这两个比较相似,合成一个是否合适?比如 |
||
| - `model.load_training_state(checkpoint_dir, ...)` restores optimizer-related state and returns the training progress dictionary. | ||
|
|
||
| For full-parameter training, restore model weights by constructing `TransformersModel` with the checkpoint path as `model_id`, for example `TransformersModel(model_id='./output/fsdp2/last-checkpoint')`, and then call `load_training_state(...)` to restore optimizer state and training progress. | ||
|
|
||
| For end-to-end resume logic, including dataloader skipping, refer to `cookbook/transformers/fsdp2.py` and `cookbook/transformers/resume_utils.py`. | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
load_training_state和read_training_progress什么区别,能否合并为一个呢