5 changes: 3 additions & 2 deletions cookbook/transformers/sp_fsdp_dense.py
@@ -19,9 +19,10 @@
     device_type=Platform.get_platform().device_prefix(),
 )]

-# FSDP + SP validation over 4 GPUs: dp=2, fsdp=2 (SP only affects input slicing)
+# FSDP + sequence-parallel validation over 4 GPUs: dp=2, fsdp=2.
+# In Transformers route, ulysses_size is the total sequence-parallel degree.
 device_mesh = DeviceMesh(
-    device_type='cuda',
+    device_type=Platform.get_platform().device_prefix(),
     mesh=np.arange(4).reshape(2, 2),
     mesh_dim_names=('dp', 'fsdp'),
     ulysses_size=2,
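The mesh above folds a sequence-parallel degree of 2 (`ulysses_size=2`) into the 2x2 dp/fsdp grid, and the original comment notes that SP only affects input slicing. A minimal sketch of what that slicing implies, assuming contiguous sharding (the function name is hypothetical, not twinkle's API):

```python
import numpy as np

# Hedged sketch, not twinkle's actual implementation: with ulysses_size=2,
# each sequence-parallel rank keeps one contiguous shard of the input
# sequence; the model weights and the rest of the mesh are untouched.
def slice_for_ulysses_rank(input_ids, ulysses_size, ulysses_rank):
    seq_len = len(input_ids)
    assert seq_len % ulysses_size == 0, 'sequence length must divide evenly'
    shard = seq_len // ulysses_size
    return input_ids[ulysses_rank * shard:(ulysses_rank + 1) * shard]

tokens = np.arange(8)                          # toy 8-token sequence
rank0 = slice_for_ulysses_rank(tokens, 2, 0)   # first half of the sequence
rank1 = slice_for_ulysses_rank(tokens, 2, 1)   # second half of the sequence
```

Each ulysses rank then runs attention over the full sequence via all-to-all exchanges, which is why the loss it ends up holding can cover all tokens while its labels remain shard-local.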
3 changes: 2 additions & 1 deletion cookbook/transformers/sp_fsdp_dense.sh
@@ -1,5 +1,6 @@
 #!/bin/bash
-# To enabele sequence parallelism, please set ulysses_size > 1
+# To enable Transformers sequence parallelism, please set ulysses_size > 1.
+# ulysses_size is interpreted as the total sequence-parallel degree.
 # device_mesh = DeviceMesh(
 #     device_type="cuda",
 #     mesh=np.arange(4).reshape(2, 2),
6 changes: 6 additions & 0 deletions src/twinkle/metric/loss.py
@@ -25,13 +25,19 @@ def accumulate(self, inputs: Union[InputFeature, List[InputFeature]], outputs: M
             return
         loss = outputs['loss']
         loss_reduction = kwargs.get('loss_reduction', 'mean')
+        ulysses_size = getattr(self.device_mesh, 'ulysses_size', None) or 1
         if loss_reduction == 'sum':
             if not isinstance(inputs, list):
                 inputs = [inputs]
             for input in inputs:
                 # `Transformers` models may use reduction=sum, to average grads before step
                 labels = input['labels']
                 self.num_tokens += (labels >= 0).sum().item()
+            # Sequence-parallel gathered loss is replicated on each ulysses rank, while
+            # local labels still count only the shard-local tokens. Normalize the loss
+            # contribution here so metric-side averaging matches the non-SP path.
+            if ulysses_size > 1:
+                loss = loss / float(ulysses_size)
Collaborator:

Why is this placed here? Put differently, does the loss used for the model's backward pass also need to be divided by ulysses_size?

Collaborator (Author):

When the loss instance's reduction is `sum`, the `loss` here is the full-sequence loss replicated on every ulysses rank, while the `num_tokens` counted here is still the token count of each rank's local shard. The two are measured on different scopes, so we divide by `ulysses_size` once. The division here only fixes the metric reporting; the loss used for backpropagation is not divided by ulysses_size, since `GatherLoss.apply` keeps only the local gradient.

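The mismatch the author describes can be checked with toy numbers (illustrative only, not twinkle internals): each ulysses rank holds a replica of the full-sequence summed loss, but its token count covers only the local shard.

```python
# Hedged sketch of the metric normalization in the diff above, under the
# assumption of 2 ulysses ranks sharing one 8-token sequence.
ulysses_size = 2
full_seq_loss = 12.0    # sum of per-token losses over the whole sequence,
                        # replicated on every ulysses rank
tokens_per_rank = 4     # shard-local token count counted on each rank
total_tokens = tokens_per_rank * ulysses_size  # 8 tokens overall

# Naive metric: summing the replicated loss across ranks double-counts it.
naive_avg = (full_seq_loss * ulysses_size) / total_tokens             # 3.0, wrong
# Fixed metric: divide each rank's contribution by ulysses_size first.
fixed_avg = (full_seq_loss / ulysses_size * ulysses_size) / total_tokens  # 1.5
assert fixed_avg == full_seq_loss / total_tokens
```

The fixed average matches what a non-SP run would report (total loss over total tokens), which is exactly the "metric reporting" fix the author describes; the backward loss is left untouched.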
Collaborator (Author):

Also, I think this check is problematic when SP is enabled: with fsdp=2, raw_dp_world_size=2 but data_world_size=1, so the gather gets skipped.
[screenshot]

In addition, this should probably be changed to raw_dp_fsdp_world_size, because the later gather runs over process_group, not data_world_size.
[screenshot]

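The arithmetic behind this objection can be spelled out (the variable names follow the review comment; twinkle's actual internals may differ):

```python
# Illustrative world-size bookkeeping for the guard issue described above,
# assuming 4 GPUs with the dp=2, fsdp=2, ulysses_size=2 mesh from this PR.
world_size = 4
fsdp = 2
ulysses_size = 2

raw_dp_world_size = world_size // fsdp                 # 2
data_world_size = raw_dp_world_size // ulysses_size    # 1

# A guard like `if data_world_size <= 1: skip gather` would wrongly skip
# here, even though the gather actually runs over process_group, whose
# size is raw_dp_world_size * fsdp rather than data_world_size.
raw_dp_fsdp_world_size = raw_dp_world_size * fsdp      # 4
```

Under these assumed definitions, data_world_size collapses to 1 while the gather group still spans 4 ranks, which is why the reviewer suggests keying the check on raw_dp_fsdp_world_size instead.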
         grad_norm = kwargs.get('grad_norm')
         if grad_norm is not None:
             self.grad_norm = grad_norm