Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 15 additions & 8 deletions api/routers/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import threading
import traceback
import uuid
from typing import Dict
from typing import Dict, List
from fastapi import APIRouter, File, Form, UploadFile, HTTPException, BackgroundTasks
from services.generators.base import smooth_progress

Expand All @@ -18,7 +18,8 @@
@router.post("/from-image")
async def generate_from_image(
background_tasks: BackgroundTasks,
image: UploadFile = File(...),
image: List[UploadFile] = File(...),
view_labels: str = Form(""),
model_id: str = Form("sf3d"),
collection: str = Form("Default"),
vertex_count: int = Form(10000),
Expand All @@ -30,8 +31,9 @@ async def generate_from_image(
seed: int = Form(-1),
num_inference_steps: int = Form(30),
):
if not image.content_type or not image.content_type.startswith("image/"):
raise HTTPException(400, "File must be an image")
for img in image:
if not img.content_type or not img.content_type.startswith("image/"):
raise HTTPException(400, "All files must be images")

if remesh not in ("quad", "triangle", "none"):
raise HTTPException(400, "remesh must be 'quad', 'triangle', or 'none'")
Expand All @@ -56,8 +58,12 @@ async def generate_from_image(

generator_registry.switch_model(model_id)

job_id = str(uuid.uuid4())
image_bytes = await image.read()
job_id = str(uuid.uuid4())
image_bytes_list = [await img.read() for img in image]
# Pass single bytes for backward compat, list for multi-view
image_data = image_bytes_list[0] if len(image_bytes_list) == 1 else image_bytes_list
# Parse view labels (e.g. "front,back" → ["front", "back"])
parsed_view_labels = [v.strip() for v in view_labels.split(",") if v.strip()] if view_labels else []
params = {
"vertex_count": vertex_count,
"remesh": remesh,
Expand All @@ -67,12 +73,13 @@ async def generate_from_image(
"guidance_scale": guidance_scale,
"seed": seed,
"num_inference_steps": num_inference_steps,
"view_labels": parsed_view_labels,
}

job = JobStatus(job_id=job_id, status="pending", progress=0)
_jobs[job_id] = job

background_tasks.add_task(_run_generation, job_id, image_bytes, params, collection)
background_tasks.add_task(_run_generation, job_id, image_data, params, collection)

return {"job_id": job_id}

Expand All @@ -86,7 +93,7 @@ async def job_status(job_id: str):
return job


async def _run_generation(job_id: str, image_bytes: bytes, params: dict, collection: str = "Default") -> None:
async def _run_generation(job_id: str, image_bytes, params: dict, collection: str = "Default") -> None:
job = _jobs[job_id]
job.status = "running"

Expand Down
7 changes: 4 additions & 3 deletions api/services/generators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from abc import ABC, abstractmethod
import threading
from pathlib import Path
from typing import Callable, Optional
from typing import Callable, List, Optional, Union


def smooth_progress(
Expand Down Expand Up @@ -84,12 +84,13 @@ def is_loaded(self) -> bool:
@abstractmethod
def generate(
self,
image_bytes: bytes,
image_bytes: Union[bytes, List[bytes]],
params: dict,
progress_cb: Optional[Callable[[int, str], None]] = None,
) -> Path:
"""
Starts 3D generation from an image.
Starts 3D generation from one or more images.
Pass a single bytes for single-view, or List[bytes] for multi-view.
Returns the path to the generated .glb file.
progress_cb(percent: int, step_label: str)
"""
Expand Down
24 changes: 19 additions & 5 deletions api/services/generators/hunyuan3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import uuid
import zipfile
from pathlib import Path
from typing import Callable, Optional
from typing import Callable, List, Optional, Union

from PIL import Image

Expand Down Expand Up @@ -84,7 +84,7 @@ def unload(self) -> None:

def generate(
self,
image_bytes: bytes,
image_bytes: Union[bytes, List[bytes]],
params: dict,
progress_cb: Optional[Callable[[int, str], None]] = None,
) -> Path:
Expand All @@ -93,9 +93,23 @@ def generate(
num_steps = int(params.get("num_inference_steps", 50))
vert_count = int(params.get("vertex_count", 0))

# Step 1 — background removal
self._report(progress_cb, 5, "Removing background…")
image = self._preprocess(image_bytes)
# Step 1 — background removal (single or multi-view)
view_labels = params.get("view_labels", [])
is_multiview = isinstance(image_bytes, list) and len(image_bytes) > 1
if is_multiview:
self._report(progress_cb, 5, f"Removing backgrounds ({len(image_bytes)} images)…")
processed_images = [self._preprocess(ib) for ib in image_bytes]
if view_labels and len(view_labels) == len(processed_images):
image = {label: img for label, img in zip(view_labels, processed_images)}
else:
fallback_keys = ["front", "left", "back", "right"]
image = {fallback_keys[i]: img for i, img in enumerate(processed_images[:4])}
elif isinstance(image_bytes, list):
self._report(progress_cb, 5, "Removing background…")
image = self._preprocess(image_bytes[0])
else:
self._report(progress_cb, 5, "Removing background…")
image = self._preprocess(image_bytes)

# Step 2 — shape generation (long, no internal callbacks)
self._report(progress_cb, 12, "Generating 3D shape…")
Expand Down
24 changes: 19 additions & 5 deletions api/services/generators/hunyuan3d_mini.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import uuid
import zipfile
from pathlib import Path
from typing import Callable, Optional
from typing import Callable, List, Optional, Union

from PIL import Image

Expand Down Expand Up @@ -83,7 +83,7 @@ def unload(self) -> None:

def generate(
self,
image_bytes: bytes,
image_bytes: Union[bytes, List[bytes]],
params: dict,
progress_cb: Optional[Callable[[int, str], None]] = None,
) -> Path:
Expand All @@ -96,9 +96,23 @@ def generate(
guidance_scale = float(params.get("guidance_scale", 5.5))
seed = int(params.get("seed", -1))

# Step 1 — background removal
self._report(progress_cb, 5, "Removing background…")
image = self._preprocess(image_bytes)
# Step 1 — background removal (single or multi-view)
view_labels = params.get("view_labels", [])
is_multiview = isinstance(image_bytes, list) and len(image_bytes) > 1
if is_multiview:
self._report(progress_cb, 5, f"Removing backgrounds ({len(image_bytes)} images)…")
processed_images = [self._preprocess(ib) for ib in image_bytes]
if view_labels and len(view_labels) == len(processed_images):
image = {label: img for label, img in zip(view_labels, processed_images)}
else:
fallback_keys = ["front", "left", "back", "right"]
image = {fallback_keys[i]: img for i, img in enumerate(processed_images[:4])}
elif isinstance(image_bytes, list):
self._report(progress_cb, 5, "Removing background…")
image = self._preprocess(image_bytes[0])
else:
self._report(progress_cb, 5, "Removing background…")
image = self._preprocess(image_bytes)

# Step 2 — shape generation
# If texture is enabled, reserve 5-70% for shape and 70-95% for texture
Expand Down
8 changes: 6 additions & 2 deletions api/services/generators/sf3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import uuid
import zipfile
from pathlib import Path
from typing import Callable, Optional
from typing import Callable, List, Optional, Union

from PIL import Image

Expand Down Expand Up @@ -69,12 +69,16 @@ def load(self) -> None:

def generate(
self,
image_bytes: bytes,
image_bytes: Union[bytes, List[bytes]],
params: dict,
progress_cb: Optional[Callable[[int, str], None]] = None,
) -> Path:
import torch

# SF3D only supports single-image input; use first image if list provided
if isinstance(image_bytes, list):
image_bytes = image_bytes[0]

vertex_count = int(params.get("vertex_count", 10000))
remesh = str(params.get("remesh", "quad"))

Expand Down
7 changes: 4 additions & 3 deletions src/areas/generate/GeneratePage.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@ import WorkspacePanel from './components/WorkspacePanel'
import Viewer3D from './components/Viewer3D'

export default function GeneratePage(): JSX.Element {
const selectedImagePath = useAppStore((s) => s.selectedImagePath)
const viewImages = useAppStore((s) => s.viewImages)
const { currentJob, startGeneration } = useGeneration()
const isGenerating = currentJob?.status === 'uploading' || currentJob?.status === 'generating'
const hasFrontImage = !!viewImages.front

return (
<>
Expand All @@ -23,8 +24,8 @@ export default function GeneratePage(): JSX.Element {
{/* Sticky bottom: Generate button */}
<div className="p-4 border-t border-zinc-800">
<button
onClick={() => selectedImagePath && startGeneration(selectedImagePath)}
disabled={!selectedImagePath || isGenerating}
onClick={() => hasFrontImage && startGeneration()}
disabled={!hasFrontImage || isGenerating}
className="w-full py-2.5 rounded-lg text-sm font-semibold bg-accent hover:bg-accent-dark disabled:opacity-40 disabled:cursor-not-allowed text-white transition-colors"
>
{isGenerating ? 'Generating…' : 'Generate 3D Model'}
Expand Down
Loading