Add Apple Silicon (MPS) support to the FP8 inference path #19
Open
sammcj wants to merge 5 commits into
MPS can neither store nor cast float8_e4m3fn, and supports neither float64 nor ndtri/expit, so the weight-only FP8 path crashed twice on MPS: first when dequantizing Linear weights, then in the logit-normal sampler. Both are now handled, with the CUDA/CPU paths unchanged by construction.
- quantized_loading.py: add device_supports_fp8(); on MPS, Fp8Linear holds an already-dequantized bf16 weight (no scale buffer) and load_fp8_state_dict dequantizes each FP8 weight via its per-row scale at load time, dropping the now-unused .weight_scale keys. CUDA/CPU keep the float8 storage path (store_fp8=True) byte-for-byte.
- scheduler.py: run the float64 warp (ndtri/expit) on CPU for MPS only, returning float32 on the caller's device; CUDA/CPU keep the original on-device path.
- pipeline_ideogram4.py: thread the target device into swap_linears_to_fp8 at both call sites.
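For readers unfamiliar with what "dequantize via its per-row scale" amounts to, here is a minimal pure-Python sketch of the arithmetic only. It is not the code in quantized_loading.py, which works on tensors; `decode_e4m3fn` and `dequantize_rows` are hypothetical names, and the multiplicative convention (weight ≈ fp8 value × row scale) is taken from the "float32 cast and multiply" wording in this PR.

```python
import math
from typing import List, Sequence


def decode_e4m3fn(byte: int) -> float:
    """Decode one float8_e4m3fn bit pattern (0..255) into a Python float.

    Layout: 1 sign bit, 4 exponent bits (bias 7), 3 mantissa bits.
    The "fn" variant has no infinities; exponent=15, mantissa=7 is NaN.
    """
    sign = -1.0 if byte & 0x80 else 1.0
    exponent = (byte >> 3) & 0x0F
    mantissa = byte & 0x07
    if exponent == 0x0F and mantissa == 0x07:
        return math.nan
    if exponent == 0:
        # Subnormal: no implicit leading 1, fixed exponent of 2**-6.
        return sign * (mantissa / 8.0) * 2.0 ** -6
    return sign * (1.0 + mantissa / 8.0) * 2.0 ** (exponent - 7)


def dequantize_rows(
    fp8_rows: Sequence[Sequence[int]], row_scales: Sequence[float]
) -> List[List[float]]:
    """Multiply every decoded FP8 value by the scale of its output row."""
    return [
        [decode_e4m3fn(b) * scale for b in row]
        for row, scale in zip(fp8_rows, row_scales)
    ]
```

Doing this once at load time is what lets the MPS path hold a plain bf16 weight and drop the scale buffer, at the cost of the memory saving that FP8 storage gives on CUDA/CPU.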
Pull request overview
Note
Copilot was unable to run its full agentic suite in this review.
This PR improves Apple Silicon (MPS) compatibility by routing unsupported float64/special-function operations and FP8 weight handling through CPU/dequantized paths while keeping the public pipeline behavior consistent.
Changes:
- Run the LogitNormalSchedule “warp” on CPU for MPS and return results back to the original device.
- Add MPS-aware FP8 loading: on MPS, dequantize FP8 weights at load time and drop .weight_scale entries.
- Thread device through FP8 linear swapping so layer layouts match the chosen storage strategy.
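As context for the scheduler change, the sketch below shows the kind of float64 scalar math a logit-normal warp involves, using only the standard library. It assumes the warp is the logit-normal quantile function, expit(mean + std * ndtri(u)); the actual LogitNormalSchedule parameterization is not shown in this PR, so `logit_normal_warp`, `mean`, and `std` are illustrative names rather than the project's API.

```python
import math
from statistics import NormalDist


def logit_normal_warp(u: float, mean: float = 0.0, std: float = 1.0) -> float:
    """Map a uniform position u in (0, 1) through a logit-normal quantile.

    Python floats are float64, so this mirrors the precision that the
    on-device path cannot provide on MPS.
    """
    # ndtri: inverse CDF of the standard normal.
    # inv_cdf raises StatisticsError for u <= 0 or u >= 1.
    z = NormalDist().inv_cdf(u)
    x = mean + std * z
    # expit: logistic sigmoid.
    return 1.0 / (1.0 + math.exp(-x))
```

Because the inputs are a handful of scalars per sampling run, evaluating this on the CPU and handing back a float32 result costs little next to a transformer forward pass.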
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| src/ideogram4/scheduler.py | Adds an MPS CPU fallback for float64 special functions and restores output to the original device. |
| src/ideogram4/quantized_loading.py | Introduces device capability gating for FP8 storage and MPS dequantization-at-load behavior. |
| src/ideogram4/pipeline_ideogram4.py | Passes device into FP8 swapping so modules match the intended FP8 storage/dequantization path. |
- device_supports_fp8 now probes the backend at runtime (cached per device type) instead of hard-coding "not mps", so it adapts to other backends that may lack float8 storage/casting.
- load_fp8_state_dict dequant path moves the fp8 weight and scale to CPU explicitly before the float32 cast and multiply, matching the comment and guarding against non-CPU state-dict tensors.

Scheduler CPU-roundtrip review comments were left as-is: the warp runs on a single-element tensor that the caller immediately scalarises via .item(), so the transfer is negligible (measured ~6.8s/step, transformer-bound, at V4_QUALITY_48).
step_intervals is fixed for the whole sampling loop, so warp it through LogitNormalSchedule once before the loop and index the results, rather than calling the schedule twice per step. Output is byte-identical; this hoists the loop-invariant scalar work (and its per-step host syncs) out of the inner loop. Addresses the scheduler.py review comments on the MPS CPU roundtrip.
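The hoisting can be illustrated with plain Python. This is a sketch under assumptions: it treats step_intervals as N+1 boundary values with each step reading its two neighbouring warped values, which is a guess based on "twice per step", and `run_per_step` / `run_hoisted` are hypothetical names.

```python
from typing import Callable, List, Sequence, Tuple

Pair = Tuple[float, float]


def run_per_step(
    schedule: Callable[[float], float], step_intervals: Sequence[float]
) -> List[Pair]:
    """Before: warp both boundaries inside the loop (2 calls per step)."""
    pairs = []
    for i in range(len(step_intervals) - 1):
        t_cur = schedule(step_intervals[i])
        t_next = schedule(step_intervals[i + 1])
        pairs.append((t_cur, t_next))
    return pairs


def run_hoisted(
    schedule: Callable[[float], float], step_intervals: Sequence[float]
) -> List[Pair]:
    """After: warp every boundary once up front, then index the results."""
    warped = [schedule(u) for u in step_intervals]
    return list(zip(warped[:-1], warped[1:]))
```

Both functions yield the same pairs as long as the schedule is a pure function of its input; the hoisted form calls it N+1 times instead of 2N.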
…e cache, explicit missing-scale error
- pipeline_ideogram4: warp the whole step_intervals tensor through the schedule in one vectorized call (.tolist()) instead of per-element calls; the values are byte-identical so output is unchanged.
- quantized_loading: key the fp8 capability probe cache by concrete device (e.g. cuda:0) rather than device type, so a heterogeneous multi-device setup is probed individually.
- quantized_loading: raise a clear RuntimeError naming the weight and the missing scale key when an FP8 weight has no matching .weight_scale, instead of a bare KeyError.
…t scale-key guard
- pipeline_ideogram4: pass a CPU tensor to the schedule precompute so the warp result isn't bounced back to the device just for tolist() to pull it off again; values are byte-identical.
- scheduler: document __call__ as an inference-time, non-differentiable helper (the sampler runs under no_grad and reads the values out as Python scalars).
- quantized_loading: reword device_supports_fp8 docstring so the runtime probe is authoritative rather than asserting CUDA/CPU always pass.
- quantized_loading: guard that an FP8 weight key ends with '.weight' before deriving the scale key, raising a clear RuntimeError otherwise.
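The two scale-key checks might look roughly like the following. This is an illustrative sketch, not the code in quantized_loading.py: `scale_key_for` is a hypothetical helper, the error messages are invented, and the only detail taken from this PR is that the scale lives under the weight's key with a `_scale` suffix (`.weight` becomes `.weight_scale`).

```python
from typing import Any, Mapping


def scale_key_for(weight_key: str, state_dict: Mapping[str, Any]) -> str:
    """Derive and validate the per-row scale key for an FP8 weight key."""
    if not weight_key.endswith(".weight"):
        raise RuntimeError(
            f"FP8 tensor {weight_key!r} does not end with '.weight'; "
            "cannot derive its scale key"
        )
    scale_key = weight_key + "_scale"  # "x.weight" -> "x.weight_scale"
    if scale_key not in state_dict:
        raise RuntimeError(
            f"FP8 weight {weight_key!r} has no matching scale {scale_key!r}"
        )
    return scale_key
```

Naming both keys in the message makes a truncated or mismatched checkpoint much easier to diagnose than a bare KeyError from a dictionary lookup.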
here for this. thanks!
The weight-only FP8 path crashed on Apple Silicon (MPS) because PyTorch's MPS backend cannot store or cast float8_e4m3fn, and supports neither float64 nor the ndtri/expit special functions the sampler uses.
The README and run_inference.py --help already claim FP8 "runs on any device (no FP8 hardware needed)", but MPS hit two hard TypeErrors. This change makes that claim true while leaving the CUDA and CPU paths unchanged by construction.
Fixes #18
Problem
Two separate MPS dtype walls, in order of failure:
Changes
Backwards compatibility
Testing
Performance / resource notes (MPS)