29 changes: 27 additions & 2 deletions diffsynth/pipelines/wan_video.py
@@ -2,7 +2,7 @@
import numpy as np
from PIL import Image
from einops import repeat
from typing import Optional, Union
from typing import Callable, Optional, Union
from einops import rearrange
import numpy as np
from PIL import Image
@@ -247,9 +247,22 @@ def __call__(
# progress_bar
progress_bar_cmd=tqdm,
output_type: Optional[Literal["quantized", "floatpoint"]] = "quantized",
# Prior-based step skip: optional callback after each denoising step
step_callback: Optional[Callable[[int, torch.Tensor, torch.Tensor], None]] = None,
# Prior-based step skip: resume from saved latent (requires prior_latents + prior_timesteps)
prior_latents: Optional[torch.Tensor] = None,
prior_timesteps: Optional[torch.Tensor] = None,
prior_sigmas: Optional[torch.Tensor] = None,
start_from_step: Optional[int] = None,
):
# Scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)

# Prior-based step skip: override latents, timesteps, and sigmas when resuming from prior
if prior_latents is not None and prior_timesteps is not None and start_from_step is not None:
self.scheduler.timesteps = prior_timesteps.to(self.scheduler.timesteps.device)
if prior_sigmas is not None:
self.scheduler.sigmas = prior_sigmas.to(self.scheduler.sigmas.device)

# Inputs
inputs_posi = {
@@ -284,10 +297,18 @@ def __call__(
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)

# Prior-based step skip: replace latents with loaded prior
if prior_latents is not None and start_from_step is not None:
inputs_shared["latents"] = prior_latents.to(dtype=self.torch_dtype, device=self.device)
Comment on lines +301 to +302
Contributor

critical

The condition for replacing latents with the prior is missing a check for prior_timesteps. The check at line 262 correctly requires prior_latents, prior_timesteps, and start_from_step. If prior_timesteps is not provided here, the scheduler will use incorrect timesteps with the loaded prior latents, which can lead to incorrect generation results. To ensure consistency and prevent bugs, the condition should be the same as the one at line 262.

Suggested change
if prior_latents is not None and start_from_step is not None:
inputs_shared["latents"] = prior_latents.to(dtype=self.torch_dtype, device=self.device)
if prior_latents is not None and prior_timesteps is not None and start_from_step is not None:
inputs_shared["latents"] = prior_latents.to(dtype=self.torch_dtype, device=self.device)


# Denoise
self.load_models_to_device(self.in_iteration_models)
models = {name: getattr(self, name) for name in self.in_iteration_models}
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timesteps = self.scheduler.timesteps
start_idx = (start_from_step + 1) if start_from_step is not None else 0
for progress_id, timestep in enumerate(progress_bar_cmd(timesteps)):
if progress_id < start_idx:
continue
# Switch DiT if necessary
if timestep.item() < switch_DiT_boundary * 1000 and self.dit2 is not None and not models["dit"] is self.dit2:
self.load_models_to_device(self.in_iteration_models_2)
@@ -312,6 +333,10 @@ def __call__(
inputs_shared["latents"] = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], inputs_shared["latents"])
if "first_frame_latents" in inputs_shared:
inputs_shared["latents"][:, :, 0:1] = inputs_shared["first_frame_latents"]

# Prior-based step skip: call optional callback after each step
if step_callback is not None:
step_callback(progress_id, inputs_shared["latents"].clone(), timestep)

# VACE (TODO: remove it)
if vace_reference_image is not None or (animate_pose_video is not None and animate_face_video is not None):
6 changes: 6 additions & 0 deletions docs/en/Model_Details/Wan.md
@@ -201,6 +201,12 @@ Input parameters for `WanVideoPipeline` inference include:

If VRAM is insufficient, please enable [VRAM Management](../Pipeline_Usage/VRAM_management.md). We provide recommended low VRAM configurations for each model in the example code, see the table in the "Model Overview" section above.

### Prior-Based Step Skip

For fixed identity/scene with varying motion (e.g. lip-sync, different actions), early diffusion steps are largely redundant. You can run full inference once, save latents at each step, then resume from a saved latent to run only the remaining steps — ~70% fewer steps, same quality.

See [prior-based step skip example](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/prior_based_step_skip) for `generate_prior.py` and `infer_from_prior.py`.

## Model Training

Wan series models are uniformly trained through [`examples/wanvideo/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/train.py), and the script parameters include:
6 changes: 6 additions & 0 deletions docs/zh/Model_Details/Wan.md
@@ -202,6 +202,12 @@ DeepSpeed ZeRO 3 training: Wan series models support DeepSpeed ZeRO 3 training,

If VRAM is insufficient, please enable [VRAM Management](../Pipeline_Usage/VRAM_management.md). We provide recommended low-VRAM configurations for each model in the example code; see the table in the "Model Overview" section above.

### Prior-Based Step Skip

When the identity/scene is fixed and only the motion varies (e.g. lip-sync, different actions), the early diffusion steps are largely redundant. You can run full inference once, saving the latents at each step, then resume from a saved latent and run only the remaining steps: roughly 70% fewer steps at comparable quality.

See the [prior-based step skip example](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/prior_based_step_skip) for `generate_prior.py` and `infer_from_prior.py`.

## Model Training

Wan series models are uniformly trained through [`examples/wanvideo/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/train.py), and the script parameters include:
212 changes: 212 additions & 0 deletions examples/wanvideo/prior_based_step_skip/README.md
@@ -0,0 +1,212 @@
# Prior-Based Diffusion Step Skip

**~70% fewer inference steps, same quality, zero retraining.**

When you have a **fixed identity or scene** and only **one aspect varies** (e.g. motion, lip-sync, lighting), early diffusion steps are largely redundant. This module lets you:

1. **Generate a prior**: Run full inference once, saving the latent at each step
2. **Infer from prior**: Load a saved latent (e.g. from step 6) and run only the remaining steps (steps 7–9 in a 10-step run)

---

## Concept: What Is a Prior?

In diffusion models, generation is a **multi-step denoising process**. Each step refines the latent representation:

```
Step 0 (noisy) → Step 1 → Step 2 → … → Step T (clean)
```

The key insight: **early steps fix identity and structure** (who/what is in the scene), while **later steps refine the varying aspect** (motion, expression, lighting). When identity is fixed and only one aspect changes, the early trajectory is nearly identical across runs. We can reuse it.

A **prior** is a saved latent at some intermediate step. Instead of starting from pure noise every time, we inject a prior and run only the remaining steps. The prior encodes “what we already know” — the identity — so we only compute “what changes” — the motion or other variable.

### Pseudocode: Standard Diffusion (No Prior)

```
function GENERATE(prompt, image, num_steps):
latents ← sample_noise(shape) # Start from random noise
timesteps ← scheduler.get_timesteps(num_steps)

for step in 0 .. num_steps - 1:
t ← timesteps[step]
noise_pred ← model(latents, t, prompt, image)
latents ← scheduler.step(latents, noise_pred, t)
# ... (optionally save latents here for prior generation)

return decode(latents)
```

### Pseudocode: Prior-Based Step Skip

```
function GENERATE_PRIOR(prompt, image, num_steps, output_dir):
# One-time: run full inference and save latents at each step
latents ← sample_noise(shape)
timesteps ← scheduler.get_timesteps(num_steps)

for step in 0 .. num_steps - 1:
t ← timesteps[step]
noise_pred ← model(latents, t, prompt, image)
latents ← scheduler.step(latents, noise_pred, t)
save(latents, output_dir / f"step_{step}.pt") # ← Prior checkpoint

save_metadata(timesteps, scheduler_params, output_dir)
return decode(latents)


function INFER_FROM_PRIOR(prompt, image, prior_dir, start_step):
# Accelerated: load prior, run only remaining steps
prior_latents ← load(prior_dir / f"step_{start_step}.pt")
metadata ← load_metadata(prior_dir)
timesteps ← metadata.timesteps
num_steps ← len(timesteps)

latents ← prior_latents
# Skip steps 0 .. start_step; begin at start_step + 1
for step in (start_step + 1) .. num_steps - 1:
t ← timesteps[step]
noise_pred ← model(latents, t, prompt, image) # New prompt can differ!
latents ← scheduler.step(latents, noise_pred, t)

return decode(latents)
```
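
The two functions above can be exercised with a toy stand-in for the model and scheduler to confirm the key property: resuming from a saved intermediate latent reproduces the full run's remaining trajectory exactly. The update rule below is purely illustrative, not the real scheduler math.

```python
# Toy check of prior-based step skip: with a deterministic denoising
# update, resuming from a checkpoint must match the full run exactly.
# toy_denoise_step stands in for model prediction + scheduler.step.

def toy_denoise_step(latent: float, t: float) -> float:
    # Illustrative update: pull the latent toward zero, scaled by t.
    return latent - 0.1 * latent * t

def generate_prior(latent: float, timesteps: list) -> list:
    saved = []
    for t in timesteps:
        latent = toy_denoise_step(latent, t)
        saved.append(latent)  # per-step checkpoint, like step_XXXX.pt
    return saved

def infer_from_prior(saved: list, timesteps: list, start_step: int) -> float:
    latent = saved[start_step]              # load the prior
    for t in timesteps[start_step + 1:]:    # skip steps 0..start_step
        latent = toy_denoise_step(latent, t)
    return latent

timesteps = [1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]
saved = generate_prior(1.0, timesteps)
full_result = saved[-1]
resumed_result = infer_from_prior(saved, timesteps, start_step=6)
assert resumed_result == full_result  # identical tail trajectory
```

In the real pipeline the prompt may change for the resumed steps, so the refinement phase differs by design; this toy keeps conditioning fixed to isolate the checkpoint/resume bookkeeping.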

### Why It Works

| Phase | Steps (e.g. 10-step run) | What happens |
|--------------------|---------------------------|---------------------------------------|
| Identity formation | 0–5 | Geometry, lighting, scene layout |
| **Inflection** | **6** | Identity fixed; motion not committed |
| Refinement | 7–9 | Temporal details, sharpness |

By injecting the prior at step 6, we skip redundant identity formation. The remaining steps refine the **varying aspect** (e.g. motion) driven by the **new prompt**. Same identity, different motion — with ~70% fewer steps.

### Constraints

- **Same scheduler**: Prior and inference must use identical `num_inference_steps`, `denoising_strength`, `sigma_shift`.
- **Same conditioning (identity)**: Same input image (I2V) or seed-dependent structure.
- **Varying aspect**: Prompt (or other conditioning) can change for the refinement phase.

---

## Quick Start

Scripts work from **repo root** or from this directory. Run from repo root for consistent paths.

### Step 1: Generate the prior

**From repo root:**

```bash
# Download example image and run full inference
python examples/wanvideo/prior_based_step_skip/generate_prior.py \
--download_example \
--output_dir ./prior_output \
--num_inference_steps 10
```

**Or with your own image:**

```bash
python examples/wanvideo/prior_based_step_skip/generate_prior.py \
--image path/to/image.jpg \
--output_dir ./prior_output \
--num_inference_steps 10
```

**From this directory:**

```bash
cd examples/wanvideo/prior_based_step_skip

# With --download_example (downloads to repo root data/)
python generate_prior.py --download_example --output_dir ./prior_output --num_inference_steps 10

# Or with your own image
python generate_prior.py --image path/to/image.jpg --output_dir ./prior_output --num_inference_steps 10
```

Output: `./prior_output/run_<id>/` with `step_0000.pt` … `step_0009.pt`, `run_metadata.json`, and `output_full.mp4`.

### Step 2: Run accelerated inference

```bash
# From repo root (replace run_<id> with actual run ID from step 1)
python examples/wanvideo/prior_based_step_skip/infer_from_prior.py \
--prior_dir ./prior_output/run_<id> \
--start_step 6 \
--image data/examples/wan/input_image.jpg \
--prompt "Different motion: the boat turns sharply to the left."
```

Or from this directory:

```bash
python infer_from_prior.py \
--prior_dir ./prior_output/run_<id> \
--start_step 6 \
--image data/examples/wan/input_image.jpg \
--prompt "Different motion: the boat turns sharply to the left."
```

This runs only 3 steps (7, 8, 9) instead of 10 — ~70% fewer steps.
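
In general, resuming at `start_step` skips steps 0 through `start_step` and runs the rest. A quick sanity check of the arithmetic (a hypothetical helper, not part of the scripts):

```python
def steps_saved(num_inference_steps: int, start_step: int):
    # Steps 0..start_step are skipped; execution resumes at start_step + 1.
    steps_run = num_inference_steps - (start_step + 1)
    fraction_skipped = 1 - steps_run / num_inference_steps
    return steps_run, fraction_skipped

steps_run, frac = steps_saved(num_inference_steps=10, start_step=6)
# 3 steps run (7, 8, 9); ~70% skipped
```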

## How It Works

| Steps   | Content                                                          |
|---------|------------------------------------------------------------------|
| 0–5     | Identity formation (geometry, lighting)                          |
| **6**   | **Inflection point**: identity formed, motion not yet committed  |
| 7–9     | Temporal refinement (details, sharpness)                         |

By injecting the latent at step 6, we skip redundant identity formation. The remaining steps refine the motion (or other varying aspect) driven by the new prompt.

## Scripts

| Script | Purpose |
|---------------------|--------------------------------------------------------|
| `generate_prior.py` | Full inference with latent saving at each step |
| `infer_from_prior.py` | Accelerated inference from a saved prior |
| `prior_utils.py` | Latent save/load, metadata, scheduler validation |
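
Per-step saving plugs into the optional `step_callback(progress_id, latents, timestep)` hook that this PR adds to `WanVideoPipeline.__call__`, presumably how `generate_prior.py` checkpoints latents. A minimal saver sketch, with `pickle` standing in for `torch.save` so it runs without a GPU, using the `step_XXXX.pt` naming shown above:

```python
import pickle
import tempfile
from pathlib import Path

def make_latent_saver(output_dir):
    """Return a callback matching step_callback(progress_id, latents, timestep)."""
    out = Path(output_dir)
    out.mkdir(parents=True, exist_ok=True)

    def save_step(progress_id, latents, timestep):
        # The real scripts presumably use torch.save on tensor latents;
        # pickle keeps this sketch dependency-free.
        with open(out / f"step_{progress_id:04d}.pt", "wb") as f:
            pickle.dump(latents, f)

    return save_step

# Demo with placeholder latents in a throwaway directory.
tmp = tempfile.mkdtemp()
callback = make_latent_saver(tmp)
for i in range(3):
    callback(i, [0.0] * 4, timestep=1000 - 100 * i)
saved_files = sorted(p.name for p in Path(tmp).iterdir())
# saved_files == ["step_0000.pt", "step_0001.pt", "step_0002.pt"]
```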

## Options

### generate_prior.py

- `--image` — Input image (required unless `--download_example`)
- `--download_example` — Download example image from ModelScope (saves to `data/examples/wan/`)
- `--output_dir` — Where to save latents (default: `./prior_output`)
- `--num_inference_steps` — Total steps (default: 10)
- `--start_step` — Not used here; for reference when calling infer_from_prior
- `--save_decoded_videos` — Decode and save video at each step (for finding formation point)

### infer_from_prior.py

- `--prior_dir` — Path to prior run (e.g. `./prior_output/run_123`)
- `--start_step` — Step to resume from (default: 6)
- `--image` — Same image used for prior generation
- `--prompt` — New prompt for the varying aspect

## Scheduler Identity

The scheduler used during prior generation **must match** inference. The scripts save and validate:

- `num_inference_steps`
- `denoising_strength`
- `sigma_shift`
- `scheduler_timesteps` and `scheduler_sigmas`

Do not change these between prior generation and inference.
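
A minimal sketch of that validation, assuming the prior run wrote its settings to `run_metadata.json`; the exact field names are defined in `prior_utils.py`, so the ones below are illustrative:

```python
import json
import tempfile
from pathlib import Path

# Illustrative field names; prior_utils.py defines the real ones.
CHECKED_KEYS = ("num_inference_steps", "denoising_strength", "sigma_shift")

def validate_scheduler(prior_dir, current_settings):
    meta = json.loads((Path(prior_dir) / "run_metadata.json").read_text())
    for key in CHECKED_KEYS:
        if meta[key] != current_settings[key]:
            raise ValueError(
                f"Scheduler mismatch on {key!r}: "
                f"prior={meta[key]}, current={current_settings[key]}"
            )

# Demo: write metadata for a fake prior run, then validate configurations.
with tempfile.TemporaryDirectory() as d:
    settings = {"num_inference_steps": 10, "denoising_strength": 1.0, "sigma_shift": 5.0}
    (Path(d) / "run_metadata.json").write_text(json.dumps(settings))
    validate_scheduler(d, settings)  # matching settings: no exception
    try:
        validate_scheduler(d, dict(settings, sigma_shift=3.0))
        caught = False
    except ValueError:
        caught = True  # mismatch is rejected
```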

## Requirements

- DiffSynth-Studio installed (`pip install -e .` from repo root)
- GPU with ≥8GB VRAM (low-VRAM config uses disk offload)
- Wan2.1-I2V-14B-480P model (downloaded automatically from ModelScope)

## See Also

- [Scheduler README](SCHEDULER_README.md) — What the scheduler is, its role in the prior, and how to use/modify it
- [Wan model documentation](../../../docs/en/Model_Details/Wan.md)
- [Model inference examples](../model_inference_low_vram/)