chore: optimize metal backend performance #1669
Code Review
This pull request introduces a contiguous KV cache fast path for Metal (MLX) to optimize performance during low-concurrency dense serving. Key enhancements include native MLX sampling for greedy and random strategies, optimized logit projection to reduce computation, and one-token lookahead prefetching. The PR also implements caching for metadata arrays and block tables within the PagedAttentionContext. Feedback indicates that the native sampling path currently ignores active logits processors, which could bypass penalties or constraints. Additionally, a logic error was identified in the native random sampling implementation that incorrectly handles top_k masking in mixed batches.
    if not (
        batch.can_use_native_greedy_for_batch()
        or batch.can_use_native_random_for_batch()
    ):
        return None
The native MLX sampling path (both greedy and random) currently bypasses any active logitsprocs. This means custom plugins, penalties, or constraints implemented as logits processors will be ignored when the fast path is taken. You should check if self._logitsprocs.all is empty before allowing the native path.
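The check described in this comment can be sketched in isolation as follows. This is a minimal illustration, not the project's code: `Batch`, `LogitsProcs`, and `choose_sampling_path` are hypothetical stand-ins, and `all` is assumed to be a plain list here. The emptiness test uses `any(...)` so that it would also behave correctly if `all` turned out to be a lazy iterator, since iterator objects are always truthy in Python.

```python
from dataclasses import dataclass, field
from typing import Callable, List


@dataclass
class Batch:
    """Hypothetical stand-in for the real batch object."""
    all_greedy: bool = False
    all_random: bool = False

    def can_use_native_greedy_for_batch(self) -> bool:
        return self.all_greedy

    def can_use_native_random_for_batch(self) -> bool:
        return self.all_random


@dataclass
class LogitsProcs:
    """Hypothetical stand-in; `all` is assumed to be a list of callables."""
    all: List[Callable] = field(default_factory=list)


def choose_sampling_path(batch: Batch, logitsprocs: LogitsProcs) -> str:
    # Any active logits processor forces the regular (non-native) path,
    # so penalties and constraints are never skipped.
    has_procs = any(True for _ in logitsprocs.all)
    if has_procs:
        return "fallback"
    if (
        batch.can_use_native_greedy_for_batch()
        or batch.can_use_native_random_for_batch()
    ):
        return "native"
    return "fallback"


print(choose_sampling_path(Batch(all_greedy=True), LogitsProcs()))
print(choose_sampling_path(Batch(all_greedy=True), LogitsProcs(all=[lambda x: x])))
```

The first call takes the native path; the second falls back because a processor is registered, even though the batch is otherwise eligible.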
Suggested change:

    if (
        self._logitsprocs.all
        or not (
            batch.can_use_native_greedy_for_batch()
            or batch.can_use_native_random_for_batch()
        )
    ):

    if not self.all_random:
        return False
    if max_top_k < batch.vocab_size:
        topk_indices = mx.argpartition(-logits, max_top_k - 1, axis=-1)[
            :, :max_top_k
        ]
        logits = mx.take_along_axis(logits, topk_indices, axis=-1)
        if len(set(top_ks)) != 1:
            positions = mx.arange(max_top_k)[None, :]
            row_top_ks = mx.array(top_ks, dtype=mx.int32)[:, None]
            logits = mx.where(positions < row_top_ks, logits, -float("inf"))

        if not batch.no_top_p:
            sorted_positions = mx.argsort(-logits, axis=-1)
            sorted_logits = mx.take_along_axis(logits, sorted_positions, axis=-1)
            sorted_indices = mx.take_along_axis(
                topk_indices, sorted_positions, axis=-1
            )
            sorted_probs = mx.softmax(sorted_logits, axis=-1)
            top_ps = mx.array(
                [
                    sampling_params.top_p
                    for sampling_params in batch.sampling_params_list
                ],
                dtype=mx.float32,
            )[:, None]
            # Keep the first token that crosses top-p, matching nucleus
            # sampling's usual "cumulative probability before this token"
            # test.
            remove = (mx.cumsum(sorted_probs, axis=-1) - sorted_probs) > top_ps
            sorted_logits = mx.where(remove, -float("inf"), sorted_logits)
            sampled_positions = mx.random.categorical(sorted_logits, axis=-1)
            return mx.take_along_axis(
                sorted_indices, sampled_positions[:, None], axis=-1
            )[:, 0]

        sampled_positions = mx.random.categorical(logits, axis=-1)
        return mx.take_along_axis(
            topk_indices, sampled_positions[:, None], axis=-1
        )[:, 0]

    topk_values = mx.topk(logits, max_top_k, axis=-1)
    topk_thresholds = mx.min(topk_values, axis=-1, keepdims=True)
    logits = mx.where(logits < topk_thresholds, -float("inf"), logits)
The else block (lines 641-643) is logically broken for mixed batches where some requests have top_k enabled and others have it disabled (set to vocab_size). In such cases, max_top_k becomes vocab_size, the if block is skipped, and the else block calculates a threshold based on the minimum logit of the entire row, effectively disabling top_k for all requests in the batch.
You should remove the if max_top_k < batch.vocab_size check and always use the argpartition logic if not batch.no_top_k. mx.argpartition handles k = vocab_size - 1 correctly, and the subsequent mx.where (line 610) will correctly apply the per-request top_k masking.
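To make the mixed-batch failure concrete, here is a small pure-Python sketch (no MLX) that contrasts a single batch-wide threshold with per-row masking. The function names and the toy logits are made up for illustration; a vocabulary of 4 is assumed, so `top_k = 4` stands for "top_k disabled".

```python
import math

NEG_INF = -math.inf


def shared_threshold_mask(logits, top_ks):
    """Mimics the fallback branch: every row uses the batch-wide maximum k."""
    max_top_k = max(top_ks)
    masked = []
    for row in logits:
        # Smallest value among the row's max_top_k largest logits.
        threshold = min(sorted(row, reverse=True)[:max_top_k])
        masked.append([x if x >= threshold else NEG_INF for x in row])
    return masked


def per_row_mask(logits, top_ks):
    """Applies each row's own k by ranking that row's logits."""
    masked = []
    for row, k in zip(logits, top_ks):
        threshold = sorted(row, reverse=True)[k - 1]
        masked.append([x if x >= threshold else NEG_INF for x in row])
    return masked


logits = [[2.0, 0.5, 1.0, -1.0], [0.1, 0.4, 0.3, 0.2]]
top_ks = [2, 4]  # row 0 wants top-2; row 1 has top_k disabled

print(shared_threshold_mask(logits, top_ks))  # row 0 is left unmasked
print(per_row_mask(logits, top_ks))           # row 0 keeps only 2.0 and 1.0
```

With the shared threshold, row 0 keeps all four logits because `max_top_k` equals the vocabulary size, which is the behaviour described above. The per-row version ranks each row fully (ties at the threshold are all kept). Whether a position-based mask over `mx.argpartition` output is equivalent depends on the ordering of elements inside the partition, which may be worth verifying separately.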
Suggested change:

    topk_indices = mx.argpartition(-logits, max_top_k - 1, axis=-1)[
        :, :max_top_k
    ]
    logits = mx.take_along_axis(logits, topk_indices, axis=-1)
    if len(set(top_ks)) != 1 or max_top_k == batch.vocab_size:
        positions = mx.arange(max_top_k)[None, :]
        row_top_ks = mx.array(top_ks, dtype=mx.int32)[:, None]
        logits = mx.where(positions < row_top_ks, logits, -float("inf"))
    if not batch.no_top_p:
        sorted_positions = mx.argsort(-logits, axis=-1)
        sorted_logits = mx.take_along_axis(logits, sorted_positions, axis=-1)
        sorted_indices = mx.take_along_axis(
            topk_indices, sorted_positions, axis=-1
        )
        sorted_probs = mx.softmax(sorted_logits, axis=-1)
        top_ps = mx.array(
            [
                sampling_params.top_p
                for sampling_params in batch.sampling_params_list
            ],
            dtype=mx.float32,
        )[:, None]
        remove = (mx.cumsum(sorted_probs, axis=-1) - sorted_probs) > top_ps
        sorted_logits = mx.where(remove, -float("inf"), sorted_logits)
        sampled_positions = mx.random.categorical(sorted_logits, axis=-1)
        return mx.take_along_axis(
            sorted_indices, sampled_positions[:, None], axis=-1
        )[:, 0]
    sampled_positions = mx.random.categorical(logits, axis=-1)
    return mx.take_along_axis(
        topk_indices, sampled_positions[:, None], axis=-1
    )[:, 0]
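The top-p step in the code above removes a token only when the cumulative probability before it already exceeds `top_p`, which keeps the first token that crosses the threshold. A minimal pure-Python sketch of that rule for a single row follows; `top_p_keep_mask` and the example probabilities are illustrative, not part of the PR.

```python
def top_p_keep_mask(sorted_probs, top_p):
    """Return keep flags for one row of probabilities sorted in descending order.

    Mirrors `(cumsum(p) - p) > top_p`: the test uses the mass accumulated
    before the current token, so the token that crosses top_p is kept.
    """
    keep = []
    mass_before = 0.0
    for p in sorted_probs:
        keep.append(not (mass_before > top_p))
        mass_before += p
    return keep


probs = [0.5, 0.3, 0.15, 0.05]
print(top_p_keep_mask(probs, 0.7))  # second token crosses 0.7 and is kept
print(top_p_keep_mask(probs, 0.4))  # only the first token survives
```

Because the comparison is strict, a row with `top_p = 1.0` keeps every token, and the highest-probability token is always kept since no mass precedes it.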
Qwen3-0.6B, M4 Pro.
Before:
After: