diff --git a/README.md b/README.md
index 2972aa4..66f181f 100644
--- a/README.md
+++ b/README.md
@@ -8,7 +8,7 @@
 [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
 [![Coverage](https://codecov.io/gh/wildedge/wildedge-python/branch/main/graph/badge.svg)](https://codecov.io/gh/wildedge/wildedge-python)
-On-device ML inference monitoring for Python. Tracks latency, errors, and model metadata without capturing inputs or outputs.
+On-device ML inference monitoring for Python. Track in-depth model quality and performance information.

 ## Install

@@ -16,158 +16,62 @@ On-device ML inference monitoring for Python. Tracks latency, errors, and model
 uv add wildedge-sdk
 ```

-## Quick start
+## CLI — zero code changes

-```python
-import wildedge
-
-client = wildedge.WildEdge(
-    dsn="...",  # or set WILDEDGE_DSN
-)
-```
-
-## CLI wrapper
-
-Use `wildedge run` to execute an existing Python entrypoint with WildEdge runtime initialization and integration instrumentation enabled before user code starts:
-
-```bash
-wildedge run --dsn "https://@ingest.wildedge.dev/" -- python app.py
-```
-
-End-to-end CLI wrapper example:
+Drop `wildedge run` in front of your existing command. WildEdge instruments the runtime before your code starts — no SDK calls required in user code.

 ```bash
 WILDEDGE_DSN="https://@ingest.wildedge.dev/" \
-wildedge run --print-startup-report --integrations timm -- \
-python examples/cli/cli_wrapper_example.py
+wildedge run --integrations timm -- python app.py
 ```

-See `examples/cli/cli_wrapper_example.py` for a script that has no WildEdge SDK calls in user code.
-Use `examples/cli/demo.sh` for a single-script linear flow (sync, doctor, run).
-This demo uses `examples/cli/pyproject.toml`, so dependency management is delegated to `uv`.
-
-Module entrypoints are supported:
+Validate your environment before deploying:

 ```bash
-wildedge run -- python -m your_package.main --arg value
+wildedge doctor --integrations all --network-check
 ```

-Validate local runtime readiness:
+Useful flags:

-```bash
-wildedge doctor --integrations all
-```
+| Flag | Description |
+|---|---|
+| `--integrations` | Comma-separated list of integrations to activate (or `all`) |
+| `--hubs` | Hub trackers to activate: `huggingface`, `torchhub` |
+| `--print-startup-report` | Print per-integration status at startup |
+| `--strict-integrations` | Fail if a requested integration can't be loaded |
+| `--no-propagate` | Don't pass WildEdge env vars to child processes |

-Machine-readable diagnostics:
+## SDK

-```bash
-wildedge doctor --format json --integrations all
-```
+```python
+import wildedge

-Optional DSN reachability probe:
+client = wildedge.WildEdge(dsn="...")  # or WILDEDGE_DSN env var
+client.instrument("transformers", hubs=["huggingface"])

-```bash
-wildedge doctor --network-check --dsn "https://@ingest.wildedge.dev/"
+# models loaded after this point are tracked automatically
 ```

-Useful run flags:
-
-- `--strict-integrations`: fail startup if a requested integration cannot be instrumented.
-- `--no-propagate`: do not propagate WildEdge runtime env vars to nested child processes.
-- `--print-startup-report`: print startup diagnostics with per-integration status.
-
-## Integrations
-
-Call `client.instrument()` to activate auto-tracking for a supported library. Models created afterwards are registered and timed automatically with no changes to existing call sites.
-
-See the `examples/` folder for complete working examples.
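+For example, a `pipeline()` call made after `instrument()` needs no further changes; load timing, Hub downloads, and each inference are recorded (minimal sketch, adapted from [transformers_example.py](examples/transformers_example.py)):
+
+```python
+from transformers import pipeline
+
+# pipeline() is patched by the instrument("transformers") call above, so the
+# model load, any weight download, and every call below are tracked automatically.
+pipe = pipeline("text-classification", model="distilbert-base-uncased-finetuned-sst-2-english")
+print(pipe("On-device inference keeps your data private."))
+
+client.flush()  # send any pending events now
+```
+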
+## Supported integrations -### Integration initialization - -Initialize integrations at process startup, before model loading begins. Instrumentation patches are applied per process and should be installed before imports and constructor calls on instrumented libraries. - -For high-priority paths, keep explicit registration with `client.load(...)` or `client.register_model(...)` as a fallback when model creation does not go through a patched API. -For an explicit fallback pattern, see `examples/gguf_gemma_manual_example.py`. - -### PyTorch (custom models) +| Integration | Patches | Hub tracking | Example | +|---|---|---|---| +| `transformers` | `pipeline()`, `AutoModel.from_pretrained()` | `huggingface` | [transformers_example.py](examples/transformers_example.py) | +| `timm` | `timm.create_model()` | `huggingface`, `torchhub` | [timm_example.py](examples/timm_example.py) | +| `gguf` | `llama_cpp.Llama.__init__` | `huggingface` | [gguf_example.py](examples/gguf_example.py) | +| `onnx` | `ort.InferenceSession` | `huggingface` | [onnx_example.py](examples/onnx_example.py) | +| `ultralytics` | `ultralytics.YOLO.__init__` | — | — | +| `tensorflow` | `tf.keras.models.load_model`, `tf.saved_model.load` | — | [tensorflow_example.py](examples/tensorflow_example.py) | +| `torch` | forward hooks via `client.load()` | `torchhub` | [pytorch_example.py](examples/pytorch_example.py) | +| `keras` | forward hooks via `client.load()` | — | [keras_example.py](examples/keras_example.py) | -PyTorch models are user-defined subclasses, so there is no single constructor to patch. Use `client.load()` to time construction and track load/unload automatically; inference is tracked via forward hooks once the model is registered. +For `torch` and `keras`, models are user-defined subclasses so there's no constructor to patch. Use `client.load()` to get load/unload tracking alongside inference: ```python model = client.load(MyModel) output = model(x) # tracked automatically ``` -See `examples/pytorch_example.py` for a complete example. - -### Keras (custom models) - -Same pattern as PyTorch: - -```python -model = client.load(MyKerasModel) -output = model(x) # tracked automatically -``` - -See `examples/keras_example.py` for a complete example. - -### ONNX Runtime - -```python -import onnxruntime as ort - -client.instrument("onnx") - -session = ort.InferenceSession("yolov8n.onnx") -outputs = session.run(None, {"input": image}) # tracked automatically -``` - -See `examples/onnx_example.py` for a complete example. - -### TensorFlow - -```python -import tensorflow as tf - -client.instrument("tensorflow") - -model = tf.keras.models.load_model("model.keras") # tracked automatically -output = model(batch, training=False) # tracked automatically -``` - -See `examples/tensorflow_example.py` for a complete example. - -### timm - -```python -import timm - -client.instrument("timm") - -model = timm.create_model("resnet50", pretrained=True) -output = model(image_tensor) # tracked automatically -``` - -See `examples/timm_example.py` for a complete example. - -### GGUF / llama.cpp - -```python -from llama_cpp import Llama - -client.instrument("gguf") - -llm = Llama("llama-3.2-1b.Q4_K_M.gguf", n_ctx=2048, n_gpu_layers=-1) -result = llm("What is the capital of France?") # tracked automatically -``` - -See `examples/gguf_example.py` for a complete example. - -## Limitations - -- Currently supports only Python 3.10+ due to use of modern type annotations. -- Overhead: Less than 1% latency increase in internal benchmarks. 
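+The remaining integrations patch the library's own loader, so existing call sites stay unchanged. A minimal `timm` sketch (full version in [timm_example.py](examples/timm_example.py)); passing `hubs` also records weight downloads from the Hugging Face Hub or Torch Hub:
+
+```python
+import timm
+
+client.instrument("timm", hubs=["huggingface", "torchhub"])
+
+# create_model() is patched, so load timing and any pretrained-weight
+# download are recorded; image_tensor is any preprocessed input batch.
+model = timm.create_model("resnet50", pretrained=True)
+output = model(image_tensor)  # inference tracked automatically
+```
+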
-- For air-gapped environments, on-premise server installation is required. - ## Manual tracking Use `@wildedge.track` as a decorator or context manager when auto-instrumentation isn't available: @@ -175,83 +79,35 @@ Use `@wildedge.track` as a decorator or context manager when auto-instrumentatio ```python handle = client.register_model(my_model) -# decorator @wildedge.track(handle) def run(input): return my_model.predict(input) - -# context manager -with wildedge.track(handle): - result = my_model.predict(input) ``` ## Configuration | Parameter | Default | Env var | Description | |---|---|---|---| -| `dsn` | `-` | `WILDEDGE_DSN` | Required. `https://@ingest.wildedge.dev/` | -| `app_version` | `None` | `-` | Optional. Your app's version string. | -| `app_identity` | `` | `WILDEDGE_APP_IDENTITY` | Namespace for offline persistence paths. Set per-app to isolate multi-process workloads in one project. | -| `debug` | `false` | `WILDEDGE_DEBUG` | Log events to console. | -| `batch_size` | `10` | `-` | Events per transmission (recommended: 1-100). | -| `flush_interval_sec` | `60` | `-` | Max seconds between flushes (recommended: 1-3600). | -| `max_queue_size` | `200` | `-` | In-memory buffer limit (recommended: 10-10000). | -| `enable_offline_persistence` | `true` | `-` | Persist pending unsent events on disk and replay on restart. | -| `offline_queue_dir` | OS-specific state dir | `-` | Folder for pending queue persistence (defaults to platform state path). | -| `max_event_age_sec` | `900` | `-` | Max age for queued events before dead-lettering. | -| `enable_dead_letter_persistence` | `false` | `-` | Persist dropped batches/events to disk dead-letter store. | -| `dead_letter_dir` | OS-specific cache dir | `-` | Directory where dead-letter batch files are stored. | -| `max_dead_letter_batches` | `10` | `-` | Max dead-letter batch files retained on disk. | - -## Testing - -### Run tests locally - -Install development dependencies and run the test suite: - -```bash -uv sync --group dev -uv run pytest -``` - -### Run tests across Python versions - -Use `tox` to run the test suite against all supported Python versions (3.10+): - -```bash -uv sync --group dev -tox -``` - -Compatibility matrix details are documented in `docs/compatibility.md`. - -Run the compatibility matrix locally with one command: - -```bash -python3 scripts/run_compat_local.py -``` - -To fail when a dependency row is unsupported on your local platform, use: - -```bash -python3 scripts/run_compat_local.py --strict-unsupported -``` +| `dsn` | — | `WILDEDGE_DSN` | `https://@ingest.wildedge.dev/` | +| `app_version` | `None` | — | Your app's version string | +| `app_identity` | `` | `WILDEDGE_APP_IDENTITY` | Namespace for offline persistence; set per-app in multi-process workloads | +| `debug` | `false` | `WILDEDGE_DEBUG` | Log events to console | +| `batch_size` | `10` | — | Events per transmission (1–100) | +| `flush_interval_sec` | `60` | — | Max seconds between flushes (1–3600) | +| `max_queue_size` | `200` | — | In-memory buffer limit (10–10000) | +| `enable_offline_persistence` | `true` | — | Persist unsent events to disk and replay on restart | +| `max_event_age_sec` | `900` | — | Max age before dead-lettering | +| `enable_dead_letter_persistence` | `false` | — | Persist dropped batches to disk | -## Security +## Privacy -WildEdge SDK privacy model: -- **No input/output capture**: Only metadata (latency, errors, model info) is collected. -- **Secure transmission**: Data is sent over HTTPS to WildEdge servers. 
-- **Local processing**: All inference happens locally; SDK only monitors performance. -- **DSN-based auth**: Uses project-specific secrets for authentication. +WildEdge captures **no inputs or outputs** — only metadata: latency, errors, model info, and download provenance. All inference runs locally; only telemetry is transmitted over HTTPS. -If you discover a security issue, please email security@wildedge.dev instead of creating a public issue. +Report security issues to security@wildedge.dev. ## Links -- [Full Documentation](https://docs.wildedge.dev) -- [Website](https://wildedge.dev) +- [Documentation](https://docs.wildedge.dev) - [Compatibility Matrix](docs/compatibility.md) - [Changelog](CHANGELOG.md) -- [Contributing](CONTRIBUTING.md) - [License](LICENSE) diff --git a/examples/transformers_example.py b/examples/transformers_example.py new file mode 100644 index 0000000..86912a0 --- /dev/null +++ b/examples/transformers_example.py @@ -0,0 +1,106 @@ +# /// script +# requires-python = ">=3.10" +# dependencies = ["wildedge-sdk", "transformers", "torch"] +# +# [tool.uv.sources] +# wildedge-sdk = { path = "..", editable = true } +# /// +""" +HuggingFace Transformers integration example. + +WildEdge patches transformers.pipeline (and AutoModel.from_pretrained) at +client initialisation, so load timing, download tracking, inference tracking, +and unload tracking all happen automatically. + +Usage: + uv run transformers_example.py # text classification + uv run transformers_example.py --task generate # text generation + uv run transformers_example.py --task embed # feature extraction +""" + +from __future__ import annotations + +import argparse + +from transformers import pipeline + +import wildedge + + +def run_classify() -> None: + pipe = pipeline( + "text-classification", + model="distilbert-base-uncased-finetuned-sst-2-english", + ) + inputs = [ + "I absolutely loved this film — the performances were outstanding!", + "The service was awful and the food arrived cold.", + "An average experience, nothing special either way.", + ] + print("Sentiment classification:") + for text in inputs: + result = pipe(text) + label = result[0]["label"] + score = result[0]["score"] + bar = "█" * int(score * 20) + print(f" {label:<9} {bar:<20} {score:.3f} {text!r}") + + +def run_generate() -> None: + pipe = pipeline("text-generation", model="gpt2", max_new_tokens=40) + prompts = [ + "The future of on-device AI is", + "Once upon a time, a small robot learned", + ] + print("Text generation (GPT-2):") + for prompt in prompts: + result = pipe(prompt, do_sample=False) + print(f" Prompt : {prompt!r}") + print(f" Output : {result[0]['generated_text']!r}\n") + + +def run_embed() -> None: + pipe = pipeline("feature-extraction", model="bert-base-uncased") + sentences = [ + "Machine learning is transforming every industry.", + "On-device inference keeps your data private.", + "WildEdge monitors ML performance in production.", + ] + print("Feature extraction (BERT):") + for sent in sentences: + result = pipe(sent) + # result shape: [1, seq_len, hidden_size] — take CLS token embedding + cls_embedding = result[0][0] + dims = len(cls_embedding) + norm = sum(v**2 for v in cls_embedding) ** 0.5 + print(f" dims={dims} L2={norm:.2f} {sent!r}") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="WildEdge + HuggingFace Transformers example" + ) + parser.add_argument( + "--task", + choices=["classify", "generate", "embed"], + default="classify", + help="Pipeline task to demonstrate (default: classify)", 
+ ) + args = parser.parse_args() + + # instrument() patches transformers.pipeline and AutoModel.from_pretrained + # before any model is loaded — everything below is tracked automatically. + client = wildedge.WildEdge(app_version="1.0.0") # set WILDEDGE_DSN env var + client.instrument("transformers", hubs=["huggingface"]) + + print() + {"classify": run_classify, "generate": run_generate, "embed": run_embed}[ + args.task + ]() + + client.flush() + print("\nDone — events flushed to WildEdge.") + + +if __name__ == "__main__": + main() diff --git a/wildedge/client.py b/wildedge/client.py index 5056d11..df38253 100644 --- a/wildedge/client.py +++ b/wildedge/client.py @@ -24,6 +24,7 @@ from wildedge.integrations.pytorch import PytorchExtractor from wildedge.integrations.registry import noop_integrations, supported_integrations from wildedge.integrations.tensorflow import TensorflowExtractor +from wildedge.integrations.transformers import TransformersExtractor from wildedge.integrations.ultralytics import UltralyticsExtractor from wildedge.logging import enable_debug, logger from wildedge.model import ModelHandle, ModelInfo, ModelRegistry @@ -82,6 +83,7 @@ def parse_dsn(dsn: str) -> tuple[str, str, str]: OnnxExtractor(), GgufExtractor(), UltralyticsExtractor(), + TransformersExtractor(), PytorchExtractor(), TensorflowExtractor(), KerasExtractor(), @@ -108,6 +110,7 @@ class WildEdge: "onnx": OnnxExtractor.install_auto_load_patch, "timm": PytorchExtractor.install_timm_patch, "tensorflow": TensorflowExtractor.install_auto_load_patch, + "transformers": TransformersExtractor.install_auto_load_patch, "ultralytics": UltralyticsExtractor.install_auto_load_patch, } @@ -426,6 +429,10 @@ def instrument( ``"tensorflow"`` Patches ``tf.keras.models.load_model`` and ``tf.saved_model.load``. Requires ``tensorflow``. + ``"transformers"`` + Patches ``transformers.pipeline`` and + ``PreTrainedModel.from_pretrained`` (covers all ``AutoModel.*`` + calls). Requires ``transformers``. ``"ultralytics"`` Patches ``ultralytics.YOLO.__init__``. Requires ``ultralytics``. 
Emits a download event on first load if weights were fetched from diff --git a/wildedge/integrations/registry.py b/wildedge/integrations/registry.py index 3a53778..6b09b7a 100644 --- a/wildedge/integrations/registry.py +++ b/wildedge/integrations/registry.py @@ -31,6 +31,7 @@ class IntegrationSpec: IntegrationSpec("keras", ("keras",), "noop"), IntegrationSpec("tensorflow", ("tensorflow",), "client_patch"), IntegrationSpec("ultralytics", ("ultralytics",), "client_patch"), + IntegrationSpec("transformers", ("transformers",), "client_patch"), ) INTEGRATIONS_BY_NAME: dict[str, IntegrationSpec] = { diff --git a/wildedge/integrations/transformers.py b/wildedge/integrations/transformers.py new file mode 100644 index 0000000..8404cf2 --- /dev/null +++ b/wildedge/integrations/transformers.py @@ -0,0 +1,477 @@ +"""HuggingFace Transformers integration.""" + +from __future__ import annotations + +import os +import threading +import time +from typing import TYPE_CHECKING + +from wildedge import constants +from wildedge.events.inference import ( + ClassificationOutputMeta, + EmbeddingOutputMeta, + GenerationOutputMeta, + TextInputMeta, + TopKPrediction, +) +from wildedge.integrations.base import BaseExtractor, patch_instance_call_once +from wildedge.logging import logger +from wildedge.model import ModelInfo +from wildedge.timing import elapsed_ms + +try: + import transformers as _transformers +except ImportError: + _transformers = None # type: ignore[assignment] + +if TYPE_CHECKING: + from wildedge.model import ModelHandle + +# --- Patch state --- +_transformers_patched = False +_TRANSFORMERS_PATCH_LOCK = threading.Lock() +TRANSFORMERS_AUTO_LOAD_PATCH_NAME = "transformers_auto_load" + +# --- Pipeline instance patching --- +PIPELINE_CALL_PATCH_NAME = "transformers_pipeline_call" +PIPELINE_HANDLE_ATTR = "__wildedge_pipeline_handle__" + +# Thread-local flag: suppress from_pretrained tracking when called inside pipeline() +_tl = threading.local() + + +def _debug_failure(context: str, exc: BaseException) -> None: + logger.debug("wildedge: transformers %s failed: %s", context, exc) + + +def _is_pretrained_model(obj: object) -> bool: + """String-check avoids importing transformers when not installed.""" + for cls in type(obj).__mro__: + if cls.__name__ == "PreTrainedModel" and "transformers" in cls.__module__: + return True + return False + + +def _is_pipeline(obj: object) -> bool: + for cls in type(obj).__mro__: + if cls.__name__ == "Pipeline" and "transformers" in cls.__module__: + return True + return False + + +def _extract_model_config(obj: object) -> tuple[str | None, str | None, str | None]: + """Returns (name_or_path, model_type, architectures[0]). 
Never raises.""" + try: + config = getattr(obj, "config", None) + if config is None: + inner = getattr(obj, "model", None) + config = getattr(inner, "config", None) if inner is not None else None + if config is None: + return None, None, None + name_or_path = getattr(config, "name_or_path", None) or None + model_type = getattr(config, "model_type", None) or None + archs = getattr(config, "architectures", None) + arch = archs[0] if archs else None + return name_or_path, model_type, arch + except Exception as exc: + _debug_failure("config extraction", exc) + return None, None, None + + +def _is_local_path(name_or_path: str | None) -> bool: + if not name_or_path: + return False + return os.path.sep in name_or_path or os.path.exists(name_or_path) + + +def _detect_quantization(obj: object) -> str | None: + try: + # Prefer quantization_config (bitsandbytes, GPTQ, AWQ, etc.) + config = getattr(obj, "config", None) + if config is None: + inner = getattr(obj, "model", None) + config = getattr(inner, "config", None) if inner is not None else None + if config is not None: + qconfig = getattr(config, "quantization_config", None) + if qconfig is not None: + quant_type = getattr(qconfig, "quant_type", None) or getattr( + qconfig, "quantization_type", None + ) + bits = getattr(qconfig, "bits", None) or getattr( + qconfig, "num_bits", None + ) + if quant_type: + return str(quant_type).lower() + if bits: + return f"int{int(bits)}" + # Fall back to model dtype + model = getattr(obj, "model", obj) + dtype = getattr(model, "dtype", None) + if dtype is not None: + s = str(dtype) + if "bfloat16" in s: + return "bf16" + if "float16" in s: + return "f16" + if "int8" in s: + return "int8" + except Exception as exc: + _debug_failure("quantization detection", exc) + return None + + +def _detect_accelerator(obj: object) -> str: + try: + model = getattr(obj, "model", obj) + first = next(model.parameters()) # type: ignore[union-attr] + return str(getattr(first.device, "type", "cpu") or "cpu") + except Exception: + pass + return "cpu" + + +def _infer_task_from_arch(arch: str | None) -> str | None: + """Guess broad task category from architecture class name.""" + if not arch: + return None + lower = arch.lower() + if any(k in lower for k in ("forsequenceclassification", "fortokenclassification")): + return "classification" + if any( + k in lower for k in ("forcausallm", "forseq2seqlm", "forconditionalgeneration") + ): + return "generation" + if lower.endswith("model"): + return "embedding" + return None + + +# --------------------------------------------------------------------------- +# Pipeline call patching +# --------------------------------------------------------------------------- + + +def _pipeline_input_meta(inputs: object) -> TextInputMeta | None: + try: + texts: list[str] = [] + if isinstance(inputs, str): + texts = [inputs] + elif isinstance(inputs, list): + texts = [t for t in inputs if isinstance(t, str)] + if not texts: + return None + char_count = sum(len(t) for t in texts) + word_count = sum(len(t.split()) for t in texts) + return TextInputMeta(char_count=char_count, word_count=word_count) + except Exception as exc: + _debug_failure("pipeline input meta", exc) + return None + + +def _pipeline_output_meta( + task: str | None, outputs: object +) -> ClassificationOutputMeta | GenerationOutputMeta | EmbeddingOutputMeta | None: + if task is None: + return None + try: + t = task.lower() + if any( + k in t + for k in ( + "classification", + "sentiment", + "zero-shot", + "ner", + "token-class", + ) + ): + if 
isinstance(outputs, list) and outputs: + first = outputs[0] + if isinstance(first, dict) and "score" in first: + label = str(first.get("label", "")) + score = round(float(first["score"]), 4) + return ClassificationOutputMeta( + avg_confidence=score, + top_k=[TopKPrediction(label=label, confidence=score)], + ) + return ClassificationOutputMeta() + + if any( + k in t for k in ("generation", "translation", "summariz", "conversational") + ): + tokens_out: int | None = None + if isinstance(outputs, list) and outputs: + first = outputs[0] + if isinstance(first, dict): + text = ( + first.get("generated_text") + or first.get("translation_text") + or first.get("summary_text") + ) + if text: + tokens_out = len(str(text).split()) + return GenerationOutputMeta(tokens_out=tokens_out) + + if any(k in t for k in ("feature", "embed", "similarity")): + dims: int | None = None + if isinstance(outputs, list) and outputs: + first = outputs[0] + if ( + isinstance(first, list) + and first + and isinstance(first[0], (int, float)) + ): + dims = len(first) + elif isinstance(first, list) and first and isinstance(first[0], list): + # [[[token_embs]]] shape for feature-extraction + dims = len(first[0]) + return EmbeddingOutputMeta(dimensions=dims) + + except Exception as exc: + _debug_failure("pipeline output meta", exc) + return None + + +def _pipeline_modalities(task: str | None) -> tuple[str | None, str | None]: + if not task: + return "text", None + t = task.lower() + if any(k in t for k in ("classification", "sentiment", "zero-shot", "ner")): + return "text", "classification" + if any(k in t for k in ("generation", "translation", "summariz", "conversational")): + return "text", "generation" + if any(k in t for k in ("feature", "embed", "similarity")): + return "text", "embedding" + return "text", None + + +def _build_pipeline_patched_call(original_call): # type: ignore[no-untyped-def] + def patched_call(self_inner, inputs, *args, **kwargs): # type: ignore[no-untyped-def] + handle = getattr(self_inner, PIPELINE_HANDLE_ATTR, None) + if handle is None: + return original_call(self_inner, inputs, *args, **kwargs) + + task = getattr(self_inner, "task", None) + batch_size: int | None = ( + len(inputs) + if isinstance(inputs, list) + else (1 if isinstance(inputs, str) else None) + ) + input_meta = _pipeline_input_meta(inputs) + input_modality, output_modality = _pipeline_modalities(task) + + t0 = time.perf_counter() + try: + outputs = original_call(self_inner, inputs, *args, **kwargs) + duration_ms = elapsed_ms(t0) + output_meta = _pipeline_output_meta(task, outputs) + handle.track_inference( + duration_ms=duration_ms, + batch_size=batch_size, + input_modality=input_modality, + output_modality=output_modality, + input_meta=input_meta, + output_meta=output_meta, + success=True, + ) + return outputs + except Exception as exc: + handle.track_error( + error_code="UNKNOWN", + error_message=str(exc)[: constants.ERROR_MSG_MAX_LEN], + ) + raise + + return patched_call + + +# --------------------------------------------------------------------------- +# Extractor +# --------------------------------------------------------------------------- + + +class TransformersExtractor(BaseExtractor): + def can_handle(self, obj: object) -> bool: + return _is_pretrained_model(obj) or _is_pipeline(obj) + + def extract_info( + self, obj: object, overrides: dict + ) -> tuple[str | None, ModelInfo]: + name_or_path, model_type, arch = _extract_model_config(obj) + + model_name = arch or model_type or type(obj).__name__ + model_id = 
overrides.pop("id", None) or name_or_path or model_name + family = overrides.pop("family", None) or model_type + version = overrides.pop("version", "unknown") + source = overrides.pop("source", None) + if source is None: + source = "local" if _is_local_path(name_or_path) else "huggingface" + quantization = overrides.pop("quantization", None) or _detect_quantization(obj) + + info = ModelInfo( + model_name=model_name, + model_version=version, + model_source=source, + model_format="transformers", + model_family=family, + quantization=quantization, + ) + for k, v in overrides.items(): + if hasattr(info, k): + setattr(info, k, v) + + return model_id, info + + def memory_bytes(self, obj: object) -> int | None: + try: + model = getattr(obj, "model", obj) + params = sum(p.numel() * p.element_size() for p in model.parameters()) # type: ignore[union-attr] + buffers = sum(b.numel() * b.element_size() for b in model.buffers()) # type: ignore[union-attr] + return params + buffers + except Exception as exc: + _debug_failure("memory estimation", exc) + return None + + def install_hooks(self, obj: object, handle: ModelHandle) -> None: + handle.detected_accelerator = _detect_accelerator(obj) + + if _is_pipeline(obj): + setattr(obj, PIPELINE_HANDLE_ATTR, handle) + patch_instance_call_once( + obj, + patch_name=PIPELINE_CALL_PATCH_NAME, + make_patched_call=_build_pipeline_patched_call, + ) + else: + # PreTrainedModel: use PyTorch forward hooks + _local = threading.local() + _, _, arch = _extract_model_config(obj) + task_hint = _infer_task_from_arch(arch) + + def pre_hook(module, args): # type: ignore[no-untyped-def] + _local.t0 = time.perf_counter() + + def post_hook(module, args, output): # type: ignore[no-untyped-def] + t0 = getattr(_local, "t0", None) + duration_ms = elapsed_ms(t0) if t0 is not None else 0 + + batch_size: int | None = None + input_meta: TextInputMeta | None = None + input_modality, output_modality = "text", task_hint + + try: + # args[0] is typically input_ids: (batch, seq_len) + input_ids = args[0] if args else None + if input_ids is not None and hasattr(input_ids, "shape"): + shape = input_ids.shape + if len(shape) >= 1: + batch_size = int(shape[0]) + if len(shape) >= 2: + input_meta = TextInputMeta(token_count=int(shape[1])) + except Exception as exc: + _debug_failure("forward hook input extraction", exc) + + handle.track_inference( + duration_ms=duration_ms, + batch_size=batch_size, + input_modality=input_modality, + output_modality=output_modality, + input_meta=input_meta, + output_meta=None, + success=True, + ) + + obj.register_forward_pre_hook(pre_hook) # type: ignore[union-attr] + obj.register_forward_hook(post_hook) # type: ignore[union-attr] + + # ----------------------------------------------------------------------- + # Auto-load patches + # ----------------------------------------------------------------------- + + @classmethod + def install_auto_load_patch(cls, client_ref: object) -> None: + """Patch transformers.pipeline and PreTrainedModel.from_pretrained. + + Called once at WildEdge client initialisation. Any subsequent + ``pipeline(...)`` or ``AutoModel.from_pretrained(...)`` call is timed + and registered automatically. HuggingFace Hub downloads are intercepted + for the duration of the call and emitted as a model_download event. + A thread-local guard prevents double-tracking when pipeline() calls + from_pretrained() internally. 
+ """ + global _transformers_patched + if _transformers_patched or _transformers is None: + return + + with _TRANSFORMERS_PATCH_LOCK: + if _transformers_patched: + return + cls._patch_pipeline(client_ref) + cls._patch_from_pretrained(client_ref) + _transformers_patched = True + + @classmethod + def _patch_pipeline(cls, client_ref: object) -> None: + original_pipeline = _transformers.pipeline + if ( + getattr(original_pipeline, "__wildedge_patch_name__", None) + == TRANSFORMERS_AUTO_LOAD_PATCH_NAME + ): + return + + def patched_pipeline(*args, **kwargs): # type: ignore[no-untyped-def] + c = client_ref() # type: ignore[call-arg] + hub_before = ( + c._snapshot_hub_caches() if c is not None and not c.closed else {} + ) + t0 = time.perf_counter() + _tl.inside_pipeline = True + try: + pipe = original_pipeline(*args, **kwargs) + finally: + _tl.inside_pipeline = False + load_ms = elapsed_ms(t0) + if c is not None and not c.closed: + downloads = c._diff_hub_caches(hub_before, load_ms) or None + c._on_model_auto_loaded(pipe, load_ms=load_ms, downloads=downloads) + return pipe + + patched_pipeline.__wildedge_patch_name__ = TRANSFORMERS_AUTO_LOAD_PATCH_NAME # type: ignore[attr-defined] + patched_pipeline.__wildedge_original_call__ = original_pipeline # type: ignore[attr-defined] + _transformers.pipeline = patched_pipeline + + @classmethod + def _patch_from_pretrained(cls, client_ref: object) -> None: + original_bound = _transformers.PreTrainedModel.from_pretrained + if ( + getattr(original_bound, "__wildedge_patch_name__", None) + == TRANSFORMERS_AUTO_LOAD_PATCH_NAME + ): + return + + original_func = original_bound.__func__ + + def patched_from_pretrained(model_cls, *args, **kwargs): # type: ignore[no-untyped-def] + # Don't double-track models loaded inside pipeline() + if getattr(_tl, "inside_pipeline", False): + return original_func(model_cls, *args, **kwargs) + c = client_ref() # type: ignore[call-arg] + hub_before = ( + c._snapshot_hub_caches() if c is not None and not c.closed else {} + ) + t0 = time.perf_counter() + model = original_func(model_cls, *args, **kwargs) + load_ms = elapsed_ms(t0) + if c is not None and not c.closed: + downloads = c._diff_hub_caches(hub_before, load_ms) or None + c._on_model_auto_loaded(model, load_ms=load_ms, downloads=downloads) + return model + + patched_from_pretrained.__wildedge_patch_name__ = ( + TRANSFORMERS_AUTO_LOAD_PATCH_NAME # type: ignore[attr-defined] + ) + patched_from_pretrained.__wildedge_original_call__ = original_func # type: ignore[attr-defined] + _transformers.PreTrainedModel.from_pretrained = classmethod( # type: ignore[assignment] + patched_from_pretrained + )