From a2807332042cb431ef36f86a9cc23da45a5b49d5 Mon Sep 17 00:00:00 2001
From: Piotr Duda
Date: Mon, 16 Mar 2026 17:15:21 +0100
Subject: [PATCH 1/2] OpenAI SDK integration

---
 pyproject.toml                    |   2 +
 tests/test_integrations_openai.py | 513 ++++++++++++++++++++++++++++++
 uv.lock                           |  25 ++
 wildedge/client.py                |   3 +
 wildedge/events/__init__.py       |   2 +
 wildedge/events/inference.py      |  25 ++
 wildedge/integrations/openai.py   | 304 ++++++++++++++++++
 wildedge/integrations/registry.py |   1 +
 wildedge/model.py                 |   3 +
 9 files changed, 878 insertions(+)
 create mode 100644 tests/test_integrations_openai.py
 create mode 100644 wildedge/integrations/openai.py

diff --git a/pyproject.toml b/pyproject.toml
index 561a335..5ff8799 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -38,6 +38,7 @@ ban-relative-imports = "all"
 [tool.pytest.ini_options]
 testpaths = ["tests"]
 log_level = "ERROR"
+asyncio_mode = "auto"
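+# "auto" lets pytest-asyncio collect plain `async def` tests without markers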
 markers = [
     "requires_linux: test only on Linux",
     "requires_macos: test only on macOS",
@@ -47,6 +48,7 @@ markers = [
 [dependency-groups]
 dev = [
     "pytest>=9.0.2",
+    "pytest-asyncio>=0.25",
     "pytest-mock>=3.15.1",
     "ruff>=0.15.4",
     "tox",
diff --git a/tests/test_integrations_openai.py b/tests/test_integrations_openai.py
new file mode 100644
index 0000000..9fea3ce
--- /dev/null
+++ b/tests/test_integrations_openai.py
@@ -0,0 +1,513 @@
+"""Tests for the OpenAI / OpenRouter integration."""
+
+from __future__ import annotations
+
+import types
+from types import SimpleNamespace
+from unittest.mock import MagicMock
+
+import pytest
+
+import wildedge.integrations.openai as openai_mod
+from wildedge.integrations.openai import (
+    OpenAIExtractor,
+    build_api_meta,
+    build_input_meta,
+    build_output_meta,
+    source_from_base_url,
+    wrap_async_completions,
+    wrap_sync_completions,
+)
+from wildedge.model import ModelHandle, ModelInfo
+
+# ---------------------------------------------------------------------------
+# Fake objects — no openai library required
+# ---------------------------------------------------------------------------
+
+
+class FakePromptDetails:
+    cached_tokens = 5
+
+
+class FakeCompletionDetails:
+    reasoning_tokens = 3
+
+
+class FakeUsage:
+    prompt_tokens = 10
+    completion_tokens = 20
+    prompt_tokens_details = FakePromptDetails()
+    completion_tokens_details = FakeCompletionDetails()
+
+
+class FakeUsageNoDetails:
+    prompt_tokens = 10
+    completion_tokens = 20
+    prompt_tokens_details = None
+    completion_tokens_details = None
+
+
+class FakeChoice:
+    finish_reason = "stop"
+
+
+class FakeResponse:
+    model = "gpt-4o-2024-08-06"
+    system_fingerprint = "fp_abc123"
+    service_tier = "default"
+    usage = FakeUsage()
+    choices = [FakeChoice()]
+
+
+class FakeResponseNoUsage:
+    model = None
+    system_fingerprint = None
+    service_tier = None
+    usage = None
+    choices = []
+
+
+class FakeCompletions:
+    def __init__(self, response=None):
+        self._response = response or FakeResponse()
+
+    def create(self, *args, **kwargs):
+        return self._response
+
+
+class FakeAsyncCompletions:
+    def __init__(self, response=None):
+        self._response = response or FakeResponse()
+
+    async def create(self, *args, **kwargs):
+        return self._response
+
+
+# Named "OpenAI" / "AsyncOpenAI" so can_handle sees the right type name.
+class OpenAI:
+    def __init__(self, base_url="https://api.openai.com/v1", api_key=None):
+        self.base_url = base_url
+
+
+class AsyncOpenAI:
+    def __init__(self, base_url="https://api.openai.com/v1", api_key=None):
+        self.base_url = base_url
+
+
+def make_handle(publish_spy) -> ModelHandle:
+    info = ModelInfo(
+        model_name="test",
+        model_version="1.0",
+        model_source="openai",
+        model_format="api",
+    )
+    return ModelHandle(model_id="gpt-4o", info=info, publish=publish_spy)
+
+
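+# Stand-in for the WildEdge client: register_model caches one mocked handle per
+# model_id, mirroring the lazy per-model registration the wrappers rely on.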
+def make_fake_client(closed=False):
+    client = SimpleNamespace(closed=closed, handles={})
+
+    def register_model(obj, *, model_id=None, source=None, **kwargs):
+        if model_id not in client.handles:
+            client.handles[model_id] = SimpleNamespace(
+                model_id=model_id,
+                track_inference=MagicMock(),
+                track_error=MagicMock(),
+            )
+        return client.handles[model_id]
+
+    client.register_model = register_model
+    return client
+
+
+# ---------------------------------------------------------------------------
+# source_from_base_url
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.parametrize(
+    "url,expected",
+    [
+        (None, "openai"),
+        ("", "openai"),
+        ("https://openrouter.ai/api/v1", "openrouter"),
+        ("https://api.openai.com/v1", "openai"),
+        ("https://api.together.xyz/v1", "api.together.xyz"),
+        ("https://localhost:11434/v1", "localhost"),
+    ],
+)
+def test_source_from_base_url(url, expected):
+    assert source_from_base_url(url) == expected
+
+
+# ---------------------------------------------------------------------------
+# build_input_meta
+# ---------------------------------------------------------------------------
+
+
+def test_build_input_meta_picks_last_user_message():
+    messages = [
+        {"role": "system", "content": "You are helpful."},
+        {"role": "user", "content": "Hello world"},
+    ]
+    meta = build_input_meta(messages, tokens_in=5)
+    assert meta is not None
+    assert meta.char_count == len("Hello world")
+    assert meta.word_count == 2
+    assert meta.token_count == 5
+    assert meta.prompt_type == "chat"
+
+
+def test_build_input_meta_no_user_message_returns_none():
+    assert (
+        build_input_meta([{"role": "system", "content": "sys"}], tokens_in=None) is None
+    )
+
+
+def test_build_input_meta_empty_messages_returns_none():
+    assert build_input_meta([], tokens_in=None) is None
+
+
+def test_build_input_meta_non_string_content_returns_none():
+    assert (
+        build_input_meta(
+            [{"role": "user", "content": [{"type": "image_url"}]}], tokens_in=None
+        )
+        is None
+    )
+
+
+# ---------------------------------------------------------------------------
+# build_output_meta
+# ---------------------------------------------------------------------------
+
+
+def test_build_output_meta_extracts_tokens_and_stop_reason():
+    meta = build_output_meta(FakeResponse(), duration_ms=500)
+    assert meta is not None
+    assert meta.tokens_in == 10
+    assert meta.tokens_out == 20
+    assert meta.stop_reason == "stop"
+    assert meta.tokens_per_second == pytest.approx(40.0)
+
+
+def test_build_output_meta_extracts_cached_and_reasoning_tokens():
+    meta = build_output_meta(FakeResponse(), duration_ms=500)
+    assert meta is not None
+    assert meta.cached_input_tokens == 5
+    assert meta.reasoning_tokens_out == 3
+
+
+def test_build_output_meta_none_cached_when_no_details():
+    response = SimpleNamespace(
+        usage=SimpleNamespace(
+            prompt_tokens=10,
+            completion_tokens=5,
+            prompt_tokens_details=None,
+            completion_tokens_details=None,
+        ),
+        choices=[SimpleNamespace(finish_reason="stop")],
+    )
+    meta = build_output_meta(response, duration_ms=100)
+    assert meta is not None
+    assert meta.cached_input_tokens is None
+    assert meta.reasoning_tokens_out is None
+
+
+def test_build_output_meta_none_when_no_usage():
+    assert build_output_meta(FakeResponseNoUsage(), duration_ms=500) is None
+
+
+def test_build_output_meta_zero_duration_gives_no_tps():
+    meta = build_output_meta(FakeResponse(), duration_ms=0)
+    assert meta is not None
+    assert meta.tokens_per_second is None
+
+
+# ---------------------------------------------------------------------------
+# build_api_meta
+# ---------------------------------------------------------------------------
+
+
+def test_build_api_meta_extracts_all_fields():
+    meta = build_api_meta(FakeResponse())
+    assert meta is not None
+    assert meta.resolved_model_id == "gpt-4o-2024-08-06"
+    assert meta.system_fingerprint == "fp_abc123"
+    assert meta.service_tier == "default"
+
+
+def test_build_api_meta_none_when_all_fields_absent():
+    assert build_api_meta(FakeResponseNoUsage()) is None
+
+
+def test_build_api_meta_partial_fields():
+    response = SimpleNamespace(
+        model="gpt-4o", system_fingerprint=None, service_tier=None
+    )
+    meta = build_api_meta(response)
+    assert meta is not None
+    assert meta.resolved_model_id == "gpt-4o"
+    assert meta.system_fingerprint is None
+
+
+def test_build_api_meta_to_dict_omits_none_fields():
+    response = SimpleNamespace(
+        model="gpt-4o", system_fingerprint=None, service_tier=None
+    )
+    meta = build_api_meta(response)
+    assert meta is not None
+    d = meta.to_dict()
+    assert "resolved_model_id" in d
+    assert "system_fingerprint" not in d
+    assert "service_tier" not in d
+
+
+# ---------------------------------------------------------------------------
+# OpenAIExtractor
+# ---------------------------------------------------------------------------
+
+
+class TestOpenAIExtractor:
+    extractor = OpenAIExtractor()
+
+    def test_can_handle_openai(self):
+        assert self.extractor.can_handle(OpenAI())
+
+    def test_can_handle_async_openai(self):
+        assert self.extractor.can_handle(AsyncOpenAI())
+
+    def test_can_handle_rejects_other_types(self):
+        assert not self.extractor.can_handle(object())
+        assert not self.extractor.can_handle("string")
+
+    def test_extract_info_uses_override_model_id_and_source(self):
+        obj = OpenAI(base_url="https://openrouter.ai/api/v1")
+        model_id, info = self.extractor.extract_info(
+            obj, {"id": "qwen/qwen3-235b", "source": "openrouter"}
+        )
+        assert model_id == "qwen/qwen3-235b"
+        assert info.model_source == "openrouter"
+        assert info.model_format == "api"
+        assert info.model_name == "qwen/qwen3-235b"
+
+    def test_extract_info_derives_source_from_base_url(self):
+        obj = OpenAI(base_url="https://openrouter.ai/api/v1")
+        _, info = self.extractor.extract_info(obj, {"id": "qwen/qwen3-235b"})
+        assert info.model_source == "openrouter"
+
+    def test_extract_info_returns_none_model_id_when_not_provided(self):
+        model_id, _ = self.extractor.extract_info(OpenAI(), {})
+        assert model_id is None
+
+    def test_install_hooks_is_noop(self, publish_spy):
+        handle = make_handle(publish_spy)
+        self.extractor.install_hooks(OpenAI(), handle)  # must not raise
+
+
+# ---------------------------------------------------------------------------
+# wrap_sync_completions
+# ---------------------------------------------------------------------------
+
+
+class TestWrapSyncCompletions:
+    def setup(self, response=None, closed=False):
+        completions = FakeCompletions(response)
+        client = make_fake_client(closed=closed)
wrap_sync_completions(completions, "openai", lambda: client) + return completions, client + + def test_returns_response(self): + completions, _ = self.setup() + result = completions.create( + model="gpt-4o", messages=[{"role": "user", "content": "hi"}] + ) + assert isinstance(result, FakeResponse) + + def test_registers_model_on_first_call(self): + completions, client = self.setup() + completions.create(model="gpt-4o", messages=[]) + assert "gpt-4o" in client.handles + + def test_lazy_registration_only_once(self): + completions, client = self.setup() + completions.create(model="gpt-4o", messages=[]) + completions.create(model="gpt-4o", messages=[]) + assert len(client.handles) == 1 + + def test_tracks_inference_with_token_counts(self): + completions, client = self.setup() + completions.create( + model="gpt-4o", messages=[{"role": "user", "content": "hello"}] + ) + handle = client.handles["gpt-4o"] + handle.track_inference.assert_called_once() + kwargs = handle.track_inference.call_args.kwargs + assert kwargs["input_modality"] == "text" + assert kwargs["output_modality"] == "generation" + assert kwargs["success"] is True + assert kwargs["output_meta"].tokens_out == 20 + + def test_tracks_api_meta(self): + completions, client = self.setup() + completions.create(model="gpt-4o", messages=[]) + kwargs = client.handles["gpt-4o"].track_inference.call_args.kwargs + assert kwargs["api_meta"] is not None + assert kwargs["api_meta"].resolved_model_id == "gpt-4o-2024-08-06" + assert kwargs["api_meta"].system_fingerprint == "fp_abc123" + + def test_tracks_error_and_reraises(self): + class ErrorCompletions: + def create(self, *args, **kwargs): + raise RuntimeError("api error") + + client = make_fake_client() + completions = ErrorCompletions() + wrap_sync_completions(completions, "openai", lambda: client) + + with pytest.raises(RuntimeError, match="api error"): + completions.create(model="gpt-4o", messages=[]) + + client.handles["gpt-4o"].track_error.assert_called_once() + client.handles["gpt-4o"].track_inference.assert_not_called() + + def test_streaming_skips_tracking(self): + completions, client = self.setup() + completions.create(model="gpt-4o", messages=[], stream=True) + if "gpt-4o" in client.handles: + client.handles["gpt-4o"].track_inference.assert_not_called() + + def test_closed_client_passes_through(self): + completions, client = self.setup(closed=True) + result = completions.create(model="gpt-4o", messages=[]) + assert isinstance(result, FakeResponse) + assert "gpt-4o" not in client.handles + + def test_different_models_get_separate_handles(self): + completions, client = self.setup() + completions.create(model="gpt-4o", messages=[]) + completions.create(model="gpt-4-turbo", messages=[]) + assert "gpt-4o" in client.handles + assert "gpt-4-turbo" in client.handles + + +# --------------------------------------------------------------------------- +# wrap_async_completions +# --------------------------------------------------------------------------- + + +class TestWrapAsyncCompletions: + def setup(self, response=None, closed=False): + completions = FakeAsyncCompletions(response) + client = make_fake_client(closed=closed) + wrap_async_completions(completions, "openrouter", lambda: client) + return completions, client + + async def test_returns_response(self): + completions, _ = self.setup() + result = await completions.create( + model="qwen/qwen3-235b", messages=[{"role": "user", "content": "hi"}] + ) + assert isinstance(result, FakeResponse) + + async def 
+    async def test_registers_model_on_first_call(self):
+        completions, client = self.setup()
+        await completions.create(model="qwen/qwen3-235b", messages=[])
+        assert "qwen/qwen3-235b" in client.handles
+
+    async def test_tracks_inference(self):
+        completions, client = self.setup()
+        await completions.create(
+            model="qwen/qwen3-235b", messages=[{"role": "user", "content": "hello"}]
+        )
+        handle = client.handles["qwen/qwen3-235b"]
+        handle.track_inference.assert_called_once()
+        assert handle.track_inference.call_args.kwargs["output_meta"].tokens_out == 20
+
+    async def test_tracks_error_and_reraises(self):
+        class ErrorAsyncCompletions:
+            async def create(self, *args, **kwargs):
+                raise RuntimeError("timeout")
+
+        client = make_fake_client()
+        completions = ErrorAsyncCompletions()
+        wrap_async_completions(completions, "openai", lambda: client)
+
+        with pytest.raises(RuntimeError, match="timeout"):
+            await completions.create(model="gpt-4o", messages=[])
+
+        client.handles["gpt-4o"].track_error.assert_called_once()
+
+    async def test_streaming_skips_tracking(self):
+        completions, client = self.setup()
+        await completions.create(model="qwen/qwen3-235b", messages=[], stream=True)
+        if "qwen/qwen3-235b" in client.handles:
+            client.handles["qwen/qwen3-235b"].track_inference.assert_not_called()
+
+
+# ---------------------------------------------------------------------------
+# install_auto_load_patch
+# ---------------------------------------------------------------------------
+
+
+def test_install_auto_load_patch_is_idempotent(monkeypatch):
+    class FakeOpenAI:
+        def __init__(self, *args, **kwargs):
+            pass
+
+    class FakeAsyncOpenAI:
+        def __init__(self, *args, **kwargs):
+            pass
+
+    fake_openai = types.SimpleNamespace(OpenAI=FakeOpenAI, AsyncOpenAI=FakeAsyncOpenAI)
+    monkeypatch.setattr(openai_mod, "_openai", fake_openai)
+    monkeypatch.setattr(openai_mod, "_openai_patched", False)
+
+    OpenAIExtractor.install_auto_load_patch(lambda: None)
+    first_sync = fake_openai.OpenAI.__init__
+    first_async = fake_openai.AsyncOpenAI.__init__
+
+    OpenAIExtractor.install_auto_load_patch(lambda: None)
+    assert fake_openai.OpenAI.__init__ is first_sync
+    assert fake_openai.AsyncOpenAI.__init__ is first_async
+
+
+def test_install_auto_load_patch_skips_when_openai_missing(monkeypatch):
+    monkeypatch.setattr(openai_mod, "_openai", None)
+    monkeypatch.setattr(openai_mod, "_openai_patched", False)
+
+    OpenAIExtractor.install_auto_load_patch(lambda: None)
+    assert not openai_mod._openai_patched
+
+
+def test_install_auto_load_patch_wraps_new_client_instances(monkeypatch):
+    class FakeCompletionsInner:
+        def create(self, *args, **kwargs):
+            return FakeResponseNoUsage()
+
+    class FakeChat:
+        completions = FakeCompletionsInner()
+
+    class FakeOpenAI:
+        base_url = "https://api.openai.com/v1"
+        chat = FakeChat()
+
+        def __init__(self, *args, **kwargs):
+            pass
+
+    class FakeAsyncOpenAI:
+        base_url = "https://api.openai.com/v1"
+        chat = FakeChat()
+
+        def __init__(self, *args, **kwargs):
+            pass
+
+    fake_openai = types.SimpleNamespace(OpenAI=FakeOpenAI, AsyncOpenAI=FakeAsyncOpenAI)
+    monkeypatch.setattr(openai_mod, "_openai", fake_openai)
+    monkeypatch.setattr(openai_mod, "_openai_patched", False)
+
+    client = make_fake_client()
+    OpenAIExtractor.install_auto_load_patch(lambda: client)
+
+    instance = FakeOpenAI()
+    # wrap_sync_completions installs `create` directly on the completions
+    # instance; comparing a bound method to the class function with `is not`
+    # would pass even without patching, so assert on the instance dict instead.
+    assert "create" in vars(instance.chat.completions)
+    assert vars(instance.chat.completions)["create"] is not FakeCompletionsInner.create
diff --git a/uv.lock b/uv.lock
index 20fb044..de926f8 100644
--- a/uv.lock
+++ b/uv.lock
@@ -7,6 +7,15 @@ resolution-markers = [
     "python_full_version < '3.11'",
'3.11'", ] +[[package]] +name = "backports-asyncio-runner" +version = "1.2.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/8e/ff/70dca7d7cb1cbc0edb2c6cc0c38b65cba36cccc491eca64cabd5fe7f8670/backports_asyncio_runner-1.2.0.tar.gz", hash = "sha256:a5aa7b2b7d8f8bfcaa2b57313f70792df84e32a2a746f585213373f900b42162", size = 69893, upload-time = "2025-07-02T02:27:15.685Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a0/59/76ab57e3fe74484f48a53f8e337171b4a2349e506eabe136d7e01d059086/backports_asyncio_runner-1.2.0-py3-none-any.whl", hash = "sha256:0da0a936a8aeb554eccb426dc55af3ba63bcdc69fa1a600b5bb305413a4477b5", size = 12313, upload-time = "2025-07-02T02:27:14.263Z" }, +] + [[package]] name = "cachetools" version = "7.0.2" @@ -244,6 +253,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3b/ab/b3226f0bd7cdcf710fbede2b3548584366da3b19b5021e74f5bde2a8fa3f/pytest-9.0.2-py3-none-any.whl", hash = "sha256:711ffd45bf766d5264d487b917733b453d917afd2b0ad65223959f59089f875b", size = 374801, upload-time = "2025-12-06T21:30:49.154Z" }, ] +[[package]] +name = "pytest-asyncio" +version = "1.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "backports-asyncio-runner", marker = "python_full_version < '3.11'" }, + { name = "pytest" }, + { name = "typing-extensions", marker = "python_full_version < '3.13'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/90/2c/8af215c0f776415f3590cac4f9086ccefd6fd463befeae41cd4d3f193e5a/pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5", size = 50087, upload-time = "2025-11-10T16:07:47.256Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e5/35/f8b19922b6a25bc0880171a2f1a003eaeb93657475193ab516fd87cac9da/pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5", size = 15075, upload-time = "2025-11-10T16:07:45.537Z" }, +] + [[package]] name = "pytest-mock" version = "3.15.1" @@ -403,6 +426,7 @@ source = { editable = "." 
 dev = [
     { name = "coverage" },
     { name = "pytest" },
+    { name = "pytest-asyncio" },
     { name = "pytest-mock" },
     { name = "ruff" },
     { name = "tox" },
@@ -414,6 +438,7 @@ dev = [
     { name = "coverage" },
     { name = "pytest", specifier = ">=9.0.2" },
+    { name = "pytest-asyncio", specifier = ">=0.25" },
     { name = "pytest-mock", specifier = ">=3.15.1" },
     { name = "ruff", specifier = ">=0.15.4" },
     { name = "tox" },
diff --git a/wildedge/client.py b/wildedge/client.py
index dc9c9d2..da08f02 100644
--- a/wildedge/client.py
+++ b/wildedge/client.py
@@ -21,6 +21,7 @@
 from wildedge.integrations.keras import KerasExtractor
 from wildedge.integrations.mlx import MlxExtractor
 from wildedge.integrations.onnx import OnnxExtractor
+from wildedge.integrations.openai import OpenAIExtractor
 from wildedge.integrations.pytorch import PytorchExtractor
 from wildedge.integrations.registry import noop_integrations, supported_integrations
 from wildedge.integrations.tensorflow import TensorflowExtractor
@@ -84,6 +85,7 @@ def parse_dsn(dsn: str) -> tuple[str, str, str]:
 DEFAULT_EXTRACTORS: list[BaseExtractor] = [
     OnnxExtractor(),
     GgufExtractor(),
+    OpenAIExtractor(),
     UltralyticsExtractor(),
     TransformersExtractor(),
     MlxExtractor(),
@@ -112,6 +114,7 @@ class WildEdge:
         "gguf": GgufExtractor.install_auto_load_patch,
         "mlx": MlxExtractor.install_auto_load_patch,
         "onnx": OnnxExtractor.install_auto_load_patch,
+        "openai": OpenAIExtractor.install_auto_load_patch,
         "timm": PytorchExtractor.install_timm_patch,
         "tensorflow": TensorflowExtractor.install_auto_load_patch,
         "transformers": TransformersExtractor.install_auto_load_patch,
diff --git a/wildedge/events/__init__.py b/wildedge/events/__init__.py
index fceb348..b08bc78 100644
--- a/wildedge/events/__init__.py
+++ b/wildedge/events/__init__.py
@@ -3,6 +3,7 @@
 from wildedge.events.error import ErrorCode, ErrorEvent
 from wildedge.events.feedback import FeedbackEvent, FeedbackType
 from wildedge.events.inference import (
+    ApiMeta,
     AudioInputMeta,
     ClassificationOutputMeta,
     DetectionOutputMeta,
@@ -20,6 +21,7 @@ from wildedge.events.model_unload import ModelUnloadEvent
 
 __all__ = [
     "AdapterDownload",
     "AdapterLoad",
+    "ApiMeta",
     "AudioInputMeta",
diff --git a/wildedge/events/inference.py b/wildedge/events/inference.py
index b730020..0c02996 100644
--- a/wildedge/events/inference.py
+++ b/wildedge/events/inference.py
@@ -225,6 +225,8 @@ class GenerationOutputMeta:
     task: str = "generation"
     tokens_in: int | None = None
     tokens_out: int | None = None
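+    # Filled in by API integrations from provider usage details, e.g. OpenAI's
+    # prompt_tokens_details.cached_tokens / completion_tokens_details.reasoning_tokens.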
+    cached_input_tokens: int | None = None
+    reasoning_tokens_out: int | None = None
     time_to_first_token_ms: int | None = None
     tokens_per_second: float | None = None
     stop_reason: str | None = None
@@ -239,6 +241,8 @@ def to_dict(self) -> dict:
             "task": self.task,
             "tokens_in": self.tokens_in,
             "tokens_out": self.tokens_out,
+            "cached_input_tokens": self.cached_input_tokens,
+            "reasoning_tokens_out": self.reasoning_tokens_out,
             "time_to_first_token_ms": self.time_to_first_token_ms,
             "tokens_per_second": self.tokens_per_second,
             "stop_reason": self.stop_reason,
@@ -250,6 +254,24 @@
         }
 
 
+@dataclass
+class ApiMeta:
+    resolved_model_id: str | None = None
+    system_fingerprint: str | None = None
+    service_tier: str | None = None
+
+    def to_dict(self) -> dict:
+        return {
+            k: v
+            for k, v in {
+                "resolved_model_id": self.resolved_model_id,
+                "system_fingerprint": self.system_fingerprint,
+                "service_tier": self.service_tier,
+            }.items()
+            if v is not None
+        }
+
+
 @dataclass
 class EmbeddingOutputMeta:
     task: str = "embedding"
@@ -281,6 +303,7 @@ class InferenceEvent:
     ) = None
     generation_config: GenerationConfig | None = None
     hardware: HardwareContext | None = None
+    api_meta: ApiMeta | None = None
     event_id: str = field(default_factory=lambda: str(uuid.uuid4()))
     timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
     inference_id: str = field(default_factory=lambda: str(uuid.uuid4()))
@@ -307,6 +330,8 @@ def to_dict(self) -> dict:
             inference_data["generation_config"] = self.generation_config.to_dict()
         if self.hardware is not None:
             inference_data["hardware"] = self.hardware.to_dict()
+        if self.api_meta is not None:
+            inference_data["api_meta"] = self.api_meta.to_dict()
 
         return {
             "event_id": self.event_id,
diff --git a/wildedge/integrations/openai.py b/wildedge/integrations/openai.py
new file mode 100644
index 0000000..289f59a
--- /dev/null
+++ b/wildedge/integrations/openai.py
@@ -0,0 +1,304 @@
+"""OpenAI / OpenRouter integration."""
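+
+# Rough usage sketch (``wildedge_client`` is a placeholder for however the
+# application obtains its WildEdge client; wildedge.client does the real wiring):
+#
+#     import openai
+#
+#     OpenAIExtractor.install_auto_load_patch(lambda: wildedge_client)
+#     resp = openai.OpenAI().chat.completions.create(
+#         model="gpt-4o", messages=[{"role": "user", "content": "hi"}]
+#     )
+#
+# Each new client instance gets chat.completions.create wrapped, and every
+# non-streaming completion is reported through ModelHandle.track_inference.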
+
+from __future__ import annotations
+
+import functools
+import threading
+import time
+from typing import TYPE_CHECKING
+from urllib.parse import urlparse
+
+from wildedge import constants
+from wildedge.events.inference import ApiMeta, GenerationOutputMeta, TextInputMeta
+from wildedge.integrations.base import BaseExtractor
+from wildedge.integrations.common import debug_failure
+from wildedge.model import ModelInfo
+from wildedge.timing import elapsed_ms
+
+try:
+    import openai as _openai
+except ImportError:
+    _openai = None  # type: ignore[assignment]
+
+if TYPE_CHECKING:
+    from wildedge.model import ModelHandle
+
+_openai_patched = False
+_OPENAI_PATCH_LOCK = threading.Lock()
+OPENAI_INIT_PATCH_NAME = "openai_auto_load"
+
+debug_openai_failure = functools.partial(debug_failure, "openai")
+
+
+def source_from_base_url(base_url: str | None) -> str:
+    if not base_url:
+        return "openai"
+    s = base_url.lower()
+    if "openrouter" in s:
+        return "openrouter"
+    if "openai.com" in s:
+        return "openai"
+    try:
+        return urlparse(s).hostname or s
+    except Exception:
+        return s
+
+
+def build_input_meta(messages: list, tokens_in: int | None) -> TextInputMeta | None:
+    if not messages:
+        return None
+    last_user = next((m for m in reversed(messages) if m.get("role") == "user"), None)
+    if not last_user:
+        return None
+    content = last_user.get("content", "")
+    if not isinstance(content, str) or not content:
+        return None
+    return TextInputMeta(
+        char_count=len(content),
+        word_count=len(content.split()),
+        token_count=tokens_in,
+        prompt_type="chat",
+    )
+
+
+def build_output_meta(
+    response: object, duration_ms: int
+) -> GenerationOutputMeta | None:
+    try:
+        usage = getattr(response, "usage", None)
+        if usage is None:
+            return None
+        tokens_in = getattr(usage, "prompt_tokens", None)
+        tokens_out = getattr(usage, "completion_tokens", None)
+        prompt_details = getattr(usage, "prompt_tokens_details", None)
+        cached_input_tokens = getattr(prompt_details, "cached_tokens", None)
+        completion_details = getattr(usage, "completion_tokens_details", None)
+        reasoning_tokens_out = getattr(completion_details, "reasoning_tokens", None)
+        choices = getattr(response, "choices", None) or []
+        stop_reason = getattr(choices[0], "finish_reason", None) if choices else None
+        tps = (
+            round(tokens_out / duration_ms * 1000, 1)
+            if duration_ms > 0 and tokens_out
+            else None
+        )
+        return GenerationOutputMeta(
+            task="generation",
+            tokens_in=tokens_in,
+            tokens_out=tokens_out,
+            cached_input_tokens=cached_input_tokens,
+            reasoning_tokens_out=reasoning_tokens_out,
+            tokens_per_second=tps,
+            stop_reason=stop_reason,
+        )
+    except Exception as exc:
+        debug_openai_failure("output meta extraction", exc)
+        return None
+
+
+def build_api_meta(response: object) -> ApiMeta | None:
+    try:
+        resolved_model_id = getattr(response, "model", None)
+        system_fingerprint = getattr(response, "system_fingerprint", None)
+        service_tier = getattr(response, "service_tier", None)
+        if not any([resolved_model_id, system_fingerprint, service_tier]):
+            return None
+        return ApiMeta(
+            resolved_model_id=resolved_model_id,
+            system_fingerprint=system_fingerprint,
+            service_tier=service_tier,
+        )
+    except Exception as exc:
+        debug_openai_failure("api meta extraction", exc)
+        return None
+
+
+def wrap_sync_completions(completions: object, source: str, client_ref: object) -> None:
+    original_create = completions.create  # type: ignore[attr-defined]
+    model_handles: dict[str, ModelHandle] = {}
+
+    def patched_create(*args, **kwargs):
+        model_id: str | None = kwargs.get("model") or (args[0] if args else None)
+        messages: list = kwargs.get("messages", [])
+        is_streaming: bool = bool(kwargs.get("stream", False))
+
+        c = client_ref()  # type: ignore[call-arg]
+        if c is None or c.closed or not model_id:
+            return original_create(*args, **kwargs)
+
+        if model_id not in model_handles:
+            try:
+                model_handles[model_id] = c.register_model(
+                    completions, model_id=model_id, source=source
+                )
+            except Exception as exc:
+                debug_openai_failure("model registration", exc)
+
+        handle = model_handles.get(model_id)
+        t0 = time.perf_counter()
+        try:
+            result = original_create(*args, **kwargs)
+            if is_streaming or handle is None:
+                return result
+            duration = elapsed_ms(t0)
+            usage = getattr(result, "usage", None)
+            tokens_in = getattr(usage, "prompt_tokens", None) if usage else None
+            handle.track_inference(
+                duration_ms=duration,
+                input_modality="text",
+                output_modality="generation",
+                success=True,
+                input_meta=build_input_meta(messages, tokens_in),
+                output_meta=build_output_meta(result, duration),
+                api_meta=build_api_meta(result),
+            )
+            return result
+        except Exception as exc:
+            if handle is not None:
+                handle.track_error(
+                    error_code="UNKNOWN",
+                    error_message=str(exc)[: constants.ERROR_MSG_MAX_LEN],
+                )
+            raise
+
+    completions.create = patched_create  # type: ignore[attr-defined]
+
+
+def wrap_async_completions(
+    completions: object, source: str, client_ref: object
+) -> None:
+    original_create = completions.create  # type: ignore[attr-defined]
+    model_handles: dict[str, ModelHandle] = {}
+
+    async def patched_create(*args, **kwargs):
+        model_id: str | None = kwargs.get("model") or (args[0] if args else None)
+        messages: list = kwargs.get("messages", [])
+        is_streaming: bool = bool(kwargs.get("stream", False))
+
+        c = client_ref()  # type: ignore[call-arg]
+        if c is None or c.closed or not model_id:
+            return await original_create(*args, **kwargs)
+
+        if model_id not in model_handles:
+            try:
+                model_handles[model_id] = c.register_model(
+                    completions, model_id=model_id, source=source
+                )
+            except Exception as exc:
+                debug_openai_failure("model registration", exc)
+
+        handle = model_handles.get(model_id)
+        t0 = time.perf_counter()
+        try:
+            result = await original_create(*args, **kwargs)
+            if is_streaming or handle is None:
+                return result
+            duration = elapsed_ms(t0)
+            usage = getattr(result, "usage", None)
+            tokens_in = getattr(usage, "prompt_tokens", None) if usage else None
+            handle.track_inference(
+                duration_ms=duration,
+                input_modality="text",
+                output_modality="generation",
+                success=True,
+                input_meta=build_input_meta(messages, tokens_in),
+                output_meta=build_output_meta(result, duration),
+                api_meta=build_api_meta(result),
+            )
+            return result
+        except Exception as exc:
+            if handle is not None:
+                handle.track_error(
+                    error_code="UNKNOWN",
+                    error_message=str(exc)[: constants.ERROR_MSG_MAX_LEN],
+                )
+            raise
+
+    completions.create = patched_create  # type: ignore[attr-defined]
+
+
+class OpenAIExtractor(BaseExtractor):
+    def can_handle(self, obj: object) -> bool:
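+        # Match on class name so arbitrary objects can be probed without
+        # importing the (optional) openai package.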
+        return type(obj).__name__ in (
+            "OpenAI",
+            "AsyncOpenAI",
+            "Completions",
+            "AsyncCompletions",
+        )
+
+    def extract_info(
+        self, obj: object, overrides: dict
+    ) -> tuple[str | None, ModelInfo]:
+        model_id = overrides.pop("id", None)
+        source = overrides.pop("source", None) or source_from_base_url(
+            str(getattr(obj, "base_url", None) or "")
+        )
+        info = ModelInfo(
+            model_name=model_id or "openai-model",
+            model_version=overrides.pop("version", "unknown"),
+            model_source=source,
+            model_format="api",
+            model_family=overrides.pop("family", None),
+            quantization=None,
+        )
+        return model_id, info
+
+    def install_hooks(self, obj: object, handle: ModelHandle) -> None:
+        pass
+
+    @classmethod
+    def install_auto_load_patch(cls, client_ref: object) -> None:
+        """Patch openai.OpenAI and openai.AsyncOpenAI to wrap chat.completions.create."""
+        global _openai_patched
+        if _openai_patched or _openai is None:
+            return
+
+        with _OPENAI_PATCH_LOCK:
+            if _openai_patched:
+                return
+
+            original_sync_init = _openai.OpenAI.__init__
+            original_async_init = _openai.AsyncOpenAI.__init__
+
+            if (
+                getattr(original_sync_init, "__wildedge_patch_name__", None)
+                == OPENAI_INIT_PATCH_NAME
+            ):
+                _openai_patched = True
+                return
+
+            def patched_sync_init(self_inner, *args, **kwargs):  # type: ignore[no-untyped-def]
+                original_sync_init(self_inner, *args, **kwargs)
+                c = client_ref()  # type: ignore[call-arg]
+                if c is not None and not c.closed:
+                    source = source_from_base_url(
+                        str(getattr(self_inner, "base_url", None) or "")
+                    )
+                    try:
+                        wrap_sync_completions(
+                            self_inner.chat.completions, source, client_ref
+                        )
+                    except Exception as exc:
+                        debug_openai_failure("sync client wrap", exc)
+
+            def patched_async_init(self_inner, *args, **kwargs):  # type: ignore[no-untyped-def]
+                original_async_init(self_inner, *args, **kwargs)
+                c = client_ref()  # type: ignore[call-arg]
+                if c is not None and not c.closed:
+                    source = source_from_base_url(
+                        str(getattr(self_inner, "base_url", None) or "")
+                    )
+                    try:
+                        wrap_async_completions(
+                            self_inner.chat.completions, source, client_ref
+                        )
+                    except Exception as exc:
+                        debug_openai_failure("async client wrap", exc)
+
+            patched_sync_init.__wildedge_patch_name__ = OPENAI_INIT_PATCH_NAME  # type: ignore[attr-defined]
+            patched_sync_init.__wildedge_original_call__ = original_sync_init  # type: ignore[attr-defined]
+            patched_async_init.__wildedge_patch_name__ = OPENAI_INIT_PATCH_NAME  # type: ignore[attr-defined]
+            patched_async_init.__wildedge_original_call__ = original_async_init  # type: ignore[attr-defined]
+
+            _openai.OpenAI.__init__ = patched_sync_init
+            _openai.AsyncOpenAI.__init__ = patched_async_init
+            _openai_patched = True
diff --git a/wildedge/integrations/registry.py b/wildedge/integrations/registry.py
index bd3e07f..e795f38 100644
--- a/wildedge/integrations/registry.py
+++ b/wildedge/integrations/registry.py
@@ -25,6 +25,7 @@ class IntegrationSpec:
 INTEGRATION_SPECS: tuple[IntegrationSpec, ...] = (
     IntegrationSpec("gguf", ("llama_cpp",), "client_patch"),
     IntegrationSpec("onnx", ("onnxruntime",), "client_patch"),
+    IntegrationSpec("openai", ("openai",), "client_patch"),
     IntegrationSpec("timm", ("timm",), "client_patch"),
     IntegrationSpec("torch", ("torch",), "noop"),
     IntegrationSpec("keras", ("keras",), "noop"),
diff --git a/wildedge/model.py b/wildedge/model.py
index e5b355c..de7e7b1 100644
--- a/wildedge/model.py
+++ b/wildedge/model.py
@@ -7,6 +7,7 @@
 from typing import Any
 
 from wildedge.events import (
+    ApiMeta,
     AudioInputMeta,
     ClassificationOutputMeta,
     DetectionOutputMeta,
@@ -144,6 +145,7 @@ def track_inference(
         | None = None,
         generation_config: GenerationConfig | None = None,
         hardware: HardwareContext | None = None,
+        api_meta: ApiMeta | None = None,
     ) -> str:
         if hardware is None and is_sampling():
             hardware = capture_hardware()
@@ -159,6 +161,7 @@ def track_inference(
             output_meta=output_meta,
             generation_config=generation_config,
             hardware=hardware,
+            api_meta=api_meta,
         )
         self.last_inference_id = event.inference_id
         self.publish(event.to_dict())

From 7234a71dafcec34f309bbe8a6025ce3e0393b9de Mon Sep 17 00:00:00 2001
From: Piotr Duda
Date: Mon, 16 Mar 2026 17:24:05 +0100
Subject: [PATCH 2/2] cleanup

---
 wildedge/integrations/openai.py | 111 ++++++++++++++------------------
 1 file changed, 50 insertions(+), 61 deletions(-)

diff --git a/wildedge/integrations/openai.py b/wildedge/integrations/openai.py
index 289f59a..e5d1525 100644
--- a/wildedge/integrations/openai.py
+++ b/wildedge/integrations/openai.py
@@ -30,18 +30,15 @@
 debug_openai_failure = functools.partial(debug_failure, "openai")
 
 
+SOURCE_BY_HOSTNAME: dict[str, str] = {
+    "api.openai.com": "openai",
+    "openrouter.ai": "openrouter",
+}
+
+
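+# e.g. "https://openrouter.ai/api/v1" -> "openrouter",
+#      "https://api.together.xyz/v1" -> "api.together.xyz" (hostname fallback)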
 def source_from_base_url(base_url: str | None) -> str:
-    if not base_url:
-        return "openai"
-    s = base_url.lower()
-    if "openrouter" in s:
-        return "openrouter"
-    if "openai.com" in s:
-        return "openai"
-    try:
-        return urlparse(s).hostname or s
-    except Exception:
-        return s
+    hostname = urlparse(base_url.lower()).hostname if base_url else ""
+    return SOURCE_BY_HOSTNAME.get(hostname or "", hostname or "openai")
 
 
 def build_input_meta(messages: list, tokens_in: int | None) -> TextInputMeta | None:
@@ -112,6 +109,42 @@ def build_api_meta(response: object) -> ApiMeta | None:
         return None
 
 
+def resolve_handle(
+    model_id: str,
+    completions: object,
+    model_handles: dict[str, ModelHandle],
+    client: object,
+    source: str,
+) -> ModelHandle | None:
+    if model_id not in model_handles:
+        try:
+            model_handles[model_id] = client.register_model(  # type: ignore[attr-defined]
+                completions, model_id=model_id, source=source
+            )
+        except Exception as exc:
+            debug_openai_failure("model registration", exc)
+    return model_handles.get(model_id)
+
+
+def record_inference(
+    handle: ModelHandle,
+    result: object,
+    messages: list,
+    duration: int,
+) -> None:
+    usage = getattr(result, "usage", None)
+    tokens_in = getattr(usage, "prompt_tokens", None) if usage else None
+    handle.track_inference(
+        duration_ms=duration,
+        input_modality="text",
+        output_modality="generation",
+        success=True,
+        input_meta=build_input_meta(messages, tokens_in),
+        output_meta=build_output_meta(result, duration),
+        api_meta=build_api_meta(result),
+    )
+
+
 def wrap_sync_completions(completions: object, source: str, client_ref: object) -> None:
     original_create = completions.create  # type: ignore[attr-defined]
     model_handles: dict[str, ModelHandle] = {}
@@ -120,37 +153,15 @@ def patched_create(*args, **kwargs):
         model_id: str | None = kwargs.get("model") or (args[0] if args else None)
         messages: list = kwargs.get("messages", [])
         is_streaming: bool = bool(kwargs.get("stream", False))
-
         c = client_ref()  # type: ignore[call-arg]
         if c is None or c.closed or not model_id:
             return original_create(*args, **kwargs)
-
-        if model_id not in model_handles:
-            try:
-                model_handles[model_id] = c.register_model(
-                    completions, model_id=model_id, source=source
-                )
-            except Exception as exc:
-                debug_openai_failure("model registration", exc)
-
-        handle = model_handles.get(model_id)
+        handle = resolve_handle(model_id, completions, model_handles, c, source)
         t0 = time.perf_counter()
         try:
             result = original_create(*args, **kwargs)
-            if is_streaming or handle is None:
-                return result
-            duration = elapsed_ms(t0)
-            usage = getattr(result, "usage", None)
-            tokens_in = getattr(usage, "prompt_tokens", None) if usage else None
-            handle.track_inference(
-                duration_ms=duration,
-                input_modality="text",
-                output_modality="generation",
-                success=True,
-                input_meta=build_input_meta(messages, tokens_in),
-                output_meta=build_output_meta(result, duration),
-                api_meta=build_api_meta(result),
-            )
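+            # Streaming calls return an iterator rather than a completion
+            # carrying usage stats, so they are passed through untracked.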
kwargs.get("model") or (args[0] if args else None) messages: list = kwargs.get("messages", []) is_streaming: bool = bool(kwargs.get("stream", False)) - c = client_ref() # type: ignore[call-arg] if c is None or c.closed or not model_id: return original_create(*args, **kwargs) - - if model_id not in model_handles: - try: - model_handles[model_id] = c.register_model( - completions, model_id=model_id, source=source - ) - except Exception as exc: - debug_openai_failure("model registration", exc) - - handle = model_handles.get(model_id) + handle = resolve_handle(model_id, completions, model_handles, c, source) t0 = time.perf_counter() try: result = original_create(*args, **kwargs) - if is_streaming or handle is None: - return result - duration = elapsed_ms(t0) - usage = getattr(result, "usage", None) - tokens_in = getattr(usage, "prompt_tokens", None) if usage else None - handle.track_inference( - duration_ms=duration, - input_modality="text", - output_modality="generation", - success=True, - input_meta=build_input_meta(messages, tokens_in), - output_meta=build_output_meta(result, duration), - api_meta=build_api_meta(result), - ) + if not is_streaming and handle is not None: + record_inference(handle, result, messages, elapsed_ms(t0)) return result except Exception as exc: if handle is not None: @@ -173,37 +184,15 @@ async def patched_create(*args, **kwargs): model_id: str | None = kwargs.get("model") or (args[0] if args else None) messages: list = kwargs.get("messages", []) is_streaming: bool = bool(kwargs.get("stream", False)) - c = client_ref() # type: ignore[call-arg] if c is None or c.closed or not model_id: return await original_create(*args, **kwargs) - - if model_id not in model_handles: - try: - model_handles[model_id] = c.register_model( - completions, model_id=model_id, source=source - ) - except Exception as exc: - debug_openai_failure("model registration", exc) - - handle = model_handles.get(model_id) + handle = resolve_handle(model_id, completions, model_handles, c, source) t0 = time.perf_counter() try: result = await original_create(*args, **kwargs) - if is_streaming or handle is None: - return result - duration = elapsed_ms(t0) - usage = getattr(result, "usage", None) - tokens_in = getattr(usage, "prompt_tokens", None) if usage else None - handle.track_inference( - duration_ms=duration, - input_modality="text", - output_modality="generation", - success=True, - input_meta=build_input_meta(messages, tokens_in), - output_meta=build_output_meta(result, duration), - api_meta=build_api_meta(result), - ) + if not is_streaming and handle is not None: + record_inference(handle, result, messages, elapsed_ms(t0)) return result except Exception as exc: if handle is not None: