From d7cba8f217590fe179a087a8c2eb14e244909ede Mon Sep 17 00:00:00 2001 From: Andrei Date: Sat, 9 May 2026 02:10:10 +0300 Subject: [PATCH 1/8] feat: auto-load HF ONNX artifacts on CPU --- AGENTS.md | 9 + .../test_text_classifier_inference_api.py | 27 + .../text_classifier_inference_api.py | 10 + .../default_inference/nlp/th_hf_model_base.py | 622 +++++++++++++++++- extensions/serving/test_th_hf_model_base.py | 232 ++++++- 5 files changed, 885 insertions(+), 15 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index ca6dcbb5..dccc2061 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -686,3 +686,12 @@ Entry format: - Details: Added repo purpose/runtime constraints, ownership table, safe-edit boundaries, required verification matrix, role-based agent cards, A2A-style task contract, mandatory handoff envelope, single-agent loop, actor-critic workflow, reusable lessons-learned section, worked examples, and explicit AGENTS review triggers. Critic concerns addressed in the update: keep single-agent as default to avoid unnecessary delegation, require executable evidence for actor-vs-critic disputes, and keep memory logging critical-only instead of turning the file into an activity log. - Verification: `sed -n '1,260p' AGENTS.md`; `rg -n "Module And File Ownership|Safe-Edit Boundaries|Required Verification Commands|Agent Cards|A2A-Style Task Contract|Actor-Critic|AGENTS Review Triggers|ML-20260317-001" AGENTS.md` - Links: `AGENTS.md` + +- ID: `ML-20260508-001` +- Timestamp: `2026-05-08T23:08:05Z` +- Type: `change` +- Summary: Shared HF text serving now auto-selects ONNX artifacts for CPU-only runtime when the HF repo declares a compatible runtime manifest. +- Criticality: Serving runtime architecture change affecting model loading, artifact downloads, output decoding, and API metadata across generic text-classification deployments. +- Details: `ThHfModelBase` keeps Transformers/PT as the default GPU and fallback path, but CPU-only `HF_RUNTIME=auto` now loads `artifact_manifest.json`, selects a declared ONNX Runtime artifact, downloads only safe allow-patterns, loads schema and contract decoder from HF artifacts, and exposes the decoded artifact contract through the existing text-classifier flow. Business API response shaping now passes through generic model/runtime metadata emitted by serving. +- Verification: `python3 -m unittest extensions.serving.test_th_hf_model_base extensions.serving.test_th_text_classifier extensions.serving.test_th_privacy_filter extensions.business.edge_inference_api.test_text_classifier_inference_api extensions.business.edge_inference_api.test_privacy_filter_inference_api`; `python3 -m py_compile extensions/serving/default_inference/nlp/th_hf_model_base.py extensions/business/edge_inference_api/text_classifier_inference_api.py`; required serving gate `python3 -m unittest extensions.serving.model_testing.test_llm_servings` currently fails at import with `ImportError: cannot import name 'Logger' from 'naeural_core'`. 
+- Links: `extensions/serving/default_inference/nlp/th_hf_model_base.py`, `extensions/business/edge_inference_api/text_classifier_inference_api.py`, `extensions/serving/test_th_hf_model_base.py` diff --git a/extensions/business/edge_inference_api/test_text_classifier_inference_api.py b/extensions/business/edge_inference_api/test_text_classifier_inference_api.py index 6e3c8085..aa945ea3 100644 --- a/extensions/business/edge_inference_api/test_text_classifier_inference_api.py +++ b/extensions/business/edge_inference_api/test_text_classifier_inference_api.py @@ -152,6 +152,33 @@ def test_build_result_from_inference_preserves_classifier_output(self): self.assertEqual(result_payload["model_name"], "openai/privacy-filter") self.assertEqual(result_payload["pipeline_task"], "token-classification") + def test_build_result_from_inference_preserves_runtime_model_metadata(self): + plugin = TextClassifierInferenceApiPlugin() + + result_payload = plugin._build_result_from_inference( # pylint: disable=protected-access + request_id="req-onnx", + inference={ + "REQUEST_ID": "req-onnx", + "TEXT": "example text", + "result": {"prediction": "safe"}, + "MODEL": {"key": "generic_text_classifier", "version": "2026.05.09"}, + "MODEL_VERSION": "2026.05.09", + "HF_RUNTIME": "onnx_fp32", + "RUNTIME": "onnxruntime", + }, + metadata={}, + request_data={"metadata": {}, "parameters": {"text": "example text"}}, + ) + + self.assertEqual(result_payload["classification"], {"prediction": "safe"}) + self.assertEqual( + result_payload["model"], + {"key": "generic_text_classifier", "version": "2026.05.09"}, + ) + self.assertEqual(result_payload["model_version"], "2026.05.09") + self.assertEqual(result_payload["hf_runtime"], "onnx_fp32") + self.assertEqual(result_payload["runtime"], "onnxruntime") + def test_handle_inferences_falls_back_to_payload_request_id(self): plugin = TextClassifierInferenceApiPlugin() plugin._requests = {"req-1": {"status": "pending"}} # pylint: disable=protected-access diff --git a/extensions/business/edge_inference_api/text_classifier_inference_api.py b/extensions/business/edge_inference_api/text_classifier_inference_api.py index 867f1ff4..2de5aa1a 100644 --- a/extensions/business/edge_inference_api/text_classifier_inference_api.py +++ b/extensions/business/edge_inference_api/text_classifier_inference_api.py @@ -405,6 +405,16 @@ def _build_result_from_inference( result_payload["tokenizer_name"] = inference["TOKENIZER_NAME"] if "PIPELINE_TASK" in inference: result_payload["pipeline_task"] = inference["PIPELINE_TASK"] + if "MODEL" in inference: + result_payload["model"] = inference["MODEL"] + if "MODEL_VERSION" in inference: + result_payload["model_version"] = inference["MODEL_VERSION"] + if "MODEL_REVISION" in inference: + result_payload["model_revision"] = inference["MODEL_REVISION"] + if "HF_RUNTIME" in inference: + result_payload["hf_runtime"] = inference["HF_RUNTIME"] + if "RUNTIME" in inference: + result_payload["runtime"] = inference["RUNTIME"] return result_payload def handle_inference_for_request( diff --git a/extensions/serving/default_inference/nlp/th_hf_model_base.py b/extensions/serving/default_inference/nlp/th_hf_model_base.py index 436febf0..db935025 100644 --- a/extensions/serving/default_inference/nlp/th_hf_model_base.py +++ b/extensions/serving/default_inference/nlp/th_hf_model_base.py @@ -6,6 +6,11 @@ input/output handling. 
""" +import importlib.util +import inspect +import json +from pathlib import Path + import torch as th from transformers import BitsAndBytesConfig, pipeline as hf_pipeline @@ -24,6 +29,11 @@ "MODEL_NAME": None, "TOKENIZER_NAME": None, "PIPELINE_TASK": None, + "MODEL_REVISION": None, + "HF_RUNTIME": "auto", + "HF_ARTIFACT_MANIFEST": "artifact_manifest.json", + "HF_ONNX_RUNTIME_KEY": "onnx_fp32", + "HF_ONNX_ALLOW_PATTERNS": None, "TEXT_KEYS": ["text", "email_text", "content", "request", "body"], "REQUEST_ID_KEYS": ["request_id", "REQUEST_ID"], "MAX_LENGTH": 512, @@ -44,6 +54,164 @@ } +class HfOnnxArtifactPipeline: + """Callable adapter that exposes an ONNX artifact as a pipeline-like object.""" + + def __init__( + self, + repo_id, + runtime_key, + runtime_config, + tokenizer, + session, + schema, + decoder, + task=None, + max_length=None, + ): + self.repo_id = repo_id + self.runtime_key = runtime_key + self.runtime_config = runtime_config or {} + self.tokenizer = tokenizer + self.session = session + self.schema = schema or {} + self.decoder = decoder + self.task = task + self.framework = "onnxruntime" + self.max_length = max_length + return + + def __call__(self, texts, **kwargs): + """Run one or more text inputs through the ONNX artifact.""" + is_single_text = isinstance(texts, str) + text_items = [texts] if is_single_text else list(texts or []) + results = [ + self._run_single_text(text=text, inference_kwargs=kwargs) + for text in text_items + ] + return results[0] if is_single_text or len(results) == 1 else results + + def _get_max_length(self, inference_kwargs): + max_length = inference_kwargs.get("max_length") + if max_length is not None: + return max_length + if self.max_length is not None: + return self.max_length + schema_max_length = self.schema.get("max_length") + return schema_max_length if schema_max_length is not None else None + + def _tokenize(self, text, inference_kwargs): + tokenize_kwargs = { + "return_tensors": "np", + "truncation": bool(inference_kwargs.get("truncation", True)), + } + max_length = self._get_max_length(inference_kwargs) + if max_length is not None: + tokenize_kwargs["max_length"] = max_length + if "padding" in inference_kwargs: + tokenize_kwargs["padding"] = inference_kwargs["padding"] + return self.tokenizer(text, **tokenize_kwargs) + + def _input_specs(self): + inputs = self.schema.get("inputs") + if isinstance(inputs, list): + return inputs + if isinstance(inputs, dict): + return [ + {"name": name, **(spec if isinstance(spec, dict) else {})} + for name, spec in inputs.items() + ] + return [ + {"name": "input_ids", "dtype": "int64"}, + {"name": "attention_mask", "dtype": "int64"}, + ] + + def _output_names(self): + output_names = self.runtime_config.get("output_names") + if isinstance(output_names, list) and output_names: + return output_names + output_order = self.schema.get("output_order") + if isinstance(output_order, list) and output_order: + return output_order + outputs = self.schema.get("outputs") + if isinstance(outputs, list): + names = [] + for output in outputs: + if isinstance(output, dict) and output.get("name"): + names.append(output["name"]) + elif isinstance(output, str): + names.append(output) + if names: + return names + if hasattr(self.session, "get_outputs"): + session_output_names = [ + output.name for output in self.session.get_outputs() + if getattr(output, "name", None) + ] + if session_output_names: + return session_output_names + return None + + def _prepare_session_inputs(self, encoded): + session_inputs = {} + for input_spec 
in self._input_specs(): + if isinstance(input_spec, dict): + input_name = input_spec.get("name") + dtype = input_spec.get("dtype") + else: + input_name = str(input_spec) + dtype = None + if not input_name or input_name not in encoded: + continue + value = encoded[input_name] + if dtype is not None and hasattr(value, "astype"): + value = value.astype(dtype) + session_inputs[input_name] = value + if not session_inputs and hasattr(encoded, "items"): + session_inputs = dict(encoded.items()) + return session_inputs + + def _build_output_map(self, raw_outputs, output_names): + if output_names is None: + output_names = [f"output_{idx}" for idx in range(len(raw_outputs))] + return { + output_name: output_value + for output_name, output_value in zip(output_names, raw_outputs) + } + + def _call_decoder(self, outputs_by_name, text): + if self.decoder is None: + return outputs_by_name + decoder_kwargs = { + "runtime": self.runtime_key, + "runtime_key": self.runtime_key, + "text": text, + "repo_id": self.repo_id, + } + try: + signature = inspect.signature(self.decoder) + accepts_var_kwargs = any( + param.kind == inspect.Parameter.VAR_KEYWORD + for param in signature.parameters.values() + ) + if not accepts_var_kwargs: + decoder_kwargs = { + key: value for key, value in decoder_kwargs.items() + if key in signature.parameters + } + except (TypeError, ValueError): + pass + return self.decoder(outputs_by_name, self.schema, **decoder_kwargs) + + def _run_single_text(self, text, inference_kwargs): + encoded = self._tokenize(text=text, inference_kwargs=inference_kwargs) + session_inputs = self._prepare_session_inputs(encoded) + output_names = self._output_names() + raw_outputs = self.session.run(output_names, session_inputs) + outputs_by_name = self._build_output_map(raw_outputs, output_names) + return self._call_decoder(outputs_by_name=outputs_by_name, text=text) + + class ThHfModelBase(BaseServingProcess): CONFIG = _CONFIG @@ -57,6 +225,9 @@ def __init__(self, **kwargs): """ self.classifier = None self.device = None + self.hf_runtime = "pt" + self.hf_runtime_config = {} + self.hf_artifact_manifest = None super(ThHfModelBase, self).__init__(**kwargs) return @@ -102,6 +273,16 @@ def get_pipeline_task(self): """ return self.cfg_pipeline_task + def get_model_revision(self): + """Return the optional Hugging Face model revision. + + Returns + ------- + str or None + Configured `MODEL_REVISION`, or `None` when unset. + """ + return getattr(self, "cfg_model_revision", None) + @property def cache_dir(self): """Return the local cache directory for Hugging Face artifacts. 
@@ -255,6 +436,346 @@ def _get_model_load_config(self): cache_dir=self.cache_dir, ) + def _requested_hf_runtime(self): + """Return the normalized requested HF runtime selector.""" + requested = getattr(self, "cfg_hf_runtime", "auto") + if requested is None: + return "auto" + requested = str(requested).strip().lower() + if requested in {"", "auto"}: + return "auto" + if requested in {"pt", "torch", "pytorch", "transformers"}: + return "pt" + if requested == "onnx": + return "onnx" + return requested + + def _should_load_hf_artifact_manifest(self, requested_runtime): + """Return whether startup needs the HF artifact manifest.""" + if requested_runtime == "pt": + return False + if requested_runtime == "auto": + return self.device == -1 + return True + + def _download_hf_artifact_file(self, filename): + """Download one HF artifact file and return its local path.""" + from huggingface_hub import hf_hub_download + + return hf_hub_download( + repo_id=self.get_model_name(), + filename=filename, + revision=self.get_model_revision(), + token=self.hf_token, + cache_dir=self.cache_dir, + repo_type="model", + ) + + def _load_hf_artifact_manifest(self): + """Load the optional artifact manifest from the configured HF model repo.""" + manifest_name = getattr(self, "cfg_hf_artifact_manifest", None) + if not manifest_name: + return None + try: + manifest_path = self._download_hf_artifact_file(manifest_name) + return json.loads(Path(manifest_path).read_text(encoding="utf-8")) + except Exception as exc: + if self._requested_hf_runtime() != "auto": + raise + self.P( + f"HF artifact manifest {manifest_name} not available for {self.get_model_name()}: {exc}", + color="y", + ) + return None + + def _get_hf_manifest_runtimes(self, manifest): + """Extract runtime definitions from an artifact manifest.""" + if not isinstance(manifest, dict): + return {} + runtimes = manifest.get("runtimes") + return runtimes if isinstance(runtimes, dict) else {} + + def _runtime_is_onnx(self, runtime_key, runtime_config): + """Return whether a manifest runtime is backed by ONNX Runtime.""" + runtime_config = runtime_config or {} + runtime_name = str(runtime_config.get("runtime", "")).lower() + entrypoint = str(runtime_config.get("entrypoint", "")).lower() + runtime_key = str(runtime_key or "").lower() + return ( + "onnxruntime" in runtime_name + or "onnxruntime" in entrypoint + or runtime_key.startswith("onnx") + ) + + def _resolve_hf_onnx_runtime_key(self, runtimes): + """Find the preferred ONNX runtime key from manifest runtimes.""" + preferred = getattr(self, "cfg_hf_onnx_runtime_key", None) + if preferred in runtimes and self._runtime_is_onnx(preferred, runtimes[preferred]): + return preferred + for runtime_key, runtime_config in runtimes.items(): + if self._runtime_is_onnx(runtime_key, runtime_config): + return runtime_key + return None + + def _select_hf_runtime(self, manifest): + """Select the runtime to load for this startup.""" + requested_runtime = self._requested_hf_runtime() + runtimes = self._get_hf_manifest_runtimes(manifest) + if requested_runtime == "pt": + return "pt", runtimes.get("pt", {}) + if requested_runtime == "auto": + if self.device == -1: + runtime_key = self._resolve_hf_onnx_runtime_key(runtimes) + if runtime_key is not None: + return runtime_key, runtimes[runtime_key] + return "pt", runtimes.get("pt", {}) + if requested_runtime in runtimes: + return requested_runtime, runtimes[requested_runtime] + if requested_runtime == "onnx": + runtime_key = self._resolve_hf_onnx_runtime_key(runtimes) + if runtime_key 
is not None: + return runtime_key, runtimes[runtime_key] + manifest_name = getattr(self, "cfg_hf_artifact_manifest", "artifact_manifest.json") + raise ValueError( + f"HF runtime {requested_runtime!r} is not declared in {manifest_name!r} for {self.get_model_name()}." + ) + + def _blocked_hf_weight_pattern(self, pattern): + """Return whether a download pattern could pull framework weight files.""" + pattern = str(pattern) + blocked_suffixes = ( + ".safetensors", + "pytorch_model.bin", + "tf_model.h5", + "flax_model.msgpack", + ) + blocked_wildcards = ("*.safetensors", "*.bin", "*.h5", "*.msgpack") + return pattern.endswith(blocked_suffixes) or pattern in blocked_wildcards + + def _build_hf_runtime_allow_patterns(self, runtime_config): + """Build safe HF snapshot allow-patterns for an ONNX runtime.""" + configured_patterns = getattr(self, "cfg_hf_onnx_allow_patterns", None) + if configured_patterns: + patterns = configured_patterns + else: + patterns = runtime_config.get("recommended_allow_patterns") or runtime_config.get("files") + if not patterns: + model_file = runtime_config.get("model") + patterns = [ + model_file, + "*.onnx", + "**/*.onnx", + "onnx/*", + "onnx/**", + "*.json", + "*.py", + "*.txt", + "*.model", + "*.tiktoken", + ] + if isinstance(patterns, str): + patterns = [patterns] + safe_patterns = [] + for pattern in patterns or []: + if not pattern or self._blocked_hf_weight_pattern(pattern): + continue + if pattern not in safe_patterns: + safe_patterns.append(pattern) + if not safe_patterns: + raise ValueError("HF ONNX runtime download has no safe allow patterns.") + return safe_patterns + + def _download_hf_runtime_snapshot(self, runtime_key, runtime_config, allow_patterns): + """Download the minimal HF snapshot needed for a selected runtime.""" + from huggingface_hub import snapshot_download + + self.P( + f"Downloading HF runtime {runtime_key} artifacts for {self.get_model_name()}...", + color="y", + ) + return snapshot_download( + repo_id=self.get_model_name(), + revision=self.get_model_revision(), + token=self.hf_token, + cache_dir=self.cache_dir, + allow_patterns=allow_patterns, + repo_type="model", + ) + + def _runtime_file_list(self, runtime_config): + files = runtime_config.get("files") if isinstance(runtime_config, dict) else None + return files if isinstance(files, list) else [] + + def _first_manifest_file_with_suffix(self, runtime_config, suffixes): + """Return the first exact manifest file path ending with any suffix.""" + for file_path in self._runtime_file_list(runtime_config): + file_path = str(file_path) + if any(file_path.endswith(suffix) for suffix in suffixes): + return file_path + return None + + def _resolve_manifest_file_path(self, model_dir, manifest, runtime_config, keys, suffixes): + """Resolve a model-repo file path declared directly or inferred by suffix.""" + for key in keys: + value = runtime_config.get(key) if isinstance(runtime_config, dict) else None + if value is None and isinstance(manifest, dict): + value = manifest.get(key) + if value: + return Path(model_dir) / str(value) + inferred = self._first_manifest_file_with_suffix(runtime_config, suffixes) + if inferred: + return Path(model_dir) / inferred + return None + + def _load_hf_schema(self, model_dir, manifest, runtime_config): + """Load the JSON schema declared by the selected HF runtime.""" + schema_path = self._resolve_manifest_file_path( + model_dir=model_dir, + manifest=manifest, + runtime_config=runtime_config, + keys=("schema", "schema_file", "contract_schema"), + 
suffixes=("_schema.json", "schema.json"), + ) + if schema_path is None or not schema_path.exists(): + raise ValueError(f"HF runtime {self.hf_runtime} does not declare a usable schema file.") + return json.loads(schema_path.read_text(encoding="utf-8")) + + def _load_hf_contract_decoder(self, model_dir, manifest, runtime_config): + """Load the artifact decoder function declared by the selected HF runtime.""" + decoder_path = self._resolve_manifest_file_path( + model_dir=model_dir, + manifest=manifest, + runtime_config=runtime_config, + keys=("decoder", "decoder_file", "contract", "contract_file"), + suffixes=("_contract.py", "contract.py"), + ) + if decoder_path is None or not decoder_path.exists(): + raise ValueError(f"HF runtime {self.hf_runtime} does not declare a usable contract decoder.") + module_name = f"hf_artifact_contract_{abs(hash(str(decoder_path)))}" + spec = importlib.util.spec_from_file_location(module_name, decoder_path) + if spec is None or spec.loader is None: + raise ValueError(f"Could not load HF contract decoder from {decoder_path}.") + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + function_name = None + if isinstance(runtime_config, dict): + function_name = runtime_config.get("decoder_function") + if function_name is None and isinstance(manifest, dict): + function_name = manifest.get("decoder_function") + if function_name is None and callable(getattr(module, "decode_outputs", None)): + function_name = "decode_outputs" + if function_name is None: + decode_functions = [ + name for name in dir(module) + if name.startswith("decode_") + and name.endswith("_outputs") + and callable(getattr(module, name, None)) + ] + if len(decode_functions) == 1: + function_name = decode_functions[0] + decoder = getattr(module, function_name, None) if function_name else None + if not callable(decoder): + raise ValueError(f"Could not resolve a decoder function in {decoder_path}.") + return decoder + + def _resolve_hf_onnx_model_path(self, model_dir, runtime_key, runtime_config, schema): + """Resolve the ONNX model file for the selected runtime.""" + for key in ("model", "model_file", "path"): + value = runtime_config.get(key) if isinstance(runtime_config, dict) else None + if value: + return Path(model_dir) / str(value) + models = schema.get("models") if isinstance(schema, dict) else None + if isinstance(models, dict): + candidates = [ + runtime_key, + str(runtime_key).replace("_", "-"), + str(runtime_key).replace("-", "_"), + ] + for candidate in candidates: + value = models.get(candidate) + if value: + if isinstance(value, dict): + value = value.get("path") or value.get("file") or value.get("model") + if not value: + continue + return Path(model_dir) / str(value) + model_file = self._first_manifest_file_with_suffix(runtime_config, (".onnx",)) + if model_file: + return Path(model_dir) / model_file + raise ValueError(f"HF runtime {runtime_key} does not declare an ONNX model file.") + + def _resolve_hf_tokenizer_dir(self, model_dir, manifest, runtime_config, schema): + """Resolve tokenizer directory for the selected artifact runtime.""" + tokenizer_dir = None + for source in (runtime_config, schema, manifest): + if isinstance(source, dict) and source.get("tokenizer_dir"): + tokenizer_dir = source["tokenizer_dir"] + break + return Path(model_dir) / str(tokenizer_dir or ".") + + def _load_hf_onnx_tokenizer(self, model_dir, runtime_config): + """Load the tokenizer for an ONNX HF artifact.""" + from transformers import AutoTokenizer + + return 
AutoTokenizer.from_pretrained( + str(model_dir), + token=self.hf_token, + trust_remote_code=bool(runtime_config.get("trust_remote_code", False)), + ) + + def _create_hf_onnx_session(self, model_path, providers): + """Create an ONNX Runtime inference session.""" + import onnxruntime as ort + + return ort.InferenceSession(str(model_path), providers=providers) + + def _build_hf_onnx_artifact_pipeline(self, model_dir, runtime_key, runtime_config, manifest): + """Build a callable ONNX artifact pipeline from downloaded HF files.""" + schema = self._load_hf_schema( + model_dir=model_dir, + manifest=manifest, + runtime_config=runtime_config, + ) + decoder = self._load_hf_contract_decoder( + model_dir=model_dir, + manifest=manifest, + runtime_config=runtime_config, + ) + tokenizer_dir = self._resolve_hf_tokenizer_dir( + model_dir=model_dir, + manifest=manifest, + runtime_config=runtime_config, + schema=schema, + ) + tokenizer = self._load_hf_onnx_tokenizer( + model_dir=tokenizer_dir, + runtime_config=runtime_config, + ) + model_path = self._resolve_hf_onnx_model_path( + model_dir=model_dir, + runtime_key=runtime_key, + runtime_config=runtime_config, + schema=schema, + ) + provider = runtime_config.get("provider") or "CPUExecutionProvider" + providers = runtime_config.get("providers") or [provider] + session = self._create_hf_onnx_session( + model_path=model_path, + providers=providers, + ) + return HfOnnxArtifactPipeline( + repo_id=self.get_model_name(), + runtime_key=runtime_key, + runtime_config=runtime_config, + tokenizer=tokenizer, + session=session, + schema=schema, + decoder=decoder, + task=runtime_config.get("pipeline_task") or manifest.get("pipeline_task") or self.get_pipeline_task(), + max_length=self.cfg_max_length, + ) + def _normalize_pipeline_runtime_contract(self): """Patch known gaps in custom remote-code pipeline initialization. @@ -299,19 +820,9 @@ def _run_startup_warmup(self): ) return - def startup(self): - """Load the Hugging Face pipeline and prepare it for inference. - - Raises - ------ - ValueError - If `MODEL_NAME` is not configured. - """ + def _startup_transformers_pipeline(self): + """Load the standard Transformers pipeline runtime.""" model_name = self.get_model_name() - if not model_name: - raise ValueError(f"{self.__class__.__name__} serving requires MODEL_NAME.") - - self.device = self._resolve_pipeline_device() model_load_params, quantization_params = self._get_model_load_config() pipeline_kwargs = self.build_pipeline_kwargs() model_kwargs = { @@ -334,12 +845,84 @@ def startup(self): trust_remote_code=bool(self.cfg_trust_remote_code), device=self.device, model_kwargs=model_kwargs, + revision=self.get_model_revision(), **pipeline_kwargs, ) self._normalize_pipeline_runtime_contract() + return + + def _startup_hf_onnx_artifact(self, runtime_key, runtime_config, manifest): + """Load the selected ONNX artifact runtime from the HF repository.""" + allow_patterns = self._build_hf_runtime_allow_patterns(runtime_config) + model_dir = self._download_hf_runtime_snapshot( + runtime_key=runtime_key, + runtime_config=runtime_config, + allow_patterns=allow_patterns, + ) + self.classifier = self._build_hf_onnx_artifact_pipeline( + model_dir=model_dir, + runtime_key=runtime_key, + runtime_config=runtime_config, + manifest=manifest or {}, + ) + return + + def startup(self): + """Load the Hugging Face runtime and prepare it for inference. + + Raises + ------ + ValueError + If `MODEL_NAME` is not configured. 
+ """ + model_name = self.get_model_name() + if not model_name: + raise ValueError(f"{self.__class__.__name__} serving requires MODEL_NAME.") + + self.device = self._resolve_pipeline_device() + requested_runtime = self._requested_hf_runtime() + manifest = None + if self._should_load_hf_artifact_manifest(requested_runtime=requested_runtime): + manifest = self._load_hf_artifact_manifest() + runtime_key, runtime_config = self._select_hf_runtime(manifest=manifest) + self.hf_runtime = runtime_key + self.hf_runtime_config = dict(runtime_config or {}) + self.hf_artifact_manifest = manifest if isinstance(manifest, dict) else None + + if self._runtime_is_onnx(runtime_key=runtime_key, runtime_config=runtime_config): + self._startup_hf_onnx_artifact( + runtime_key=runtime_key, + runtime_config=runtime_config, + manifest=manifest, + ) + else: + self._startup_transformers_pipeline() self._run_startup_warmup() return + def _get_hf_artifact_model_metadata(self): + """Return model metadata declared by the loaded artifact.""" + metadata = {} + has_artifact_metadata = False + for source in (self.hf_artifact_manifest, getattr(self.classifier, "schema", None)): + if not isinstance(source, dict): + continue + for key in ( + "repo_id", + "repo_key", + "model_key", + "model_version", + "release_channel", + "release_alias_of", + "source_repo_id", + ): + if key not in metadata and source.get(key) is not None: + metadata[key] = source[key] + has_artifact_metadata = True + if has_artifact_metadata and self.hf_runtime: + metadata["runtime"] = self.hf_runtime + return metadata + def get_additional_metadata(self): """Return model metadata attached to decoded predictions. @@ -349,11 +932,24 @@ def get_additional_metadata(self): Model name, tokenizer name, and pipeline task metadata. """ pipeline_task = getattr(self.classifier, "task", None) if self.classifier is not None else None - return { + metadata = { "MODEL_NAME": self.get_model_name(), "TOKENIZER_NAME": self.get_tokenizer_name(), "PIPELINE_TASK": pipeline_task or self.get_pipeline_task(), + "HF_RUNTIME": self.hf_runtime, + "RUNTIME": self.hf_runtime_config.get("runtime") or ( + "onnxruntime" if self._runtime_is_onnx(self.hf_runtime, self.hf_runtime_config) else "transformers" + ), } + model_revision = self.get_model_revision() + if model_revision is not None: + metadata["MODEL_REVISION"] = model_revision + artifact_model_metadata = self._get_hf_artifact_model_metadata() + if artifact_model_metadata: + metadata["MODEL"] = artifact_model_metadata + if artifact_model_metadata.get("model_version") is not None: + metadata["MODEL_VERSION"] = artifact_model_metadata["model_version"] + return metadata def _extract_serving_target(self, struct_payload): """Extract the reserved serving-target metadata from a payload. 
diff --git a/extensions/serving/test_th_hf_model_base.py b/extensions/serving/test_th_hf_model_base.py index e69c4c0e..6692bc8e 100644 --- a/extensions/serving/test_th_hf_model_base.py +++ b/extensions/serving/test_th_hf_model_base.py @@ -2,6 +2,7 @@ import unittest from pathlib import Path +from tempfile import TemporaryDirectory ROOT = Path(__file__).resolve().parents[2] @@ -14,6 +15,11 @@ def __init__(self, **kwargs): self.cfg_model_name = kwargs.get("MODEL_NAME") self.cfg_tokenizer_name = kwargs.get("TOKENIZER_NAME") self.cfg_pipeline_task = kwargs.get("PIPELINE_TASK") + self.cfg_model_revision = kwargs.get("MODEL_REVISION") + self.cfg_hf_runtime = kwargs.get("HF_RUNTIME", "auto") + self.cfg_hf_artifact_manifest = kwargs.get("HF_ARTIFACT_MANIFEST", "artifact_manifest.json") + self.cfg_hf_onnx_runtime_key = kwargs.get("HF_ONNX_RUNTIME_KEY", "onnx_fp32") + self.cfg_hf_onnx_allow_patterns = kwargs.get("HF_ONNX_ALLOW_PATTERNS") self.cfg_max_length = kwargs.get("MAX_LENGTH", 512) self.cfg_model_weights_size = kwargs.get("MODEL_WEIGHTS_SIZE") self.cfg_hf_token = kwargs.get("HF_TOKEN") @@ -77,6 +83,7 @@ def __init__(self, **kwargs): class _FakePipeline: def __init__(self, task=None): self.task = task + self.framework = "pt" self.inference_calls = [] def __call__(self, text, **kwargs): @@ -104,6 +111,37 @@ def is_available(): return True +class _FakeEncodedValue: + def __init__(self, value): + self.value = value + self.dtype = None + + def astype(self, dtype): + self.dtype = dtype + return self + + +class _FakeTokenizer: + def __init__(self): + self.calls = [] + + def __call__(self, text, **kwargs): + self.calls.append((text, kwargs)) + return { + "input_ids": _FakeEncodedValue([1, 2, 3]), + "attention_mask": _FakeEncodedValue([1, 1, 1]), + } + + +class _FakeOrtSession: + def __init__(self): + self.calls = [] + + def run(self, output_names, inputs): + self.calls.append((output_names, inputs)) + return [[0.25, 0.75]] + + def _load_base_class(): factory = _PipelineFactory() source_path = ROOT / "extensions" / "serving" / "default_inference" / "nlp" / "th_hf_model_base.py" @@ -125,10 +163,10 @@ def _load_base_class(): "__name__": "loaded_th_hf_model_base", } exec(compile(source, str(source_path), "exec"), namespace) # noqa: S102 - return namespace["ThHfModelBase"], factory + return namespace["ThHfModelBase"], namespace["HfOnnxArtifactPipeline"], factory -ThHfModelBase, _PIPELINE_FACTORY = _load_base_class() +ThHfModelBase, HfOnnxArtifactPipeline, _PIPELINE_FACTORY = _load_base_class() class _ConcreteHfModel(ThHfModelBase): @@ -136,6 +174,11 @@ class _ConcreteHfModel(ThHfModelBase): class ThHfModelBaseTests(unittest.TestCase): + def setUp(self): + _PIPELINE_FACTORY.calls = [] + _PIPELINE_FACTORY.instance.inference_calls = [] + return + def test_hf_serving_raises_default_wait_time_above_generic_base(self): self.assertEqual(_ConcreteHfModel.CONFIG["MAX_WAIT_TIME"], 60) @@ -224,6 +267,191 @@ def test_startup_can_disable_warmup(self): self.assertEqual(after_calls, before_calls) + def test_forced_pt_runtime_passes_model_revision_to_transformers_pipeline(self): + plugin = _ConcreteHfModel( + MODEL_NAME="test/model", + MODEL_REVISION="rev-123", + HF_RUNTIME="pt", + PIPELINE_TASK="text-classification", + WARMUP_ENABLED=False, + ) + + plugin.startup() + + _args, kwargs = _PIPELINE_FACTORY.calls[-1] + self.assertEqual(kwargs["revision"], "rev-123") + self.assertEqual(plugin.hf_runtime, "pt") + + def test_auto_runtime_uses_onnx_artifact_on_cpu_only(self): + manifest = { + "model_key": 
"generic_text_classifier", + "model_version": "2026.05.09", + "pipeline_task": "text-classification", + "runtimes": { + "onnx_fp32": { + "runtime": "onnxruntime", + "entrypoint": "onnxruntime.InferenceSession", + "files": [ + "model.onnx", + "tokenizer.json", + "contract.py", + "schema.json", + "model.safetensors", + ], + } + }, + } + plugin = _ConcreteHfModel( + MODEL_NAME="test/model", + DEVICE="cpu", + PIPELINE_TASK="text-classification", + WARMUP_ENABLED=False, + ) + download_calls = [] + plugin._load_hf_artifact_manifest = lambda: manifest # pylint: disable=protected-access + + def fake_download(runtime_key, runtime_config, allow_patterns): + download_calls.append((runtime_key, runtime_config, allow_patterns)) + return "/tmp/models/test-model" + + plugin._download_hf_runtime_snapshot = fake_download # pylint: disable=protected-access + plugin._build_hf_onnx_artifact_pipeline = ( # pylint: disable=protected-access + lambda model_dir, runtime_key, runtime_config, manifest: _FakePipeline(task="text-classification") + ) + + plugin.startup() + + self.assertEqual(plugin.hf_runtime, "onnx_fp32") + self.assertEqual(len(_PIPELINE_FACTORY.calls), 0) + self.assertEqual(download_calls[0][0], "onnx_fp32") + self.assertIn("model.onnx", download_calls[0][2]) + self.assertNotIn("model.safetensors", download_calls[0][2]) + + def test_auto_runtime_keeps_transformers_pipeline_when_gpu_available(self): + plugin = _ConcreteHfModel( + MODEL_NAME="test/model", + PIPELINE_TASK="text-classification", + WARMUP_ENABLED=False, + ) + plugin._load_hf_artifact_manifest = lambda: { # pylint: disable=protected-access + "runtimes": { + "onnx_fp32": { + "runtime": "onnxruntime", + "entrypoint": "onnxruntime.InferenceSession", + "files": ["model.onnx"], + } + }, + } + + plugin.startup() + + self.assertEqual(plugin.device, 0) + self.assertEqual(plugin.hf_runtime, "pt") + self.assertEqual(len(_PIPELINE_FACTORY.calls), 1) + + def test_forced_onnx_runtime_uses_manifest_runtime_without_hardcoded_key(self): + manifest = { + "runtimes": { + "cpu_artifact": { + "runtime": "onnxruntime", + "entrypoint": "onnxruntime.InferenceSession", + "files": ["model.onnx", "schema.json", "contract.py"], + } + }, + } + plugin = _ConcreteHfModel( + MODEL_NAME="test/model", + DEVICE="cpu", + HF_RUNTIME="onnx", + PIPELINE_TASK="text-classification", + WARMUP_ENABLED=False, + ) + plugin._load_hf_artifact_manifest = lambda: manifest # pylint: disable=protected-access + plugin._download_hf_runtime_snapshot = ( # pylint: disable=protected-access + lambda runtime_key, runtime_config, allow_patterns: "/tmp/models/test-model" + ) + plugin._build_hf_onnx_artifact_pipeline = ( # pylint: disable=protected-access + lambda model_dir, runtime_key, runtime_config, manifest: _FakePipeline(task="text-classification") + ) + + plugin.startup() + + self.assertEqual(plugin.hf_runtime, "cpu_artifact") + self.assertEqual(len(_PIPELINE_FACTORY.calls), 0) + + def test_onnx_artifact_pipeline_uses_hf_contract_decoder(self): + plugin = _ConcreteHfModel( + MODEL_NAME="test/model", + DEVICE="cpu", + PIPELINE_TASK="text-classification", + WARMUP_ENABLED=False, + ) + fake_tokenizer = _FakeTokenizer() + fake_session = _FakeOrtSession() + created_sessions = [] + plugin._load_hf_onnx_tokenizer = lambda model_dir, runtime_config: fake_tokenizer # pylint: disable=protected-access + + def fake_create_session(model_path, providers): + created_sessions.append((model_path, providers)) + return fake_session + + plugin._create_hf_onnx_session = fake_create_session # pylint: 
disable=protected-access + + with TemporaryDirectory() as tmpdir: + model_dir = Path(tmpdir) + (model_dir / "model.onnx").write_text("fake", encoding="utf-8") + (model_dir / "schema.json").write_text( + ( + '{"inputs":[{"name":"input_ids","dtype":"int64"},' + '{"name":"attention_mask","dtype":"int64"}],' + '"outputs":[{"name":"scores"}],' + '"models":{"onnx_fp32":{"path":"model.onnx"}}}' + ), + encoding="utf-8", + ) + (model_dir / "contract.py").write_text( + ( + "def decode_generic_outputs(outputs, schema, **kwargs):\n" + " return {\n" + " 'contract': 'hf',\n" + " 'outputs': outputs,\n" + " 'repo_id': kwargs.get('repo_id'),\n" + " 'runtime': kwargs.get('runtime_key'),\n" + " }\n" + ), + encoding="utf-8", + ) + manifest = { + "pipeline_task": "text-classification", + "runtimes": { + "onnx_fp32": { + "runtime": "onnxruntime", + "files": ["model.onnx", "schema.json", "contract.py"], + } + }, + } + + pipeline = plugin._build_hf_onnx_artifact_pipeline( # pylint: disable=protected-access + model_dir=str(model_dir), + runtime_key="onnx_fp32", + runtime_config=manifest["runtimes"]["onnx_fp32"], + manifest=manifest, + ) + result = pipeline("hello world") + batched_single_result = pipeline(["hello world"]) + + self.assertIsInstance(pipeline, HfOnnxArtifactPipeline) + self.assertEqual(result["contract"], "hf") + self.assertEqual(batched_single_result["contract"], "hf") + self.assertEqual(result["outputs"], {"scores": [0.25, 0.75]}) + self.assertEqual(result["repo_id"], "test/model") + self.assertEqual(result["runtime"], "onnx_fp32") + self.assertEqual(Path(created_sessions[0][0]).name, "model.onnx") + output_names, inputs = fake_session.calls[-1] + self.assertEqual(output_names, ["scores"]) + self.assertEqual(inputs["input_ids"].dtype, "int64") + self.assertEqual(fake_tokenizer.calls[-1][1]["return_tensors"], "np") + if __name__ == "__main__": unittest.main() From e5441964bdf64fa0babaf0d175ec5b94044ec670 Mon Sep 17 00:00:00 2001 From: Cristi Bleotiu Date: Mon, 11 May 2026 12:57:31 +0300 Subject: [PATCH 2/8] fix: harden hf onnx runtime fallback What changed: - Make auto ONNX startup opportunistic and fall back to Transformers/PT on ONNX init or warmup failure. - Keep explicit ONNX runtimes fail-fast while explicit PT skips manifest lookup. - Gate decoder and tokenizer remote code on global and runtime trust flags. - Confine manifest-declared artifact paths to the downloaded HF snapshot and filter broad/framework-weight allow patterns. - Forward runtime metadata consistently for privacy-filter responses and add focused regression coverage. Why: - Preserve seamless CPU ONNX when available without breaking Transformers fallback or weakening remote-code/path safety. 
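Illustration (not a published schema): a minimal artifact_manifest.json runtime
entry that the CPU-only HF_RUNTIME=auto path would pick, using only field names
exercised by the new tests:

  {
    "model_key": "generic_text_classifier",
    "model_version": "2026.05.09",
    "pipeline_task": "text-classification",
    "runtimes": {
      "onnx_fp32": {
        "runtime": "onnxruntime",
        "entrypoint": "onnxruntime.InferenceSession",
        "files": ["model.onnx", "schema.json", "contract.py", "tokenizer.json"],
        "trust_remote_code": true
      }
    }
  }

With such a manifest, auto selects onnx_fp32 on CPU-only nodes; loading the
decoder in contract.py still requires both global TRUST_REMOTE_CODE=True and
the runtime trust_remote_code flag, and any ONNX init or warmup failure falls
back to the Transformers/PT pipeline.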
--- .../privacy_filter_inference_api.py | 10 + .../test_privacy_filter_inference_api.py | 13 + .../default_inference/nlp/th_hf_model_base.py | 100 ++++-- extensions/serving/test_th_hf_model_base.py | 306 +++++++++++++++++- 4 files changed, 403 insertions(+), 26 deletions(-) diff --git a/extensions/business/edge_inference_api/privacy_filter_inference_api.py b/extensions/business/edge_inference_api/privacy_filter_inference_api.py index 295d8c00..3cf0bd72 100644 --- a/extensions/business/edge_inference_api/privacy_filter_inference_api.py +++ b/extensions/business/edge_inference_api/privacy_filter_inference_api.py @@ -95,4 +95,14 @@ def _build_result_from_inference( # pylint: disable=arguments-differ result_payload["tokenizer_name"] = inference["TOKENIZER_NAME"] if "PIPELINE_TASK" in inference: result_payload["pipeline_task"] = inference["PIPELINE_TASK"] + if "MODEL" in inference: + result_payload["model"] = inference["MODEL"] + if "MODEL_VERSION" in inference: + result_payload["model_version"] = inference["MODEL_VERSION"] + if "MODEL_REVISION" in inference: + result_payload["model_revision"] = inference["MODEL_REVISION"] + if "HF_RUNTIME" in inference: + result_payload["hf_runtime"] = inference["HF_RUNTIME"] + if "RUNTIME" in inference: + result_payload["runtime"] = inference["RUNTIME"] return result_payload diff --git a/extensions/business/edge_inference_api/test_privacy_filter_inference_api.py b/extensions/business/edge_inference_api/test_privacy_filter_inference_api.py index d5219e83..ecb3bec7 100644 --- a/extensions/business/edge_inference_api/test_privacy_filter_inference_api.py +++ b/extensions/business/edge_inference_api/test_privacy_filter_inference_api.py @@ -55,6 +55,11 @@ def test_build_result_from_inference_uses_findings_key(self): "FINDINGS_COUNT": 1, "MODEL_NAME": "openai/privacy-filter", "PIPELINE_TASK": "token-classification", + "MODEL": {"model_key": "privacy_filter", "model_version": "2026.05.09"}, + "MODEL_VERSION": "2026.05.09", + "MODEL_REVISION": "rev-privacy", + "HF_RUNTIME": "pt", + "RUNTIME": "transformers", }, metadata={}, request_data={"metadata": {}, "parameters": {"text": "example text"}}, @@ -73,6 +78,14 @@ def test_build_result_from_inference_uses_findings_key(self): self.assertEqual(result_payload["findings_count"], 1) self.assertEqual(result_payload["model_name"], "openai/privacy-filter") self.assertEqual(result_payload["pipeline_task"], "token-classification") + self.assertEqual( + result_payload["model"], + {"model_key": "privacy_filter", "model_version": "2026.05.09"}, + ) + self.assertEqual(result_payload["model_version"], "2026.05.09") + self.assertEqual(result_payload["model_revision"], "rev-privacy") + self.assertEqual(result_payload["hf_runtime"], "pt") + self.assertEqual(result_payload["runtime"], "transformers") if __name__ == "__main__": diff --git a/extensions/serving/default_inference/nlp/th_hf_model_base.py b/extensions/serving/default_inference/nlp/th_hf_model_base.py index db935025..bca48e19 100644 --- a/extensions/serving/default_inference/nlp/th_hf_model_base.py +++ b/extensions/serving/default_inference/nlp/th_hf_model_base.py @@ -179,14 +179,16 @@ def _build_output_map(self, raw_outputs, output_names): for output_name, output_value in zip(output_names, raw_outputs) } - def _call_decoder(self, outputs_by_name, text): + def _call_decoder(self, outputs_by_name, text, inference_kwargs): if self.decoder is None: return outputs_by_name decoder_kwargs = { + **dict(inference_kwargs or {}), "runtime": self.runtime_key, "runtime_key": 
self.runtime_key, "text": text, "repo_id": self.repo_id, + "inference_kwargs": dict(inference_kwargs or {}), } try: signature = inspect.signature(self.decoder) @@ -209,7 +211,11 @@ def _run_single_text(self, text, inference_kwargs): output_names = self._output_names() raw_outputs = self.session.run(output_names, session_inputs) outputs_by_name = self._build_output_map(raw_outputs, output_names) - return self._call_decoder(outputs_by_name=outputs_by_name, text=text) + return self._call_decoder( + outputs_by_name=outputs_by_name, + text=text, + inference_kwargs=inference_kwargs, + ) class ThHfModelBase(BaseServingProcess): @@ -545,12 +551,13 @@ def _blocked_hf_weight_pattern(self, pattern): pattern = str(pattern) blocked_suffixes = ( ".safetensors", - "pytorch_model.bin", - "tf_model.h5", - "flax_model.msgpack", + ".bin", + ".h5", + ".msgpack", ) - blocked_wildcards = ("*.safetensors", "*.bin", "*.h5", "*.msgpack") - return pattern.endswith(blocked_suffixes) or pattern in blocked_wildcards + blocked_wildcards = ("*", "**/*", "*.safetensors", "*.bin", "*.h5", "*.msgpack") + blocked_directory_globs = pattern.endswith("/*") or pattern.endswith("/**") + return pattern.endswith(blocked_suffixes) or pattern in blocked_wildcards or blocked_directory_globs def _build_hf_runtime_allow_patterns(self, runtime_config): """Build safe HF snapshot allow-patterns for an ONNX runtime.""" @@ -565,8 +572,6 @@ def _build_hf_runtime_allow_patterns(self, runtime_config): model_file, "*.onnx", "**/*.onnx", - "onnx/*", - "onnx/**", "*.json", "*.py", "*.txt", @@ -606,6 +611,19 @@ def _runtime_file_list(self, runtime_config): files = runtime_config.get("files") if isinstance(runtime_config, dict) else None return files if isinstance(files, list) else [] + def _resolve_hf_snapshot_path(self, model_dir, file_path): + """Resolve a manifest path while keeping it inside the downloaded snapshot.""" + path = Path(str(file_path)) + if path.is_absolute(): + raise ValueError(f"HF artifact path {file_path!r} must be relative to the model snapshot.") + snapshot_dir = Path(model_dir).resolve() + resolved_path = (snapshot_dir / path).resolve() + try: + resolved_path.relative_to(snapshot_dir) + except ValueError as exc: + raise ValueError(f"HF artifact path {file_path!r} escapes the model snapshot.") from exc + return resolved_path + def _first_manifest_file_with_suffix(self, runtime_config, suffixes): """Return the first exact manifest file path ending with any suffix.""" for file_path in self._runtime_file_list(runtime_config): @@ -621,10 +639,10 @@ def _resolve_manifest_file_path(self, model_dir, manifest, runtime_config, keys, if value is None and isinstance(manifest, dict): value = manifest.get(key) if value: - return Path(model_dir) / str(value) + return self._resolve_hf_snapshot_path(model_dir=model_dir, file_path=value) inferred = self._first_manifest_file_with_suffix(runtime_config, suffixes) if inferred: - return Path(model_dir) / inferred + return self._resolve_hf_snapshot_path(model_dir=model_dir, file_path=inferred) return None def _load_hf_schema(self, model_dir, manifest, runtime_config): @@ -640,6 +658,13 @@ def _load_hf_schema(self, model_dir, manifest, runtime_config): raise ValueError(f"HF runtime {self.hf_runtime} does not declare a usable schema file.") return json.loads(schema_path.read_text(encoding="utf-8")) + def _runtime_allows_remote_code(self, manifest, runtime_config): + """Return whether the selected runtime explicitly allows Python artifact code.""" + for source in (runtime_config, manifest): + if 
isinstance(source, dict) and "trust_remote_code" in source: + return bool(source.get("trust_remote_code")) + return False + def _load_hf_contract_decoder(self, model_dir, manifest, runtime_config): """Load the artifact decoder function declared by the selected HF runtime.""" decoder_path = self._resolve_manifest_file_path( @@ -651,6 +676,14 @@ def _load_hf_contract_decoder(self, model_dir, manifest, runtime_config): ) if decoder_path is None or not decoder_path.exists(): raise ValueError(f"HF runtime {self.hf_runtime} does not declare a usable contract decoder.") + if not bool(self.cfg_trust_remote_code) or not self._runtime_allows_remote_code( + manifest=manifest, + runtime_config=runtime_config, + ): + raise ValueError( + "HF ONNX artifact decoder requires global TRUST_REMOTE_CODE=True and runtime " + f"trust_remote_code=True because it executes Python code from {decoder_path}." + ) module_name = f"hf_artifact_contract_{abs(hash(str(decoder_path)))}" spec = importlib.util.spec_from_file_location(module_name, decoder_path) if spec is None or spec.loader is None: @@ -684,7 +717,7 @@ def _resolve_hf_onnx_model_path(self, model_dir, runtime_key, runtime_config, sc for key in ("model", "model_file", "path"): value = runtime_config.get(key) if isinstance(runtime_config, dict) else None if value: - return Path(model_dir) / str(value) + return self._resolve_hf_snapshot_path(model_dir=model_dir, file_path=value) models = schema.get("models") if isinstance(schema, dict) else None if isinstance(models, dict): candidates = [ @@ -699,10 +732,10 @@ def _resolve_hf_onnx_model_path(self, model_dir, runtime_key, runtime_config, sc value = value.get("path") or value.get("file") or value.get("model") if not value: continue - return Path(model_dir) / str(value) + return self._resolve_hf_snapshot_path(model_dir=model_dir, file_path=value) model_file = self._first_manifest_file_with_suffix(runtime_config, (".onnx",)) if model_file: - return Path(model_dir) / model_file + return self._resolve_hf_snapshot_path(model_dir=model_dir, file_path=model_file) raise ValueError(f"HF runtime {runtime_key} does not declare an ONNX model file.") def _resolve_hf_tokenizer_dir(self, model_dir, manifest, runtime_config, schema): @@ -712,16 +745,19 @@ def _resolve_hf_tokenizer_dir(self, model_dir, manifest, runtime_config, schema) if isinstance(source, dict) and source.get("tokenizer_dir"): tokenizer_dir = source["tokenizer_dir"] break - return Path(model_dir) / str(tokenizer_dir or ".") + return self._resolve_hf_snapshot_path(model_dir=model_dir, file_path=tokenizer_dir or ".") - def _load_hf_onnx_tokenizer(self, model_dir, runtime_config): + def _load_hf_onnx_tokenizer(self, model_dir, runtime_config, manifest=None): """Load the tokenizer for an ONNX HF artifact.""" from transformers import AutoTokenizer return AutoTokenizer.from_pretrained( str(model_dir), token=self.hf_token, - trust_remote_code=bool(runtime_config.get("trust_remote_code", False)), + trust_remote_code=bool(self.cfg_trust_remote_code) and self._runtime_allows_remote_code( + manifest=manifest, + runtime_config=runtime_config, + ), ) def _create_hf_onnx_session(self, model_path, providers): @@ -751,6 +787,7 @@ def _build_hf_onnx_artifact_pipeline(self, model_dir, runtime_key, runtime_confi tokenizer = self._load_hf_onnx_tokenizer( model_dir=tokenizer_dir, runtime_config=runtime_config, + manifest=manifest, ) model_path = self._resolve_hf_onnx_model_path( model_dir=model_dir, @@ -889,15 +926,32 @@ def startup(self): self.hf_runtime_config = dict(runtime_config 
or {}) self.hf_artifact_manifest = manifest if isinstance(manifest, dict) else None + run_warmup = True if self._runtime_is_onnx(runtime_key=runtime_key, runtime_config=runtime_config): - self._startup_hf_onnx_artifact( - runtime_key=runtime_key, - runtime_config=runtime_config, - manifest=manifest, - ) + try: + self._startup_hf_onnx_artifact( + runtime_key=runtime_key, + runtime_config=runtime_config, + manifest=manifest, + ) + self._run_startup_warmup() + run_warmup = False + except Exception as exc: + if requested_runtime != "auto": + raise + self.P( + f"HF auto runtime could not start ONNX artifact {runtime_key!r} for " + f"{self.get_model_name()}: {exc}. Falling back to Transformers/PT.", + color="y", + ) + self.hf_runtime = "pt" + self.hf_runtime_config = {} + self.hf_artifact_manifest = None + self._startup_transformers_pipeline() else: self._startup_transformers_pipeline() - self._run_startup_warmup() + if run_warmup: + self._run_startup_warmup() return def _get_hf_artifact_model_metadata(self): diff --git a/extensions/serving/test_th_hf_model_base.py b/extensions/serving/test_th_hf_model_base.py index 6692bc8e..f541a9bb 100644 --- a/extensions/serving/test_th_hf_model_base.py +++ b/extensions/serving/test_th_hf_model_base.py @@ -1,3 +1,4 @@ +import sys import types import unittest @@ -282,6 +283,45 @@ def test_forced_pt_runtime_passes_model_revision_to_transformers_pipeline(self): self.assertEqual(kwargs["revision"], "rev-123") self.assertEqual(plugin.hf_runtime, "pt") + def test_forced_pt_runtime_on_cpu_skips_manifest_lookup(self): + plugin = _ConcreteHfModel( + MODEL_NAME="test/model", + DEVICE="cpu", + HF_RUNTIME="pt", + PIPELINE_TASK="text-classification", + WARMUP_ENABLED=False, + ) + plugin._load_hf_artifact_manifest = ( # pylint: disable=protected-access + lambda: (_ for _ in ()).throw(AssertionError("manifest should not be loaded")) + ) + + plugin.startup() + + self.assertEqual(plugin.device, -1) + self.assertEqual(plugin.hf_runtime, "pt") + self.assertEqual(len(_PIPELINE_FACTORY.calls), 1) + + def test_onnx_allow_patterns_reject_framework_weights_and_broad_downloads(self): + plugin = _ConcreteHfModel(MODEL_NAME="test/model") + + allow_patterns = plugin._build_hf_runtime_allow_patterns({ # pylint: disable=protected-access + "files": [ + "*", + "**/*", + "onnx/*", + "onnx/**", + "model.onnx", + "tokenizer.json", + "contract.py", + "pytorch_model-00001-of-00002.bin", + "model.safetensors", + "tf_model.h5", + "flax_model.msgpack", + ], + }) + + self.assertEqual(allow_patterns, ["model.onnx", "tokenizer.json", "contract.py"]) + def test_auto_runtime_uses_onnx_artifact_on_cpu_only(self): manifest = { "model_key": "generic_text_classifier", @@ -379,6 +419,259 @@ def test_forced_onnx_runtime_uses_manifest_runtime_without_hardcoded_key(self): self.assertEqual(plugin.hf_runtime, "cpu_artifact") self.assertEqual(len(_PIPELINE_FACTORY.calls), 0) + def test_auto_runtime_falls_back_to_transformers_when_onnx_startup_fails(self): + plugin = _ConcreteHfModel( + MODEL_NAME="test/model", + DEVICE="cpu", + PIPELINE_TASK="text-classification", + WARMUP_ENABLED=False, + ) + plugin._load_hf_artifact_manifest = lambda: { # pylint: disable=protected-access + "runtimes": { + "onnx_fp32": { + "runtime": "onnxruntime", + "entrypoint": "onnxruntime.InferenceSession", + "files": ["model.onnx", "schema.json", "contract.py"], + } + }, + } + + def fail_onnx_startup(runtime_key, runtime_config, manifest): # pylint: disable=unused-argument + raise RuntimeError("onnxruntime is not installed") + + 
plugin._startup_hf_onnx_artifact = fail_onnx_startup # pylint: disable=protected-access + + plugin.startup() + + self.assertEqual(plugin.hf_runtime, "pt") + self.assertEqual(plugin.hf_runtime_config, {}) + self.assertIsNone(plugin.hf_artifact_manifest) + self.assertEqual(len(_PIPELINE_FACTORY.calls), 1) + self.assertTrue( + any("Falling back to Transformers/PT" in message[0][0] for message in plugin.logged_messages) + ) + + def test_forced_onnx_runtime_does_not_fallback_after_startup_failure(self): + plugin = _ConcreteHfModel( + MODEL_NAME="test/model", + DEVICE="cpu", + HF_RUNTIME="onnx", + PIPELINE_TASK="text-classification", + WARMUP_ENABLED=False, + ) + plugin._load_hf_artifact_manifest = lambda: { # pylint: disable=protected-access + "runtimes": { + "onnx_fp32": { + "runtime": "onnxruntime", + "entrypoint": "onnxruntime.InferenceSession", + "files": ["model.onnx", "schema.json", "contract.py"], + } + }, + } + plugin._startup_hf_onnx_artifact = ( # pylint: disable=protected-access + lambda runtime_key, runtime_config, manifest: (_ for _ in ()).throw(RuntimeError("bad onnx")) + ) + + with self.assertRaisesRegex(RuntimeError, "bad onnx"): + plugin.startup() + + self.assertEqual(len(_PIPELINE_FACTORY.calls), 0) + + def test_named_onnx_runtime_does_not_fallback_after_startup_failure(self): + plugin = _ConcreteHfModel( + MODEL_NAME="test/model", + DEVICE="cpu", + HF_RUNTIME="onnx_fp32", + PIPELINE_TASK="text-classification", + WARMUP_ENABLED=False, + ) + plugin._load_hf_artifact_manifest = lambda: { # pylint: disable=protected-access + "runtimes": { + "onnx_fp32": { + "runtime": "onnxruntime", + "entrypoint": "onnxruntime.InferenceSession", + "files": ["model.onnx", "schema.json", "contract.py"], + } + }, + } + plugin._startup_hf_onnx_artifact = ( # pylint: disable=protected-access + lambda runtime_key, runtime_config, manifest: (_ for _ in ()).throw(RuntimeError("bad named onnx")) + ) + + with self.assertRaisesRegex(RuntimeError, "bad named onnx"): + plugin.startup() + + self.assertEqual(len(_PIPELINE_FACTORY.calls), 0) + + def test_auto_runtime_falls_back_to_transformers_when_onnx_warmup_fails(self): + class _FailingWarmupPipeline: + task = "text-classification" + schema = {} + + def __call__(self, text, **kwargs): # pylint: disable=unused-argument + raise RuntimeError("onnx warmup failed") + + plugin = _ConcreteHfModel( + MODEL_NAME="test/model", + DEVICE="cpu", + PIPELINE_TASK="text-classification", + ) + plugin._load_hf_artifact_manifest = lambda: { # pylint: disable=protected-access + "runtimes": { + "onnx_fp32": { + "runtime": "onnxruntime", + "entrypoint": "onnxruntime.InferenceSession", + "files": ["model.onnx", "schema.json", "contract.py"], + } + }, + } + + def set_failing_pipeline(runtime_key, runtime_config, manifest): # pylint: disable=unused-argument + plugin.classifier = _FailingWarmupPipeline() + return + + plugin._startup_hf_onnx_artifact = set_failing_pipeline # pylint: disable=protected-access + + plugin.startup() + + self.assertEqual(plugin.hf_runtime, "pt") + self.assertEqual(len(_PIPELINE_FACTORY.calls), 1) + self.assertEqual(_PIPELINE_FACTORY.instance.inference_calls[-1][0], "Warmup request.") + self.assertTrue( + any("Falling back to Transformers/PT" in message[0][0] for message in plugin.logged_messages) + ) + + def test_hf_contract_decoder_requires_global_trust_remote_code(self): + plugin = _ConcreteHfModel( + MODEL_NAME="test/model", + DEVICE="cpu", + TRUST_REMOTE_CODE=False, + ) + plugin.hf_runtime = "onnx_fp32" + + with TemporaryDirectory() as tmpdir: + 
model_dir = Path(tmpdir) + (model_dir / "contract.py").write_text( + "def decode_outputs(outputs, schema):\n return outputs\n", + encoding="utf-8", + ) + + with self.assertRaisesRegex(ValueError, "TRUST_REMOTE_CODE=True"): + plugin._load_hf_contract_decoder( # pylint: disable=protected-access + model_dir=str(model_dir), + manifest={}, + runtime_config={"decoder": "contract.py"}, + ) + + def test_hf_contract_decoder_requires_runtime_trust_remote_code(self): + plugin = _ConcreteHfModel( + MODEL_NAME="test/model", + DEVICE="cpu", + TRUST_REMOTE_CODE=True, + ) + plugin.hf_runtime = "onnx_fp32" + + with TemporaryDirectory() as tmpdir: + model_dir = Path(tmpdir) + (model_dir / "contract.py").write_text( + "def decode_outputs(outputs, schema):\n return outputs\n", + encoding="utf-8", + ) + + with self.assertRaisesRegex(ValueError, "runtime trust_remote_code=True"): + plugin._load_hf_contract_decoder( # pylint: disable=protected-access + model_dir=str(model_dir), + manifest={}, + runtime_config={"decoder": "contract.py", "trust_remote_code": False}, + ) + + def test_hf_artifact_paths_must_stay_inside_snapshot(self): + plugin = _ConcreteHfModel(MODEL_NAME="test/model") + plugin.hf_runtime = "onnx_fp32" + + with TemporaryDirectory() as tmpdir: + model_dir = Path(tmpdir) + (model_dir / "schema.json").write_text("{}", encoding="utf-8") + + with self.assertRaisesRegex(ValueError, "escapes the model snapshot"): + plugin._load_hf_schema( # pylint: disable=protected-access + model_dir=str(model_dir), + manifest={}, + runtime_config={"schema": "../schema.json"}, + ) + + with self.assertRaisesRegex(ValueError, "must be relative"): + plugin._load_hf_contract_decoder( # pylint: disable=protected-access + model_dir=str(model_dir), + manifest={}, + runtime_config={"decoder": str((model_dir / "contract.py").resolve())}, + ) + + with self.assertRaisesRegex(ValueError, "escapes the model snapshot"): + plugin._resolve_hf_onnx_model_path( # pylint: disable=protected-access + model_dir=str(model_dir), + runtime_key="onnx_fp32", + runtime_config={"model": "../model.onnx"}, + schema={}, + ) + + with self.assertRaisesRegex(ValueError, "escapes the model snapshot"): + plugin._resolve_hf_tokenizer_dir( # pylint: disable=protected-access + model_dir=str(model_dir), + manifest={}, + runtime_config={"tokenizer_dir": "../tokenizer"}, + schema={}, + ) + + def test_onnx_tokenizer_remote_code_requires_global_trust_remote_code(self): + calls = [] + + class _FakeAutoTokenizer: + @staticmethod + def from_pretrained(model_dir, **kwargs): + calls.append((model_dir, kwargs)) + return _FakeTokenizer() + + fake_transformers = types.SimpleNamespace(AutoTokenizer=_FakeAutoTokenizer) + original_transformers = sys.modules.get("transformers") + sys.modules["transformers"] = fake_transformers + try: + plugin = _ConcreteHfModel( + MODEL_NAME="test/model", + DEVICE="cpu", + TRUST_REMOTE_CODE=False, + ) + plugin._load_hf_onnx_tokenizer( # pylint: disable=protected-access + model_dir="/tmp/model", + runtime_config={"trust_remote_code": True}, + ) + + self.assertFalse(calls[-1][1]["trust_remote_code"]) + + plugin = _ConcreteHfModel( + MODEL_NAME="test/model", + DEVICE="cpu", + TRUST_REMOTE_CODE=True, + ) + plugin._load_hf_onnx_tokenizer( # pylint: disable=protected-access + model_dir="/tmp/model", + runtime_config={"trust_remote_code": True}, + ) + + self.assertTrue(calls[-1][1]["trust_remote_code"]) + + plugin._load_hf_onnx_tokenizer( # pylint: disable=protected-access + model_dir="/tmp/model", + runtime_config={"trust_remote_code": False}, + ) + 
+ self.assertFalse(calls[-1][1]["trust_remote_code"]) + finally: + if original_transformers is None: + sys.modules.pop("transformers", None) + else: + sys.modules["transformers"] = original_transformers + def test_onnx_artifact_pipeline_uses_hf_contract_decoder(self): plugin = _ConcreteHfModel( MODEL_NAME="test/model", @@ -389,7 +682,9 @@ def test_onnx_artifact_pipeline_uses_hf_contract_decoder(self): fake_tokenizer = _FakeTokenizer() fake_session = _FakeOrtSession() created_sessions = [] - plugin._load_hf_onnx_tokenizer = lambda model_dir, runtime_config: fake_tokenizer # pylint: disable=protected-access + plugin._load_hf_onnx_tokenizer = ( # pylint: disable=protected-access + lambda model_dir, runtime_config, manifest=None: fake_tokenizer + ) def fake_create_session(model_path, providers): created_sessions.append((model_path, providers)) @@ -411,12 +706,14 @@ def fake_create_session(model_path, providers): ) (model_dir / "contract.py").write_text( ( - "def decode_generic_outputs(outputs, schema, **kwargs):\n" + "def decode_generic_outputs(outputs, schema, aggregation_strategy=None, inference_kwargs=None, **kwargs):\n" " return {\n" " 'contract': 'hf',\n" " 'outputs': outputs,\n" " 'repo_id': kwargs.get('repo_id'),\n" " 'runtime': kwargs.get('runtime_key'),\n" + " 'aggregation_strategy': aggregation_strategy,\n" + " 'inference_kwargs': inference_kwargs,\n" " }\n" ), encoding="utf-8", @@ -426,6 +723,7 @@ def fake_create_session(model_path, providers): "runtimes": { "onnx_fp32": { "runtime": "onnxruntime", + "trust_remote_code": True, "files": ["model.onnx", "schema.json", "contract.py"], } }, @@ -437,7 +735,7 @@ def fake_create_session(model_path, providers): runtime_config=manifest["runtimes"]["onnx_fp32"], manifest=manifest, ) - result = pipeline("hello world") + result = pipeline("hello world", aggregation_strategy="simple", threshold=0.7) batched_single_result = pipeline(["hello world"]) self.assertIsInstance(pipeline, HfOnnxArtifactPipeline) @@ -446,6 +744,8 @@ def fake_create_session(model_path, providers): self.assertEqual(result["outputs"], {"scores": [0.25, 0.75]}) self.assertEqual(result["repo_id"], "test/model") self.assertEqual(result["runtime"], "onnx_fp32") + self.assertEqual(result["aggregation_strategy"], "simple") + self.assertEqual(result["inference_kwargs"]["threshold"], 0.7) self.assertEqual(Path(created_sessions[0][0]).name, "model.onnx") output_names, inputs = fake_session.calls[-1] self.assertEqual(output_names, ["scores"]) From 9c317ade87013b0096d013fa7814333c83972dbd Mon Sep 17 00:00:00 2001 From: Cristi Bleotiu Date: Mon, 11 May 2026 12:58:48 +0300 Subject: [PATCH 3/8] fix: require runtime trust for hf onnx code What changed: - Require selected ONNX runtime config trust_remote_code=True before executing artifact decoder or tokenizer remote code. - Add regression coverage proving a top-level manifest trust flag cannot enable runtime code execution by itself. Why: - Avoid remote-code trust bypasses from broad manifest metadata; the selected runtime must explicitly opt in. 
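Example (illustrative sketch of the trust rule enforced by this patch; decoder_code_allowed is a hypothetical stand-in for the gate inside _load_hf_contract_decoder, not code from the diff):

    def decoder_code_allowed(global_trust, runtime_config):
        # Only the selected runtime entry can opt in to artifact Python code;
        # a top-level manifest trust flag is deliberately ignored.
        runtime_flag = isinstance(runtime_config, dict) and bool(runtime_config.get("trust_remote_code"))
        return bool(global_trust) and runtime_flag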
--- .../default_inference/nlp/th_hf_model_base.py | 5 +---- extensions/serving/test_th_hf_model_base.py | 22 +++++++++++++++++++ 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/extensions/serving/default_inference/nlp/th_hf_model_base.py b/extensions/serving/default_inference/nlp/th_hf_model_base.py index bca48e19..aabc9af9 100644 --- a/extensions/serving/default_inference/nlp/th_hf_model_base.py +++ b/extensions/serving/default_inference/nlp/th_hf_model_base.py @@ -660,10 +660,7 @@ def _load_hf_schema(self, model_dir, manifest, runtime_config): def _runtime_allows_remote_code(self, manifest, runtime_config): """Return whether the selected runtime explicitly allows Python artifact code.""" - for source in (runtime_config, manifest): - if isinstance(source, dict) and "trust_remote_code" in source: - return bool(source.get("trust_remote_code")) - return False + return isinstance(runtime_config, dict) and bool(runtime_config.get("trust_remote_code")) def _load_hf_contract_decoder(self, model_dir, manifest, runtime_config): """Load the artifact decoder function declared by the selected HF runtime.""" diff --git a/extensions/serving/test_th_hf_model_base.py b/extensions/serving/test_th_hf_model_base.py index f541a9bb..da73e9a4 100644 --- a/extensions/serving/test_th_hf_model_base.py +++ b/extensions/serving/test_th_hf_model_base.py @@ -585,6 +585,28 @@ def test_hf_contract_decoder_requires_runtime_trust_remote_code(self): runtime_config={"decoder": "contract.py", "trust_remote_code": False}, ) + def test_top_level_manifest_trust_remote_code_does_not_enable_runtime_decoder(self): + plugin = _ConcreteHfModel( + MODEL_NAME="test/model", + DEVICE="cpu", + TRUST_REMOTE_CODE=True, + ) + plugin.hf_runtime = "onnx_fp32" + + with TemporaryDirectory() as tmpdir: + model_dir = Path(tmpdir) + (model_dir / "contract.py").write_text( + "def decode_outputs(outputs, schema):\n return outputs\n", + encoding="utf-8", + ) + + with self.assertRaisesRegex(ValueError, "runtime trust_remote_code=True"): + plugin._load_hf_contract_decoder( # pylint: disable=protected-access + model_dir=str(model_dir), + manifest={"trust_remote_code": True}, + runtime_config={"decoder": "contract.py"}, + ) + def test_hf_artifact_paths_must_stay_inside_snapshot(self): plugin = _ConcreteHfModel(MODEL_NAME="test/model") plugin.hf_runtime = "onnx_fp32" From 9fd9a16a724e537d83941980d22086479091ca7d Mon Sep 17 00:00:00 2001 From: Codex Date: Mon, 11 May 2026 17:33:40 +0300 Subject: [PATCH 4/8] feat: support privacy-filter onnx fallback What changed: - Added subclass ONNX fallback hooks in the HF serving base. - Added local privacy-filter ONNX discovery and BIOES/Viterbi span decoding. - Covered fallback runtime selection and privacy-filter decoder behavior with tests. Why: - Allow openai/privacy-filter ONNX artifacts to run without a remote artifact manifest or remote Python decoder code. 
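Example (illustrative sketch of the BIOES span merge performed by the new local decoder; the labels, offsets, and expected span below are made up to mirror the regression test, not real model output):

    labels  = ["O", "B-private_email", "E-private_email", "O"]
    offsets = [(0, 0), (0, 5), (5, 17), (0, 0)]  # (start, end) per token; (0, 0) marks special tokens
    # B-/I-/E- tokens of one entity merge into a single span, S- tokens become
    # single-token spans, and O closes any open span, so the expected result is:
    # [{"entity_group": "private_email", "start": 0, "end": 17, "word": "alice@example.com"}]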
--- .../default_inference/nlp/th_hf_model_base.py | 51 ++- .../nlp/th_privacy_filter.py | 348 ++++++++++++++++++ extensions/serving/test_th_hf_model_base.py | 70 ++++ extensions/serving/test_th_privacy_filter.py | 131 +++++++ 4 files changed, 596 insertions(+), 4 deletions(-) diff --git a/extensions/serving/default_inference/nlp/th_hf_model_base.py b/extensions/serving/default_inference/nlp/th_hf_model_base.py index aabc9af9..951469a7 100644 --- a/extensions/serving/default_inference/nlp/th_hf_model_base.py +++ b/extensions/serving/default_inference/nlp/th_hf_model_base.py @@ -105,6 +105,10 @@ def _tokenize(self, text, inference_kwargs): "return_tensors": "np", "truncation": bool(inference_kwargs.get("truncation", True)), } + for source in (self.schema, self.runtime_config): + extra_tokenize_kwargs = source.get("tokenizer_kwargs") if isinstance(source, dict) else None + if isinstance(extra_tokenize_kwargs, dict): + tokenize_kwargs.update(extra_tokenize_kwargs) max_length = self._get_max_length(inference_kwargs) if max_length is not None: tokenize_kwargs["max_length"] = max_length @@ -179,7 +183,7 @@ def _build_output_map(self, raw_outputs, output_names): for output_name, output_value in zip(output_names, raw_outputs) } - def _call_decoder(self, outputs_by_name, text, inference_kwargs): + def _call_decoder(self, outputs_by_name, text, encoded, inference_kwargs): if self.decoder is None: return outputs_by_name decoder_kwargs = { @@ -188,6 +192,8 @@ def _call_decoder(self, outputs_by_name, text, inference_kwargs): "runtime_key": self.runtime_key, "text": text, "repo_id": self.repo_id, + "tokenizer_output": encoded, + "encoded": encoded, "inference_kwargs": dict(inference_kwargs or {}), } try: @@ -214,6 +220,7 @@ def _run_single_text(self, text, inference_kwargs): return self._call_decoder( outputs_by_name=outputs_by_name, text=text, + encoded=encoded, inference_kwargs=inference_kwargs, ) @@ -477,15 +484,32 @@ def _download_hf_artifact_file(self, filename): repo_type="model", ) + def _get_hf_onnx_fallback_manifest(self): + """Return a subclass-provided ONNX manifest when the repo has no manifest. + + This hook lets dedicated serving classes support standard HF ONNX layouts + without requiring remote Python artifact code or model-specific logic in + the shared base class. 
+ """ + return None + def _load_hf_artifact_manifest(self): """Load the optional artifact manifest from the configured HF model repo.""" manifest_name = getattr(self, "cfg_hf_artifact_manifest", None) if not manifest_name: - return None + return self._get_hf_onnx_fallback_manifest() try: manifest_path = self._download_hf_artifact_file(manifest_name) return json.loads(Path(manifest_path).read_text(encoding="utf-8")) except Exception as exc: + fallback_manifest = self._get_hf_onnx_fallback_manifest() + if isinstance(fallback_manifest, dict): + self.P( + f"HF artifact manifest {manifest_name} not available for {self.get_model_name()}; " + "using subclass ONNX fallback manifest.", + color="y", + ) + return fallback_manifest if self._requested_hf_runtime() != "auto": raise self.P( @@ -647,6 +671,9 @@ def _resolve_manifest_file_path(self, model_dir, manifest, runtime_config, keys, def _load_hf_schema(self, model_dir, manifest, runtime_config): """Load the JSON schema declared by the selected HF runtime.""" + inline_schema = runtime_config.get("inline_schema") if isinstance(runtime_config, dict) else None + if isinstance(inline_schema, dict): + return inline_schema schema_path = self._resolve_manifest_file_path( model_dir=model_dir, manifest=manifest, @@ -709,6 +736,22 @@ def _load_hf_contract_decoder(self, model_dir, manifest, runtime_config): raise ValueError(f"Could not resolve a decoder function in {decoder_path}.") return decoder + def _get_hf_onnx_artifact_schema(self, model_dir, manifest, runtime_config): + """Return the schema used by an ONNX artifact runtime.""" + return self._load_hf_schema( + model_dir=model_dir, + manifest=manifest, + runtime_config=runtime_config, + ) + + def _get_hf_onnx_artifact_decoder(self, model_dir, manifest, runtime_config): + """Return the decoder used by an ONNX artifact runtime.""" + return self._load_hf_contract_decoder( + model_dir=model_dir, + manifest=manifest, + runtime_config=runtime_config, + ) + def _resolve_hf_onnx_model_path(self, model_dir, runtime_key, runtime_config, schema): """Resolve the ONNX model file for the selected runtime.""" for key in ("model", "model_file", "path"): @@ -765,12 +808,12 @@ def _create_hf_onnx_session(self, model_path, providers): def _build_hf_onnx_artifact_pipeline(self, model_dir, runtime_key, runtime_config, manifest): """Build a callable ONNX artifact pipeline from downloaded HF files.""" - schema = self._load_hf_schema( + schema = self._get_hf_onnx_artifact_schema( model_dir=model_dir, manifest=manifest, runtime_config=runtime_config, ) - decoder = self._load_hf_contract_decoder( + decoder = self._get_hf_onnx_artifact_decoder( model_dir=model_dir, manifest=manifest, runtime_config=runtime_config, diff --git a/extensions/serving/default_inference/nlp/th_privacy_filter.py b/extensions/serving/default_inference/nlp/th_privacy_filter.py index c4ce806a..e7a13ea3 100644 --- a/extensions/serving/default_inference/nlp/th_privacy_filter.py +++ b/extensions/serving/default_inference/nlp/th_privacy_filter.py @@ -7,6 +7,9 @@ - redaction-friendly post-processing metadata """ +import json +import math + from extensions.serving.default_inference.nlp.th_hf_model_base import ( _CONFIG as BASE_HF_MODEL_CONFIG, ThHfModelBase, @@ -30,11 +33,356 @@ FIXED_CENSOR_SIZE = 4 +PRIVACY_FILTER_ONNX_RUNTIME_KEY = "onnx_fp32" +PRIVACY_FILTER_ONNX_MODEL_FILE = "onnx/model.onnx" +PRIVACY_FILTER_VITERBI_FILE = "viterbi_calibration.json" class ThPrivacyFilter(ThHfModelBase): CONFIG = _CONFIG + def _get_hf_onnx_fallback_manifest(self): + """Declare 
the public HF ONNX layout when no artifact manifest exists.""" + if self.get_model_name() != "openai/privacy-filter": + return None + return { + "model_key": "openai_privacy_filter", + "source_repo_id": "openai/privacy-filter", + "pipeline_task": "token-classification", + "runtimes": { + PRIVACY_FILTER_ONNX_RUNTIME_KEY: { + "runtime": "onnxruntime", + "entrypoint": "onnxruntime.InferenceSession", + "pipeline_task": "token-classification", + "model": PRIVACY_FILTER_ONNX_MODEL_FILE, + "decoder_type": "privacy_filter_span_decoder", + "files": [ + "config.json", + "tokenizer.json", + "tokenizer_config.json", + PRIVACY_FILTER_VITERBI_FILE, + PRIVACY_FILTER_ONNX_MODEL_FILE, + "onnx/model.onnx_data", + "onnx/model.onnx_data_1", + "onnx/model.onnx_data_2", + ], + "recommended_allow_patterns": [ + "config.json", + "tokenizer.json", + "tokenizer_config.json", + PRIVACY_FILTER_VITERBI_FILE, + PRIVACY_FILTER_ONNX_MODEL_FILE, + "onnx/model.onnx_data", + "onnx/model.onnx_data_1", + "onnx/model.onnx_data_2", + ], + "providers": ["CPUExecutionProvider"], + }, + }, + } + + def _get_hf_onnx_artifact_schema(self, model_dir, manifest, runtime_config): + """Build a local schema for the privacy-filter ONNX artifacts.""" + if runtime_config.get("decoder_type") != "privacy_filter_span_decoder": + return super()._get_hf_onnx_artifact_schema( + model_dir=model_dir, + manifest=manifest, + runtime_config=runtime_config, + ) + config_path = self._resolve_hf_snapshot_path(model_dir=model_dir, file_path="config.json") + config = json.loads(config_path.read_text(encoding="utf-8")) + calibration = {} + calibration_path = self._resolve_hf_snapshot_path( + model_dir=model_dir, + file_path=PRIVACY_FILTER_VITERBI_FILE, + ) + if calibration_path.exists(): + calibration = json.loads(calibration_path.read_text(encoding="utf-8")) + return { + "inputs": [ + {"name": "input_ids", "dtype": "int64"}, + {"name": "attention_mask", "dtype": "int64"}, + ], + "outputs": [{"name": "logits"}], + "output_order": ["logits"], + "id2label": config.get("id2label", {}), + "tokenizer_kwargs": {"return_offsets_mapping": True}, + "viterbi_calibration": calibration, + } + + def _get_hf_onnx_artifact_decoder(self, model_dir, manifest, runtime_config): + """Use the local privacy-filter decoder instead of remote Python code.""" + if runtime_config.get("decoder_type") == "privacy_filter_span_decoder": + return self._decode_privacy_filter_onnx_outputs + return super()._get_hf_onnx_artifact_decoder( + model_dir=model_dir, + manifest=manifest, + runtime_config=runtime_config, + ) + + def _to_plain_list(self, value): + """Convert tensors/arrays to plain Python lists for decoder logic.""" + if hasattr(value, "tolist"): + return value.tolist() + return value + + def _first_batch_item(self, value): + """Return the first batch element from a tensor-like value.""" + value = self._to_plain_list(value) + if isinstance(value, list) and len(value) == 1 and isinstance(value[0], list): + return value[0] + return value + + def _get_tokenizer_field(self, tokenizer_output, field_name): + if not hasattr(tokenizer_output, "get"): + return None + return self._first_batch_item(tokenizer_output.get(field_name)) + + def _get_privacy_filter_id2label(self, schema): + raw_id2label = schema.get("id2label") if isinstance(schema, dict) else None + if not isinstance(raw_id2label, dict) or len(raw_id2label) == 0: + raise ValueError("Privacy-filter ONNX schema must provide id2label.") + labels_by_id = { + int(label_id): label + for label_id, label in raw_id2label.items() + } + return [ 
+ labels_by_id[idx] + for idx in range(max(labels_by_id) + 1) + ] + + def _split_privacy_filter_label(self, label): + if not isinstance(label, str) or label == "O": + return "O", None + if "-" not in label: + return label, None + prefix, entity = label.split("-", 1) + return prefix, entity + + def _get_privacy_filter_transition_biases(self, schema): + calibration = schema.get("viterbi_calibration") if isinstance(schema, dict) else None + operating_points = calibration.get("operating_points") if isinstance(calibration, dict) else None + default_point = operating_points.get("default") if isinstance(operating_points, dict) else None + biases = default_point.get("biases") if isinstance(default_point, dict) else None + return biases if isinstance(biases, dict) else {} + + def _privacy_filter_transition_is_valid(self, previous_label, current_label): + current_prefix, current_entity = self._split_privacy_filter_label(current_label) + previous_prefix, previous_entity = self._split_privacy_filter_label(previous_label) + if previous_label is None: + return current_prefix in {"O", "B", "S"} + if previous_prefix in {"O", "E", "S"}: + return current_prefix in {"O", "B", "S"} + if previous_prefix in {"B", "I"}: + return current_prefix in {"I", "E"} and current_entity == previous_entity + return False + + def _privacy_filter_terminal_is_valid(self, label): + prefix, _entity = self._split_privacy_filter_label(label) + return prefix in {"O", "E", "S"} + + def _privacy_filter_transition_bias(self, previous_label, current_label, biases): + if previous_label is None: + return 0.0 + previous_prefix, previous_entity = self._split_privacy_filter_label(previous_label) + current_prefix, current_entity = self._split_privacy_filter_label(current_label) + if previous_prefix == "O" and current_prefix == "O": + return float(biases.get("transition_bias_background_stay", 0.0)) + if previous_prefix == "O" and current_prefix in {"B", "S"}: + return float(biases.get("transition_bias_background_to_start", 0.0)) + if previous_prefix in {"E", "S"} and current_prefix == "O": + return float(biases.get("transition_bias_end_to_background", 0.0)) + if previous_prefix in {"E", "S"} and current_prefix in {"B", "S"}: + return float(biases.get("transition_bias_end_to_start", 0.0)) + if ( + previous_prefix in {"B", "I"} + and current_prefix == "I" + and current_entity == previous_entity + ): + return float(biases.get("transition_bias_inside_to_continue", 0.0)) + if ( + previous_prefix in {"B", "I"} + and current_prefix == "E" + and current_entity == previous_entity + ): + return float(biases.get("transition_bias_inside_to_end", 0.0)) + return 0.0 + + def _softmax(self, values): + if not values: + return [] + max_value = max(values) + exps = [math.exp(value - max_value) for value in values] + total = sum(exps) + if total == 0: + return [0.0 for _ in values] + return [value / total for value in exps] + + def _decode_privacy_filter_label_ids(self, logits, labels, offsets, attention_mask, schema): + """Run constrained BIOES Viterbi decoding over token logits.""" + o_label_id = labels.index("O") if "O" in labels else 0 + biases = self._get_privacy_filter_transition_biases(schema) + previous_scores = None + backpointers = [] + selected_probabilities = [] + probabilities_by_token = [] + invalid_score = -1e9 + for token_idx, token_logits in enumerate(logits): + token_logits = [float(value) for value in token_logits] + probabilities_by_token.append(self._softmax(token_logits)) + is_content_token = True + if attention_mask is not None and 
token_idx < len(attention_mask): + is_content_token = bool(attention_mask[token_idx]) + if offsets is not None and token_idx < len(offsets): + start, end = offsets[token_idx] + if int(start) == int(end): + is_content_token = False + if not is_content_token: + token_logits = [ + 0.0 if label_idx == o_label_id else invalid_score + for label_idx, _label in enumerate(labels) + ] + current_scores = [] + current_backpointers = [] + for label_idx, label in enumerate(labels): + emission_score = token_logits[label_idx] + if previous_scores is None: + if self._privacy_filter_transition_is_valid(None, label): + current_scores.append(emission_score) + current_backpointers.append(None) + else: + current_scores.append(invalid_score) + current_backpointers.append(None) + continue + best_score = invalid_score + best_previous_idx = 0 + for previous_idx, previous_label in enumerate(labels): + if not self._privacy_filter_transition_is_valid(previous_label, label): + continue + score = ( + previous_scores[previous_idx] + + self._privacy_filter_transition_bias(previous_label, label, biases) + + emission_score + ) + if score > best_score: + best_score = score + best_previous_idx = previous_idx + current_scores.append(best_score) + current_backpointers.append(best_previous_idx) + previous_scores = current_scores + backpointers.append(current_backpointers) + if not previous_scores: + return [], [] + terminal_scores = [ + score if self._privacy_filter_terminal_is_valid(labels[idx]) else invalid_score + for idx, score in enumerate(previous_scores) + ] + if max(terminal_scores) > invalid_score: + previous_scores = terminal_scores + best_label_idx = max(range(len(previous_scores)), key=lambda idx: previous_scores[idx]) + label_ids = [] + for token_idx in range(len(backpointers) - 1, -1, -1): + label_ids.append(best_label_idx) + previous_idx = backpointers[token_idx][best_label_idx] + best_label_idx = previous_idx if previous_idx is not None else o_label_id + label_ids.reverse() + for token_idx, label_idx in enumerate(label_ids): + probabilities = probabilities_by_token[token_idx] + selected_probabilities.append(probabilities[label_idx] if label_idx < len(probabilities) else 0.0) + return label_ids, selected_probabilities + + def _build_privacy_filter_spans(self, text, labels, label_ids, probabilities, offsets): + spans = [] + current_span = None + for token_idx, label_id in enumerate(label_ids): + if offsets is None or token_idx >= len(offsets): + continue + start, end = offsets[token_idx] + start = int(start) + end = int(end) + if start == end: + continue + label = labels[label_id] + prefix, entity = self._split_privacy_filter_label(label) + token_score = probabilities[token_idx] if token_idx < len(probabilities) else 0.0 + if prefix == "O": + if current_span is not None: + spans.append(current_span) + current_span = None + continue + if prefix == "S": + if current_span is not None: + spans.append(current_span) + current_span = None + spans.append({ + "entity_group": entity, + "entity": entity, + "score": token_score, + "word": text[start:end], + "start": start, + "end": end, + }) + continue + if prefix == "B" or current_span is None or current_span["entity_group"] != entity: + if current_span is not None: + spans.append(current_span) + current_span = { + "entity_group": entity, + "entity": entity, + "score": token_score, + "word": text[start:end], + "start": start, + "end": end, + "_scores": [token_score], + } + if prefix == "E": + current_span["_scores"].append(token_score) + current_span["end"] = end + 
current_span["word"] = text[current_span["start"]:current_span["end"]] + spans.append(current_span) + current_span = None + continue + current_span["end"] = end + current_span["word"] = text[current_span["start"]:current_span["end"]] + current_span["_scores"].append(token_score) + current_span["score"] = sum(current_span["_scores"]) / len(current_span["_scores"]) + if prefix == "E": + spans.append(current_span) + current_span = None + if current_span is not None: + spans.append(current_span) + for span in spans: + span.pop("_scores", None) + return spans + + def _decode_privacy_filter_onnx_outputs(self, outputs, schema, text=None, tokenizer_output=None, **kwargs): + """Decode ONNX token logits into privacy-filter span dictionaries.""" + logits = outputs.get("logits") if isinstance(outputs, dict) else None + if logits is None and isinstance(outputs, dict) and outputs: + logits = next(iter(outputs.values())) + logits = self._first_batch_item(logits) + if not isinstance(logits, list): + raise ValueError("Privacy-filter ONNX decoder expected logits output.") + offsets = self._get_tokenizer_field(tokenizer_output, "offset_mapping") + if offsets is None: + raise ValueError("Privacy-filter ONNX decoder requires tokenizer offset_mapping.") + attention_mask = self._get_tokenizer_field(tokenizer_output, "attention_mask") + labels = self._get_privacy_filter_id2label(schema) + label_ids, probabilities = self._decode_privacy_filter_label_ids( + logits=logits, + labels=labels, + offsets=offsets, + attention_mask=attention_mask, + schema=schema, + ) + return self._build_privacy_filter_spans( + text=text or "", + labels=labels, + label_ids=label_ids, + probabilities=probabilities, + offsets=offsets, + ) + def _extract_struct_payload(self, payload): """Extract the structured payload used by the privacy filter. 
diff --git a/extensions/serving/test_th_hf_model_base.py b/extensions/serving/test_th_hf_model_base.py index da73e9a4..9e9db7dd 100644 --- a/extensions/serving/test_th_hf_model_base.py +++ b/extensions/serving/test_th_hf_model_base.py @@ -174,6 +174,22 @@ class _ConcreteHfModel(ThHfModelBase): pass +class _FallbackManifestHfModel(ThHfModelBase): + def _get_hf_onnx_fallback_manifest(self): + return { + "runtimes": { + "onnx_fp32": { + "runtime": "onnxruntime", + "files": ["model.onnx"], + "inline_schema": { + "inputs": [{"name": "input_ids", "dtype": "int64"}], + "outputs": [{"name": "scores"}], + }, + }, + }, + } + + class ThHfModelBaseTests(unittest.TestCase): def setUp(self): _PIPELINE_FACTORY.calls = [] @@ -367,6 +383,58 @@ def fake_download(runtime_key, runtime_config, allow_patterns): self.assertIn("model.onnx", download_calls[0][2]) self.assertNotIn("model.safetensors", download_calls[0][2]) + def test_auto_runtime_uses_subclass_onnx_fallback_manifest_when_hf_manifest_missing(self): + plugin = _FallbackManifestHfModel( + MODEL_NAME="test/model", + DEVICE="cpu", + PIPELINE_TASK="text-classification", + WARMUP_ENABLED=False, + ) + download_calls = [] + plugin._download_hf_artifact_file = ( # pylint: disable=protected-access + lambda filename: (_ for _ in ()).throw(RuntimeError("not found")) + ) + plugin._download_hf_runtime_snapshot = ( # pylint: disable=protected-access + lambda runtime_key, runtime_config, allow_patterns: download_calls.append( + (runtime_key, runtime_config, allow_patterns) + ) or "/tmp/models/test-model" + ) + plugin._build_hf_onnx_artifact_pipeline = ( # pylint: disable=protected-access + lambda model_dir, runtime_key, runtime_config, manifest: _FakePipeline(task="text-classification") + ) + + plugin.startup() + + self.assertEqual(plugin.hf_runtime, "onnx_fp32") + self.assertEqual(download_calls[0][0], "onnx_fp32") + self.assertEqual(len(_PIPELINE_FACTORY.calls), 0) + self.assertTrue( + any("using subclass ONNX fallback manifest" in message[0][0] for message in plugin.logged_messages) + ) + + def test_forced_onnx_runtime_uses_subclass_fallback_manifest_when_hf_manifest_missing(self): + plugin = _FallbackManifestHfModel( + MODEL_NAME="test/model", + DEVICE="cpu", + HF_RUNTIME="onnx", + PIPELINE_TASK="text-classification", + WARMUP_ENABLED=False, + ) + plugin._download_hf_artifact_file = ( # pylint: disable=protected-access + lambda filename: (_ for _ in ()).throw(RuntimeError("not found")) + ) + plugin._download_hf_runtime_snapshot = ( # pylint: disable=protected-access + lambda runtime_key, runtime_config, allow_patterns: "/tmp/models/test-model" + ) + plugin._build_hf_onnx_artifact_pipeline = ( # pylint: disable=protected-access + lambda model_dir, runtime_key, runtime_config, manifest: _FakePipeline(task="text-classification") + ) + + plugin.startup() + + self.assertEqual(plugin.hf_runtime, "onnx_fp32") + self.assertEqual(len(_PIPELINE_FACTORY.calls), 0) + def test_auto_runtime_keeps_transformers_pipeline_when_gpu_available(self): plugin = _ConcreteHfModel( MODEL_NAME="test/model", @@ -722,6 +790,7 @@ def fake_create_session(model_path, providers): '{"inputs":[{"name":"input_ids","dtype":"int64"},' '{"name":"attention_mask","dtype":"int64"}],' '"outputs":[{"name":"scores"}],' + '"tokenizer_kwargs":{"return_offsets_mapping":true},' '"models":{"onnx_fp32":{"path":"model.onnx"}}}' ), encoding="utf-8", @@ -773,6 +842,7 @@ def fake_create_session(model_path, providers): self.assertEqual(output_names, ["scores"]) self.assertEqual(inputs["input_ids"].dtype, "int64") 
self.assertEqual(fake_tokenizer.calls[-1][1]["return_tensors"], "np") + self.assertTrue(fake_tokenizer.calls[-1][1]["return_offsets_mapping"]) if __name__ == "__main__": diff --git a/extensions/serving/test_th_privacy_filter.py b/extensions/serving/test_th_privacy_filter.py index 266b308b..417a4a8b 100644 --- a/extensions/serving/test_th_privacy_filter.py +++ b/extensions/serving/test_th_privacy_filter.py @@ -92,6 +92,14 @@ def _payload_matches_current_serving(self, struct_payload): return False return True + def _resolve_hf_snapshot_path(self, model_dir, file_path): + path = Path(str(file_path)) + if path.is_absolute(): + raise ValueError("path must be relative") + resolved = (Path(model_dir).resolve() / path).resolve() + resolved.relative_to(Path(model_dir).resolve()) + return resolved + def _load_plugin_class(): source_path = ROOT / "extensions" / "serving" / "default_inference" / "nlp" / "th_privacy_filter.py" @@ -126,6 +134,129 @@ def test_config_pins_privacy_filter_defaults(self): "simple", ) + def test_privacy_filter_declares_local_onnx_fallback_manifest(self): + plugin = ThPrivacyFilter(MODEL_NAME="openai/privacy-filter") + + manifest = plugin._get_hf_onnx_fallback_manifest() # pylint: disable=protected-access + runtime = manifest["runtimes"]["onnx_fp32"] + + self.assertEqual(runtime["runtime"], "onnxruntime") + self.assertEqual(runtime["decoder_type"], "privacy_filter_span_decoder") + self.assertIn("onnx/model.onnx", runtime["files"]) + self.assertIn("onnx/model.onnx_data_2", runtime["files"]) + self.assertIn("viterbi_calibration.json", runtime["recommended_allow_patterns"]) + + def test_privacy_filter_does_not_declare_onnx_fallback_for_other_models(self): + plugin = ThPrivacyFilter(MODEL_NAME="other/privacy-filter") + + self.assertIsNone(plugin._get_hf_onnx_fallback_manifest()) # pylint: disable=protected-access + + def test_privacy_filter_builds_local_onnx_schema_from_hf_files(self): + plugin = ThPrivacyFilter(MODEL_NAME="openai/privacy-filter") + + from tempfile import TemporaryDirectory + + with TemporaryDirectory() as tmpdir: + model_dir = Path(tmpdir) + (model_dir / "config.json").write_text( + '{"id2label":{"0":"O","1":"S-private_email"}}', + encoding="utf-8", + ) + (model_dir / "viterbi_calibration.json").write_text( + '{"operating_points":{"default":{"biases":{"transition_bias_background_stay":0.0}}}}', + encoding="utf-8", + ) + + schema = plugin._get_hf_onnx_artifact_schema( # pylint: disable=protected-access + model_dir=str(model_dir), + manifest={}, + runtime_config={"decoder_type": "privacy_filter_span_decoder"}, + ) + + self.assertEqual(schema["output_order"], ["logits"]) + self.assertEqual(schema["id2label"]["1"], "S-private_email") + self.assertTrue(schema["tokenizer_kwargs"]["return_offsets_mapping"]) + self.assertIn("viterbi_calibration", schema) + + def test_privacy_filter_local_onnx_decoder_emits_spans_from_bioes_logits(self): + plugin = ThPrivacyFilter(MODEL_NAME="openai/privacy-filter") + schema = { + "id2label": { + "0": "O", + "1": "B-private_email", + "2": "I-private_email", + "3": "E-private_email", + "4": "S-private_email", + }, + "viterbi_calibration": { + "operating_points": { + "default": { + "biases": {}, + }, + }, + }, + } + outputs = { + "logits": [[ + [8.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 8.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 8.0, 0.0], + [8.0, 0.0, 0.0, 0.0, 0.0], + ]], + } + tokenizer_output = { + "offset_mapping": [[[0, 0], [0, 5], [5, 17], [0, 0]]], + "attention_mask": [[1, 1, 1, 1]], + } + + spans = plugin._decode_privacy_filter_onnx_outputs( # 
pylint: disable=protected-access + outputs=outputs, + schema=schema, + text="alice@example.com", + tokenizer_output=tokenizer_output, + ) + + self.assertEqual(len(spans), 1) + self.assertEqual(spans[0]["entity_group"], "private_email") + self.assertEqual(spans[0]["word"], "alice@example.com") + self.assertEqual(spans[0]["start"], 0) + self.assertEqual(spans[0]["end"], 17) + self.assertGreater(spans[0]["score"], 0.9) + + def test_privacy_filter_viterbi_decoder_rejects_invalid_terminal_inside_label(self): + plugin = ThPrivacyFilter(MODEL_NAME="openai/privacy-filter") + schema = { + "id2label": { + "0": "O", + "1": "B-private_email", + "2": "I-private_email", + "3": "E-private_email", + "4": "S-private_email", + }, + "viterbi_calibration": {}, + } + outputs = { + "logits": [[ + [0.0, 8.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 8.0, 7.5, 0.0], + ]], + } + tokenizer_output = { + "offset_mapping": [[[0, 5], [5, 17]]], + "attention_mask": [[1, 1]], + } + + spans = plugin._decode_privacy_filter_onnx_outputs( # pylint: disable=protected-access + outputs=outputs, + schema=schema, + text="alice@example.com", + tokenizer_output=tokenizer_output, + ) + + self.assertEqual(len(spans), 1) + self.assertEqual(spans[0]["entity_group"], "private_email") + self.assertEqual(spans[0]["end"], 17) + def test_post_process_emits_redaction_friendly_fields(self): plugin = ThPrivacyFilter() From fb95019730c45f4f0d6b1bbcc5016bde1da6d104 Mon Sep 17 00:00:00 2001 From: Codex Date: Mon, 11 May 2026 18:24:51 +0300 Subject: [PATCH 5/8] fix: support hf snapshot symlink artifacts What changed: - Keep HF artifact path traversal checks lexical so valid snapshot symlinks into the cache blob store are accepted. - Merge exact manifest files with recommended ONNX allow patterns after filtering broad or framework-weight downloads. - Add regression coverage for both behaviors. Why: - Live PR image validation showed Sentinel and privacy-filter ONNX startup falling back because valid HF snapshot files were rejected as escaping the snapshot. 
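Example (illustrative sketch of the merged allow-pattern selection, based on the regression-test expectations; the file names are hypothetical):

    runtime_config = {
        "recommended_allow_patterns": ["onnx/*", "schema.json"],
        "files": ["*", "model.safetensors", "tokenizer.json"],
        "model": "model.onnx",
    }
    # _build_hf_runtime_allow_patterns(runtime_config) keeps exact, safe entries and
    # drops broad globs and framework weights:
    # -> ["schema.json", "tokenizer.json", "model.onnx"]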
--- .../default_inference/nlp/th_hf_model_base.py | 32 ++++++++++++------- extensions/serving/test_th_hf_model_base.py | 31 +++++++++++++++++- 2 files changed, 50 insertions(+), 13 deletions(-) diff --git a/extensions/serving/default_inference/nlp/th_hf_model_base.py b/extensions/serving/default_inference/nlp/th_hf_model_base.py index 951469a7..50b4aac5 100644 --- a/extensions/serving/default_inference/nlp/th_hf_model_base.py +++ b/extensions/serving/default_inference/nlp/th_hf_model_base.py @@ -9,7 +9,7 @@ import importlib.util import inspect import json -from pathlib import Path +from pathlib import Path, PurePosixPath import torch as th @@ -589,11 +589,19 @@ def _build_hf_runtime_allow_patterns(self, runtime_config): if configured_patterns: patterns = configured_patterns else: - patterns = runtime_config.get("recommended_allow_patterns") or runtime_config.get("files") + patterns = [] + for source_patterns in ( + runtime_config.get("recommended_allow_patterns"), + runtime_config.get("files"), + [runtime_config.get("model")] if runtime_config.get("model") else None, + ): + if not source_patterns: + continue + if isinstance(source_patterns, str): + source_patterns = [source_patterns] + patterns.extend(source_patterns) if not patterns: - model_file = runtime_config.get("model") patterns = [ - model_file, "*.onnx", "**/*.onnx", "*.json", @@ -637,16 +645,16 @@ def _runtime_file_list(self, runtime_config): def _resolve_hf_snapshot_path(self, model_dir, file_path): """Resolve a manifest path while keeping it inside the downloaded snapshot.""" - path = Path(str(file_path)) + raw_path = str(file_path) + path = PurePosixPath(raw_path) if path.is_absolute(): raise ValueError(f"HF artifact path {file_path!r} must be relative to the model snapshot.") - snapshot_dir = Path(model_dir).resolve() - resolved_path = (snapshot_dir / path).resolve() - try: - resolved_path.relative_to(snapshot_dir) - except ValueError as exc: - raise ValueError(f"HF artifact path {file_path!r} escapes the model snapshot.") from exc - return resolved_path + if ".." in path.parts: + raise ValueError(f"HF artifact path {file_path!r} escapes the model snapshot.") + # Hugging Face snapshots commonly symlink files into the shared cache + # blob store. A resolved containment check would reject valid snapshots, + # so keep the traversal guard lexical and return the snapshot path itself. 
+ return Path(model_dir) / Path(*path.parts) def _first_manifest_file_with_suffix(self, runtime_config, suffixes): """Return the first exact manifest file path ending with any suffix.""" diff --git a/extensions/serving/test_th_hf_model_base.py b/extensions/serving/test_th_hf_model_base.py index 9e9db7dd..0d7569a3 100644 --- a/extensions/serving/test_th_hf_model_base.py +++ b/extensions/serving/test_th_hf_model_base.py @@ -321,6 +321,10 @@ def test_onnx_allow_patterns_reject_framework_weights_and_broad_downloads(self): plugin = _ConcreteHfModel(MODEL_NAME="test/model") allow_patterns = plugin._build_hf_runtime_allow_patterns({ # pylint: disable=protected-access + "recommended_allow_patterns": [ + "onnx/*", + "schema.json", + ], "files": [ "*", "**/*", @@ -334,9 +338,13 @@ def test_onnx_allow_patterns_reject_framework_weights_and_broad_downloads(self): "tf_model.h5", "flax_model.msgpack", ], + "model": "model.onnx", }) - self.assertEqual(allow_patterns, ["model.onnx", "tokenizer.json", "contract.py"]) + self.assertEqual( + allow_patterns, + ["schema.json", "model.onnx", "tokenizer.json", "contract.py"], + ) def test_auto_runtime_uses_onnx_artifact_on_cpu_only(self): manifest = { @@ -713,6 +721,27 @@ def test_hf_artifact_paths_must_stay_inside_snapshot(self): schema={}, ) + def test_hf_artifact_paths_allow_snapshot_symlink_targets_outside_snapshot(self): + plugin = _ConcreteHfModel(MODEL_NAME="test/model") + plugin.hf_runtime = "onnx_fp32" + + with TemporaryDirectory() as tmpdir: + root_dir = Path(tmpdir) + model_dir = root_dir / "snapshot" + blob_dir = root_dir / "blobs" + model_dir.mkdir() + blob_dir.mkdir() + (blob_dir / "schema.json").write_text('{"inputs": []}', encoding="utf-8") + (model_dir / "schema.json").symlink_to(blob_dir / "schema.json") + + schema = plugin._load_hf_schema( # pylint: disable=protected-access + model_dir=str(model_dir), + manifest={}, + runtime_config={"schema": "schema.json"}, + ) + + self.assertEqual(schema, {"inputs": []}) + def test_onnx_tokenizer_remote_code_requires_global_trust_remote_code(self): calls = [] From d27b9e8429756293748c1de1f1e73afa1cfd4d1e Mon Sep 17 00:00:00 2001 From: Codex Date: Mon, 11 May 2026 20:15:58 +0300 Subject: [PATCH 6/8] fix: allow legacy trusted onnx decoders What changed: - Temporarily allow ONNX artifact decoders without runtime-level trust_remote_code to inherit global TRUST_REMOTE_CODE=True. - Keep explicit runtime trust_remote_code=False as a hard block. - Add a TODO documenting the security concern and declarative decoder replacement path. Why: - The current Sentinel ONNX artifact predates runtime-level trust metadata and uses a reviewed contract decoder, so it needs a compatibility path until the artifact moves to declarative decoding. 
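Example (illustrative sketch of the temporary compatibility rule; decoder_code_allowed is a hypothetical stand-in for the gate in _load_hf_contract_decoder, not code from the diff):

    def decoder_code_allowed(global_trust, runtime_config):
        # Global TRUST_REMOTE_CODE stays mandatory; an explicit runtime flag wins,
        # and a runtime without the flag temporarily inherits the global setting.
        if not global_trust:
            return False
        if isinstance(runtime_config, dict) and "trust_remote_code" in runtime_config:
            return bool(runtime_config["trust_remote_code"])
        return True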
--- .../default_inference/nlp/th_hf_model_base.py | 13 ++++++++++--- extensions/serving/test_th_hf_model_base.py | 17 +++++++++-------- 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/extensions/serving/default_inference/nlp/th_hf_model_base.py b/extensions/serving/default_inference/nlp/th_hf_model_base.py index 50b4aac5..a0605870 100644 --- a/extensions/serving/default_inference/nlp/th_hf_model_base.py +++ b/extensions/serving/default_inference/nlp/th_hf_model_base.py @@ -695,7 +695,14 @@ def _load_hf_schema(self, model_dir, manifest, runtime_config): def _runtime_allows_remote_code(self, manifest, runtime_config): """Return whether the selected runtime explicitly allows Python artifact code.""" - return isinstance(runtime_config, dict) and bool(runtime_config.get("trust_remote_code")) + if isinstance(runtime_config, dict) and "trust_remote_code" in runtime_config: + return bool(runtime_config.get("trust_remote_code")) + # TODO: replace this temporary compatibility path with declarative ONNX + # decoders (for example multihead_classification_v1) so artifact Python + # does not execute unless each runtime explicitly opts into remote code. + # This currently preserves legacy Sentinel ONNX artifacts whose decoder is + # a reviewed contract file but whose manifest predates runtime-level trust. + return bool(self.cfg_trust_remote_code) def _load_hf_contract_decoder(self, model_dir, manifest, runtime_config): """Load the artifact decoder function declared by the selected HF runtime.""" @@ -713,8 +720,8 @@ def _load_hf_contract_decoder(self, model_dir, manifest, runtime_config): runtime_config=runtime_config, ): raise ValueError( - "HF ONNX artifact decoder requires global TRUST_REMOTE_CODE=True and runtime " - f"trust_remote_code=True because it executes Python code from {decoder_path}." + "HF ONNX artifact decoder requires TRUST_REMOTE_CODE=True and no explicit " + f"runtime trust_remote_code=False because it executes Python code from {decoder_path}." 
) module_name = f"hf_artifact_contract_{abs(hash(str(decoder_path)))}" spec = importlib.util.spec_from_file_location(module_name, decoder_path) diff --git a/extensions/serving/test_th_hf_model_base.py b/extensions/serving/test_th_hf_model_base.py index 0d7569a3..c5f8d30d 100644 --- a/extensions/serving/test_th_hf_model_base.py +++ b/extensions/serving/test_th_hf_model_base.py @@ -654,14 +654,14 @@ def test_hf_contract_decoder_requires_runtime_trust_remote_code(self): encoding="utf-8", ) - with self.assertRaisesRegex(ValueError, "runtime trust_remote_code=True"): + with self.assertRaisesRegex(ValueError, "runtime trust_remote_code=False"): plugin._load_hf_contract_decoder( # pylint: disable=protected-access model_dir=str(model_dir), manifest={}, runtime_config={"decoder": "contract.py", "trust_remote_code": False}, ) - def test_top_level_manifest_trust_remote_code_does_not_enable_runtime_decoder(self): + def test_missing_runtime_trust_remote_code_temporarily_inherits_global_trust(self): plugin = _ConcreteHfModel( MODEL_NAME="test/model", DEVICE="cpu", @@ -676,12 +676,13 @@ def test_top_level_manifest_trust_remote_code_does_not_enable_runtime_decoder(se encoding="utf-8", ) - with self.assertRaisesRegex(ValueError, "runtime trust_remote_code=True"): - plugin._load_hf_contract_decoder( # pylint: disable=protected-access - model_dir=str(model_dir), - manifest={"trust_remote_code": True}, - runtime_config={"decoder": "contract.py"}, - ) + decoder = plugin._load_hf_contract_decoder( # pylint: disable=protected-access + model_dir=str(model_dir), + manifest={}, + runtime_config={"decoder": "contract.py"}, + ) + + self.assertEqual(decoder({"ok": True}, {}), {"ok": True}) def test_hf_artifact_paths_must_stay_inside_snapshot(self): plugin = _ConcreteHfModel(MODEL_NAME="test/model") From 1026dcc5d189e6fc2265ff0bf28f4a12c2ecf3d3 Mon Sep 17 00:00:00 2001 From: Codex Date: Mon, 11 May 2026 20:20:52 +0300 Subject: [PATCH 7/8] fix: let global trust gate legacy onnx decoders What changed: - Split ONNX remote-code trust between tokenizer/model loading and decoder execution. - Keep tokenizer/model loading tied to runtime-level trust_remote_code. - Temporarily allow Python decoder execution when global TRUST_REMOTE_CODE=True, even for legacy runtimes that mark ONNX trust_remote_code=False. Why: - Current Sentinel ONNX artifacts use trust_remote_code=False for tokenizer/model loading but still declare a Python contract decoder. This keeps the temporary compatibility path narrow until declarative decoding replaces it. 
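Example (illustrative sketch of the split trust gates; both helpers are hypothetical stand-ins for the checks in the diff below):

    def loader_remote_code_allowed(global_trust, runtime_config):
        # Tokenizer/model remote code still needs the runtime-level opt-in,
        # capped by the global flag.
        runtime_flag = isinstance(runtime_config, dict) and bool(runtime_config.get("trust_remote_code"))
        return bool(global_trust) and runtime_flag

    def decoder_code_allowed(global_trust, runtime_config):
        # Python decoder execution is now gated only by the global flag, even when
        # the runtime marks trust_remote_code=False for loading.
        return bool(global_trust)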
--- .../default_inference/nlp/th_hf_model_base.py | 17 +++++++++++------ extensions/serving/test_th_hf_model_base.py | 15 ++++++++------- 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/extensions/serving/default_inference/nlp/th_hf_model_base.py b/extensions/serving/default_inference/nlp/th_hf_model_base.py index a0605870..55d4b6d7 100644 --- a/extensions/serving/default_inference/nlp/th_hf_model_base.py +++ b/extensions/serving/default_inference/nlp/th_hf_model_base.py @@ -695,13 +695,18 @@ def _load_hf_schema(self, model_dir, manifest, runtime_config): def _runtime_allows_remote_code(self, manifest, runtime_config): """Return whether the selected runtime explicitly allows Python artifact code.""" - if isinstance(runtime_config, dict) and "trust_remote_code" in runtime_config: - return bool(runtime_config.get("trust_remote_code")) + return isinstance(runtime_config, dict) and bool(runtime_config.get("trust_remote_code")) + + def _runtime_allows_decoder_remote_code(self, manifest, runtime_config): + """Return whether the selected runtime may execute Python decoder code.""" # TODO: replace this temporary compatibility path with declarative ONNX # decoders (for example multihead_classification_v1) so artifact Python # does not execute unless each runtime explicitly opts into remote code. # This currently preserves legacy Sentinel ONNX artifacts whose decoder is - # a reviewed contract file but whose manifest predates runtime-level trust. + # a reviewed contract file but whose manifest marks the ONNX runtime as + # trust_remote_code=False because tokenizer/model loading does not need HF + # remote code. The decoder still executes Python, so this is intentionally + # gated by global TRUST_REMOTE_CODE and should be removed after repackaging. return bool(self.cfg_trust_remote_code) def _load_hf_contract_decoder(self, model_dir, manifest, runtime_config): @@ -715,13 +720,13 @@ def _load_hf_contract_decoder(self, model_dir, manifest, runtime_config): ) if decoder_path is None or not decoder_path.exists(): raise ValueError(f"HF runtime {self.hf_runtime} does not declare a usable contract decoder.") - if not bool(self.cfg_trust_remote_code) or not self._runtime_allows_remote_code( + if not bool(self.cfg_trust_remote_code) or not self._runtime_allows_decoder_remote_code( manifest=manifest, runtime_config=runtime_config, ): raise ValueError( - "HF ONNX artifact decoder requires TRUST_REMOTE_CODE=True and no explicit " - f"runtime trust_remote_code=False because it executes Python code from {decoder_path}." + "HF ONNX artifact decoder requires TRUST_REMOTE_CODE=True because it executes " + f"Python code from {decoder_path}." 
) module_name = f"hf_artifact_contract_{abs(hash(str(decoder_path)))}" spec = importlib.util.spec_from_file_location(module_name, decoder_path) diff --git a/extensions/serving/test_th_hf_model_base.py b/extensions/serving/test_th_hf_model_base.py index c5f8d30d..ff017b5e 100644 --- a/extensions/serving/test_th_hf_model_base.py +++ b/extensions/serving/test_th_hf_model_base.py @@ -639,7 +639,7 @@ def test_hf_contract_decoder_requires_global_trust_remote_code(self): runtime_config={"decoder": "contract.py"}, ) - def test_hf_contract_decoder_requires_runtime_trust_remote_code(self): + def test_runtime_trust_remote_code_false_temporarily_inherits_global_trust(self): plugin = _ConcreteHfModel( MODEL_NAME="test/model", DEVICE="cpu", @@ -654,12 +654,13 @@ def test_hf_contract_decoder_requires_runtime_trust_remote_code(self): encoding="utf-8", ) - with self.assertRaisesRegex(ValueError, "runtime trust_remote_code=False"): - plugin._load_hf_contract_decoder( # pylint: disable=protected-access - model_dir=str(model_dir), - manifest={}, - runtime_config={"decoder": "contract.py", "trust_remote_code": False}, - ) + decoder = plugin._load_hf_contract_decoder( # pylint: disable=protected-access + model_dir=str(model_dir), + manifest={}, + runtime_config={"decoder": "contract.py", "trust_remote_code": False}, + ) + + self.assertEqual(decoder({"ok": True}, {}), {"ok": True}) def test_missing_runtime_trust_remote_code_temporarily_inherits_global_trust(self): plugin = _ConcreteHfModel( From 550fbadec10057218fcc2d40bbc05dd46c2bb672 Mon Sep 17 00:00:00 2001 From: Codex Date: Mon, 11 May 2026 20:30:20 +0300 Subject: [PATCH 8/8] fix: materialize hf onnx external data What changed: - Prepare HF ONNX artifacts in an edge-node-owned materialized cache before creating ONNX Runtime sessions. - Hardlink resolved HF cache blobs when possible and copy as fallback. - Preserve runtime relative layout for .onnx and external data sidecars. - Add regression coverage for symlinked external data files. Why: - ONNX Runtime rejects HF snapshot symlinks for external data because resolved sidecars can escape the model directory. 
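Example (illustrative, self-contained sketch of the hardlink-with-copy fallback that _materialize_hf_onnx_file applies; the function name and paths are hypothetical):

    import os
    import shutil
    from pathlib import Path

    def materialize(source: Path, destination: Path) -> None:
        # Resolve the snapshot symlink to the cache blob, then hardlink it next to the
        # .onnx model so ONNX Runtime sees real files; fall back to a copy when
        # hardlinks are not possible (for example across filesystems).
        destination.parent.mkdir(parents=True, exist_ok=True)
        resolved = source.resolve()
        try:
            os.link(resolved, destination)
        except OSError:
            shutil.copy2(resolved, destination)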
--- .../default_inference/nlp/th_hf_model_base.py | 68 ++++++++++++++++++ extensions/serving/test_th_hf_model_base.py | 72 +++++++++++++++++++ 2 files changed, 140 insertions(+) diff --git a/extensions/serving/default_inference/nlp/th_hf_model_base.py b/extensions/serving/default_inference/nlp/th_hf_model_base.py index 55d4b6d7..5b530c42 100644 --- a/extensions/serving/default_inference/nlp/th_hf_model_base.py +++ b/extensions/serving/default_inference/nlp/th_hf_model_base.py @@ -9,6 +9,8 @@ import importlib.util import inspect import json +import os +import shutil from pathlib import Path, PurePosixPath import torch as th @@ -798,6 +800,65 @@ def _resolve_hf_onnx_model_path(self, model_dir, runtime_key, runtime_config, sc return self._resolve_hf_snapshot_path(model_dir=model_dir, file_path=model_file) raise ValueError(f"HF runtime {runtime_key} does not declare an ONNX model file.") + def _hf_onnx_materialized_root(self, model_dir, runtime_key): + """Return the local directory used for ORT-compatible ONNX artifacts.""" + snapshot_name = Path(model_dir).name + model_key = str(self.get_model_name()).replace("/", "--") + return Path(self.cache_dir) / "_onnx_materialized" / model_key / snapshot_name / str(runtime_key) + + def _materialize_hf_onnx_file(self, source_path, destination_path): + """Materialize one HF snapshot file as a real local file or hardlink.""" + source_path = Path(source_path) + destination_path = Path(destination_path) + destination_path.parent.mkdir(parents=True, exist_ok=True) + resolved_source = source_path.resolve() + if destination_path.exists(): + try: + if not destination_path.is_symlink() and destination_path.stat().st_size == resolved_source.stat().st_size: + return + except OSError: + pass + destination_path.unlink() + try: + os.link(resolved_source, destination_path) + except OSError: + shutil.copy2(resolved_source, destination_path) + return + + def _materialize_hf_onnx_artifact(self, model_dir, runtime_key, runtime_config, schema, model_path): + """Prepare an ONNX artifact outside the HF symlink snapshot for ORT.""" + root_dir = self._hf_onnx_materialized_root(model_dir=model_dir, runtime_key=runtime_key) + materialized_paths = [] + model_snapshot_dir = Path(model_dir) + model_relative_path = Path(model_path).relative_to(model_snapshot_dir) + file_paths = [] + for file_path in self._runtime_file_list(runtime_config): + file_path = str(file_path) + if file_path.endswith(".onnx") or ".onnx_data" in file_path or file_path.endswith(".onnx.data"): + file_paths.append(file_path) + if str(model_relative_path) not in file_paths: + file_paths.append(str(model_relative_path)) + for file_path in file_paths: + source_path = self._resolve_hf_snapshot_path(model_dir=model_dir, file_path=file_path) + if not source_path.exists(): + continue + destination_path = root_dir / Path(file_path) + self._materialize_hf_onnx_file( + source_path=source_path, + destination_path=destination_path, + ) + materialized_paths.append(destination_path) + materialized_model_path = root_dir / model_relative_path + if not materialized_model_path.exists(): + raise ValueError(f"Could not materialize ONNX model file {model_relative_path!s}.") + if materialized_paths: + self.P( + f"Materialized HF ONNX artifact {runtime_key} with {len(materialized_paths)} file(s) " + f"under {root_dir}.", + color="y", + ) + return materialized_model_path + def _resolve_hf_tokenizer_dir(self, model_dir, manifest, runtime_config, schema): """Resolve tokenizer directory for the selected artifact runtime.""" tokenizer_dir 
= None @@ -855,6 +916,13 @@ def _build_hf_onnx_artifact_pipeline(self, model_dir, runtime_key, runtime_confi runtime_config=runtime_config, schema=schema, ) + model_path = self._materialize_hf_onnx_artifact( + model_dir=model_dir, + runtime_key=runtime_key, + runtime_config=runtime_config, + schema=schema, + model_path=model_path, + ) provider = runtime_config.get("provider") or "CPUExecutionProvider" providers = runtime_config.get("providers") or [provider] session = self._create_hf_onnx_session( diff --git a/extensions/serving/test_th_hf_model_base.py b/extensions/serving/test_th_hf_model_base.py index ff017b5e..5c2077dc 100644 --- a/extensions/serving/test_th_hf_model_base.py +++ b/extensions/serving/test_th_hf_model_base.py @@ -875,6 +875,78 @@ def fake_create_session(model_path, providers): self.assertEqual(fake_tokenizer.calls[-1][1]["return_tensors"], "np") self.assertTrue(fake_tokenizer.calls[-1][1]["return_offsets_mapping"]) + def test_onnx_artifact_pipeline_materializes_symlinked_external_data(self): + plugin = _ConcreteHfModel( + MODEL_NAME="test/model", + DEVICE="cpu", + TRUST_REMOTE_CODE=True, + PIPELINE_TASK="text-classification", + WARMUP_ENABLED=False, + ) + fake_tokenizer = _FakeTokenizer() + fake_session = _FakeOrtSession() + created_sessions = [] + plugin._load_hf_onnx_tokenizer = ( # pylint: disable=protected-access + lambda model_dir, runtime_config, manifest=None: fake_tokenizer + ) + plugin._create_hf_onnx_session = ( # pylint: disable=protected-access + lambda model_path, providers: created_sessions.append((Path(model_path), providers)) or fake_session + ) + + with TemporaryDirectory() as tmpdir: + root_dir = Path(tmpdir) + model_dir = root_dir / "snapshot" + blob_dir = root_dir / "blobs" + cache_dir = root_dir / "models-cache" + onnx_dir = model_dir / "onnx" + blob_dir.mkdir(parents=True) + onnx_dir.mkdir(parents=True) + cache_dir.mkdir() + plugin.log.get_models_folder = lambda: str(cache_dir) + (blob_dir / "model.onnx").write_text("onnx", encoding="utf-8") + (blob_dir / "model.onnx_data").write_text("weights", encoding="utf-8") + (onnx_dir / "model.onnx").symlink_to(blob_dir / "model.onnx") + (onnx_dir / "model.onnx_data").symlink_to(blob_dir / "model.onnx_data") + (model_dir / "schema.json").write_text( + '{"outputs":[{"name":"scores"}],"models":{"onnx_fp32":{"path":"onnx/model.onnx"}}}', + encoding="utf-8", + ) + (model_dir / "contract.py").write_text( + "def decode_outputs(outputs, schema, **kwargs):\n return outputs\n", + encoding="utf-8", + ) + manifest = { + "pipeline_task": "text-classification", + "runtimes": { + "onnx_fp32": { + "runtime": "onnxruntime", + "trust_remote_code": False, + "files": [ + "onnx/model.onnx", + "onnx/model.onnx_data", + "schema.json", + "contract.py", + ], + } + }, + } + + plugin._build_hf_onnx_artifact_pipeline( # pylint: disable=protected-access + model_dir=str(model_dir), + runtime_key="onnx_fp32", + runtime_config=manifest["runtimes"]["onnx_fp32"], + manifest=manifest, + ) + + materialized_model_path = created_sessions[0][0] + materialized_sidecar_path = materialized_model_path.parent / "model.onnx_data" + self.assertTrue(materialized_model_path.exists()) + self.assertTrue(materialized_sidecar_path.exists()) + self.assertFalse(materialized_model_path.is_symlink()) + self.assertFalse(materialized_sidecar_path.is_symlink()) + self.assertEqual(materialized_sidecar_path.read_text(encoding="utf-8"), "weights") + self.assertIn("_onnx_materialized", str(materialized_model_path)) + if __name__ == "__main__": unittest.main()