diff --git a/tests/test_integrations_ultralytics.py b/tests/test_integrations_ultralytics.py new file mode 100644 index 0000000..e8b2b6a --- /dev/null +++ b/tests/test_integrations_ultralytics.py @@ -0,0 +1,448 @@ +"""Tests for the Ultralytics/YOLO integration extractor.""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest + +from wildedge.integrations.ultralytics import ( + YOLO_FAMILY_RE, + UltralyticsExtractor, + build_download_record, + classify_output_meta, + detect_accelerator, + detect_quantization, + detection_output_meta, + image_input_meta, + is_yolo, + weights_file_exists, +) +from wildedge.model import ModelHandle, ModelInfo + +# --------------------------------------------------------------------------- +# Fake objects — no ultralytics required +# --------------------------------------------------------------------------- + + +class FakeParam: + class Device: + type = "cpu" + + device = Device() + dtype = type("dtype", (), {"__str__": lambda self: "torch.float32"})() + + +class FakeInnerModel: + def parameters(self): + yield FakeParam() + + +class FakeBoxes: + class _Tensor: + def __init__(self, data): + self._data = data + + def tolist(self): + return self._data + + def __init__(self, preds): + # preds: list of (x1, y1, x2, y2, conf, cls) + self.xyxy = self._Tensor([[p[0], p[1], p[2], p[3]] for p in preds]) + self.conf = self._Tensor([p[4] for p in preds]) + self.cls = self._Tensor([p[5] for p in preds]) + + +class FakeResult: + def __init__(self, preds=None): + self.boxes = FakeBoxes(preds or []) + self.probs = None + + +class FakeClassifyProbs: + top5 = [2, 0, 1, 3, 4] + + class _Tensor: + def tolist(self): + return [0.9, 0.05, 0.02, 0.02, 0.01] + + top5conf = _Tensor() + + +class FakeClassifyResult: + boxes = None + probs = FakeClassifyProbs() + + +# YOLO must have this exact class name for is_yolo() to match. +class YOLO: + task = "detect" + names = {0: "person", 1: "bicycle"} + ckpt_path = "/models/yolov8n.pt" + model = FakeInnerModel() + + def __call__(self, source=None, stream=False, **kwargs): + return [FakeResult([(10, 20, 100, 200, 0.9, 0)])] + + +class FailingYOLO(YOLO): + def __call__(self, source=None, stream=False, **kwargs): + raise RuntimeError("cuda oom") + + +class ClassifyYOLO(YOLO): + task = "classify" + + def __call__(self, source=None, stream=False, **kwargs): + return [FakeClassifyResult()] + + +def make_handle(publish_spy) -> ModelHandle: + info = ModelInfo( + model_name="yolov8n", + model_version="unknown", + model_source="local", + model_format="pytorch", + ) + return ModelHandle(model_id="yolov8n", info=info, publish=publish_spy) + + +# --------------------------------------------------------------------------- +# is_yolo +# --------------------------------------------------------------------------- + + +def test_is_yolo_true_for_yolo_instance(): + assert is_yolo(YOLO()) is True + + +def test_is_yolo_false_for_plain_object(): + assert is_yolo(object()) is False + + +def test_is_yolo_false_for_string(): + assert is_yolo("yolo") is False + + +# --------------------------------------------------------------------------- +# YOLO_FAMILY_RE +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "name,expected", + [ + ("yolov8n", "yolov8"), + ("yolov8s-seg", "yolov8"), + ("yolov9e", "yolov9"), + ("yolov10n", "yolov10"), + ("yolo11n", "yolo11"), + ], +) +def test_yolo_family_re_extracts_prefix(name, expected): + m = YOLO_FAMILY_RE.match(name) + assert m is not None + assert m.group(1).lower() == expected + + +def test_yolo_family_re_no_match_for_unknown_format(): + assert YOLO_FAMILY_RE.match("yolo_nas_s") is None + + +# --------------------------------------------------------------------------- +# detect_accelerator +# --------------------------------------------------------------------------- + + +def test_detect_accelerator_returns_cpu_by_default(): + assert detect_accelerator(YOLO()) == "cpu" + + +def test_detect_accelerator_returns_cuda_when_param_on_cuda(): + model = YOLO() + cuda_param = MagicMock() + cuda_param.device.type = "cuda" + model.model = MagicMock() + model.model.parameters = MagicMock(return_value=iter([cuda_param])) + assert detect_accelerator(model) == "cuda" + + +def test_detect_accelerator_returns_cpu_when_no_inner_model(): + model = YOLO() + model.model = None + assert detect_accelerator(model) == "cpu" + + +# --------------------------------------------------------------------------- +# detect_quantization +# --------------------------------------------------------------------------- + + +def test_detect_quantization_f32_from_float32_dtype(): + assert detect_quantization(YOLO()) == "f32" + + +def test_detect_quantization_f16_from_float16_param(): + model = YOLO() + p = MagicMock() + p.dtype.__str__ = lambda self: "torch.float16" + model.model = MagicMock() + model.model.parameters = MagicMock(return_value=iter([p])) + assert detect_quantization(model) == "f16" + + +def test_detect_quantization_none_when_no_inner_model(): + model = YOLO() + model.model = None + assert detect_quantization(model) is None + + +# --------------------------------------------------------------------------- +# image_input_meta +# --------------------------------------------------------------------------- + + +def test_image_input_meta_returns_none_when_numpy_absent(): + with patch("wildedge.integrations.ultralytics.np", None): + assert image_input_meta(object()) is None + + +def test_image_input_meta_extracts_hwc_array(): + pytest.importorskip("numpy") + import numpy as np + + arr = np.zeros((480, 640, 3), dtype=np.uint8) + meta = image_input_meta(arr) + assert meta is not None + assert meta.width == 640 + assert meta.height == 480 + assert meta.channels == 3 + + +def test_image_input_meta_returns_none_for_non_array(): + assert image_input_meta("not_an_array") is None + + +# --------------------------------------------------------------------------- +# detection_output_meta +# --------------------------------------------------------------------------- + + +def test_detection_output_meta_builds_from_results(): + results = [FakeResult([(10, 20, 100, 200, 0.9, 0), (5, 5, 50, 50, 0.6, 1)])] + names = {0: "person", 1: "bicycle"} + meta = detection_output_meta(results, names) + assert meta is not None + assert meta.num_predictions == 2 + assert meta.top_k is not None + assert meta.top_k[0].label == "person" + assert meta.top_k[0].confidence == 0.9 + assert meta.avg_confidence == pytest.approx(0.75, abs=0.01) + assert meta.num_classes == 2 + + +def test_detection_output_meta_empty_results_gives_zero_preds(): + meta = detection_output_meta([FakeResult([])], {}) + assert meta is not None + assert meta.num_predictions == 0 + assert meta.top_k is None + + +def test_detection_output_meta_top_k_capped_at_five(): + preds = [(i, i, i + 10, i + 10, 0.5, i) for i in range(8)] + names = {i: str(i) for i in range(8)} + meta = detection_output_meta([FakeResult(preds)], names) + assert meta is not None + assert len(meta.top_k) == 5 + + +# --------------------------------------------------------------------------- +# classify_output_meta +# --------------------------------------------------------------------------- + + +def test_classify_output_meta_builds_from_results(): + names = {0: "cat", 1: "dog", 2: "bird", 3: "fish", 4: "hamster"} + meta = classify_output_meta([FakeClassifyResult()], names) + assert meta is not None + assert meta.num_predictions == 5 + assert meta.top_k is not None + assert meta.top_k[0].label == "bird" # top5[0] = 2 → "bird" + assert meta.avg_confidence == 0.9 + + +def test_classify_output_meta_missing_probs_gives_empty(): + result = MagicMock() + result.probs = None + meta = classify_output_meta([result], {}) + assert meta is not None + assert meta.num_predictions == 0 + + +# --------------------------------------------------------------------------- +# weights_file_exists +# --------------------------------------------------------------------------- + + +def test_weights_file_exists_true_when_file_on_disk(tmp_path): + f = tmp_path / "yolov8n.pt" + f.write_bytes(b"fake") + assert weights_file_exists(str(f)) is True + + +def test_weights_file_exists_false_when_file_missing(tmp_path): + with patch("wildedge.integrations.ultralytics._ULTRALYTICS_WEIGHTS_DIR", None): + assert weights_file_exists(str(tmp_path / "missing.pt")) is False + + +def test_weights_file_exists_true_for_non_string_arg(): + assert weights_file_exists(42) is True + + +def test_weights_file_exists_checks_ultralytics_weights_dir(tmp_path): + f = tmp_path / "yolov8n.pt" + f.write_bytes(b"fake") + with patch("wildedge.integrations.ultralytics._ULTRALYTICS_WEIGHTS_DIR", tmp_path): + assert weights_file_exists("yolov8n.pt") is True + + +# --------------------------------------------------------------------------- +# build_download_record +# --------------------------------------------------------------------------- + + +def test_build_download_record_builds_for_existing_file(tmp_path): + f = tmp_path / "yolov8n.pt" + f.write_bytes(b"x" * 1000) + model = MagicMock() + model.ckpt_path = str(f) + rec = build_download_record(model, load_ms=500) + assert rec is not None + assert rec["repo_id"] == "yolov8n" + assert rec["size"] == 1000 + assert rec["cache_hit"] is False + assert rec["source_type"] == "ultralytics" + assert "yolov8n.pt" in rec["source_url"] + + +def test_build_download_record_returns_none_when_no_ckpt_path(): + model = MagicMock() + model.ckpt_path = None + assert build_download_record(model, load_ms=100) is None + + +def test_build_download_record_returns_none_when_file_missing(tmp_path): + model = MagicMock() + model.ckpt_path = str(tmp_path / "missing.pt") + assert build_download_record(model, load_ms=100) is None + + +# --------------------------------------------------------------------------- +# UltralyticsExtractor +# --------------------------------------------------------------------------- + +_extractor = UltralyticsExtractor() + + +def test_extractor_can_handle_yolo_instance(): + assert _extractor.can_handle(YOLO()) is True + + +def test_extractor_can_handle_rejects_other_types(): + assert _extractor.can_handle(object()) is False + + +def test_extractor_extract_info_uses_ckpt_stem_as_model_id(): + model_id, info = _extractor.extract_info(YOLO(), {}) + assert model_id == "yolov8n" + assert info.model_name == "yolov8n" + assert info.model_format == "pytorch" + + +def test_extractor_extract_info_derives_family(): + _, info = _extractor.extract_info(YOLO(), {}) + assert info.model_family == "yolov8" + + +def test_extractor_extract_info_override_model_id(): + model_id, _ = _extractor.extract_info(YOLO(), {"id": "my-detector"}) + assert model_id == "my-detector" + + +def test_extractor_extract_info_fallback_model_name_when_no_ckpt(): + model = YOLO() + model.ckpt_path = "" + model_id, info = _extractor.extract_info(model, {}) + assert model_id == "yolo" + assert info.model_name == "yolo" + + +def test_extractor_memory_bytes_returns_file_size(tmp_path): + f = tmp_path / "yolov8n.pt" + f.write_bytes(b"x" * 2048) + model = YOLO() + model.ckpt_path = str(f) + assert _extractor.memory_bytes(model) == 2048 + + +def test_extractor_memory_bytes_returns_none_when_file_missing(): + model = YOLO() + model.ckpt_path = "/nonexistent/path.pt" + assert _extractor.memory_bytes(model) is None + + +def test_extractor_install_hooks_publishes_inference_on_call(publish_spy): + model = YOLO() + handle = make_handle(publish_spy) + _extractor.install_hooks(model, handle) + model() + assert len(publish_spy.events) == 1 + assert publish_spy.events[0]["event_type"] == "inference" + + +def test_extractor_install_hooks_detect_output_meta(publish_spy): + model = YOLO() + handle = make_handle(publish_spy) + _extractor.install_hooks(model, handle) + model() + event = publish_spy.events[0] + assert event["inference"]["output_meta"]["task"] == "detection" + assert event["inference"]["output_meta"]["num_predictions"] == 1 + assert event["inference"]["input_modality"] == "image" + assert event["inference"]["output_modality"] == "detection" + + +def test_extractor_install_hooks_classify_output_meta(publish_spy): + model = ClassifyYOLO() + handle = make_handle(publish_spy) + _extractor.install_hooks(model, handle) + model() + assert publish_spy.events[0]["inference"]["output_modality"] == "classification" + + +def test_extractor_install_hooks_tracks_error_on_exception(publish_spy): + model = FailingYOLO() + handle = make_handle(publish_spy) + _extractor.install_hooks(model, handle) + with pytest.raises(RuntimeError, match="cuda oom"): + model() + assert publish_spy.events[0]["event_type"] == "error" + + +def test_extractor_install_hooks_does_not_affect_other_instances(publish_spy): + model_a = YOLO() + model_b = YOLO() + handle = make_handle(publish_spy) + _extractor.install_hooks(model_a, handle) + model_b() + assert len(publish_spy.events) == 0 + + +def test_extractor_install_hooks_batch_size_from_list_source(publish_spy): + np = pytest.importorskip("numpy") + model = YOLO() + handle = make_handle(publish_spy) + _extractor.install_hooks(model, handle) + frames = [np.zeros((480, 640, 3), dtype=np.uint8)] * 3 + model(frames) + assert publish_spy.events[0]["inference"]["batch_size"] == 3 diff --git a/wildedge/client.py b/wildedge/client.py index ac2dec2..5056d11 100644 --- a/wildedge/client.py +++ b/wildedge/client.py @@ -24,6 +24,7 @@ from wildedge.integrations.pytorch import PytorchExtractor from wildedge.integrations.registry import noop_integrations, supported_integrations from wildedge.integrations.tensorflow import TensorflowExtractor +from wildedge.integrations.ultralytics import UltralyticsExtractor from wildedge.logging import enable_debug, logger from wildedge.model import ModelHandle, ModelInfo, ModelRegistry from wildedge.paths import ( @@ -80,6 +81,7 @@ def parse_dsn(dsn: str) -> tuple[str, str, str]: DEFAULT_EXTRACTORS: list[BaseExtractor] = [ OnnxExtractor(), GgufExtractor(), + UltralyticsExtractor(), PytorchExtractor(), TensorflowExtractor(), KerasExtractor(), @@ -106,6 +108,7 @@ class WildEdge: "onnx": OnnxExtractor.install_auto_load_patch, "timm": PytorchExtractor.install_timm_patch, "tensorflow": TensorflowExtractor.install_auto_load_patch, + "ultralytics": UltralyticsExtractor.install_auto_load_patch, } # Hub trackers: record download provenance (where models came from). @@ -423,6 +426,10 @@ def instrument( ``"tensorflow"`` Patches ``tf.keras.models.load_model`` and ``tf.saved_model.load``. Requires ``tensorflow``. + ``"ultralytics"`` + Patches ``ultralytics.YOLO.__init__``. Requires ``ultralytics``. + Emits a download event on first load if weights were fetched from + the ultralytics CDN. ``"torch"`` / ``"keras"`` No global constructor to patch; models are user-defined subclasses. Inference is tracked automatically once a model is registered; diff --git a/wildedge/integrations/registry.py b/wildedge/integrations/registry.py index e6b6f7c..3a53778 100644 --- a/wildedge/integrations/registry.py +++ b/wildedge/integrations/registry.py @@ -30,6 +30,7 @@ class IntegrationSpec: IntegrationSpec("torch", ("torch",), "noop"), IntegrationSpec("keras", ("keras",), "noop"), IntegrationSpec("tensorflow", ("tensorflow",), "client_patch"), + IntegrationSpec("ultralytics", ("ultralytics",), "client_patch"), ) INTEGRATIONS_BY_NAME: dict[str, IntegrationSpec] = { diff --git a/wildedge/integrations/ultralytics.py b/wildedge/integrations/ultralytics.py new file mode 100644 index 0000000..c4aa6b7 --- /dev/null +++ b/wildedge/integrations/ultralytics.py @@ -0,0 +1,446 @@ +"""Ultralytics integration.""" + +from __future__ import annotations + +import re +import threading +import time +from pathlib import Path +from typing import TYPE_CHECKING + +from wildedge import constants +from wildedge.events.inference import ( + ClassificationOutputMeta, + DetectionOutputMeta, + HistogramSummary, + ImageInputMeta, + TopKPrediction, +) +from wildedge.integrations.base import BaseExtractor, patch_instance_call_once +from wildedge.logging import logger +from wildedge.model import ModelInfo +from wildedge.timing import elapsed_ms + +try: + import ultralytics as _ultralytics + from ultralytics.utils import ( + WEIGHTS_DIR as _ULTRALYTICS_WEIGHTS_DIR, # type: ignore[import-untyped] + ) +except ImportError: + _ultralytics = None # type: ignore[assignment] + _ULTRALYTICS_WEIGHTS_DIR = None # type: ignore[assignment] + +try: + import numpy as np # type: ignore[import-untyped] +except ImportError: + np = None # type: ignore[assignment] + +if TYPE_CHECKING: + from wildedge.model import ModelHandle + +# --- Patch state (mutable, module-level) --- +_ultralytics_patched = False +_ULTRALYTICS_PATCH_LOCK = threading.Lock() + +# --- Marker names written onto patched objects and classes --- +YOLO_CALL_PATCH_NAME = "ultralytics_call" +YOLO_HANDLE_ATTR = "__wildedge_yolo_handle__" +YOLO_AUTO_LOAD_PATCH_NAME = "ultralytics_auto_load" + +# --- Extracts family prefix from names like "yolov8n", "yolov9e", "yolo11n" --- +YOLO_FAMILY_RE = re.compile(r"^(yolo(?:v\d+|\d+))", re.IGNORECASE) + + +def debug_failure(context: str, exc: BaseException) -> None: + logger.debug("wildedge: ultralytics %s failed: %s", context, exc) + + +def is_yolo(obj: object) -> bool: + # String check avoids importing ultralytics when it is not installed. + return type(obj).__name__ == "YOLO" + + +def detect_accelerator(obj: object) -> str: + try: + inner = getattr(obj, "model", None) + if inner is not None: + try: + first_param = next(inner.parameters()) + device_type = str(getattr(first_param.device, "type", "")) + if device_type: + return device_type + except StopIteration: + pass + except Exception as exc: + debug_failure("accelerator detection", exc) + return "cpu" + + +def detect_quantization(obj: object) -> str | None: + try: + inner = getattr(obj, "model", None) + if inner is None: + return None + for p in inner.parameters(): + dtype = str(p.dtype) + if "bfloat16" in dtype: + return "bf16" + if "float16" in dtype: + return "f16" + if "int8" in dtype or "qint" in dtype: + return "int8" + if "float32" in dtype: + return "f32" + break # only need the first parameter's dtype + except Exception as exc: + debug_failure("quantization detection", exc) + return None + + +def image_input_meta(arr: object) -> ImageInputMeta | None: + """Extract ImageInputMeta from a numpy array (H, W, C). Best-effort, never raises.""" + try: + if np is None or not isinstance(arr, np.ndarray): + return None + shape = arr.shape + if len(shape) == 3: + h, w, c = shape + elif len(shape) == 2: + h, w = shape + c = 1 + else: + return None + + floats = arr.astype(np.float32) + t_min = float(floats.min()) + t_max = float(floats.max()) + span = t_max - t_min + norm = (floats - t_min) / span if span > 0 else floats + + brightness_mean = round(float(norm.mean()), 4) + brightness_stddev = round(float(norm.std()), 4) + + flat = norm.flatten() + edges = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0] + buckets = [ + int(((flat >= lo) & (flat < hi)).sum()) for lo, hi in zip(edges, edges[1:]) + ] + buckets[-1] += int((flat == 1.0).sum()) + + return ImageInputMeta( + width=int(w), + height=int(h), + channels=int(c), + histogram_summary=HistogramSummary( + brightness_mean=brightness_mean, + brightness_stddev=brightness_stddev, + brightness_buckets=buckets, + contrast=brightness_stddev, + ), + ) + except Exception as exc: + debug_failure("image input meta extraction", exc) + return None + + +def detection_output_meta( + results: list, names: dict | None +) -> DetectionOutputMeta | None: + """Build DetectionOutputMeta from a list of ultralytics Results. Best-effort, never raises.""" + try: + total_preds = 0 + all_confs: list[float] = [] + top_preds: list[TopKPrediction] = [] + + for result in results: + boxes = getattr(result, "boxes", None) + if boxes is None: + continue + try: + xyxy = boxes.xyxy.tolist() if hasattr(boxes.xyxy, "tolist") else [] + confs = boxes.conf.tolist() if hasattr(boxes.conf, "tolist") else [] + classes = boxes.cls.tolist() if hasattr(boxes.cls, "tolist") else [] + except Exception: + continue + + total_preds += len(confs) + all_confs.extend(confs) + + for i, (conf, cls_id) in enumerate(zip(confs, classes)): + if len(top_preds) >= 5: + break + label = (names or {}).get(int(cls_id), str(int(cls_id))) + bbox = [int(v) for v in xyxy[i]] if i < len(xyxy) else None + top_preds.append( + TopKPrediction( + label=label, + confidence=round(float(conf), 4), + bbox=bbox, + ) + ) + + avg_conf = round(sum(all_confs) / len(all_confs), 4) if all_confs else None + return DetectionOutputMeta( + task="detection", + num_predictions=total_preds, + top_k=top_preds if top_preds else None, + avg_confidence=avg_conf, + num_classes=len(names) if names else None, + ) + except Exception as exc: + debug_failure("detection output meta extraction", exc) + return None + + +def classify_output_meta( + results: list, names: dict | None +) -> ClassificationOutputMeta | None: + """Build ClassificationOutputMeta from a list of ultralytics Results for classify task. + Best-effort, never raises.""" + try: + top_preds: list[TopKPrediction] = [] + top1_conf: float | None = None + + for result in results: + probs = getattr(result, "probs", None) + if probs is None: + continue + try: + top5 = probs.top5 + top5conf = ( + probs.top5conf.tolist() if hasattr(probs.top5conf, "tolist") else [] + ) + for cls_id, conf in zip(top5, top5conf): + label = (names or {}).get(int(cls_id), str(int(cls_id))) + top_preds.append( + TopKPrediction(label=label, confidence=round(float(conf), 4)) + ) + if top5conf: + top1_conf = round(float(top5conf[0]), 4) + except Exception: + continue + break # one result per sample for classify + + return ClassificationOutputMeta( + num_predictions=len(top_preds), + top_k=top_preds if top_preds else None, + avg_confidence=top1_conf, + ) + except Exception as exc: + debug_failure("classify output meta extraction", exc) + return None + + +def weights_file_exists(model_arg: object) -> bool: + """Return True if the weights file appears to already be on disk.""" + if not isinstance(model_arg, str): + return True # not a path string — weights already in memory or a loaded object + p = Path(model_arg) + if p.is_file(): + return True + # ultralytics resolves short names (e.g. "yolov8n.pt") against its assets dir + if ( + _ULTRALYTICS_WEIGHTS_DIR is not None + and (_ULTRALYTICS_WEIGHTS_DIR / p.name).is_file() + ): + return True + return False + + +def build_download_record(obj: object, load_ms: int) -> dict | None: + """Build a single download record for the model weights file. Best-effort.""" + try: + ckpt_path = getattr(obj, "ckpt_path", None) + if not ckpt_path: + return None + p = Path(ckpt_path) + if not p.is_file(): + return None + file_size = p.stat().st_size + bandwidth_bps = int(file_size / load_ms * 1000) if load_ms > 0 else None + return { + "repo_id": p.stem, + "source_type": "ultralytics", + "source_url": f"https://github.com/ultralytics/assets/releases/download/v0.0.0/{p.name}", + "size": file_size, + "duration_ms": load_ms, + "cache_hit": False, + "bandwidth_bps": bandwidth_bps, + } + except Exception as exc: + debug_failure("download record build", exc) + return None + + +def build_patched_call(original_call): # type: ignore[no-untyped-def] + def patched_call(self_inner, *args, **kwargs): # type: ignore[no-untyped-def] + handle = getattr(self_inner, YOLO_HANDLE_ATTR, None) + if handle is None: + return original_call(self_inner, *args, **kwargs) + + source = args[0] if args else kwargs.get("source") + + batch_size: int | None = None + input_meta = None + if source is not None and np is not None: + try: + if isinstance(source, np.ndarray): + input_meta = image_input_meta(source) + batch_size = 1 + elif isinstance(source, list) and source: + batch_size = len(source) + if isinstance(source[0], np.ndarray): + input_meta = image_input_meta(source[0]) + except Exception as exc: + debug_failure("input meta extraction", exc) + + t0 = time.perf_counter() + try: + results = original_call(self_inner, *args, **kwargs) + duration_ms = elapsed_ms(t0) + + task = getattr(self_inner, "task", "detect") or "detect" + names = getattr(self_inner, "names", None) + + output_meta = None + output_modality = "detection" + if isinstance(results, list) and results: + try: + if task == "classify": + output_meta = classify_output_meta(results, names) + output_modality = "classification" + else: + output_meta = detection_output_meta(results, names) + except Exception as exc: + debug_failure("output meta extraction", exc) + + handle.track_inference( + duration_ms=duration_ms, + batch_size=batch_size, + input_modality="image", + output_modality=output_modality, + input_meta=input_meta, + output_meta=output_meta, + success=True, + ) + return results + except Exception as exc: + handle.track_error( + error_code="UNKNOWN", + error_message=str(exc)[: constants.ERROR_MSG_MAX_LEN], + ) + raise + + return patched_call + + +class UltralyticsExtractor(BaseExtractor): + def can_handle(self, obj: object) -> bool: + return is_yolo(obj) + + def extract_info( + self, obj: object, overrides: dict + ) -> tuple[str | None, ModelInfo]: + ckpt_path = getattr(obj, "ckpt_path", None) or "" + stem = Path(ckpt_path).stem if ckpt_path else None + model_name = stem or "yolo" + model_id = overrides.pop("id", None) or model_name + + family = overrides.pop("family", None) + if family is None and model_name: + m = YOLO_FAMILY_RE.match(model_name) + family = m.group(1).lower() if m else None + if family is None: + logger.warning( + "wildedge: ultralytics model family could not be detected - sending as null" + ) + + quantization = overrides.pop("quantization", None) or detect_quantization(obj) + if quantization is None: + logger.warning( + "wildedge: ultralytics model quantization could not be detected - sending as null" + ) + + version = overrides.pop("version", "unknown") + source = overrides.pop("source", "local") + + info = ModelInfo( + model_name=model_name, + model_version=version, + model_source=source, + model_format="pytorch", + model_family=family, + quantization=quantization, + ) + for k, v in overrides.items(): + if hasattr(info, k): + setattr(info, k, v) + + return model_id, info + + def memory_bytes(self, obj: object) -> int | None: + try: + ckpt_path = getattr(obj, "ckpt_path", None) + if ckpt_path: + return Path(ckpt_path).stat().st_size + except Exception as exc: + debug_failure("model size detection", exc) + return None + + def install_hooks(self, obj: object, handle: ModelHandle) -> None: + handle.detected_accelerator = detect_accelerator(obj) + setattr(obj, YOLO_HANDLE_ATTR, handle) + patch_instance_call_once( + obj, + patch_name=YOLO_CALL_PATCH_NAME, + make_patched_call=build_patched_call, + ) + + @classmethod + def install_auto_load_patch(cls, client_ref: object) -> None: + """Patch ultralytics.YOLO.__init__ for automatic load, unload, and download tracking. + + Called once at WildEdge client initialisation. Any subsequent + ``YOLO(...)`` construction is timed and registered automatically. + If the weights file does not exist before the call, a download event + is emitted with the file size and the total load duration. + """ + global _ultralytics_patched + if _ultralytics_patched or _ultralytics is None: + return + + with _ULTRALYTICS_PATCH_LOCK: + if _ultralytics_patched: + return + + original_init = _ultralytics.YOLO.__init__ + if ( + getattr(original_init, "__wildedge_patch_name__", None) + == YOLO_AUTO_LOAD_PATCH_NAME + ): + _ultralytics_patched = True + return + + def patched_init(self_inner, *args, **kwargs): # type: ignore[no-untyped-def] + model_arg = args[0] if args else kwargs.get("model", "yolov8n.pt") + weights_existed = weights_file_exists(model_arg) + + t0 = time.perf_counter() + original_init(self_inner, *args, **kwargs) + load_ms = elapsed_ms(t0) + + c = client_ref() # type: ignore[call-arg] + if c is not None and not c.closed: + downloads = None + if not weights_existed: + record = build_download_record(self_inner, load_ms) + if record is not None: + downloads = [record] + c._on_model_auto_loaded( + self_inner, load_ms=load_ms, downloads=downloads + ) + + patched_init.__wildedge_patch_name__ = YOLO_AUTO_LOAD_PATCH_NAME # type: ignore[attr-defined] + patched_init.__wildedge_original_call__ = original_init # type: ignore[attr-defined] + _ultralytics.YOLO.__init__ = patched_init + _ultralytics_patched = True