Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 5 additions & 12 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,15 @@ dynamic = [
]
dependencies = [
"beartype>=0.22.9",
"brisque>=0.0.17",
"flask>=3.0.3",
"numpy>=1.26.4",
"pillow>=11.0.0",
"piq>=0.8.0",
"pydantic-settings>=2.6.1",
"requests>=2.32.3",
"responses>=0.25.3",
"torch>=2.5.1",
"torchmetrics>=1.5.1",
"torchvision>=0.20.1",
"scikit-image>=0.21.0",
"scipy>=1.12.0",
"tzdata; sys_platform=='win32'",
"vws-auth-tools>=2024.7.12",
"werkzeug>=3.1.2",
Expand Down Expand Up @@ -139,11 +138,6 @@ fallback_version = "0.0.0"
# Code to match this is in ``conf.py``.
version_scheme = "post-release"

[tool.uv]
sources.torch = { index = "pytorch-cpu" }
sources.torchvision = { index = "pytorch-cpu" }
index = [ { name = "pytorch-cpu", url = "https://download.pytorch.org/whl/cpu", explicit = true } ]

[tool.ruff]
line-length = 79
lint.select = [
Expand Down Expand Up @@ -315,12 +309,11 @@ pep621_dev_dependency_groups = [
"release",
]
per_rule_ignores.DEP002 = [
# scipy is a transitive dependency of brisque.
"scipy",
# tzdata is needed on Windows for zoneinfo to work.
# See https://docs.python.org/3/library/zoneinfo.html#data-sources.
"tzdata",
# torchvision is used transitively via piq, but must be a direct dependency
# so that tool.uv.sources can route it to the CPU-only PyTorch index.
"torchvision",
]

[tool.pyproject-fmt]
Expand Down
1 change: 1 addition & 0 deletions spelling_private_dict.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ beartype
binascii
bool
boolean
brisque
bytesio
changelog
chunked
Expand Down
23 changes: 23 additions & 0 deletions src/mock_vws/_brisque.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""Wrapper around the brisque package for image quality scoring."""

from __future__ import annotations

import io

import numpy as np
from brisque import ( # type: ignore[attr-defined] # pyright: ignore[reportMissingTypeStubs]
BRISQUE,
)
from PIL import Image

_brisque_scorer = BRISQUE(url=False) # type: ignore[no-untyped-call]


def brisque_score(image_content: bytes) -> float:
"""Return a BRISQUE quality score for the given image bytes."""
image_file = io.BytesIO(initial_bytes=image_content)
image = Image.open(fp=image_file).convert(mode="RGB")
image_np: np.ndarray = np.array(object=image)
return float(
_brisque_scorer.score(img=image_np) # type: ignore[no-untyped-call] # pyright: ignore[reportUnknownMemberType]
)
2 changes: 2 additions & 0 deletions src/mock_vws/_flask_server/Dockerfile
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
FROM ghcr.io/astral-sh/uv:0.10.4-python3.13-trixie-slim AS base
# Install build tools required by packages that compile C++ extensions (e.g. libsvm-official).
RUN apt-get update && apt-get install -y --no-install-recommends g++=3.0.3 && rm -rf /var/lib/apt/lists/*
# We set this pretend version as we do not have Git in our path, and we do
# not care enough about having the version correct inside the Docker container
# to install it.
Expand Down
51 changes: 15 additions & 36 deletions src/mock_vws/image_matchers.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
"""Matchers for query and duplicate requests."""

import io
from typing import Protocol, runtime_checkable
from typing import Protocol, cast, runtime_checkable

import numpy as np
import torch
from beartype import beartype
from PIL import Image
from torchmetrics.image import (
StructuralSimilarityIndexMeasure,
from skimage.metrics import ( # pylint: disable=no-name-in-module
structural_similarity, # pyright: ignore[reportUnknownVariableType]
)


Expand Down Expand Up @@ -78,42 +77,22 @@ def __call__(
first_image_resized = first_image.resize(size=target_size)
second_image_resized = second_image.resize(size=target_size)

first_image_np = np.array(object=first_image_resized, dtype=np.float32)
first_image_tensor = torch.tensor(data=first_image_np).float() / 255
first_image_tensor = first_image_tensor.view(
first_image_resized.size[1],
first_image_resized.size[0],
len(first_image_resized.getbands()),
first_image_np = (
np.array(object=first_image_resized, dtype=np.float32) / 255
)

second_image_np = np.array(
object=second_image_resized,
dtype=np.float32,
)
second_image_tensor = torch.tensor(data=second_image_np).float() / 255
second_image_tensor = second_image_tensor.view(
second_image_resized.size[1],
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SSIM fails on grayscale images missing channel dimension

High Severity

The structural_similarity call unconditionally passes channel_axis=2, but grayscale ("L" mode) images produce 2D numpy arrays with shape (H, W) — no axis 2 exists. The image validator explicitly allows grayscale images, so they can reach this code path. The old torch-based code handled this correctly via getbands() and .view(). The _brisque.py wrapper avoids this by calling .convert(mode="RGB"), but image_matchers.py does not convert images before creating numpy arrays.

Fix in Cursor Fix in Web

second_image_resized.size[0],
len(second_image_resized.getbands()),
second_image_np = (
np.array(object=second_image_resized, dtype=np.float32) / 255
)

first_image_tensor_batch_dimension = first_image_tensor.permute(
2,
0,
1,
).unsqueeze(dim=0)
second_image_tensor_batch_dimension = second_image_tensor.permute(
2,
0,
1,
).unsqueeze(dim=0)

ssim = StructuralSimilarityIndexMeasure(data_range=1.0)
ssim_value = ssim(
first_image_tensor_batch_dimension,
second_image_tensor_batch_dimension,
ssim_score: float = cast(
"float",
structural_similarity( # type: ignore[no-untyped-call]
im1=first_image_np,
im2=second_image_np,
data_range=1.0,
channel_axis=2,
),
)
ssim_score = ssim_value.item()

# Normalize SSIM score from -1 to 1 scale to 0 to 10 scale.
# This maps -1 to 0 and 1 to 10.
Expand Down
21 changes: 4 additions & 17 deletions src/mock_vws/target_raters.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
"""Raters for target quality."""

import functools
import io
import math
import secrets
from typing import Protocol, runtime_checkable

import numpy as np
import torch
from beartype import beartype
from PIL import Image
from piq.brisque import brisque # pyright: ignore[reportMissingTypeStubs]

from mock_vws._brisque import brisque_score


@functools.cache
Expand All @@ -25,21 +22,11 @@ def _get_brisque_target_tracking_rating(*, image_content: bytes) -> int:
Args:
image_content: A target's image's content.
"""
image_file = io.BytesIO(initial_bytes=image_content)
image = Image.open(fp=image_file)
image_np = np.array(object=image, dtype=np.float32)
image_tensor = torch.tensor(data=image_np).float() / 255
image_tensor = image_tensor.view(
image.size[1],
image.size[0],
len(image.getbands()),
)
image_tensor = image_tensor.permute(2, 0, 1).unsqueeze(dim=0)
try:
brisque_score = brisque(x=image_tensor, data_range=255)
score = brisque_score(image_content=image_content)
except (AssertionError, IndexError):
return 0
return math.ceil(int(brisque_score.item()) / 20)
return math.ceil(int(score) / 20)


@runtime_checkable
Expand Down
Loading