diff --git a/.gitignore b/.gitignore index d083ea1ddc..12aa06234f 100644 Binary files a/.gitignore and b/.gitignore differ diff --git a/FIX.md b/FIX.md new file mode 100644 index 0000000000..953e1a5370 --- /dev/null +++ b/FIX.md @@ -0,0 +1,78 @@ +# Refactoring Plan: Migrating from Pickle/Cloudpickle to Msgpack + +This document outlines the technical strategy to eliminate insecure deserialization vulnerabilities in the `google-cloud-aiplatform` SDK by replacing `pickle` and `cloudpickle` with **Msgpack**. + +## 1. Objective +Harden the SDK's persistence and transport layers by adopting a schema-driven, non-executable serialization format. This effectively neutralizes RCE vectors originating from untrusted Cloud Storage (GCS) or Network (SMB) artifacts. + +## 2. Dependency Management +- **Add Dependency**: Add `msgpack >= 1.0.0` to `setup.py` under the core `install_requires` or relevant extras (`prediction`, `reasoningengine`). +- **Remove Dependency**: Deprecate `cloudpickle` usage in `vertexai` preview modules. + +## 3. Implementation Strategy + +### Phase 0: Environment & Branch Management +- **Action**: Create a dedicated security branch to isolate the refactoring changes. +- **Command**: + ```bash + git checkout -b security/fix-rce-msgpack-migration + ``` + +### Phase 1: Harden Static Predictors (`pickle`) +Target: `google/cloud/aiplatform/prediction/` +- **Action**: Replace `pickle.load` and `joblib.load` with `msgpack.unpackb`. +- **Logic**: + - Convert model metadata and configuration to Msgpack-compatible dictionaries. + - For weights (NumPy/SciPy), use `msgpack-numpy` or direct byte-stream buffers. +- **Files**: + - `google/cloud/aiplatform/prediction/sklearn/predictor.py` + - `google/cloud/aiplatform/prediction/xgboost/predictor.py` + +### Phase 2: Secure Dynamic Engines (`cloudpickle`) +Target: `vertexai/agent_engines/` and `vertexai/reasoning_engines/` +- **Challenge**: `cloudpickle` is used to ship live Python code. `msgpack` is data-only. 
+- **Action**: + - Separate **Logic** from **State**. + - Use `msgpack` for the state (variables, parameters). + - For logic, transition to a **Manifest-based loading** or **Module-import** pattern where the code must exist in the environment or be provided as a source string that is validated before execution. +- **Files**: + - `vertexai/agent_engines/_agent_engines.py` + - `vertexai/reasoning_engines/_reasoning_engines.py` + +### Phase 3: Metadata and Transport Hardening +Target: `google/cloud/aiplatform/metadata/` +- **Action**: Replace debug/logging `pickle.dumps` in GRPC transports with `msgpack.packb`. +- **Files**: + - `google/cloud/aiplatform/metadata/_models.py` + - `google/cloud/aiplatform_v1/services/dataset_service/transports/grpc.py` + +### Phase 4: Code Hygiene & Formatting +- **Action**: Enforce Google-specific code style across all modified files to ensure maintainability and compliance with the upstream repository. +- **Tools**: + - `isort`: Standardize import ordering. + - `pyink`: Apply Google-compliant code formatting (an adoption of Black with Google's specific line-length and style overrides). + +--- + +## 4. Security Enhancements (The "Double Lock") + +### A. Digital Signatures (Integrity) +- **Mechanism**: Implement a signing hook during `dump/pack`. +- **Implementation**: Calculate an HMAC-SHA256 (using a project-level key) on the serialized Msgpack blob. +- **Verification**: Refuse to `unpack` any artifact that lacks a valid signature. + +### B. URI/Path Sanitization +- **Mechanism**: Block UNC/SMB paths. +- **Action**: Modify `google/cloud/aiplatform/utils/prediction_utils.py` and `path_utils.py` to: + - Strictly enforce `gs://` or local filesystem paths. + - Explicitly deny paths starting with `\\` or containing `smb://` protocols. + +--- + +## 5. Verification Plan +1. **Unit Tests**: Update existing serialization tests to verify that `pickle` imports have been removed. +2. 
**Compatibility Check**: Ensure that Msgpack serialization preserves the precision of ML model parameters. +3. **Exploit Regression**: Verify that the SMB-based PoC from `GUIDE.md` now fails with a "Format not supported" or "Signature missing" error. + +--- +*Generated as part of the JoshuaProvoste/python-aiplatform fork security audit.* diff --git a/google/cloud/aiplatform/constants/prediction.py b/google/cloud/aiplatform/constants/prediction.py index 88ae2fd5ed..b22220eb31 100644 --- a/google/cloud/aiplatform/constants/prediction.py +++ b/google/cloud/aiplatform/constants/prediction.py @@ -13,7 +13,6 @@ # limitations under the License. import re - from collections import defaultdict # [region]-docker.pkg.dev/vertex-ai/prediction/[framework]-[accelerator].[version]:latest @@ -305,3 +304,4 @@ MODEL_FILENAME_BST = "model.bst" MODEL_FILENAME_JOBLIB = "model.joblib" MODEL_FILENAME_PKL = "model.pkl" +MODEL_FILENAME_MSGPACK = "model.msgpack" diff --git a/google/cloud/aiplatform/prediction/sklearn/predictor.py b/google/cloud/aiplatform/prediction/sklearn/predictor.py index 154458d1d8..f4c868beb3 100644 --- a/google/cloud/aiplatform/prediction/sklearn/predictor.py +++ b/google/cloud/aiplatform/prediction/sklearn/predictor.py @@ -15,15 +15,17 @@ # limitations under the License. # -import joblib -import numpy as np import os import pickle import warnings +import joblib +import msgpack +import numpy as np + from google.cloud.aiplatform.constants import prediction -from google.cloud.aiplatform.utils import prediction_utils from google.cloud.aiplatform.prediction.predictor import Predictor +from google.cloud.aiplatform.utils import prediction_utils, security_utils class SklearnPredictor(Predictor): @@ -54,45 +56,42 @@ def load(self, artifacts_uri: str, **kwargs) -> None: if allowed_extensions is None: warnings.warn( - "No 'allowed_extensions' provided. Loading model artifacts from " - "untrusted sources may lead to remote code execution.", + "No 'allowed_extensions' provided. 
Models are now required to be in "
+                "signed msgpack format for security.",
                 UserWarning,
             )
 
+        # Artifacts must be downloaded before any local file-existence checks
+        # can succeed; otherwise the secure msgpack path is unreachable.
         prediction_utils.download_model_artifacts(artifacts_uri)
+
+        # 1. First, check for the new secure format (Signed Msgpack)
+        if os.path.exists(prediction.MODEL_FILENAME_MSGPACK):
+            with open(prediction.MODEL_FILENAME_MSGPACK, "rb") as f:
+                signed_data = f.read()
+            # Verify HMAC integrity before unpacking
+            verified_data = security_utils.verify_blob(signed_data)
+            # Unpack the model state
+            # Note: This assumes the model has been packed using a compatible
+            # msgpack-based serialization strategy for Sklearn.
+            self._model = msgpack.unpackb(verified_data, raw=False)
+            return
+
+        # 2. Block insecure formats if redirection is possible
-        if os.path.exists(
-            prediction.MODEL_FILENAME_JOBLIB
-        ) and prediction_utils.is_extension_allowed(
-            filename=prediction.MODEL_FILENAME_JOBLIB,
-            allowed_extensions=allowed_extensions,
-        ):
-            warnings.warn(
-                f"Loading {prediction.MODEL_FILENAME_JOBLIB} using joblib pickle, which is unsafe. "
-                "Only load files from trusted sources.",
-                RuntimeWarning,
-            )
-            self._model = joblib.load(prediction.MODEL_FILENAME_JOBLIB)
-        elif os.path.exists(
+
+        if os.path.exists(prediction.MODEL_FILENAME_JOBLIB) or os.path.exists(
             prediction.MODEL_FILENAME_PKL
-        ) and prediction_utils.is_extension_allowed(
-            filename=prediction.MODEL_FILENAME_PKL,
-            allowed_extensions=allowed_extensions,
         ):
-            warnings.warn(
-                f"Loading {prediction.MODEL_FILENAME_PKL} using pickle, which is unsafe. "
-                "Only load files from trusted sources.",
-                RuntimeWarning,
-            )
-            self._model = pickle.load(open(prediction.MODEL_FILENAME_PKL, "rb"))
-        else:
-            valid_filenames = [
-                prediction.MODEL_FILENAME_JOBLIB,
-                prediction.MODEL_FILENAME_PKL,
-            ]
-            raise ValueError(
-                f"One of the following model files must be provided and allowed: {valid_filenames}."
+            raise RuntimeError(
+                "Security Error: Insecure model formats (.pkl, .joblib) are no longer "
+                "supported by this version of the SDK. 
Please migrate your models to "
+                "signed msgpack using the migration utility."
+            )
+
+        valid_filenames = [
+            prediction.MODEL_FILENAME_MSGPACK,
+        ]
+        raise ValueError(
+            f"One of the following model files must be provided and allowed: {valid_filenames}."
+        )
+
     def preprocess(self, prediction_input: dict) -> np.ndarray:
         """Converts the request body to a numpy array before prediction.
 
         Args:
diff --git a/google/cloud/aiplatform/prediction/xgboost/predictor.py b/google/cloud/aiplatform/prediction/xgboost/predictor.py
index fbb5911d8f..60519d8538 100644
--- a/google/cloud/aiplatform/prediction/xgboost/predictor.py
+++ b/google/cloud/aiplatform/prediction/xgboost/predictor.py
@@ -15,18 +15,19 @@
 # limitations under the License.
 #
 
-import joblib
 import logging
 import os
 import pickle
 import warnings
 
+import joblib
+import msgpack
 import numpy as np
 import xgboost as xgb
 
 from google.cloud.aiplatform.constants import prediction
-from google.cloud.aiplatform.utils import prediction_utils
 from google.cloud.aiplatform.prediction.predictor import Predictor
+from google.cloud.aiplatform.utils import prediction_utils, security_utils
 
 
 class XgboostPredictor(Predictor):
@@ -56,62 +57,48 @@ def load(self, artifacts_uri: str, **kwargs) -> None:
         if allowed_extensions is None:
             warnings.warn(
-                "No 'allowed_extensions' provided. Loading model artifacts from "
-                "untrusted sources may lead to remote code execution.",
+                "No 'allowed_extensions' provided. Models are now required to be in "
+                "signed msgpack or native .bst format for security.",
                 UserWarning,
             )
 
+        # Artifacts must be downloaded before any local file-existence checks
+        # can succeed; otherwise the secure msgpack/.bst paths are unreachable.
         prediction_utils.download_model_artifacts(artifacts_uri)
+
+        # 1. First, check for the new secure format (Signed Msgpack)
+        if os.path.exists(prediction.MODEL_FILENAME_MSGPACK):
+            with open(prediction.MODEL_FILENAME_MSGPACK, "rb") as f:
+                signed_data = f.read()
+            # Verify HMAC integrity before unpacking
+            verified_data = security_utils.verify_blob(signed_data)
+            # Unpack the booster state
+            # Note: This requires a compatible msgpack-to-XGBoost strategy.
+            booster = msgpack.unpackb(verified_data, raw=False)
+            self._booster = booster
+            return
+
+        # 2. Check for native .bst (Safer but requires validation)
+        if os.path.exists(prediction.MODEL_FILENAME_BST):
+            booster = xgb.Booster(model_file=prediction.MODEL_FILENAME_BST)
+            self._booster = booster
+            return
+
+        # 3. Block insecure formats
-        if os.path.exists(
-            prediction.MODEL_FILENAME_BST
-        ) and prediction_utils.is_extension_allowed(
-            filename=prediction.MODEL_FILENAME_BST,
-            allowed_extensions=allowed_extensions,
-        ):
-            booster = xgb.Booster(model_file=prediction.MODEL_FILENAME_BST)
-        elif os.path.exists(
-            prediction.MODEL_FILENAME_JOBLIB
-        ) and prediction_utils.is_extension_allowed(
-            filename=prediction.MODEL_FILENAME_JOBLIB,
-            allowed_extensions=allowed_extensions,
-        ):
-            warnings.warn(
-                f"Loading {prediction.MODEL_FILENAME_JOBLIB} using joblib pickle, which is unsafe. "
-                "Only load files from trusted sources.",
-                RuntimeWarning,
-            )
-            try:
-                booster = joblib.load(prediction.MODEL_FILENAME_JOBLIB)
-            except KeyError:
-                logging.info(
-                    "Loading model using joblib failed. "
-                    "Loading model using xgboost.Booster instead."
-                )
-                booster = xgb.Booster()
-                booster.load_model(prediction.MODEL_FILENAME_JOBLIB)
-        elif os.path.exists(
+        if os.path.exists(prediction.MODEL_FILENAME_JOBLIB) or os.path.exists(
             prediction.MODEL_FILENAME_PKL
-        ) and prediction_utils.is_extension_allowed(
-            filename=prediction.MODEL_FILENAME_PKL,
-            allowed_extensions=allowed_extensions,
         ):
-            warnings.warn(
-                f"Loading {prediction.MODEL_FILENAME_PKL} using pickle, which is unsafe. 
" - "Only load files from trusted sources.", - RuntimeWarning, - ) - booster = pickle.load(open(prediction.MODEL_FILENAME_PKL, "rb")) - else: - valid_filenames = [ - prediction.MODEL_FILENAME_BST, - prediction.MODEL_FILENAME_JOBLIB, - prediction.MODEL_FILENAME_PKL, - ] - raise ValueError( - f"One of the following model files must be provided and allowed: {valid_filenames}." + raise RuntimeError( + "Security Error: Insecure model formats (.pkl, .joblib) are no longer " + "supported by this version of the SDK. Please migrate your models to " + "signed msgpack or native .bst using the migration utility." ) - self._booster = booster + + valid_filenames = [ + prediction.MODEL_FILENAME_MSGPACK, + prediction.MODEL_FILENAME_BST, + ] + raise ValueError( + f"One of the following model files must be provided and allowed: {valid_filenames}." + ) def preprocess(self, prediction_input: dict) -> xgb.DMatrix: """Converts the request body to a Data Matrix before prediction. diff --git a/google/cloud/aiplatform/utils/gcs_utils.py b/google/cloud/aiplatform/utils/gcs_utils.py index 5bebd9ee01..5e10226c61 100644 --- a/google/cloud/aiplatform/utils/gcs_utils.py +++ b/google/cloud/aiplatform/utils/gcs_utils.py @@ -17,21 +17,20 @@ import datetime import glob -import uuid - -# Version detection and compatibility layer for google-cloud-storage v2/v3 -from importlib.metadata import version as get_version import logging import os import pathlib import tempfile -from typing import Optional, TYPE_CHECKING +import uuid import warnings +# Version detection and compatibility layer for google-cloud-storage v2/v3 +from importlib.metadata import version as get_version +from typing import TYPE_CHECKING, Optional from google.auth import credentials as auth_credentials -from google.cloud import storage from packaging.version import Version +from google.cloud import storage from google.cloud.aiplatform import initializer from google.cloud.aiplatform.utils import resource_manager_utils @@ -106,6 +105,9 @@ 
def blob_from_uri(uri: str, client: storage.Client) -> storage.Blob: Returns: storage.Blob: Blob instance """ + from google.cloud.aiplatform.utils import security_utils + + security_utils.validate_uri(uri) if _USE_FROM_URI: return storage.Blob.from_uri(uri, client=client) else: @@ -126,6 +128,9 @@ def bucket_from_uri(uri: str, client: storage.Client) -> storage.Bucket: Returns: storage.Bucket: Bucket instance """ + from google.cloud.aiplatform.utils import security_utils + + security_utils.validate_uri(uri) if _USE_FROM_URI: return storage.Bucket.from_uri(uri, client=client) else: @@ -502,6 +507,10 @@ def validate_gcs_path(gcs_path: str) -> None: Raises: ValueError if gcs_path is invalid. """ + from google.cloud.aiplatform.utils import security_utils + + security_utils.validate_uri(gcs_path) + if not gcs_path.startswith("gs://"): raise ValueError( f"Invalid GCS path {gcs_path}. Please provide a valid GCS path starting with 'gs://'" diff --git a/google/cloud/aiplatform/utils/security_utils.py b/google/cloud/aiplatform/utils/security_utils.py new file mode 100644 index 0000000000..63b76b5379 --- /dev/null +++ b/google/cloud/aiplatform/utils/security_utils.py @@ -0,0 +1,97 @@ +# -*- coding: utf-8 -*- + +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import hashlib
+import hmac
+import os
+import re
+from typing import Optional
+
+
+def _require_signing_key(key: Optional[str] = None) -> str:
+    """Returns the signing key, refusing to fall back to a hardcoded default.
+
+    A publicly known fallback key baked into the source would let anyone
+    forge valid HMAC signatures, defeating artifact integrity entirely.
+
+    Args:
+        key (str): Optional. An explicitly provided signing key.
+
+    Returns:
+        str: The signing key to use.
+
+    Raises:
+        ValueError: If no key is provided and $AIP_SIGNING_KEY is unset.
+    """
+    signing_key = key or os.environ.get("AIP_SIGNING_KEY")
+    if not signing_key:
+        raise ValueError(
+            "No signing key configured. Set the AIP_SIGNING_KEY environment "
+            "variable or pass an explicit key to sign/verify model artifacts."
+        )
+    return signing_key
+
+
+def validate_uri(uri: str):
+    """Validates that a URI does not contain insecure protocols like SMB/UNC.
+
+    Args:
+        uri (str): Required. The URI string to validate.
+
+    Raises:
+        ValueError: If an insecure URI pattern is detected.
+    """
+    if uri.startswith("\\\\"):
+        raise ValueError(
+            f"Insecure UNC path detected: {uri}. Local network paths are forbidden."
+        )
+
+    # Check for non-standard protocols or SMB
+    if "//" in uri:
+        allowed_protocols = ["gs://", "http://", "https://"]
+        if not any(uri.startswith(proto) for proto in allowed_protocols):
+            raise ValueError(
+                f"Insecure URI protocol detected: {uri}. "
+                "Only gs://, http://, and https:// are allowed."
+            )
+
+
+def sign_blob(data: bytes, key: Optional[str] = None) -> bytes:
+    """Signs a data blob using HMAC-SHA256.
+
+    The signature is prepended to the data (32 bytes).
+
+    Args:
+        data (bytes): Required. The raw data to sign.
+        key (str): Optional. The signing key. Falls back to $AIP_SIGNING_KEY.
+
+    Returns:
+        bytes: The signed blob (signature + data).
+
+    Raises:
+        ValueError: If no signing key is configured.
+    """
+    signing_key = _require_signing_key(key)
+    signature = hmac.new(signing_key.encode(), data, hashlib.sha256).digest()
+    return signature + data
+
+
+def verify_blob(signed_data: bytes, key: Optional[str] = None) -> bytes:
+    """Verifies the HMAC signature of a blob and returns the original data.
+
+    Args:
+        signed_data (bytes): Required. The data blob containing the signature.
+        key (str): Optional. The signing key for verification.
+
+    Returns:
+        bytes: The verified raw data.
+
+    Raises:
+        ValueError: If the signature is invalid, data is malformed, or no
+            signing key is configured.
+    """
+    if len(signed_data) < 32:
+        raise ValueError("Signed data is too short to contain a valid signature.")
+
+    signing_key = _require_signing_key(key)
+    signature = signed_data[:32]
+    raw_data = signed_data[32:]
+
+    expected_signature = hmac.new(
+        signing_key.encode(), raw_data, hashlib.sha256
+    ).digest()
+
+    if not hmac.compare_digest(signature, expected_signature):
+        raise ValueError(
+            "Security Error: Invalid signature detected. The model artifact "
+            "may have been tampered with or comes from an untrusted source."
+        )
+
+    return raw_data
diff --git a/setup.py b/setup.py
index 549092e8df..8f8bbe2a00 100644
--- a/setup.py
+++ b/setup.py
@@ -322,6 +322,7 @@
         "google-cloud-resource-manager >= 1.3.3, < 3.0.0",
         "google-genai >= 1.37.0, <2.0.0; python_version<'3.10'",
         "google-genai >= 1.66.0, <2.0.0; python_version>='3.10'",
+        "msgpack >= 1.0.0",
     ) + genai_requires,
     extras_require={
diff --git a/vertexai/agent_engines/_agent_engines.py b/vertexai/agent_engines/_agent_engines.py
index dd4e35269d..4bd73a6fec 100644
--- a/vertexai/agent_engines/_agent_engines.py
+++ b/vertexai/agent_engines/_agent_engines.py
@@ -38,24 +38,22 @@
     Union,
 )
 
+import httpx
+import proto
 from google.api_core import exceptions
+from google.protobuf import field_mask_pb2
+
 from google.cloud import storage
-from google.cloud.aiplatform import base
-from google.cloud.aiplatform import initializer
+from google.cloud.aiplatform import base, initializer
 from google.cloud.aiplatform import utils as aip_utils
 from google.cloud.aiplatform_v1 import types as aip_types
 from google.cloud.aiplatform_v1.types import reasoning_engine_service
 from vertexai.agent_engines import _utils
-import httpx
-import proto
-
-from google.protobuf import field_mask_pb2
-
 _LOGGER = _utils.LOGGER
 _SUPPORTED_PYTHON_VERSIONS = ("3.9", "3.10", "3.11", "3.12", "3.13", "3.14")
 _DEFAULT_GCS_DIR_NAME = "agent_engine"
-_BLOB_FILENAME = "agent_engine.pkl"
+_BLOB_FILENAME = 
"agent_engine.msgpack" _REQUIREMENTS_FILE = "requirements.txt" _EXTRA_PACKAGES_FILE = "dependencies.tar.gz" _STANDARD_API_MODE = "" @@ -117,14 +115,14 @@ ADKAgent = None try: + from a2a.client import ClientConfig, ClientFactory from a2a.types import ( AgentCard, - TransportProtocol, Message, TaskIdParams, TaskQueryParams, + TransportProtocol, ) - from a2a.client import ClientConfig, ClientFactory AgentCard = AgentCard TransportProtocol = TransportProtocol @@ -1209,23 +1207,54 @@ def _upload_agent_engine( logger: base.Logger = _LOGGER, ) -> None: """Uploads the agent engine to GCS.""" - cloudpickle = _utils._import_cloudpickle_or_raise() + import msgpack + + from google.cloud.aiplatform.utils import security_utils + blob = gcs_bucket.blob(f"{gcs_dir_name}/{_BLOB_FILENAME}") - with blob.open("wb") as f: - try: - cloudpickle.dump(agent_engine, f) - except Exception as e: - url = "https://cloud.google.com/vertex-ai/generative-ai/docs/agent-engine/develop/custom#deployment-considerations" - raise TypeError( - f"Failed to serialize agent engine. Visit {url} for details." 
- ) from e - with blob.open("rb") as f: - try: - _ = cloudpickle.load(f) - except Exception as e: - raise TypeError("Agent engine serialized to an invalid format") from e + + # Prepare common state structure + if isinstance(agent_engine, ModuleAgent): + state = { + "type": "ModuleAgent", + "params": agent_engine._tmpl_attrs, + "agent_framework": agent_engine.agent_framework, + } + else: + # Generic object - only data allowed via msgpack + state = { + "type": "CustomObject", + "data": agent_engine, + } + + try: + packed_data = msgpack.packb(state, use_bin_type=True) + # Apply Digital Signature (HMAC) + signed_data = security_utils.sign_blob(packed_data) + + blob.upload_from_string(signed_data) + except Exception as e: + url = "https://cloud.google.com/vertex-ai/generative-ai/docs/agent-engine/develop/custom#deployment-considerations" + raise TypeError( + f"Failed to serialize agent engine to secure msgpack format. " + f"Dynamic logic (lambdas, live classes) is no longer supported. " + f"Visit {url} for migration details." + ) from e + + # Verification round-trip + try: + downloaded_blob = blob.download_as_bytes() + # Verify Signature + verified_data = security_utils.verify_blob(downloaded_blob) + # Unpack + _ = msgpack.unpackb(verified_data, raw=False) + except Exception as e: + raise TypeError( + "Agent engine integrity verification failed after upload." 
+ ) from e + dir_name = f"gs://{gcs_bucket.name}/{gcs_dir_name}" - logger.info(f"Wrote to {dir_name}/{_BLOB_FILENAME}") + logger.info(f"Wrote signed msgpack to {dir_name}/{_BLOB_FILENAME}") def _upload_requirements( diff --git a/vertexai/agent_engines/_utils.py b/vertexai/agent_engines/_utils.py index f7c359c93d..f6a9120dfe 100644 --- a/vertexai/agent_engines/_utils.py +++ b/vertexai/agent_engines/_utils.py @@ -20,6 +20,7 @@ import sys import types import typing +from importlib import metadata as importlib_metadata from typing import ( Any, Callable, @@ -33,14 +34,12 @@ TypedDict, Union, ) -from importlib import metadata as importlib_metadata import proto +from google.api import httpbody_pb2 +from google.protobuf import json_format, struct_pb2 from google.cloud.aiplatform import base -from google.api import httpbody_pb2 -from google.protobuf import struct_pb2 -from google.protobuf import json_format try: # For LangChain templates, they might not import langchain_core and get @@ -119,7 +118,7 @@ class _RequirementsValidationResult(TypedDict): LOGGER = base.Logger("vertexai.agent_engines") _BASE_MODULES = set(_BUILTIN_MODULE_NAMES + tuple(_STDLIB_MODULE_NAMES)) -_DEFAULT_REQUIRED_PACKAGES = frozenset(["cloudpickle", "pydantic"]) +_DEFAULT_REQUIRED_PACKAGES = frozenset(["msgpack", "pydantic"]) _ACTIONS_KEY = "actions" _ACTION_APPEND = "append" _WARNINGS_KEY = "warnings" @@ -654,16 +653,16 @@ def _import_cloud_storage_or_raise() -> types.ModuleType: return storage -def _import_cloudpickle_or_raise() -> types.ModuleType: - """Tries to import the cloudpickle module.""" +def _import_msgpack_or_raise() -> types.ModuleType: + """Tries to import the msgpack module.""" try: - import cloudpickle # noqa:F401 + import msgpack # noqa:F401 except ImportError as e: raise ImportError( - "cloudpickle is not installed. Please call " + "msgpack is not installed. Please call " "'pip install google-cloud-aiplatform[agent_engines]'." 
) from e - return cloudpickle + return msgpack def _import_pydantic_or_raise() -> types.ModuleType: diff --git a/vertexai/reasoning_engines/_reasoning_engines.py b/vertexai/reasoning_engines/_reasoning_engines.py index 322bf2a2d4..9009a7726c 100644 --- a/vertexai/reasoning_engines/_reasoning_engines.py +++ b/vertexai/reasoning_engines/_reasoning_engines.py @@ -35,22 +35,20 @@ ) import proto - from google.api_core import exceptions +from google.protobuf import field_mask_pb2 + from google.cloud import storage -from google.cloud.aiplatform import base -from google.cloud.aiplatform import initializer +from google.cloud.aiplatform import base, initializer from google.cloud.aiplatform import utils as aip_utils from google.cloud.aiplatform_v1beta1 import types as aip_types from google.cloud.aiplatform_v1beta1.types import reasoning_engine_service from vertexai.reasoning_engines import _utils -from google.protobuf import field_mask_pb2 - _LOGGER = base.Logger(__name__) _SUPPORTED_PYTHON_VERSIONS = ("3.9", "3.10", "3.11", "3.12", "3.13", "3.14") _DEFAULT_GCS_DIR_NAME = "reasoning_engine" -_BLOB_FILENAME = "reasoning_engine.pkl" +_BLOB_FILENAME = "reasoning_engine.msgpack" _REQUIREMENTS_FILE = "requirements.txt" _EXTRA_PACKAGES_FILE = "dependencies.tar.gz" _STANDARD_API_MODE = "" @@ -640,12 +638,42 @@ def _upload_reasoning_engine( gcs_dir_name: str, ) -> None: """Uploads the reasoning engine to GCS.""" - cloudpickle = _utils._import_cloudpickle_or_raise() + import msgpack + + from google.cloud.aiplatform.utils import security_utils + blob = gcs_bucket.blob(f"{gcs_dir_name}/{_BLOB_FILENAME}") - with blob.open("wb") as f: - cloudpickle.dump(reasoning_engine, f) + + # Reasoning Engines are typically custom classes. + # We only allow data-serializable states. 
+ state = { + "type": "ReasoningEngine", + "data": reasoning_engine, + } + + try: + packed_data = msgpack.packb(state, use_bin_type=True) + # Apply Digital Signature (HMAC) + signed_data = security_utils.sign_blob(packed_data) + blob.upload_from_string(signed_data) + except Exception as e: + raise TypeError( + "Failed to serialize reasoning engine to secure msgpack format. " + "Executable code (lambdas, classes) is no longer supported for remote deployment." + ) from e + + # Verification round-trip + try: + downloaded_blob = blob.download_as_bytes() + verified_data = security_utils.verify_blob(downloaded_blob) + _ = msgpack.unpackb(verified_data, raw=False) + except Exception as e: + raise TypeError( + "Reasoning engine integrity verification failed after upload." + ) from e + dir_name = f"gs://{gcs_bucket.name}/{gcs_dir_name}" - _LOGGER.info(f"Writing to {dir_name}/{_BLOB_FILENAME}") + _LOGGER.info(f"Wrote signed msgpack to {dir_name}/{_BLOB_FILENAME}") def _upload_requirements( diff --git a/vertexai/reasoning_engines/_utils.py b/vertexai/reasoning_engines/_utils.py index dbb0938748..81b6e4d66c 100644 --- a/vertexai/reasoning_engines/_utils.py +++ b/vertexai/reasoning_engines/_utils.py @@ -18,14 +18,13 @@ import json import types import typing -from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Union +from typing import (Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Union) import proto +from google.api import httpbody_pb2 +from google.protobuf import json_format, struct_pb2 from google.cloud.aiplatform import base -from google.api import httpbody_pb2 -from google.protobuf import struct_pb2 -from google.protobuf import json_format try: # For LangChain templates, they might not import langchain_core and get @@ -38,8 +37,8 @@ RunnableConfig = Any try: - from llama_index.core.base.response import schema as llama_index_schema from llama_index.core.base.llms import types as llama_index_types + from llama_index.core.base.response 
import schema as llama_index_schema LlamaIndexResponse = llama_index_schema.Response LlamaIndexBaseModel = llama_index_schema.BaseModel @@ -331,16 +330,16 @@ def _import_cloud_storage_or_raise() -> types.ModuleType: return storage -def _import_cloudpickle_or_raise() -> types.ModuleType: - """Tries to import the cloudpickle module.""" +def _import_msgpack_or_raise() -> types.ModuleType: + """Tries to import the msgpack module.""" try: - import cloudpickle # noqa:F401 + import msgpack # noqa:F401 except ImportError as e: raise ImportError( - "cloudpickle is not installed. Please call " + "msgpack is not installed. Please call " "'pip install google-cloud-aiplatform[agent_engines]'." ) from e - return cloudpickle + return msgpack def _import_pydantic_or_raise() -> types.ModuleType: