Add Apple Silicon (MPS) support to the FP8 inference path #19

Open
sammcj wants to merge 5 commits into ideogram-oss:main from sammcj:mps_support

sammcj commented Jun 6, 2026

The weight-only FP8 path crashed on Apple Silicon (MPS) because PyTorch's MPS backend cannot store or cast float8_e4m3fn, and supports neither float64 nor the ndtri/expit special functions the sampler uses.

The README and run_inference.py --help already claim FP8 "runs on any device (no FP8 hardware needed)", but MPS hit two hard TypeErrors. This change makes that claim true while leaving the CUDA and CPU paths unchanged by construction.

Fixes #18

Problem

Two separate MPS dtype walls, in order of failure:

  1. Fp8Linear stores weights as float8_e4m3fn; on MPS, moving/casting that dtype raises TypeError: Trying to convert Float8_e4m3fn to the MPS backend....
  2. LogitNormalSchedule runs its logit-normal warp in float64; on MPS, t.to(torch.float64) raises Cannot convert a MPS Tensor to float64 dtype....

Changes

  • quantized_loading.py: add device_supports_fp8(). On devices that can't store float8 (MPS), Fp8Linear holds an already-dequantized bf16 weight (no scale buffer) and load_fp8_state_dict dequantizes each FP8 weight via its per-row scale at load time, dropping the now-unused .weight_scale keys. The half-size FP8 checkpoint is still downloaded; only the in-memory weight is expanded to bf16.
  • scheduler.py: on MPS only, run the float64 warp on CPU and return float32 on the caller's device (a fused .to(cpu, float64) still casts on the MPS side first, so the move and cast are split).
  • pipeline_ideogram4.py: thread the target device into swap_linears_to_fp8 at both call sites.

Backwards compatibility

  • CUDA and CPU take store_fp8=True (the default), so buffer registration, state-dict loading, and the forward pass are identical to before; the only added op on that path is an assert, stripped under python -O.
  • The scheduler change is gated on device.type == "mps"; CUDA/CPU keep the original on-device float64 path.
  • The nf4 (CUDA bitsandbytes) path is untouched.
  • swap_linears_to_fp8 / load_fp8_state_dict / Fp8Linear are internal (not exported); the only callers are the two updated sites plus the function's own recursion.
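
The device gate on the scheduler change could look something like the sketch below. The warp formula `expit(mean + std * ndtri(t))` and the function name `logit_normal_warp` are assumptions made for illustration, as is returning float32 on every device; only the split move-then-cast on MPS is taken from the description above.

```python
import torch


def logit_normal_warp(t, mean=0.0, std=1.0):
    """Hypothetical warp: expit(mean + std * ndtri(t)), computed in float64.

    The exact formula used by LogitNormalSchedule is an assumption here;
    the point of the sketch is the device handling around it.
    """
    device = t.device
    if device.type == "mps":
        # Move first, then cast: a fused .to("cpu", torch.float64) would
        # attempt the float64 cast on the MPS side and fail.
        work = t.to("cpu").to(torch.float64)
    else:
        work = t.to(torch.float64)  # on-device float64 path
    warped = torch.special.expit(mean + std * torch.special.ndtri(work))
    # Cast down before moving back so float64 never reaches the MPS device.
    return warped.to(torch.float32).to(device)
```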

Testing

  • Hardware: Apple M5 Max (128GB), macOS 26, PyTorch 2.12, ideogram-ai/ideogram-4-fp8.
  • End-to-end generation on MPS now succeeds at 1024x1024 across V4_TURBO_12, V4_DEFAULT_20, and V4_QUALITY_48, producing correct images (verified visually with a full structured MagicPrompt caption).
  • Numerics checked clean (no NaN/Inf) in the text-encoder conditioning features and both transformer outputs during sampling.
  • Lint/type gates pass: ruff check, ruff format --check, and mypy (project config) are clean on the changed files.
  • CUDA was not run (no CUDA hardware available); the CUDA/CPU paths are unchanged by construction, so a single CUDA smoke run from a maintainer would confirm no regression.

Performance / resource notes (MPS)

  • PyTorch's MPS backend can't store float8, so the FP8 weights are dequantized to bf16 and executed as bf16 matmuls. The quantization saves download size and disk space but gives no compute speedup on MPS, unlike fp8/int8 tensor-core matmuls on CUDA. The measured ~7s/step is bf16-bound, so this is the practical floor on Apple Silicon.
  • Load/dequant is single-threaded on CPU and takes ~1-2 min; sampling then runs on the GPU at ~7s/step.
  • Measured sampling time (1024x1024, after the one-time load): V4_TURBO_12 ~85s, V4_DEFAULT_20 ~136s, V4_QUALITY_48 ~325s.
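
As a quick cross-check of the ~7s/step figure, the totals above can be divided by the step count. This assumes the trailing number in each preset name is its number of sampling steps, which the description implies but does not state.

```python
# Sampling totals quoted above: preset name -> (assumed step count, seconds).
measured = {
    "V4_TURBO_12": (12, 85),
    "V4_DEFAULT_20": (20, 136),
    "V4_QUALITY_48": (48, 325),
}

# Seconds per step for each preset, rounded to two decimals.
per_step = {
    name: round(total / steps, 2) for name, (steps, total) in measured.items()
}
print(per_step)
```

Under that assumption every preset lands between roughly 6.8 and 7.1 seconds per step.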

MPS can neither store nor cast float8_e4m3fn, and supports neither float64 nor ndtri/expit, so the weight-only FP8 path crashed twice on MPS: first dequantizing Linear weights, then in the logit-normal sampler. Both are now handled, with the CUDA/CPU paths unchanged by construction.

- quantized_loading.py: add device_supports_fp8(); on MPS, Fp8Linear holds an already-dequantized bf16 weight (no scale buffer) and load_fp8_state_dict dequantizes each FP8 weight via its per-row scale at load time, dropping the now-unused .weight_scale keys. CUDA/CPU keep the float8 storage path (store_fp8=True) byte-for-byte.
- scheduler.py: run the float64 warp (ndtri/expit) on CPU for MPS only, returning float32 on the caller's device; CUDA/CPU keep the original on-device path.
- pipeline_ideogram4.py: thread the target device into swap_linears_to_fp8 at both call sites.
Copilot AI review requested due to automatic review settings June 6, 2026 00:01

Copilot AI left a comment

Pull request overview

Note

Copilot was unable to run its full agentic suite in this review.

This PR improves Apple Silicon (MPS) compatibility by routing unsupported float64/special-function operations and FP8 weight handling through CPU/dequantized paths while keeping the public pipeline behavior consistent.

Changes:

  • Run the LogitNormalSchedule “warp” on CPU for MPS and return results back to the original device.
  • Add MPS-aware FP8 loading: on MPS, dequantize FP8 weights at load time and drop .weight_scale entries.
  • Thread device through FP8 linear swapping so layer layouts match the chosen storage strategy.

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 4 comments.

| File | Description |
| --- | --- |
| src/ideogram4/scheduler.py | Adds an MPS CPU fallback for float64 special functions and restores output to the original device. |
| src/ideogram4/quantized_loading.py | Introduces device capability gating for FP8 storage and MPS dequantization-at-load behavior. |
| src/ideogram4/pipeline_ideogram4.py | Passes device into FP8 swapping so modules match the intended FP8 storage/dequantization path. |

Comment thread src/ideogram4/quantized_loading.py Outdated
Comment thread src/ideogram4/quantized_loading.py Outdated
Comment thread src/ideogram4/scheduler.py
Comment thread src/ideogram4/scheduler.py
sammcj added 2 commits June 6, 2026 10:09
- device_supports_fp8 now probes the backend at runtime (cached per device type) instead of hard-coding "not mps", so it adapts to other backends that may lack float8 storage/casting.
- load_fp8_state_dict dequant path moves the fp8 weight and scale to CPU explicitly before the float32 cast and multiply, matching the comment and guarding against non-CPU state-dict tensors.

Scheduler CPU-roundtrip review comments were left as-is: the warp runs on a single-element tensor that the caller immediately scalarises via .item(), so the transfer is negligible (measured ~6.8s/step, transformer-bound, at V4_QUALITY_48).
step_intervals is fixed for the whole sampling loop, so warp it through LogitNormalSchedule once before the loop and index the results, rather than calling the schedule twice per step. Output is byte-identical; this hoists the loop-invariant scalar work (and its per-step host syncs) out of the inner loop. Addresses the scheduler.py review comments on the MPS CPU roundtrip.
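
The hoisting described in this commit can be illustrated with a small pure-Python sketch. The scalar `warp` below is a stand-in for the schedule (assumed to be sigmoid of a shifted and scaled probit), and the assumption that each step reads the two boundaries `step_intervals[i]` and `step_intervals[i + 1]` is inferred from "twice per step"; neither is taken from the PR's code.

```python
import math
from statistics import NormalDist

_STD_NORMAL = NormalDist()


def warp(t, mean=0.0, std=1.0):
    """Hypothetical scalar logit-normal warp: sigmoid(mean + std * probit(t))."""
    warp.calls += 1  # counts schedule evaluations, for illustration only
    z = mean + std * _STD_NORMAL.inv_cdf(t)
    return 1.0 / (1.0 + math.exp(-z))


warp.calls = 0


def sample_per_step(step_intervals):
    """Before: two schedule calls inside every loop iteration."""
    pairs = []
    for i in range(len(step_intervals) - 1):
        pairs.append((warp(step_intervals[i]), warp(step_intervals[i + 1])))
    return pairs


def sample_hoisted(step_intervals):
    """After: warp every boundary once, then index inside the loop."""
    warped = [warp(t) for t in step_intervals]
    return [(warped[i], warped[i + 1]) for i in range(len(step_intervals) - 1)]
```

Both functions return the same pairs; the hoisted version evaluates the schedule once per boundary instead of twice per step.
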
Copilot AI review requested due to automatic review settings June 6, 2026 00:47

Copilot AI left a comment

Pull request overview

Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.

Comment thread src/ideogram4/pipeline_ideogram4.py Outdated
Comment thread src/ideogram4/quantized_loading.py
Comment thread src/ideogram4/quantized_loading.py
…e cache, explicit missing-scale error

- pipeline_ideogram4: warp the whole step_intervals tensor through the schedule in one vectorized call (.tolist()) instead of per-element calls; the values are byte-identical so output is unchanged.
- quantized_loading: key the fp8 capability probe cache by concrete device (e.g. cuda:0) rather than device type, so a heterogeneous multi-device setup is probed individually.
- quantized_loading: raise a clear RuntimeError naming the weight and the missing scale key when an FP8 weight has no matching .weight_scale, instead of a bare KeyError.
@sammcj sammcj requested a review from Copilot June 6, 2026 00:57

Copilot AI left a comment

Pull request overview

Copilot reviewed 3 out of 3 changed files in this pull request and generated 5 comments.

Comment thread src/ideogram4/scheduler.py
Comment thread src/ideogram4/scheduler.py
Comment thread src/ideogram4/quantized_loading.py
Comment thread src/ideogram4/quantized_loading.py
Comment thread src/ideogram4/pipeline_ideogram4.py Outdated
…t scale-key guard

- pipeline_ideogram4: pass a CPU tensor to the schedule precompute so the warp result isn't bounced back to the device just for tolist() to pull it off again; values are byte-identical.
- scheduler: document __call__ as an inference-time, non-differentiable helper (the sampler runs under no_grad and reads the values out as Python scalars).
- quantized_loading: reword device_supports_fp8 docstring so the runtime probe is authoritative rather than asserting CUDA/CPU always pass.
- quantized_loading: guard that an FP8 weight key ends with '.weight' before deriving the scale key, raising a clear RuntimeError otherwise.
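
The two key guards mentioned in these commits (the `.weight` suffix check and the explicit missing-scale error) might be combined along these lines. `scale_key_for` is a hypothetical helper and the error wording is illustrative only.

```python
def scale_key_for(weight_key, state_dict_keys):
    """Hypothetical helper: map an FP8 weight key to its scale key."""
    suffix = ".weight"
    if not weight_key.endswith(suffix):
        raise RuntimeError(
            f"FP8 tensor {weight_key!r} does not end with {suffix!r}; "
            "cannot derive its scale key"
        )
    scale_key = weight_key + "_scale"
    if scale_key not in state_dict_keys:
        # Name both keys so the failure is actionable, unlike a bare KeyError.
        raise RuntimeError(
            f"FP8 weight {weight_key!r} has no matching scale {scale_key!r}"
        )
    return scale_key
```
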
rsl commented Jun 6, 2026

here for this. thanks!

Development

Successfully merging this pull request may close these issues.

fp8 on mps FAILED: TypeError Trying to convert Float8_e4m3fn to the MPS backend

3 participants