Conversation
Code Review
This pull request modifies the slicing logic for `mg_scale_inv` in `gpt_bridge.py`, moving the calculation earlier in the weight-retrieval path. Review feedback notes that removing the slicing after the `all_gather` operation may produce incorrect tensor shapes for RowParallel layers, because the concatenated scale tensor can contain more blocks than the global weight requires. The reviewer suggests restoring the slicing so the exported scale tensor matches the expected format.
```diff
@@ -397,7 +400,6 @@ def _get_weight(
     mg_scale_inv = self._all_gather_tp(mg_scale_inv, tp_dim, is_expert)
     mg_scale_inv = self._broadcast_ep_pp(mg_scale_inv, is_expert)
     tensor = tensor.view(torch.float8_e4m3fn)
```
The removal of the slicing for mg_scale_inv at the end of _get_weight can lead to incorrect shapes for exported weights in RowParallel layers. When tp_dim is 1 (RowParallel), the column dimension is split across ranks. After all_gather, the concatenated scale tensor might have more blocks than required by the global weight if the local column count is not a multiple of fp8_block_size. For example, if total_cols=258, tp_size=2, and block_size=128, each rank has 129 columns, which requires 2 blocks locally. Gathering them results in 4 blocks, but the global weight only needs ceil(258/128) = 3 blocks. The slicing should be restored to ensure the exported scale tensor matches the expected format.
```suggestion
tensor = tensor.view(torch.float8_e4m3fn)
if mg_scale_inv is not None:
    mg_scale_inv = mg_scale_inv[..., :math.ceil(tensor.shape[-1] / self.fp8_block_size)].contiguous()
```
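The block-count mismatch described above can be checked with simple arithmetic. The sketch below is a hypothetical standalone model of the scale-block bookkeeping (the function names and the even column split across ranks are assumptions for illustration, not code from `gpt_bridge.py`):

```python
import math

def gathered_scale_blocks(total_cols, tp_size, block_size):
    # RowParallel splits the column dimension across ranks; assume an
    # even split with padding, as in the 258-column example above.
    local_cols = math.ceil(total_cols / tp_size)
    # Each rank keeps one scale block per block_size columns it owns.
    local_blocks = math.ceil(local_cols / block_size)
    # all_gather concatenates every rank's scale blocks along the last dim.
    return tp_size * local_blocks

def required_scale_blocks(total_cols, block_size):
    # The global weight only needs one scale block per block_size columns.
    return math.ceil(total_cols / block_size)

# 258 total columns over 2 ranks with block_size=128: each rank holds 129
# columns (2 local blocks), so the gather yields 4 blocks, while the
# global weight needs only ceil(258/128) = 3.
gathered = gathered_scale_blocks(258, 2, 128)  # 4
needed = required_scale_blocks(258, 128)       # 3
```

This is why the suggested change re-slices the gathered scale tensor to `math.ceil(tensor.shape[-1] / self.fp8_block_size)` blocks: it trims the extra padding blocks introduced by gathering per-rank scales.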