From d221841b2c3d45ca612512f25bf24ba90dbcb7a5 Mon Sep 17 00:00:00 2001 From: Efrat Taig Date: Wed, 4 Mar 2026 15:02:06 +0000 Subject: [PATCH 1/2] feat(wan): Add prior-based diffusion step skip for ~70% fewer inference steps - Add step_callback, prior_latents, prior_timesteps, prior_sigmas, start_from_step to WanVideoPipeline for resuming from saved latents - Add examples/wanvideo/prior_based_step_skip/ with generate_prior.py, infer_from_prior.py, prior_utils.py, and README - Add --download_example to generate_prior.py for easy onboarding - Add Prior-based step skip section to Wan docs (en + zh) - Supports fixed identity/scene with varying motion (e.g. lip-sync, different actions) Made-with: Cursor --- diffsynth/pipelines/wan_video.py | 29 ++- docs/en/Model_Details/Wan.md | 6 + docs/zh/Model_Details/Wan.md | 6 + .../wanvideo/prior_based_step_skip/README.md | 128 ++++++++++ .../prior_based_step_skip/generate_prior.py | 225 ++++++++++++++++++ .../prior_based_step_skip/infer_from_prior.py | 160 +++++++++++++ .../prior_based_step_skip/prior_utils.py | 144 +++++++++++ 7 files changed, 696 insertions(+), 2 deletions(-) create mode 100644 examples/wanvideo/prior_based_step_skip/README.md create mode 100644 examples/wanvideo/prior_based_step_skip/generate_prior.py create mode 100644 examples/wanvideo/prior_based_step_skip/infer_from_prior.py create mode 100644 examples/wanvideo/prior_based_step_skip/prior_utils.py diff --git a/diffsynth/pipelines/wan_video.py b/diffsynth/pipelines/wan_video.py index bbc479e29..2780cb08e 100644 --- a/diffsynth/pipelines/wan_video.py +++ b/diffsynth/pipelines/wan_video.py @@ -2,7 +2,7 @@ import numpy as np from PIL import Image from einops import repeat -from typing import Optional, Union +from typing import Callable, Optional, Union from einops import rearrange import numpy as np from PIL import Image @@ -247,9 +247,22 @@ def __call__( # progress_bar progress_bar_cmd=tqdm, output_type: Optional[Literal["quantized", "floatpoint"]] = 
"quantized", + # Prior-based step skip: optional callback after each denoising step + step_callback: Optional[Callable[[int, torch.Tensor, torch.Tensor], None]] = None, + # Prior-based step skip: resume from saved latent (requires prior_latents + prior_timesteps) + prior_latents: Optional[torch.Tensor] = None, + prior_timesteps: Optional[torch.Tensor] = None, + prior_sigmas: Optional[torch.Tensor] = None, + start_from_step: Optional[int] = None, ): # Scheduler self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift) + + # Prior-based step skip: override latents, timesteps, and sigmas when resuming from prior + if prior_latents is not None and prior_timesteps is not None and start_from_step is not None: + self.scheduler.timesteps = prior_timesteps.to(self.scheduler.timesteps.device) + if prior_sigmas is not None: + self.scheduler.sigmas = prior_sigmas.to(self.scheduler.sigmas.device) # Inputs inputs_posi = { @@ -284,10 +297,18 @@ def __call__( for unit in self.units: inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) + # Prior-based step skip: replace latents with loaded prior + if prior_latents is not None and start_from_step is not None: + inputs_shared["latents"] = prior_latents.to(dtype=self.torch_dtype, device=self.device) + # Denoise self.load_models_to_device(self.in_iteration_models) models = {name: getattr(self, name) for name in self.in_iteration_models} - for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): + timesteps = self.scheduler.timesteps + start_idx = (start_from_step + 1) if start_from_step is not None else 0 + for progress_id, timestep in enumerate(progress_bar_cmd(timesteps)): + if progress_id < start_idx: + continue # Switch DiT if necessary if timestep.item() < switch_DiT_boundary * 1000 and self.dit2 is not None and not models["dit"] is self.dit2: self.load_models_to_device(self.in_iteration_models_2) 
@@ -312,6 +333,10 @@ def __call__( inputs_shared["latents"] = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], inputs_shared["latents"]) if "first_frame_latents" in inputs_shared: inputs_shared["latents"][:, :, 0:1] = inputs_shared["first_frame_latents"] + + # Prior-based step skip: call optional callback after each step + if step_callback is not None: + step_callback(progress_id, inputs_shared["latents"].clone(), timestep) # VACE (TODO: remove it) if vace_reference_image is not None or (animate_pose_video is not None and animate_face_video is not None): diff --git a/docs/en/Model_Details/Wan.md b/docs/en/Model_Details/Wan.md index 20e12822b..60907c888 100644 --- a/docs/en/Model_Details/Wan.md +++ b/docs/en/Model_Details/Wan.md @@ -201,6 +201,12 @@ Input parameters for `WanVideoPipeline` inference include: If VRAM is insufficient, please enable [VRAM Management](../Pipeline_Usage/VRAM_management.md). We provide recommended low VRAM configurations for each model in the example code, see the table in the "Model Overview" section above. +### Prior-Based Step Skip + +For fixed identity/scene with varying motion (e.g. lip-sync, different actions), early diffusion steps are largely redundant. You can run full inference once, save latents at each step, then resume from a saved latent to run only the remaining steps — ~70% fewer steps, same quality. + +See [prior-based step skip example](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/prior_based_step_skip) for `generate_prior.py` and `infer_from_prior.py`. 
+ ## Model Training Wan series models are uniformly trained through [`examples/wanvideo/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/train.py), and the script parameters include: diff --git a/docs/zh/Model_Details/Wan.md b/docs/zh/Model_Details/Wan.md index 0144bd212..61cc176fe 100644 --- a/docs/zh/Model_Details/Wan.md +++ b/docs/zh/Model_Details/Wan.md @@ -202,6 +202,12 @@ DeepSpeed ZeRO 3 训练:Wan 系列模型支持 DeepSpeed ZeRO 3 训练,将 如果显存不足,请开启[显存管理](../Pipeline_Usage/VRAM_management.md),我们在示例代码中提供了每个模型推荐的低显存配置,详见前文"模型总览"中的表格。 +### 基于先验的步长跳过 + +当身份/场景固定而仅运动变化(如口型同步、不同动作)时,早期扩散步长大多冗余。可先运行一次完整推理并保存每步的潜在表示,再从保存的潜在表示恢复,仅运行剩余步长 —— 步长减少约 70%,质量相当。 + +参见 [prior-based step skip 示例](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/prior_based_step_skip) 中的 `generate_prior.py` 与 `infer_from_prior.py`。 + ## 模型训练 Wan 系列模型统一通过 [`examples/wanvideo/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/train.py) 进行训练,脚本的参数包括: diff --git a/examples/wanvideo/prior_based_step_skip/README.md b/examples/wanvideo/prior_based_step_skip/README.md new file mode 100644 index 000000000..cfaf290c8 --- /dev/null +++ b/examples/wanvideo/prior_based_step_skip/README.md @@ -0,0 +1,128 @@ +# Prior-Based Diffusion Step Skip + +**~70% fewer inference steps, same quality, zero retraining.** + +When you have a **fixed identity or scene** and only **one aspect varies** (e.g. motion, lip-sync, lighting), early diffusion steps are largely redundant. This module lets you: + +1. **Generate a prior** — Run full inference once, save latents at each step +2. **Infer from prior** — Load a saved latent (e.g. step 6) and run only the remaining 3–4 steps + +## Quick Start + +Scripts work from **repo root** or from this directory. Run from repo root for consistent paths. 
+
+### Step 1: Generate the prior
+
+**From repo root:**
+
+```bash
+# Download example image and run full inference
+python examples/wanvideo/prior_based_step_skip/generate_prior.py \
+    --download_example \
+    --output_dir ./prior_output \
+    --num_inference_steps 10
+```
+
+**Or with your own image:**
+
+```bash
+python examples/wanvideo/prior_based_step_skip/generate_prior.py \
+    --image path/to/image.jpg \
+    --output_dir ./prior_output \
+    --num_inference_steps 10
+```
+
+**From this directory:**
+
+```bash
+cd examples/wanvideo/prior_based_step_skip
+
+# With --download_example (downloads to repo root data/)
+python generate_prior.py --download_example --output_dir ./prior_output --num_inference_steps 10
+
+# Or with your own image
+python generate_prior.py --image path/to/image.jpg --output_dir ./prior_output --num_inference_steps 10
+```
+
+Output: `./prior_output/run_<RUN_ID>/` with `step_0000.pt` … `step_0009.pt`, `run_metadata.json`, and `output_full.mp4`.
+
+### Step 2: Run accelerated inference
+
+```bash
+# From repo root (replace run_<RUN_ID> with the actual run ID from step 1)
+python examples/wanvideo/prior_based_step_skip/infer_from_prior.py \
+    --prior_dir ./prior_output/run_<RUN_ID> \
+    --start_step 6 \
+    --image data/examples/wan/input_image.jpg \
+    --prompt "Different motion: the boat turns sharply to the left."
+```
+
+Or from this directory:
+
+```bash
+python infer_from_prior.py \
+    --prior_dir ./prior_output/run_<RUN_ID> \
+    --start_step 6 \
+    --image data/examples/wan/input_image.jpg \
+    --prompt "Different motion: the boat turns sharply to the left."
+```
+
+This runs only 3 steps (7, 8, 9) instead of 10 — ~70% fewer steps.
+
+## How It Works
+
+Step indices below are 0-indexed, matching `step_NNNN.pt` filenames and `--start_step`.
+
+| Steps   | Content                                       |
+|---------|-----------------------------------------------|
+| 0–5     | Identity formation (geometry, lighting)       |
+| **6**   | **Inflection point** — identity formed, motion not yet committed |
+| 7–9     | Temporal refinement (details, sharpness)      |
+
+By injecting the latent at step 6, we skip redundant identity formation. The remaining steps refine the motion (or other varying aspect) driven by the new prompt.
+
+## Scripts
+
+| Script                | Purpose                                                |
+|-----------------------|--------------------------------------------------------|
+| `generate_prior.py`   | Full inference with latent saving at each step         |
+| `infer_from_prior.py` | Accelerated inference from a saved prior               |
+| `prior_utils.py`      | Latent save/load, metadata, scheduler validation       |
+
+## Options
+
+### generate_prior.py
+
+- `--image` — Input image (required unless `--download_example`)
+- `--download_example` — Download example image from ModelScope (saves to `data/examples/wan/`)
+- `--output_dir` — Where to save latents (default: `./prior_output`)
+- `--num_inference_steps` — Total steps (default: 10)
+- `--start_step` — Not used here; for reference when calling infer_from_prior
+- `--save_decoded_videos` — Decode and save video at each step (for finding formation point)
+
+### infer_from_prior.py
+
+- `--prior_dir` — Path to prior run (e.g. `./prior_output/run_123`)
+- `--start_step` — Step to resume from (default: 6)
+- `--image` — Same image used for prior generation
+- `--prompt` — New prompt for the varying aspect
+
+## Scheduler Identity
+
+The scheduler used during prior generation **must match** inference. The scripts save and validate:
+
+- `num_inference_steps`
+- `denoising_strength`
+- `sigma_shift`
+- `scheduler_timesteps` and `scheduler_sigmas`
+
+Do not change these between prior generation and inference.
+ +## Requirements + +- DiffSynth-Studio installed (`pip install -e .` from repo root) +- GPU with ≥8GB VRAM (low-VRAM config uses disk offload) +- Wan2.1-I2V-14B-480P model (downloaded automatically from ModelScope) + +## See Also + +- [Wan model documentation](../../../docs/en/Model_Details/Wan.md) +- [Model inference examples](../model_inference_low_vram/) diff --git a/examples/wanvideo/prior_based_step_skip/generate_prior.py b/examples/wanvideo/prior_based_step_skip/generate_prior.py new file mode 100644 index 000000000..331b79564 --- /dev/null +++ b/examples/wanvideo/prior_based_step_skip/generate_prior.py @@ -0,0 +1,225 @@ +#!/usr/bin/env python3 +""" +Generate a prior for prior-based diffusion step skip. + +Runs full inference and saves latent tensors at each denoising step. +Use infer_from_prior.py to run accelerated inference from the saved prior. + +Example: + # Image-to-video (Wan2.1-I2V-14B-480P) + python generate_prior.py \\ + --image path/to/image.jpg \\ + --output_dir ./prior_output \\ + --num_inference_steps 10 + + # With decoded videos at each step (for finding formation point) + python generate_prior.py \\ + --image path/to/image.jpg \\ + --output_dir ./prior_output \\ + --save_decoded_videos +""" + +import argparse +import os +import sys + +# Ensure prior_utils is importable when run from repo root or from this directory +_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) +if _SCRIPT_DIR not in sys.path: + sys.path.insert(0, _SCRIPT_DIR) + +import torch +from PIL import Image + +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig +from diffsynth.utils.data import save_video + +from prior_utils import build_step_callback, save_run_metadata + +# Default negative prompt (Wan-style) +DEFAULT_NEGATIVE_PROMPT = ( + "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止," + "整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指," + "画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合," + "静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" +) + + +def parse_args(): + p = 
argparse.ArgumentParser(description="Generate prior latents for step-skip inference") + p.add_argument("--image", type=str, default=None, help="Path to input image (I2V); required unless --download_example") + p.add_argument("--prompt", type=str, default=None, help="Text prompt (default: example prompt)") + p.add_argument("--output_dir", type=str, default="./prior_output", help="Output directory") + p.add_argument("--run_id", type=str, default=None, help="Run ID (default: timestamp)") + p.add_argument("--num_inference_steps", type=int, default=10, help="Total denoising steps") + p.add_argument("--denoising_strength", type=float, default=1.0) + p.add_argument("--sigma_shift", type=float, default=5.0) + p.add_argument("--seed", type=int, default=0) + p.add_argument("--height", type=int, default=480) + p.add_argument("--width", type=int, default=832) + p.add_argument("--num_frames", type=int, default=81) + p.add_argument("--cfg_scale", type=float, default=5.0) + p.add_argument("--save_decoded_videos", action="store_true", help="Decode and save video at each step") + p.add_argument("--model", type=str, default="I2V-480P", choices=["I2V-480P", "T2V-1.3B"]) + p.add_argument("--download_example", action="store_true", help="Download example image from ModelScope") + args = p.parse_args() + if not args.image and not args.download_example: + p.error("Either --image or --download_example is required") + return args + + +def main(): + args = parse_args() + + if args.download_example: + from modelscope import dataset_snapshot_download + + repo_root = os.path.abspath(os.path.join(_SCRIPT_DIR, "..", "..", "..")) + dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir=repo_root, + allow_file_pattern="data/examples/wan/input_image.jpg", + ) + args.image = os.path.join(repo_root, "data", "examples", "wan", "input_image.jpg") + + # Default prompt + prompt = args.prompt or ( + "A small boat bravely sails through the waves. 
The blue sea is turbulent, " + "white foam splashing against the hull. Sunlight reflects on the water. " + "The camera pulls in to show the flag on the boat waving in the wind." + ) + + # Load image + image = Image.open(args.image).convert("RGB").resize((args.width, args.height)) + + # VRAM config for low-memory GPUs + vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", + } + + if args.model == "I2V-480P": + pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig( + model_id="Wan-AI/Wan2.1-I2V-14B-480P", + origin_file_pattern="diffusion_pytorch_model*.safetensors", + **vram_config, + ), + ModelConfig( + model_id="Wan-AI/Wan2.1-I2V-14B-480P", + origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", + **vram_config, + ), + ModelConfig( + model_id="Wan-AI/Wan2.1-I2V-14B-480P", + origin_file_pattern="Wan2.1_VAE.pth", + **vram_config, + ), + ModelConfig( + model_id="Wan-AI/Wan2.1-I2V-14B-480P", + origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", + **vram_config, + ), + ], + tokenizer_config=ModelConfig( + model_id="Wan-AI/Wan2.1-T2V-1.3B", + origin_file_pattern="google/umt5-xxl/", + ), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024**3) - 2, + ) + else: + # T2V-1.3B (no image) + pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig( + model_id="Wan-AI/Wan2.1-T2V-1.3B", + origin_file_pattern="diffusion_pytorch_model*.safetensors", + **vram_config, + ), + ModelConfig( + model_id="Wan-AI/Wan2.1-T2V-1.3B", + origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", + **vram_config, + ), + ModelConfig( + model_id="Wan-AI/Wan2.1-T2V-1.3B", + origin_file_pattern="Wan2.1_VAE.pth", + **vram_config, + ), + 
], + tokenizer_config=ModelConfig( + model_id="Wan-AI/Wan2.1-T2V-1.3B", + origin_file_pattern="google/umt5-xxl/", + ), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024**3) - 2, + ) + + # Build step callback + make_callback, run_id = build_step_callback( + output_dir=args.output_dir, + run_id=args.run_id, + save_decoded_videos=args.save_decoded_videos, + ) + step_callback = make_callback(pipe) + + print(f"Generating prior: {args.num_inference_steps} steps -> {args.output_dir}/{run_id}") + + # Run inference with step callback + pipe_kwargs = dict( + prompt=prompt, + negative_prompt=DEFAULT_NEGATIVE_PROMPT, + num_inference_steps=args.num_inference_steps, + denoising_strength=args.denoising_strength, + sigma_shift=args.sigma_shift, + seed=args.seed, + height=args.height, + width=args.width, + num_frames=args.num_frames, + cfg_scale=args.cfg_scale, + tiled=True, + step_callback=step_callback, + ) + if args.model == "I2V-480P": + pipe_kwargs["input_image"] = image + + video = pipe(**pipe_kwargs) + + # Save metadata for infer_from_prior + save_run_metadata( + output_dir=args.output_dir, + run_id=run_id, + pipe=pipe, + height=args.height, + width=args.width, + num_frames=args.num_frames, + denoising_strength=args.denoising_strength, + sigma_shift=args.sigma_shift, + ) + + # Save final video + out_video_path = os.path.join(args.output_dir, run_id, "output_full.mp4") + save_video(video, out_video_path, fps=16, quality=5) + + print(f"Done. Prior saved to {args.output_dir}/{run_id}") + print(f" Latents: step_0000.pt ... 
step_{args.num_inference_steps - 1:04d}.pt") + print(f" Metadata: run_metadata.json") + print(f" Full video: output_full.mp4") + print(f"\nTo run accelerated inference from step 6:") + print(f" python examples/wanvideo/prior_based_step_skip/infer_from_prior.py \\") + print(f" --prior_dir {os.path.abspath(os.path.join(args.output_dir, run_id))} \\") + print(f" --start_step 6 --image {args.image}") + + +if __name__ == "__main__": + main() diff --git a/examples/wanvideo/prior_based_step_skip/infer_from_prior.py b/examples/wanvideo/prior_based_step_skip/infer_from_prior.py new file mode 100644 index 000000000..5f680cb94 --- /dev/null +++ b/examples/wanvideo/prior_based_step_skip/infer_from_prior.py @@ -0,0 +1,160 @@ +#!/usr/bin/env python3 +""" +Run accelerated inference from a saved prior. + +Loads a latent from a prior run and performs only the remaining denoising steps. +~70% fewer steps, same quality, zero retraining. + +Example: + python infer_from_prior.py \\ + --prior_dir ./prior_output/run_1234567890 \\ + --start_step 6 \\ + --image path/to/image.jpg \\ + --prompt "Different motion description" +""" + +import argparse +import os +import sys + +# Ensure prior_utils is importable when run from repo root or from this directory +_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) +if _SCRIPT_DIR not in sys.path: + sys.path.insert(0, _SCRIPT_DIR) +sys.path.insert(0, os.path.join(_SCRIPT_DIR, "..", "..", "..")) + +import torch +from PIL import Image + +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig +from diffsynth.utils.data import save_video + +from prior_utils import load_prior_metadata, validate_scheduler_match + +DEFAULT_NEGATIVE_PROMPT = ( + "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止," + "整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指," + "画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合," + "静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" +) + + +def parse_args(): + p = argparse.ArgumentParser(description="Infer from prior (accelerated)") + p.add_argument("--prior_dir", 
type=str, required=True, help="Path to prior run directory") + p.add_argument("--start_step", type=int, default=6, help="Step to resume from (0-indexed)") + p.add_argument("--image", type=str, required=True, help="Input image (must match prior)") + p.add_argument("--prompt", type=str, default=None, help="New prompt for refinement") + p.add_argument("--output", type=str, default=None, help="Output video path") + p.add_argument("--model", type=str, default="I2V-480P", choices=["I2V-480P"]) + return p.parse_args() + + +def main(): + args = parse_args() + + # Load prior metadata and validate + meta = load_prior_metadata(args.prior_dir) + validate_scheduler_match( + { + "num_inference_steps": meta["num_inference_steps"], + "denoising_strength": meta["denoising_strength"], + "sigma_shift": meta["sigma_shift"], + }, + meta, + ) + + # Load prior latent + latent_path = os.path.join(args.prior_dir, f"step_{args.start_step:04d}.pt") + if not os.path.exists(latent_path): + raise FileNotFoundError(f"Prior latent not found: {latent_path}") + prior_latents = torch.load(latent_path, map_location="cpu", weights_only=True) + prior_timesteps = torch.tensor(meta["scheduler_timesteps"], dtype=torch.float32) + # Sigmas required for scheduler.step(); fallback: timesteps/1000 for Wan + if "scheduler_sigmas" in meta: + prior_sigmas = torch.tensor(meta["scheduler_sigmas"], dtype=torch.float32) + else: + prior_sigmas = prior_timesteps / 1000.0 + + height = meta["height"] + width = meta["width"] + num_frames = meta["num_frames"] + + prompt = args.prompt or ( + "A small boat bravely sails through the waves. The blue sea is turbulent, " + "white foam splashing against the hull. Sunlight reflects on the water." 
+ ) + + image = Image.open(args.image).convert("RGB").resize((width, height)) + + vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", + } + + pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig( + model_id="Wan-AI/Wan2.1-I2V-14B-480P", + origin_file_pattern="diffusion_pytorch_model*.safetensors", + **vram_config, + ), + ModelConfig( + model_id="Wan-AI/Wan2.1-I2V-14B-480P", + origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", + **vram_config, + ), + ModelConfig( + model_id="Wan-AI/Wan2.1-I2V-14B-480P", + origin_file_pattern="Wan2.1_VAE.pth", + **vram_config, + ), + ModelConfig( + model_id="Wan-AI/Wan2.1-I2V-14B-480P", + origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", + **vram_config, + ), + ], + tokenizer_config=ModelConfig( + model_id="Wan-AI/Wan2.1-T2V-1.3B", + origin_file_pattern="google/umt5-xxl/", + ), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024**3) - 2, + ) + + remaining_steps = meta["num_inference_steps"] - args.start_step - 1 + print(f"Running {remaining_steps} steps (from step {args.start_step + 1} to {meta['num_inference_steps'] - 1})") + + video = pipe( + prompt=prompt, + negative_prompt=DEFAULT_NEGATIVE_PROMPT, + input_image=image, + num_inference_steps=meta["num_inference_steps"], + denoising_strength=meta["denoising_strength"], + sigma_shift=meta["sigma_shift"], + height=height, + width=width, + num_frames=num_frames, + cfg_scale=5.0, + tiled=True, + prior_latents=prior_latents, + prior_timesteps=prior_timesteps, + prior_sigmas=prior_sigmas, + start_from_step=args.start_step, + ) + + out_path = args.output or os.path.join(args.prior_dir, f"output_from_step_{args.start_step}.mp4") + save_video(video, out_path, fps=16, 
quality=5) + print(f"Saved: {out_path}") + + +if __name__ == "__main__": + main() diff --git a/examples/wanvideo/prior_based_step_skip/prior_utils.py b/examples/wanvideo/prior_based_step_skip/prior_utils.py new file mode 100644 index 000000000..8cf2923ad --- /dev/null +++ b/examples/wanvideo/prior_based_step_skip/prior_utils.py @@ -0,0 +1,144 @@ +""" +Utilities for prior-based diffusion step skip. + +Saves latent tensors at each denoising step and metadata required for resuming inference. +""" + +import json +import time +from pathlib import Path +from typing import Any, Callable, Optional, Tuple + +import torch + + +def build_step_callback( + output_dir: str, + run_id: Optional[str] = None, + save_decoded_videos: bool = False, +) -> Tuple[Callable, str]: + """ + Build a step_callback for WanVideoPipeline that saves latents at each step. + + Args: + output_dir: Directory to save latents (e.g. ./prior_output) + run_id: Optional run identifier; defaults to timestamp-based + save_decoded_videos: If True, decode latents to video at each step (for inspection). + Requires pipe to be passed via closure; we return a factory. + + Returns: + (callback_factory, run_id): A function that takes (pipe) and returns the actual + step_callback. The caller must pass the pipe so we can decode if requested. 
+ """ + output_path = Path(output_dir) + output_path.mkdir(parents=True, exist_ok=True) + + if run_id is None: + run_id = f"run_{int(time.time())}" + + run_dir = output_path / run_id + run_dir.mkdir(parents=True, exist_ok=True) + + def make_callback(pipe) -> Callable: + def step_callback(step_index: int, latents: torch.Tensor, timestep: torch.Tensor) -> None: + # Save latent + latent_path = run_dir / f"step_{step_index:04d}.pt" + torch.save(latents.cpu(), latent_path) + + # Optionally decode and save video for inspection + if save_decoded_videos and pipe is not None: + pipe.load_models_to_device(["vae"]) + video = pipe.vae.decode( + latents, + device=pipe.device, + tiled=getattr(pipe, "_prior_tiled", True), + tile_size=getattr(pipe, "_prior_tile_size", (30, 52)), + tile_stride=getattr(pipe, "_prior_tile_stride", (15, 26)), + ) + video_frames = pipe.vae_output_to_video(video) + video_path = run_dir / f"step_{step_index:04d}.mp4" + _save_video_frames(video_frames, str(video_path), fps=16) + pipe.load_models_to_device([]) + + return step_callback + + return make_callback, run_id + + +def save_run_metadata( + output_dir: str, + run_id: str, + pipe, + height: int = 480, + width: int = 832, + num_frames: int = 81, + denoising_strength: Optional[float] = None, + sigma_shift: Optional[float] = None, + **extra: Any, +) -> None: + """Save run_metadata.json for prior/inference compatibility checks.""" + run_dir = Path(output_dir) / run_id + run_dir.mkdir(parents=True, exist_ok=True) + + timesteps = ( + pipe.scheduler.timesteps.cpu().tolist() + if hasattr(pipe.scheduler.timesteps, "cpu") + else list(pipe.scheduler.timesteps) + ) + sigmas = ( + pipe.scheduler.sigmas.cpu().tolist() + if hasattr(pipe.scheduler.sigmas, "cpu") + else list(pipe.scheduler.sigmas) + ) + + metadata = { + "run_id": run_id, + "num_inference_steps": len(timesteps), + "scheduler_timesteps": timesteps, + "scheduler_sigmas": sigmas, + "denoising_strength": denoising_strength if denoising_strength is not None 
else 1.0, + "sigma_shift": sigma_shift if sigma_shift is not None else 5.0, + "height": height, + "width": width, + "num_frames": num_frames, + **extra, + } + + with open(run_dir / "run_metadata.json", "w", encoding="utf-8") as f: + json.dump(metadata, f, indent=2) + + +def load_prior_metadata(prior_dir: str) -> dict: + """Load run_metadata.json from a prior run directory.""" + path = Path(prior_dir) / "run_metadata.json" + if not path.exists(): + raise FileNotFoundError(f"Metadata not found: {path}") + with open(path, "r", encoding="utf-8") as f: + return json.load(f) + + +def validate_scheduler_match(current_config: dict, prior_metadata: dict) -> None: + """ + Validate that current inference config matches the prior's scheduler. + Raises ValueError if mismatch. + """ + for key in ("num_inference_steps", "denoising_strength", "sigma_shift"): + curr = current_config.get(key) + prior = prior_metadata.get(key) + if curr is not None and prior is not None and curr != prior: + raise ValueError( + f"Scheduler mismatch: {key} is {curr} but prior used {prior}. " + "Prior and inference must use identical scheduler parameters." 
+ ) + + +def _save_video_frames(frames: list, path: str, fps: int = 16) -> None: + """Save list of PIL Images to MP4.""" + import imageio + import numpy as np + + writer = imageio.get_writer(path, fps=fps, quality=8) + for frame in frames: + arr = np.array(frame) if hasattr(frame, "size") else frame + writer.append_data(arr) + writer.close() From f823cd82e9d79b8e80e4fd623f7ef014999320fc Mon Sep 17 00:00:00 2001 From: Efrat Taig Date: Thu, 5 Mar 2026 09:56:13 +0000 Subject: [PATCH 2/2] docs(wan): Add scheduler documentation and update prior-based step skip README Made-with: Cursor --- .../wanvideo/prior_based_step_skip/README.md | 84 +++++ .../prior_based_step_skip/SCHEDULER_README.md | 307 ++++++++++++++++++ 2 files changed, 391 insertions(+) create mode 100644 examples/wanvideo/prior_based_step_skip/SCHEDULER_README.md diff --git a/examples/wanvideo/prior_based_step_skip/README.md b/examples/wanvideo/prior_based_step_skip/README.md index cfaf290c8..43a2a796d 100644 --- a/examples/wanvideo/prior_based_step_skip/README.md +++ b/examples/wanvideo/prior_based_step_skip/README.md @@ -7,6 +7,89 @@ When you have a **fixed identity or scene** and only **one aspect varies** (e.g. 1. **Generate a prior** — Run full inference once, save latents at each step 2. **Infer from prior** — Load a saved latent (e.g. step 6) and run only the remaining 3–4 steps +--- + +## Concept: What Is a Prior? + +In diffusion models, generation is a **multi-step denoising process**. Each step refines the latent representation: + +``` +Step 0 (noisy) → Step 1 → Step 2 → … → Step T (clean) +``` + +The key insight: **early steps fix identity and structure** (who/what is in the scene), while **later steps refine the varying aspect** (motion, expression, lighting). When identity is fixed and only one aspect changes, the early trajectory is nearly identical across runs. We can reuse it. + +A **prior** is a saved latent at some intermediate step. 
Instead of starting from pure noise every time, we inject a prior and run only the remaining steps. The prior encodes “what we already know” — the identity — so we only compute “what changes” — the motion or other variable. + +### Pseudocode: Standard Diffusion (No Prior) + +``` +function GENERATE(prompt, image, num_steps): + latents ← sample_noise(shape) # Start from random noise + timesteps ← scheduler.get_timesteps(num_steps) + + for step in 0 .. num_steps - 1: + t ← timesteps[step] + noise_pred ← model(latents, t, prompt, image) + latents ← scheduler.step(latents, noise_pred, t) + # ... (optionally save latents here for prior generation) + + return decode(latents) +``` + +### Pseudocode: Prior-Based Step Skip + +``` +function GENERATE_PRIOR(prompt, image, num_steps, output_dir): + # One-time: run full inference and save latents at each step + latents ← sample_noise(shape) + timesteps ← scheduler.get_timesteps(num_steps) + + for step in 0 .. num_steps - 1: + t ← timesteps[step] + noise_pred ← model(latents, t, prompt, image) + latents ← scheduler.step(latents, noise_pred, t) + save(latents, output_dir / f"step_{step}.pt") # ← Prior checkpoint + + save_metadata(timesteps, scheduler_params, output_dir) + return decode(latents) + + +function INFER_FROM_PRIOR(prompt, image, prior_dir, start_step): + # Accelerated: load prior, run only remaining steps + prior_latents ← load(prior_dir / f"step_{start_step}.pt") + metadata ← load_metadata(prior_dir) + timesteps ← metadata.timesteps + num_steps ← len(timesteps) + + latents ← prior_latents + # Skip steps 0 .. start_step; begin at start_step + 1 + for step in (start_step + 1) .. num_steps - 1: + t ← timesteps[step] + noise_pred ← model(latents, t, prompt, image) # New prompt can differ! + latents ← scheduler.step(latents, noise_pred, t) + + return decode(latents) +``` + +### Why It Works + +| Phase | Steps (e.g. 
10-step run) | What happens | +|--------------------|---------------------------|---------------------------------------| +| Identity formation | 0–5 | Geometry, lighting, scene layout | +| **Inflection** | **6** | Identity fixed; motion not committed | +| Refinement | 7–9 | Temporal details, sharpness | + +By injecting the prior at step 6, we skip redundant identity formation. The remaining steps refine the **varying aspect** (e.g. motion) driven by the **new prompt**. Same identity, different motion — with ~70% fewer steps. + +### Constraints + +- **Same scheduler**: Prior and inference must use identical `num_inference_steps`, `denoising_strength`, `sigma_shift`. +- **Same conditioning (identity)**: Same input image (I2V) or seed-dependent structure. +- **Varying aspect**: Prompt (or other conditioning) can change for the refinement phase. + +--- + ## Quick Start Scripts work from **repo root** or from this directory. Run from repo root for consistent paths. @@ -124,5 +207,6 @@ Do not change these between prior generation and inference. ## See Also +- [Scheduler README](SCHEDULER_README.md) — What the scheduler is, its role in the prior, and how to use/modify it - [Wan model documentation](../../../docs/en/Model_Details/Wan.md) - [Model inference examples](../model_inference_low_vram/) diff --git a/examples/wanvideo/prior_based_step_skip/SCHEDULER_README.md b/examples/wanvideo/prior_based_step_skip/SCHEDULER_README.md new file mode 100644 index 000000000..e6a7ff9f2 --- /dev/null +++ b/examples/wanvideo/prior_based_step_skip/SCHEDULER_README.md @@ -0,0 +1,307 @@ +# The Scheduler in Prior-Based Step Skip + +This document explains **what the scheduler is**, **why it matters for the prior**, and **how to use and modify it** correctly. + +--- + +## What Is the Scheduler? + +In diffusion models, generation is an iterative denoising process. The **scheduler** defines: + +1. **The trajectory** — which points in “noise space” we visit (timesteps and sigmas) +2. 
**The step rule** — how to update the latent given the model’s prediction + +The model predicts a direction (velocity); the scheduler decides how far to move along that direction at each step. + +### Timesteps and Sigmas + +- **Timestep** `t`: A scalar (often 0–1000) that tells the model “how noisy” the current latent is. The model is conditioned on `t`. +- **Sigma** `σ`: A noise level used in the flow-matching update. For flow matching, `σ ≈ t / 1000` (normalized timestep). + +The scheduler produces two arrays of length `num_inference_steps`: + +``` +timesteps = [t₀, t₁, t₂, …, t_{T-1}] # High → low (noisy → clean) +sigmas = [σ₀, σ₁, σ₂, …, σ_{T-1}] # High → low +``` + +These are **deterministic** given the scheduler parameters. Different parameters → different trajectory → different results. + +### The Step Formula (Flow Matching) + +At each step `i`, the scheduler computes the next latent: + +``` +latent_{i+1} = latent_i + model_output × (σ_{i+1} − σ_i) +``` + +So the model’s output is scaled by the **sigma difference** between the current and next step. The scheduler’s `sigmas` array is what makes this math correct. + +--- + +## Pseudocode: What the Scheduler Does + +``` +# ═══════════════════════════════════════════════════════════════════════════ +# SCHEDULER: Definition and Role +# ═══════════════════════════════════════════════════════════════════════════ + +# 1. SET_TIMESTEPS: Build the denoising trajectory from parameters +# Inputs: num_inference_steps, denoising_strength, sigma_shift (or shift) +# Outputs: timesteps[], sigmas[] — the exact sequence for this run +# +function SCHEDULER_SET_TIMESTEPS(num_steps, denoising_strength, sigma_shift): + sigma_start ← sigma_min + (sigma_max - sigma_min) × denoising_strength + sigmas ← linspace(sigma_start, sigma_min, num_steps) + sigmas ← sigma_shift × sigmas / (1 + (sigma_shift - 1) × sigmas) # Rescale + timesteps ← sigmas × 1000 # Map to 0–1000 range for model conditioning + return (sigmas, timesteps) + + +# 2. 
STEP: Update latent using model output and sigma difference +# The model predicts a "velocity"; we move the sample by (σ_next − σ_curr) +# +function SCHEDULER_STEP(model_output, timestep, sample, sigmas, timesteps): + step_id ← index of timestep in timesteps + σ ← sigmas[step_id] + σ_next ← sigmas[step_id + 1] # or 0 if last step + sample_next ← sample + model_output × (σ_next − σ) + return sample_next + + +# 3. Standard diffusion loop (no prior) +# +function DENOISE_STANDARD(prompt, image, num_steps): + (sigmas, timesteps) ← SCHEDULER_SET_TIMESTEPS(num_steps, 1.0, 5.0) + latents ← sample_noise() + + for i in 0 .. num_steps - 1: + t ← timesteps[i] + noise_pred ← model(latents, t, prompt, image) + latents ← SCHEDULER_STEP(noise_pred, t, latents, sigmas, timesteps) + + return decode(latents) +``` + +--- + +## Why the Scheduler Matters for the Prior + +The prior latent was produced at a **specific point** on a **specific trajectory**. That trajectory is fully defined by `(timesteps, sigmas)`. + +If inference uses a **different** trajectory (e.g. different `num_inference_steps`, `denoising_strength`, or `sigma_shift`): + +- The sigma differences `(σ_{i+1} − σ_i)` change +- The step formula produces wrong updates +- The denoising path no longer matches what the model expects + +So: **prior and inference must use the same scheduler trajectory**. We achieve this by saving and restoring `timesteps` and `sigmas` from the prior run. + +--- + +## What We Did With the Scheduler (Prior-Based Step Skip) + +### 1. 
Save the trajectory when generating the prior + +When we run full inference to build the prior, we save not only the latents but also the scheduler’s `timesteps` and `sigmas`: + +```python +# prior_utils.py — save_run_metadata() +timesteps = pipe.scheduler.timesteps.cpu().tolist() +sigmas = pipe.scheduler.sigmas.cpu().tolist() + +metadata = { + "scheduler_timesteps": timesteps, + "scheduler_sigmas": sigmas, + "num_inference_steps": len(timesteps), + "denoising_strength": denoising_strength, + "sigma_shift": sigma_shift, + # ... +} +``` + +### 2. Override the scheduler when inferring from the prior + +When resuming from a prior, we **replace** the scheduler’s arrays with the saved ones instead of recomputing them: + +```python +# wan_video.py — pipeline __call__ +self.scheduler.set_timesteps(num_inference_steps, denoising_strength=..., shift=sigma_shift) + +# Prior-based step skip: use the exact trajectory from the prior run +if prior_latents is not None and prior_timesteps is not None and start_from_step is not None: + self.scheduler.timesteps = prior_timesteps.to(...) + if prior_sigmas is not None: + self.scheduler.sigmas = prior_sigmas.to(...) +``` + +### 3. Skip early steps in the loop + +We still iterate over the full `timesteps` array, but we skip the steps we’ve already “done” via the prior: + +```python +start_idx = start_from_step + 1 +for progress_id, timestep in enumerate(timesteps): + if progress_id < start_idx: + continue # Skip — we loaded the latent after step start_from_step + # ... run model, scheduler.step(), etc. 
+``` + +--- + +## Pseudocode: Prior Flow With Scheduler Handling + +``` +# ═══════════════════════════════════════════════════════════════════════════ +# PRIOR GENERATION: Save latents AND scheduler state +# ═══════════════════════════════════════════════════════════════════════════ + +function GENERATE_PRIOR(prompt, image, num_steps, output_dir): + # Build trajectory (deterministic from params) + (sigmas, timesteps) ← SCHEDULER_SET_TIMESTEPS(num_steps, 1.0, 5.0) + latents ← sample_noise() + + for i in 0 .. num_steps - 1: + t ← timesteps[i] + noise_pred ← model(latents, t, prompt, image) + latents ← SCHEDULER_STEP(noise_pred, t, latents, sigmas, timesteps) + save(latents, output_dir / f"step_{i}.pt") + + # CRITICAL: Save the trajectory so inference can reuse it exactly + save_metadata({ + "scheduler_timesteps": timesteps, + "scheduler_sigmas": sigmas, + "num_inference_steps": num_steps, + "denoising_strength": 1.0, + "sigma_shift": 5.0, + }, output_dir) + return decode(latents) + + +# ═══════════════════════════════════════════════════════════════════════════ +# PRIOR INFERENCE: Load prior, override scheduler, run remaining steps +# ═══════════════════════════════════════════════════════════════════════════ + +function INFER_FROM_PRIOR(prompt, image, prior_dir, start_step): + # Load prior latent (output of step start_step) + prior_latents ← load(prior_dir / f"step_{start_step}.pt") + metadata ← load_metadata(prior_dir) + + # Use the EXACT trajectory from the prior run — do NOT recompute + timesteps ← metadata.scheduler_timesteps + sigmas ← metadata.scheduler_sigmas + + latents ← prior_latents + start_idx ← start_step + 1 + + for i in 0 .. 
len(timesteps) - 1: + if i < start_idx: + continue # Skip steps 0..start_step (already in prior) + t ← timesteps[i] + noise_pred ← model(latents, t, prompt, image) + latents ← SCHEDULER_STEP(noise_pred, t, latents, sigmas, timesteps) + + return decode(latents) +``` + +--- + +## How to Use and Modify the Scheduler + +### Using the prior correctly + +| Requirement | Reason | +|-------------|--------| +| Same `num_inference_steps` | Same trajectory length | +| Same `denoising_strength` | Same starting sigma | +| Same `sigma_shift` | Same sigma rescaling | +| Use saved `timesteps` and `sigmas` | Exact trajectory match; avoids float drift | + +The scripts validate these via `validate_scheduler_match()` before inference. + +### Modifying scheduler parameters + +- **When generating the prior**: Choose `num_inference_steps`, `denoising_strength`, `sigma_shift` as needed. These are saved in metadata. +- **When inferring from the prior**: Pass the **same** values. The pipeline loads `prior_timesteps` and `prior_sigmas` from metadata and overrides the scheduler; the parameters are mainly for validation. + +### Example: Changing the number of steps + +If you want 20 steps instead of 10: + +1. Generate a new prior with `--num_inference_steps 20`. +2. Use that prior with `infer_from_prior.py`; it will read `num_inference_steps: 20` from metadata. +3. You can use e.g. `--start_step 12` to skip the first 13 steps and run 7 steps. + +You cannot mix a prior generated with 10 steps with inference configured for 20 steps — the trajectories differ. + +### Code: Saving and loading scheduler state + +**Saving** (in `prior_utils.save_run_metadata`): + +```python +timesteps = pipe.scheduler.timesteps.cpu().tolist() +sigmas = pipe.scheduler.sigmas.cpu().tolist() +metadata = { + "scheduler_timesteps": timesteps, + "scheduler_sigmas": sigmas, + "num_inference_steps": len(timesteps), + "denoising_strength": denoising_strength, + "sigma_shift": sigma_shift, + # ... 
+} +``` + +**Loading** (in `infer_from_prior.py`): + +```python +meta = load_prior_metadata(prior_dir) +prior_timesteps = torch.tensor(meta["scheduler_timesteps"], dtype=torch.float32) +prior_sigmas = torch.tensor(meta["scheduler_sigmas"], dtype=torch.float32) + +# Passed to pipeline; pipeline overrides scheduler.timesteps and scheduler.sigmas +video = pipe( + ..., + prior_latents=prior_latents, + prior_timesteps=prior_timesteps, + prior_sigmas=prior_sigmas, + start_from_step=args.start_step, +) +``` + +**Override in pipeline** (in `wan_video.py`): + +```python +self.scheduler.set_timesteps(num_inference_steps, denoising_strength=..., shift=sigma_shift) + +if prior_latents is not None and prior_timesteps is not None and start_from_step is not None: + self.scheduler.timesteps = prior_timesteps.to(self.scheduler.timesteps.device) + if prior_sigmas is not None: + self.scheduler.sigmas = prior_sigmas.to(self.scheduler.sigmas.device) +``` + +### Example: Wan scheduler formula + +The Wan scheduler (flow matching) uses: + +```python +# diffsynth/diffusion/flow_match.py — set_timesteps_wan +sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength +sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1] +sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) +timesteps = sigmas * num_train_timesteps +``` + +- `denoising_strength`: 0–1; 1 = full denoising from max noise +- `shift` (sigma_shift): Rescales sigmas; default 5 for Wan + +--- + +## Summary + +| Concept | Role | +|---------|------| +| **Scheduler** | Defines the denoising trajectory (timesteps, sigmas) and the step update rule | +| **Timesteps** | Conditioning for the model; index into the trajectory | +| **Sigmas** | Used in the step formula; must match between prior and inference | +| **Prior + scheduler** | Prior latent lies on a specific trajectory; inference must use that same trajectory | +| **Override, don’t recompute** | Load saved `timesteps` and `sigmas` to guarantee 
consistency |