From 0ed50a1a8792d9b231effdfd81e484a196c0b7ae Mon Sep 17 00:00:00 2001 From: James Huang Date: Tue, 10 Mar 2026 20:08:27 +0000 Subject: [PATCH] Implement CFG cache for Wan 2.1 Signed-off-by: James Huang --- src/maxdiffusion/configs/base_wan_14b.yml | 5 + src/maxdiffusion/configs/base_wan_1_3b.yml | 3 + src/maxdiffusion/configs/base_wan_27b.yml | 3 + src/maxdiffusion/configs/base_wan_i2v_14b.yml | 3 + src/maxdiffusion/configs/base_wan_i2v_27b.yml | 3 + src/maxdiffusion/generate_wan.py | 1 + .../pipelines/wan/wan_pipeline.py | 103 +++++++ .../pipelines/wan/wan_pipeline_2_1.py | 147 ++++++++-- src/maxdiffusion/tests/wan_cfg_cache_test.py | 274 ++++++++++++++++++ 9 files changed, 524 insertions(+), 18 deletions(-) create mode 100644 src/maxdiffusion/tests/wan_cfg_cache_test.py diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index 91a3e092a..fa6309610 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -324,6 +324,11 @@ num_frames: 81 guidance_scale: 5.0 flow_shift: 3.0 +# Diffusion CFG cache (FasterCache-style, WAN 2.1 T2V only) +# Skips the unconditional forward pass on ~35% of steps via residual compensation. +# See: FasterCache (Lv et al. 2024), WAN 2.1 paper §4.4.2 +use_cfg_cache: False + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf guidance_rescale: 0.0 num_inference_steps: 30 diff --git a/src/maxdiffusion/configs/base_wan_1_3b.yml b/src/maxdiffusion/configs/base_wan_1_3b.yml index ffd2864a8..5f20d11dc 100644 --- a/src/maxdiffusion/configs/base_wan_1_3b.yml +++ b/src/maxdiffusion/configs/base_wan_1_3b.yml @@ -280,6 +280,9 @@ num_frames: 81 guidance_scale: 5.0 flow_shift: 3.0 +# Diffusion CFG cache (FasterCache-style, WAN 2.1 T2V only) +use_cfg_cache: False + # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf guidance_rescale: 0.0 num_inference_steps: 30 diff --git a/src/maxdiffusion/configs/base_wan_27b.yml b/src/maxdiffusion/configs/base_wan_27b.yml index 022b18c91..6d06218cc 100644 --- a/src/maxdiffusion/configs/base_wan_27b.yml +++ b/src/maxdiffusion/configs/base_wan_27b.yml @@ -302,6 +302,9 @@ guidance_scale_high: 4.0 # timestep to switch between low noise and high noise transformer boundary_ratio: 0.875 +# Diffusion CFG cache (FasterCache-style, WAN 2.1 T2V only) +use_cfg_cache: False + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf guidance_rescale: 0.0 num_inference_steps: 30 diff --git a/src/maxdiffusion/configs/base_wan_i2v_14b.yml b/src/maxdiffusion/configs/base_wan_i2v_14b.yml index 2a5b0338c..d0c1a0140 100644 --- a/src/maxdiffusion/configs/base_wan_i2v_14b.yml +++ b/src/maxdiffusion/configs/base_wan_i2v_14b.yml @@ -286,6 +286,9 @@ num_frames: 81 guidance_scale: 5.0 flow_shift: 5.0 +# Diffusion CFG cache (FasterCache-style, WAN 2.1 T2V only) +use_cfg_cache: False + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf guidance_rescale: 0.0 num_inference_steps: 50 diff --git a/src/maxdiffusion/configs/base_wan_i2v_27b.yml b/src/maxdiffusion/configs/base_wan_i2v_27b.yml index 0bd6a27f2..93ab8ce32 100644 --- a/src/maxdiffusion/configs/base_wan_i2v_27b.yml +++ b/src/maxdiffusion/configs/base_wan_i2v_27b.yml @@ -298,6 +298,9 @@ guidance_scale_high: 4.0 # timestep to switch between low noise and high noise transformer boundary_ratio: 0.875 +# Diffusion CFG cache (FasterCache-style, WAN 2.1 T2V only) +use_cfg_cache: False + # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf guidance_rescale: 0.0 num_inference_steps: 50 diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index f53cc59b6..828bc1a2c 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -125,6 +125,7 @@ def call_pipeline(config, pipeline, prompt, negative_prompt): num_frames=config.num_frames, num_inference_steps=config.num_inference_steps, guidance_scale=config.guidance_scale, + use_cfg_cache=config.use_cfg_cache, ) elif model_key == WAN2_2: return pipeline( diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 7c0314b40..86c9f9c2e 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -778,3 +778,106 @@ def transformer_forward_pass( latents = latents[:bsz] return noise_pred, latents + + +@partial(jax.jit, static_argnames=("guidance_scale",)) +def transformer_forward_pass_full_cfg( + graphdef, + sharded_state, + rest_of_state, + latents_doubled: jnp.array, + timestep: jnp.array, + prompt_embeds_combined: jnp.array, + guidance_scale: float, + encoder_hidden_states_image=None, +): + """Full CFG forward pass. + + Accepts pre-doubled latents and pre-concatenated [cond, uncond] prompt embeds. + Returns the merged noise_pred plus raw noise_cond and noise_uncond for + CFG cache storage. Keeping cond/uncond separate avoids a second forward + pass on cache steps. 
+ """ + wan_transformer = nnx.merge(graphdef, sharded_state, rest_of_state) + bsz = latents_doubled.shape[0] // 2 + noise_pred = wan_transformer( + hidden_states=latents_doubled, + timestep=timestep, + encoder_hidden_states=prompt_embeds_combined, + encoder_hidden_states_image=encoder_hidden_states_image, + ) + noise_cond = noise_pred[:bsz] + noise_uncond = noise_pred[bsz:] + noise_pred_merged = noise_uncond + guidance_scale * (noise_cond - noise_uncond) + return noise_pred_merged, noise_cond, noise_uncond + + +@partial(jax.jit, static_argnames=("guidance_scale",)) +def transformer_forward_pass_cfg_cache( + graphdef, + sharded_state, + rest_of_state, + latents_cond: jnp.array, + timestep_cond: jnp.array, + prompt_cond_embeds: jnp.array, + cached_noise_cond: jnp.array, + cached_noise_uncond: jnp.array, + guidance_scale: float, + w1: float = 1.0, + w2: float = 1.0, + encoder_hidden_states_image=None, +): + """CFG-Cache forward pass with FFT frequency-domain compensation. + + FasterCache (Lv et al., ICLR 2025) CFG-Cache: + 1. Compute frequency-domain bias: ΔF = FFT(uncond) - FFT(cond) + 2. Split into low-freq (ΔLF) and high-freq (ΔHF) via spectral mask + 3. Apply phase-dependent weights: + F_low = FFT(new_cond)_low + w1 * ΔLF + F_high = FFT(new_cond)_high + w2 * ΔHF + 4. Reconstruct: uncond_approx = IFFT(F_low + F_high) + + w1/w2 encode the denoising phase: + Early (high noise): w1=1+α, w2=1 → boost low-freq correction + Late (low noise): w1=1, w2=1+α → boost high-freq correction + where α=0.2 (FasterCache default). + + On TPU this compiles to a single static XLA graph with half the batch size + of a full CFG pass. 
+ """ + wan_transformer = nnx.merge(graphdef, sharded_state, rest_of_state) + noise_cond = wan_transformer( + hidden_states=latents_cond, + timestep=timestep_cond, + encoder_hidden_states=prompt_cond_embeds, + encoder_hidden_states_image=encoder_hidden_states_image, + ) + + # FFT over spatial dims (H, W) — last 2 dims of [B, C, F, H, W] + fft_cond_cached = jnp.fft.rfft2(cached_noise_cond.astype(jnp.float32)) + fft_uncond_cached = jnp.fft.rfft2(cached_noise_uncond.astype(jnp.float32)) + fft_bias = fft_uncond_cached - fft_cond_cached + + # Build low/high frequency mask (25% cutoff) + h = fft_bias.shape[-2] + w_rfft = fft_bias.shape[-1] + ch = jnp.maximum(1, h // 4) + cw = jnp.maximum(1, w_rfft // 4) + freq_h = jnp.arange(h) + freq_w = jnp.arange(w_rfft) + # Low-freq: indices near DC (0) in both dims; account for wrap-around in dim H + low_h = (freq_h < ch) | (freq_h >= h - ch + 1) + low_w = freq_w < cw + low_mask = (low_h[:, None] & low_w[None, :]).astype(jnp.float32) + high_mask = 1.0 - low_mask + + # Apply phase-dependent weights to frequency bias + fft_bias_weighted = fft_bias * (low_mask * w1 + high_mask * w2) + + # Reconstruct unconditional output + fft_cond_new = jnp.fft.rfft2(noise_cond.astype(jnp.float32)) + fft_uncond_approx = fft_cond_new + fft_bias_weighted + noise_uncond_approx = jnp.fft.irfft2(fft_uncond_approx, s=noise_cond.shape[-2:]).astype(noise_cond.dtype) + + noise_pred_merged = noise_uncond_approx + guidance_scale * (noise_cond - noise_uncond_approx) + return noise_pred_merged, noise_cond diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py index c247facb5..976f0f042 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from .wan_pipeline import WanPipeline, transformer_forward_pass +from .wan_pipeline import WanPipeline, transformer_forward_pass, transformer_forward_pass_full_cfg, transformer_forward_pass_cfg_cache from ...models.wan.transformers.transformer_wan import WanModel from typing import List, Union, Optional from ...pyconfig import HyperParameters @@ -90,7 +90,14 @@ def __call__( prompt_embeds: Optional[jax.Array] = None, negative_prompt_embeds: Optional[jax.Array] = None, vae_only: bool = False, + use_cfg_cache: bool = False, ): + if use_cfg_cache and guidance_scale <= 1.0: + raise ValueError( + f"use_cfg_cache=True requires guidance_scale > 1.0 (got {guidance_scale}). " + "CFG cache accelerates classifier-free guidance, which is disabled when guidance_scale <= 1.0." + ) + latents, prompt_embeds, negative_prompt_embeds, scheduler_state, num_frames = self._prepare_model_inputs( prompt, negative_prompt, @@ -114,6 +121,8 @@ def __call__( num_inference_steps=num_inference_steps, scheduler=self.scheduler, scheduler_state=scheduler_state, + use_cfg_cache=use_cfg_cache, + height=height, ) with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): @@ -140,26 +149,128 @@ def run_inference_2_1( num_inference_steps: int, scheduler: FlaxUniPCMultistepScheduler, scheduler_state, + use_cfg_cache: bool = False, + height: int = 480, ): - do_classifier_free_guidance = guidance_scale > 1.0 - if do_classifier_free_guidance: - prompt_embeds = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0) + """Denoising loop for WAN 2.1 T2V with FasterCache CFG-Cache. + + CFG-Cache strategy (Lv et al., ICLR 2025, enabled via use_cfg_cache=True): + - Full CFG steps : run transformer on [cond, uncond] batch (batch×2). + Cache raw noise_cond and noise_uncond for FFT bias. + - Cache steps : run transformer on cond batch only (batch×1). 
+ Estimate uncond via FFT frequency-domain compensation: + ΔF = FFT(cached_uncond) - FFT(cached_cond) + Split ΔF into low-freq (ΔLF) and high-freq (ΔHF). + uncond_approx = IFFT(FFT(new_cond) + w1*ΔLF + w2*ΔHF) + Phase-dependent weights (α=0.2): + Early (high noise): w1=1.2, w2=1.0 (boost low-freq) + Late (low noise): w1=1.0, w2=1.2 (boost high-freq) + - Schedule : full CFG for the first 1/3 of steps, then + full CFG every 5 steps, cache the rest. + + Two separately-compiled JAX-jitted functions handle full and cache steps so + XLA sees static shapes throughout — the key requirement for TPU efficiency. + """ + do_cfg = guidance_scale > 1.0 + bsz = latents.shape[0] + + # Resolution-dependent CFG cache config (FasterCache / MixCache guidance) + if height >= 720: + # 720p: conservative — protect last 10%, interval=5 + cfg_cache_interval = 5 + cfg_cache_start_step = int(num_inference_steps / 3) + cfg_cache_end_step = int(num_inference_steps * 0.9) + cfg_cache_alpha = 0.2 + else: + # 480p: moderate — protect last 2 steps, interval=5 + cfg_cache_interval = 5 + cfg_cache_start_step = int(num_inference_steps / 3) + cfg_cache_end_step = num_inference_steps - 2 + cfg_cache_alpha = 0.2 + + # Pre-split embeds once, outside the loop. + prompt_cond_embeds = prompt_embeds + prompt_embeds_combined = None + if do_cfg: + prompt_embeds_combined = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0) + + # Pre-compute cache schedule and phase-dependent weights. + # t₀ = midpoint step; before t₀ boost low-freq, after boost high-freq.
+ t0_step = num_inference_steps // 2 + first_full_step_seen = False + step_is_cache = [] + step_w1w2 = [] + for s in range(num_inference_steps): + is_cache = ( + use_cfg_cache + and do_cfg + and first_full_step_seen + and s >= cfg_cache_start_step + and s < cfg_cache_end_step + and (s - cfg_cache_start_step) % cfg_cache_interval != 0 + ) + step_is_cache.append(is_cache) + if not is_cache: + first_full_step_seen = True + # Phase-dependent weights: w = 1 + α·I(condition) + if s < t0_step: + step_w1w2.append((1.0 + cfg_cache_alpha, 1.0)) # early: boost low-freq + else: + step_w1w2.append((1.0, 1.0 + cfg_cache_alpha)) # late: boost high-freq + + # Cache tensors (on-device JAX arrays, initialised to None). + cached_noise_cond = None + cached_noise_uncond = None + for step in range(num_inference_steps): t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] - if do_classifier_free_guidance: - latents = jnp.concatenate([latents] * 2) - timestep = jnp.broadcast_to(t, latents.shape[0]) - - noise_pred, latents = transformer_forward_pass( - graphdef, - sharded_state, - rest_of_state, - latents, - timestep, - prompt_embeds, - do_classifier_free_guidance=do_classifier_free_guidance, - guidance_scale=guidance_scale, - ) + is_cache_step = step_is_cache[step] + + if is_cache_step: + # ── Cache step: cond-only forward + FFT frequency compensation ── + w1, w2 = step_w1w2[step] + timestep = jnp.broadcast_to(t, bsz) + noise_pred, cached_noise_cond = transformer_forward_pass_cfg_cache( + graphdef, + sharded_state, + rest_of_state, + latents, + timestep, + prompt_cond_embeds, + cached_noise_cond, + cached_noise_uncond, + guidance_scale=guidance_scale, + w1=jnp.float32(w1), + w2=jnp.float32(w2), + ) + + elif do_cfg: + # ── Full CFG step: doubled batch, store raw cond/uncond for cache ── + latents_doubled = jnp.concatenate([latents] * 2) + timestep = jnp.broadcast_to(t, bsz * 2) + noise_pred, cached_noise_cond, cached_noise_uncond = transformer_forward_pass_full_cfg( + graphdef, 
+ sharded_state, + rest_of_state, + latents_doubled, + timestep, + prompt_embeds_combined, + guidance_scale=guidance_scale, + ) + + else: + # ── No CFG (guidance_scale <= 1.0) ── + timestep = jnp.broadcast_to(t, bsz) + noise_pred, latents = transformer_forward_pass( + graphdef, + sharded_state, + rest_of_state, + latents, + timestep, + prompt_cond_embeds, + do_classifier_free_guidance=False, + guidance_scale=guidance_scale, + ) latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() return latents diff --git a/src/maxdiffusion/tests/wan_cfg_cache_test.py b/src/maxdiffusion/tests/wan_cfg_cache_test.py new file mode 100644 index 000000000..3543cf691 --- /dev/null +++ b/src/maxdiffusion/tests/wan_cfg_cache_test.py @@ -0,0 +1,274 @@ +""" +Copyright 2025 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import os +import time +import unittest + +import numpy as np +import pytest +from absl.testing import absltest + +from maxdiffusion.pipelines.wan.wan_pipeline_2_1 import WanPipeline2_1 + +IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true" +THIS_DIR = os.path.dirname(os.path.abspath(__file__)) + + +class WanCfgCacheValidationTest(unittest.TestCase): + """Tests that use_cfg_cache=True with guidance_scale <= 1.0 raises ValueError.""" + + def _make_pipeline(self): + """Create a bare WanPipeline2_1 via __new__ (skips __init__; nothing is mocked or loaded).""" + pipeline = WanPipeline2_1.__new__(WanPipeline2_1) + return pipeline + + def test_cfg_cache_with_guidance_scale_1_raises(self): + pipeline = self._make_pipeline() + with self.assertRaises(ValueError) as ctx: + pipeline( + prompt=["test"], + guidance_scale=1.0, + use_cfg_cache=True, + ) + self.assertIn("use_cfg_cache", str(ctx.exception)) + + def test_cfg_cache_with_guidance_scale_0_raises(self): + pipeline = self._make_pipeline() + with self.assertRaises(ValueError) as ctx: + pipeline( + prompt=["test"], + guidance_scale=0.0, + use_cfg_cache=True, + ) + self.assertIn("use_cfg_cache", str(ctx.exception)) + + def test_cfg_cache_with_valid_guidance_scale_no_validation_error(self): + """guidance_scale > 1.0 should pass validation (may fail later without model).""" + pipeline = self._make_pipeline() + try: + pipeline( + prompt=["test"], + guidance_scale=5.0, + use_cfg_cache=True, + ) + except ValueError as e: + if "use_cfg_cache" in str(e): + self.fail(f"Unexpected validation error: {e}") + except Exception: + # Other errors expected (no model loaded).
+ pass + + def test_no_cfg_cache_with_low_guidance_no_error(self): + """use_cfg_cache=False should never raise our ValueError regardless of guidance_scale.""" + pipeline = self._make_pipeline() + try: + pipeline( + prompt=["test"], + guidance_scale=0.5, + use_cfg_cache=False, + ) + except ValueError as e: + if "use_cfg_cache" in str(e): + self.fail(f"Unexpected validation error: {e}") + except Exception: + pass + + +class WanCfgCacheScheduleTest(unittest.TestCase): + """Tests that CFG cache schedule produces the correct full/cache step pattern. + + Verifies the schedule logic in run_inference_2_1 without running any model. + """ + + def _get_cache_schedule(self, num_inference_steps, height=480): + """Extract the cache schedule from run_inference_2_1's logic. + + Mirrors the schedule computation in run_inference_2_1 to verify correctness. + """ + if height >= 720: + cfg_cache_interval = 5 + cfg_cache_start_step = int(num_inference_steps / 3) + cfg_cache_end_step = int(num_inference_steps * 0.9) + else: + cfg_cache_interval = 5 + cfg_cache_start_step = int(num_inference_steps / 3) + cfg_cache_end_step = num_inference_steps - 2 + + first_full_step_seen = False + schedule = [] + for s in range(num_inference_steps): + is_cache = ( + first_full_step_seen + and s >= cfg_cache_start_step + and s < cfg_cache_end_step + and (s - cfg_cache_start_step) % cfg_cache_interval != 0 + ) + schedule.append(is_cache) + if not is_cache: + first_full_step_seen = True + return schedule + + def test_480p_50_steps_schedule(self): + """480p, 50 steps: cache starts at step 16, ends at step 48.""" + schedule = self._get_cache_schedule(50, height=480) + self.assertEqual(len(schedule), 50) + # First 16 steps should all be full CFG + self.assertTrue(all(not s for s in schedule[:16])) + # Last 2 steps should be full CFG + self.assertTrue(all(not s for s in schedule[48:])) + # There should be some cache steps in the middle + cache_count = sum(schedule) + self.assertGreater(cache_count, 0, "Should 
have cache steps in 480p/50 steps") + + def test_720p_50_steps_schedule(self): + """720p, 50 steps: more conservative — cache ends at step 45.""" + schedule = self._get_cache_schedule(50, height=720) + self.assertEqual(len(schedule), 50) + # First 16 steps should all be full CFG + self.assertTrue(all(not s for s in schedule[:16])) + # Last 10% of steps (45-49) should be full CFG + self.assertTrue(all(not s for s in schedule[45:])) + cache_count = sum(schedule) + self.assertGreater(cache_count, 0, "Should have cache steps in 720p/50 steps") + + def test_720p_has_fewer_cache_steps_than_480p(self): + """720p should be more conservative (fewer cache steps) than 480p.""" + schedule_480 = self._get_cache_schedule(50, height=480) + schedule_720 = self._get_cache_schedule(50, height=720) + self.assertGreater(sum(schedule_480), sum(schedule_720)) + + def test_cache_interval_is_5(self): + """Every 5th step after start should be a full CFG step (not cached).""" + schedule = self._get_cache_schedule(50, height=480) + start = int(50 / 3) # 16 + end = 48 + for s in range(start, end): + if (s - start) % 5 == 0: + self.assertFalse(schedule[s], f"Step {s} should be full CFG (interval=5)") + + def test_short_run_no_cache(self): + """Very few steps should have no cache steps.""" + schedule = self._get_cache_schedule(3, height=480) + self.assertEqual(sum(schedule), 0, "3 steps is too short for cache") + + +@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Requires TPU v7-8 and model weights") +class WanCfgCacheSmokeTest(unittest.TestCase): + """End-to-end smoke test: CFG cache should be faster with SSIM >= 0.95. + + Runs on TPU v7-8 (8 chips, context_parallelism=8) with WAN 2.1 14B, 720p. 
+ Skipped in CI (GitHub Actions) — run locally with: + python -m pytest src/maxdiffusion/tests/wan_cfg_cache_test.py::WanCfgCacheSmokeTest -v + """ + + @classmethod + def setUpClass(cls): + from maxdiffusion import pyconfig + from maxdiffusion.checkpointing.wan_checkpointer_2_1 import WanCheckpointer2_1 + + pyconfig.initialize( + [ + None, + os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"), + "num_inference_steps=50", + "height=720", + "width=1280", + "num_frames=81", + "fps=24", + "guidance_scale=5.0", + "flow_shift=3.0", + "seed=11234567893", + "attention=flash", + "remat_policy=FULL", + "allow_split_physical_axes=True", + "skip_jax_distributed_system=True", + "weights_dtype=bfloat16", + "activations_dtype=bfloat16", + "per_device_batch_size=0.125", + "ici_data_parallelism=1", + "ici_fsdp_parallelism=1", + "ici_context_parallelism=8", + "ici_tensor_parallelism=1", + "flash_min_seq_length=0", + 'flash_block_sizes={"block_q": 2048, "block_kv_compute": 1024, "block_kv": 2048, "block_q_dkv": 2048, "block_kv_dkv": 2048, "block_kv_dkv_compute": 2048, "use_fused_bwd_kernel": true}', + ], + unittest=True, + ) + cls.config = pyconfig.config + checkpoint_loader = WanCheckpointer2_1(config=cls.config) + cls.pipeline, _, _ = checkpoint_loader.load_checkpoint() + + cls.prompt = [cls.config.prompt] * cls.config.global_batch_size_to_train_on + cls.negative_prompt = [cls.config.negative_prompt] * cls.config.global_batch_size_to_train_on + + # Warmup both XLA code paths + for use_cache in [False, True]: + cls.pipeline( + prompt=cls.prompt, + negative_prompt=cls.negative_prompt, + height=cls.config.height, + width=cls.config.width, + num_frames=cls.config.num_frames, + num_inference_steps=cls.config.num_inference_steps, + guidance_scale=cls.config.guidance_scale, + use_cfg_cache=use_cache, + ) + + def _run_pipeline(self, use_cfg_cache): + t0 = time.perf_counter() + videos = self.pipeline( + prompt=self.prompt, + negative_prompt=self.negative_prompt, + 
height=self.config.height, + width=self.config.width, + num_frames=self.config.num_frames, + num_inference_steps=self.config.num_inference_steps, + guidance_scale=self.config.guidance_scale, + use_cfg_cache=use_cfg_cache, + ) + return videos, time.perf_counter() - t0 + + def test_cfg_cache_speedup_and_fidelity(self): + """CFG cache must be faster than baseline with mean SSIM >= 0.95.""" + videos_baseline, t_baseline = self._run_pipeline(use_cfg_cache=False) + videos_cached, t_cached = self._run_pipeline(use_cfg_cache=True) + + # Speed check + speedup = t_baseline / t_cached + print(f"Baseline: {t_baseline:.2f}s, CFG cache: {t_cached:.2f}s, Speedup: {speedup:.3f}x") + self.assertGreater(speedup, 1.0, f"CFG cache should be faster. Speedup={speedup:.3f}x") + + # Fidelity check (per-frame SSIM) + v1 = np.array(videos_baseline[0], dtype=np.float64) + v2 = np.array(videos_cached[0], dtype=np.float64) + + C1, C2 = 0.01**2, 0.03**2 + ssim_scores = [] + for f in range(v1.shape[0]): + mu1, mu2 = np.mean(v1[f]), np.mean(v2[f]) + sigma1_sq, sigma2_sq = np.var(v1[f]), np.var(v2[f]) + sigma12 = np.mean((v1[f] - mu1) * (v2[f] - mu2)) + ssim = ((2 * mu1 * mu2 + C1) * (2 * sigma12 + C2)) / ((mu1**2 + mu2**2 + C1) * (sigma1_sq + sigma2_sq + C2)) + ssim_scores.append(float(ssim)) + + mean_ssim = np.mean(ssim_scores) + print(f"SSIM: mean={mean_ssim:.4f}, min={np.min(ssim_scores):.4f}") + self.assertGreaterEqual(mean_ssim, 0.95, f"Mean SSIM={mean_ssim:.4f} < 0.95") + + +if __name__ == "__main__": + absltest.main()