From fcac29d251d04bf54caa7d3be9ef2a7f72e0ccf3 Mon Sep 17 00:00:00 2001 From: XinyanZhou Date: Tue, 21 Apr 2026 20:29:57 +0800 Subject: [PATCH 1/3] feat:compatible with Pageindex SDK --- pageindex/__init__.py | 2 + pageindex/client.py | 105 ++++++++++- pageindex/cloud_api.py | 264 ++++++++++++++++++++++++++++ pageindex/errors.py | 10 +- pageindex/utils.py | 92 ++++++++-- tests/test_errors.py | 6 +- tests/test_legacy_sdk_contract.py | 262 +++++++++++++++++++++++++++ tests/test_legacy_utils_contract.py | 106 +++++++++++ 8 files changed, 828 insertions(+), 19 deletions(-) create mode 100644 pageindex/cloud_api.py create mode 100644 tests/test_legacy_sdk_contract.py create mode 100644 tests/test_legacy_utils_contract.py diff --git a/pageindex/__init__.py b/pageindex/__init__.py index 64464418f..4f2418ea5 100644 --- a/pageindex/__init__.py +++ b/pageindex/__init__.py @@ -13,6 +13,7 @@ from .events import QueryEvent from .errors import ( PageIndexError, + PageIndexAPIError, CollectionNotFoundError, DocumentNotFoundError, IndexingError, @@ -32,6 +33,7 @@ "StorageEngine", "QueryEvent", "PageIndexError", + "PageIndexAPIError", "CollectionNotFoundError", "DocumentNotFoundError", "IndexingError", diff --git a/pageindex/client.py b/pageindex/client.py index 806ebb638..fee7a0ea7 100644 --- a/pageindex/client.py +++ b/pageindex/client.py @@ -1,8 +1,11 @@ # pageindex/client.py from __future__ import annotations from pathlib import Path +from typing import Any, Iterator + from .collection import Collection from .config import IndexConfig +from .errors import PageIndexAPIError from .parser.protocol import DocumentParser @@ -39,21 +42,25 @@ class PageIndexClient: # Or use LocalClient / CloudClient for explicit mode selection """ - def __init__(self, api_key: str = None, model: str = None, + def __init__(self, api_key: str | None = None, model: str = None, retrieve_model: str = None, storage_path: str = None, storage=None, index_config: IndexConfig | dict = None): - if api_key: + if api_key is not None: self._init_cloud(api_key) else: self._init_local(model, retrieve_model, storage_path, storage, index_config) def _init_cloud(self, api_key: str): from .backend.cloud import CloudBackend + from .cloud_api import LegacyCloudAPI self._backend = CloudBackend(api_key=api_key) + self._legacy_cloud_api = LegacyCloudAPI(api_key=api_key) def _init_local(self, model: str = None, retrieve_model: str = None, storage_path: str = None, storage=None, index_config: IndexConfig | dict = None): + self._legacy_cloud_api = None + # Build IndexConfig: merge model/retrieve_model with index_config overrides = {} if model: @@ -123,6 +130,100 @@ def register_parser(self, parser: DocumentParser) -> None: raise PageIndexError("Custom parsers are not supported in cloud mode") self._backend.register_parser(parser) + def _require_cloud_api(self): + if self._legacy_cloud_api is None: + from .errors import PageIndexAPIError + raise PageIndexAPIError( + "This method is part of the pageindex 0.2.x cloud SDK API. " + "Initialize with api_key to use it." + ) + return self._legacy_cloud_api + + # pageindex 0.2.x cloud SDK compatibility methods + def submit_document( + self, + file_path: str, + mode: str | None = None, + beta_headers: list[str] | None = None, + folder_id: str | None = None, + ) -> dict[str, Any]: + return self._require_cloud_api().submit_document( + file_path=file_path, + mode=mode, + beta_headers=beta_headers, + folder_id=folder_id, + ) + + def get_ocr(self, doc_id: str, format: str = "page") -> dict[str, Any]: + return self._require_cloud_api().get_ocr(doc_id=doc_id, format=format) + + def get_tree(self, doc_id: str, node_summary: bool = False) -> dict[str, Any]: + return self._require_cloud_api().get_tree(doc_id=doc_id, node_summary=node_summary) + + def is_retrieval_ready(self, doc_id: str) -> bool: + return self._require_cloud_api().is_retrieval_ready(doc_id=doc_id) + + def submit_query(self, doc_id: str, query: str, thinking: bool = False) -> dict[str, Any]: + return self._require_cloud_api().submit_query( + doc_id=doc_id, + query=query, + thinking=thinking, + ) + + def get_retrieval(self, retrieval_id: str) -> dict[str, Any]: + return self._require_cloud_api().get_retrieval(retrieval_id=retrieval_id) + + def chat_completions( + self, + messages: list[dict[str, str]], + stream: bool = False, + doc_id: str | list[str] | None = None, + temperature: float | None = None, + stream_metadata: bool = False, + enable_citations: bool = False, + ) -> dict[str, Any] | Iterator[str] | Iterator[dict[str, Any]]: + return self._require_cloud_api().chat_completions( + messages=messages, + stream=stream, + doc_id=doc_id, + temperature=temperature, + stream_metadata=stream_metadata, + enable_citations=enable_citations, + ) + + def get_document(self, doc_id: str) -> dict[str, Any]: + return self._require_cloud_api().get_document(doc_id=doc_id) + + def delete_document(self, doc_id: str) -> dict[str, Any]: + return self._require_cloud_api().delete_document(doc_id=doc_id) + + def list_documents( + self, + limit: int = 50, + offset: int = 0, + folder_id: str | None = None, + ) -> dict[str, Any]: + return self._require_cloud_api().list_documents( + limit=limit, + offset=offset, + folder_id=folder_id, + ) + + def create_folder( + self, + name: str, + description: str | None = None, + parent_folder_id: str | None = None, + ) -> dict[str, Any]: + return self._require_cloud_api().create_folder( + name=name, + description=description, + parent_folder_id=parent_folder_id, + ) + + def list_folders(self, parent_folder_id: str | None = None) -> dict[str, Any]: + return self._require_cloud_api().list_folders(parent_folder_id=parent_folder_id) + class LocalClient(PageIndexClient): """Local mode — indexes and queries documents on your machine. diff --git a/pageindex/cloud_api.py b/pageindex/cloud_api.py new file mode 100644 index 000000000..2962299b5 --- /dev/null +++ b/pageindex/cloud_api.py @@ -0,0 +1,264 @@ +from __future__ import annotations + +import json +from typing import Any, Iterator + +import requests + +from .errors import PageIndexAPIError + + +class LegacyCloudAPI: + """Compatibility layer for the pageindex 0.2.x cloud SDK API.""" + + BASE_URL = "https://api.pageindex.ai" + REQUEST_TIMEOUT = 30 + STREAM_TIMEOUT = (30, None) + + def __init__(self, api_key: str): + self.api_key = api_key + + def _headers(self) -> dict[str, str]: + return {"api_key": self.api_key} + + def _request(self, method: str, path: str, error_prefix: str, **kwargs) -> requests.Response: + kwargs.setdefault("timeout", self.REQUEST_TIMEOUT) + try: + response = requests.request( + method, + f"{self.BASE_URL}{path}", + headers=self._headers(), + **kwargs, + ) + except requests.RequestException as e: + raise PageIndexAPIError(f"{error_prefix}: {e}") from e + + if response.status_code != 200: + raise PageIndexAPIError(f"{error_prefix}: {response.text}") + return response + + def submit_document( + self, + file_path: str, + mode: str | None = None, + beta_headers: list[str] | None = None, + folder_id: str | None = None, + ) -> dict[str, Any]: + files = {"file": open(file_path, "rb")} + data: dict[str, Any] = {"if_retrieval": True} + if mode is not None: + data["mode"] = mode + if beta_headers is not None: + data["beta_headers"] = json.dumps(beta_headers) + if folder_id is not None: + data["folder_id"] = folder_id + + try: + response = self._request( + "POST", + "/doc/", + "Failed to submit document", + files=files, + data=data, + ) + finally: + files["file"].close() + + return response.json() + + def get_ocr(self, doc_id: str, format: str = "page") -> dict[str, Any]: + if format not in ["page", "node", "raw"]: + raise ValueError("Format parameter must be 'page', 'node', or 'raw'") + + response = self._request( + "GET", + f"/doc/{doc_id}/?type=ocr&format={format}", + "Failed to get OCR result", + ) + return response.json() + + def get_tree(self, doc_id: str, node_summary: bool = False) -> dict[str, Any]: + response = self._request( + "GET", + f"/doc/{doc_id}/?type=tree&summary={node_summary}", + "Failed to get tree result", + ) + return response.json() + + def is_retrieval_ready(self, doc_id: str) -> bool: + try: + result = self.get_tree(doc_id) + return result.get("retrieval_ready", False) + except PageIndexAPIError: + return False + + def submit_query(self, doc_id: str, query: str, thinking: bool = False) -> dict[str, Any]: + payload = { + "doc_id": doc_id, + "query": query, + "thinking": thinking, + } + response = self._request( + "POST", + "/retrieval/", + "Failed to submit retrieval", + json=payload, + ) + return response.json() + + def get_retrieval(self, retrieval_id: str) -> dict[str, Any]: + response = self._request( + "GET", + f"/retrieval/{retrieval_id}/", + "Failed to get retrieval result", + ) + return response.json() + + def chat_completions( + self, + messages: list[dict[str, str]], + stream: bool = False, + doc_id: str | list[str] | None = None, + temperature: float | None = None, + stream_metadata: bool = False, + enable_citations: bool = False, + ) -> dict[str, Any] | Iterator[str] | Iterator[dict[str, Any]]: + payload: dict[str, Any] = { + "messages": messages, + "stream": stream, + } + + if doc_id is not None: + payload["doc_id"] = doc_id + if temperature is not None: + payload["temperature"] = temperature + if enable_citations: + payload["enable_citations"] = enable_citations + + response = self._request( + "POST", + "/chat/completions/", + "Failed to get chat completion", + json=payload, + stream=stream, + timeout=self.STREAM_TIMEOUT if stream else self.REQUEST_TIMEOUT, + ) + + if stream: + if stream_metadata: + return self._stream_chat_response_raw(response) + return self._stream_chat_response(response) + return response.json() + + def _stream_chat_response(self, response: requests.Response) -> Iterator[str]: + try: + for line in response.iter_lines(): + if not line: + continue + line = line.decode("utf-8") + if not line.startswith("data: "): + continue + data = line[6:] + if data == "[DONE]": + break + + try: + chunk = json.loads(data) + except json.JSONDecodeError: + continue + content = chunk.get("choices", [{}])[0].get("delta", {}).get("content", "") + if content: + yield content + except requests.RequestException as e: + raise PageIndexAPIError(f"Failed to stream chat completion: {e}") from e + + def _stream_chat_response_raw(self, response: requests.Response) -> Iterator[dict[str, Any]]: + try: + for line in response.iter_lines(): + if not line: + continue + line = line.decode("utf-8") + if not line.startswith("data: "): + continue + data = line[6:] + if data == "[DONE]": + break + + try: + yield json.loads(data) + except json.JSONDecodeError: + continue + except requests.RequestException as e: + raise PageIndexAPIError(f"Failed to stream chat completion: {e}") from e + + def get_document(self, doc_id: str) -> dict[str, Any]: + response = self._request( + "GET", + f"/doc/{doc_id}/metadata/", + "Failed to get document metadata", + ) + return response.json() + + def delete_document(self, doc_id: str) -> dict[str, Any]: + response = self._request( + "DELETE", + f"/doc/{doc_id}/", + "Failed to delete document", + ) + return response.json() + + def list_documents( + self, + limit: int = 50, + offset: int = 0, + folder_id: str | None = None, + ) -> dict[str, Any]: + if limit < 1 or limit > 100: + raise ValueError("limit must be between 1 and 100") + if offset < 0: + raise ValueError("offset must be non-negative") + + params: dict[str, Any] = {"limit": limit, "offset": offset} + if folder_id is not None: + params["folder_id"] = folder_id + + response = self._request( + "GET", + "/docs/", + "Failed to list documents", + params=params, + ) + return response.json() + + def create_folder( + self, + name: str, + description: str | None = None, + parent_folder_id: str | None = None, + ) -> dict[str, Any]: + payload: dict[str, Any] = {"name": name} + if description is not None: + payload["description"] = description + if parent_folder_id is not None: + payload["parent_folder_id"] = parent_folder_id + + response = self._request( + "POST", + "/folder/", + "Failed to create folder", + json=payload, + ) + return response.json() + + def list_folders(self, parent_folder_id: str | None = None) -> dict[str, Any]: + params = {} + if parent_folder_id is not None: + params["parent_folder_id"] = parent_folder_id + + response = self._request( + "GET", + "/folders/", + "Failed to list folders", + params=params, + ) + return response.json() diff --git a/pageindex/errors.py b/pageindex/errors.py index 790b68ffd..045a9db40 100644 --- a/pageindex/errors.py +++ b/pageindex/errors.py @@ -18,7 +18,15 @@ class IndexingError(PageIndexError): pass -class CloudAPIError(PageIndexError): +class PageIndexAPIError(PageIndexError): + """PageIndex cloud API returned an error. + + Kept for compatibility with the pageindex 0.2.x cloud SDK. + """ + pass + + +class CloudAPIError(PageIndexAPIError): """Cloud API returned error.""" pass diff --git a/pageindex/utils.py b/pageindex/utils.py index f00ccf3a7..1e0080d21 100644 --- a/pageindex/utils.py +++ b/pageindex/utils.py @@ -15,6 +15,7 @@ import logging import yaml from pathlib import Path +from pprint import pprint from types import SimpleNamespace as config # Backward compatibility: support CHATGPT_API_KEY as alias for OPENAI_API_KEY @@ -23,6 +24,22 @@ litellm.drop_params = True +async def call_llm(prompt, api_key, model="gpt-4.1", temperature=0): + """Call an LLM to generate a response to a prompt. + + Kept for compatibility with the pageindex 0.2.x SDK utility API. + """ + import openai + + client = openai.AsyncOpenAI(api_key=api_key) + response = await client.chat.completions.create( + model=model, + messages=[{"role": "user", "content": prompt}], + temperature=temperature, + ) + return response.choices[0].message.content.strip() + + def count_tokens(text, model=None): if not text: return 0 @@ -463,12 +480,14 @@ def clean_structure_post(data): clean_structure_post(section) return data -def remove_fields(data, fields=['text']): +def remove_fields(data, fields=['text'], max_len=None): if isinstance(data, dict): - return {k: remove_fields(v, fields) + return {k: remove_fields(v, fields, max_len) for k, v in data.items() if k not in fields} elif isinstance(data, list): - return [remove_fields(item, fields) for item in data] + return [remove_fields(item, fields, max_len) for item in data] + elif isinstance(data, str): + return data[:max_len] + '...' if max_len is not None and len(data) > max_len else data return data def print_toc(tree, indent=0): @@ -684,19 +703,65 @@ def load(self, user_opt=None) -> config: merged = {**self._default_dict, **user_dict} return config(**merged) -def create_node_mapping(tree): - """Create a flat dict mapping node_id to node for quick lookup.""" +def create_node_mapping(tree, include_page_ranges=False, max_page=None): + """Create a mapping of node_id to node for quick lookup. + + The optional page-range arguments are kept for compatibility with the + pageindex 0.2.x SDK utility API. + """ + def get_all_nodes(nodes): + if isinstance(nodes, dict): + return [nodes] + [ + child_node + for child in nodes.get('nodes', []) + for child_node in get_all_nodes(child) + ] + elif isinstance(nodes, list): + return [ + child_node + for item in nodes + for child_node in get_all_nodes(item) + ] + return [] + + all_nodes = get_all_nodes(tree) + + if not include_page_ranges: + return {node["node_id"]: node for node in all_nodes if node.get("node_id")} + mapping = {} - def _traverse(nodes): - for node in nodes: - if node.get('node_id'): - mapping[node['node_id']] = node - if node.get('nodes'): - _traverse(node['nodes']) - _traverse(tree) + for i, node in enumerate(all_nodes): + if not node.get("node_id"): + continue + start_page = node.get("page_index", node.get("start_index")) + if node.get("end_index") is not None: + end_page = node.get("end_index") + elif i + 1 < len(all_nodes): + next_node = all_nodes[i + 1] + end_page = next_node.get("page_index", next_node.get("start_index")) + else: + end_page = max_page + + mapping[node["node_id"]] = { + "node": node, + "start_index": start_page, + "end_index": end_page, + } + return mapping -def print_tree(tree, indent=0): +def print_tree(tree, exclude_fields=None, indent=None): + if exclude_fields is None: + exclude_fields = ['text', 'page_index'] + if isinstance(exclude_fields, int): + indent = exclude_fields + exclude_fields = None + if indent is None and exclude_fields is not None: + cleaned_tree = remove_fields(copy.deepcopy(tree), exclude_fields, max_len=40) + pprint(cleaned_tree, sort_dicts=False, width=100) + return + + indent = indent or 0 for node in tree: summary = node.get('summary') or node.get('prefix_summary', '') summary_str = f" — {summary[:60]}..." if summary else "" @@ -707,4 +772,3 @@ def print_tree(tree, indent=0): def print_wrapped(text, width=100): for line in text.splitlines(): print(textwrap.fill(line, width=width)) - diff --git a/tests/test_errors.py b/tests/test_errors.py index af55e7c57..ef71430db 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -1,5 +1,6 @@ from pageindex.errors import ( PageIndexError, + PageIndexAPIError, CollectionNotFoundError, DocumentNotFoundError, IndexingError, @@ -9,9 +10,10 @@ def test_all_errors_inherit_from_base(): - for cls in [CollectionNotFoundError, DocumentNotFoundError, IndexingError, CloudAPIError, FileTypeError]: + for cls in [PageIndexAPIError, CollectionNotFoundError, DocumentNotFoundError, IndexingError, CloudAPIError, FileTypeError]: assert issubclass(cls, PageIndexError) assert issubclass(cls, Exception) + assert issubclass(CloudAPIError, PageIndexAPIError) def test_error_message(): @@ -20,7 +22,7 @@ def test_error_message(): def test_catch_base_catches_all(): - for cls in [CollectionNotFoundError, DocumentNotFoundError, IndexingError, CloudAPIError, FileTypeError]: + for cls in [PageIndexAPIError, CollectionNotFoundError, DocumentNotFoundError, IndexingError, CloudAPIError, FileTypeError]: try: raise cls("test") except PageIndexError: diff --git a/tests/test_legacy_sdk_contract.py b/tests/test_legacy_sdk_contract.py new file mode 100644 index 000000000..05353134e --- /dev/null +++ b/tests/test_legacy_sdk_contract.py @@ -0,0 +1,262 @@ +import pytest +import requests + +from pageindex.client import PageIndexAPIError as ClientPageIndexAPIError +from pageindex import PageIndexAPIError, PageIndexClient +from pageindex.client import CloudClient + + +class FakeResponse: + def __init__(self, status_code=200, payload=None, text="ok", lines=None): + self.status_code = status_code + self._payload = payload or {} + self.text = text + self._lines = lines or [] + + def json(self): + return self._payload + + def iter_lines(self): + return iter(self._lines) + + +class StreamingErrorResponse(FakeResponse): + def iter_lines(self): + raise requests.ReadTimeout("stream stalled") + + +def test_legacy_imports_and_initializers(): + positional = PageIndexClient("pi-test") + keyword = PageIndexClient(api_key="pi-test") + cloud = CloudClient(api_key="pi-test") + + assert positional._legacy_cloud_api.api_key == "pi-test" + assert keyword._legacy_cloud_api.api_key == "pi-test" + assert cloud._legacy_cloud_api.api_key == "pi-test" + assert issubclass(PageIndexAPIError, Exception) + assert ClientPageIndexAPIError is PageIndexAPIError + + +def test_legacy_methods_exist(): + client = PageIndexClient("pi-test") + for method_name in [ + "submit_document", + "get_ocr", + "get_tree", + "is_retrieval_ready", + "submit_query", + "get_retrieval", + "chat_completions", + "get_document", + "delete_document", + "list_documents", + "create_folder", + "list_folders", + ]: + assert callable(getattr(client, method_name)) + + +def test_submit_document_uses_legacy_endpoint(monkeypatch, tmp_path): + calls = [] + + def fake_request(method, url, headers=None, files=None, data=None, **kwargs): + calls.append({ + "method": method, + "url": url, + "headers": headers, + "data": data, + "files": files, + "timeout": kwargs.get("timeout"), + }) + return FakeResponse(payload={"doc_id": "doc-1"}) + + monkeypatch.setattr("pageindex.cloud_api.requests.request", fake_request) + + pdf = tmp_path / "doc.pdf" + pdf.write_bytes(b"%PDF-1.4") + result = PageIndexClient("pi-test").submit_document( + str(pdf), + mode="mcp", + beta_headers=["block_reference"], + folder_id="folder-1", + ) + + assert result == {"doc_id": "doc-1"} + assert calls[0]["method"] == "POST" + assert calls[0]["url"] == "https://api.pageindex.ai/doc/" + assert calls[0]["headers"] == {"api_key": "pi-test"} + assert calls[0]["timeout"] == 30 + assert calls[0]["data"]["if_retrieval"] is True + assert calls[0]["data"]["mode"] == "mcp" + assert calls[0]["data"]["beta_headers"] == '["block_reference"]' + assert calls[0]["data"]["folder_id"] == "folder-1" + + +def test_get_ocr_and_tree_use_legacy_urls(monkeypatch): + get_calls = [] + + def fake_request(method, url, headers=None, **kwargs): + get_calls.append({"method": method, "url": url, "headers": headers}) + return FakeResponse(payload={"status": "completed", "retrieval_ready": True}) + + monkeypatch.setattr("pageindex.cloud_api.requests.request", fake_request) + client = PageIndexClient("pi-test") + + assert client.get_ocr("doc-1", format="page")["status"] == "completed" + assert client.get_tree("doc-1", node_summary=True)["retrieval_ready"] is True + + assert get_calls[0]["method"] == "GET" + assert get_calls[0]["url"] == "https://api.pageindex.ai/doc/doc-1/?type=ocr&format=page" + assert get_calls[1]["url"] == "https://api.pageindex.ai/doc/doc-1/?type=tree&summary=True" + + +def test_get_ocr_rejects_invalid_format(): + with pytest.raises(ValueError, match="Format parameter must be"): + PageIndexClient("pi-test").get_ocr("doc-1", format="bad") + + +def test_submit_query_uses_legacy_payload(monkeypatch): + calls = [] + + def fake_request(method, url, headers=None, json=None, **kwargs): + calls.append({"method": method, "url": url, "headers": headers, "json": json}) + return FakeResponse(payload={"retrieval_id": "ret-1"}) + + monkeypatch.setattr("pageindex.cloud_api.requests.request", fake_request) + + result = PageIndexClient("pi-test").submit_query("doc-1", "What changed?", thinking=True) + + assert result == {"retrieval_id": "ret-1"} + assert calls[0]["method"] == "POST" + assert calls[0]["url"] == "https://api.pageindex.ai/retrieval/" + assert calls[0]["json"] == { + "doc_id": "doc-1", + "query": "What changed?", + "thinking": True, + } + + +def test_chat_completions_non_stream_returns_json(monkeypatch): + calls = [] + payload = {"choices": [{"message": {"content": "answer"}}]} + + def fake_request(method, url, headers=None, json=None, stream=False, **kwargs): + calls.append({ + "method": method, + "url": url, + "headers": headers, + "json": json, + "stream": stream, + }) + return FakeResponse(payload=payload) + + monkeypatch.setattr("pageindex.cloud_api.requests.request", fake_request) + + result = PageIndexClient("pi-test").chat_completions( + [{"role": "user", "content": "hi"}], + doc_id=["doc-1"], + temperature=0.1, + enable_citations=True, + ) + + assert result == payload + assert calls[0]["method"] == "POST" + assert calls[0]["url"] == "https://api.pageindex.ai/chat/completions/" + assert calls[0]["stream"] is False + assert calls[0]["json"] == { + "messages": [{"role": "user", "content": "hi"}], + "stream": False, + "doc_id": ["doc-1"], + "temperature": 0.1, + "enable_citations": True, + } + + +def test_chat_completions_stream_parses_text_chunks(monkeypatch): + calls = [] + lines = [ + b'data: {"choices":[{"delta":{"content":"hel"}}]}', + b'data: {"choices":[{"delta":{"content":"lo"}}]}', + b"data: [DONE]", + ] + + def fake_request(method, url, **kwargs): + calls.append({"method": method, "url": url, "timeout": kwargs.get("timeout")}) + return FakeResponse(lines=lines) + + monkeypatch.setattr("pageindex.cloud_api.requests.request", fake_request) + + chunks = list(PageIndexClient("pi-test").chat_completions( + [{"role": "user", "content": "hi"}], + stream=True, + )) + + assert chunks == ["hel", "lo"] + assert calls[0]["timeout"] == (30, None) + + +def test_chat_completions_stream_metadata_returns_raw_chunks(monkeypatch): + calls = [] + lines = [ + b'data: {"object":"chat.completion.chunk"}', + b"data: [DONE]", + ] + + def fake_request(method, url, **kwargs): + calls.append({"method": method, "url": url, "json": kwargs.get("json")}) + return FakeResponse(lines=lines) + + monkeypatch.setattr("pageindex.cloud_api.requests.request", fake_request) + + chunks = list(PageIndexClient("pi-test").chat_completions( + [{"role": "user", "content": "hi"}], + stream=True, + stream_metadata=True, + )) + + assert chunks == [{"object": "chat.completion.chunk"}] + assert "stream_metadata" not in calls[0]["json"] + + +def test_chat_completions_stream_errors_are_pageindex_api_error(monkeypatch): + def fake_request(*args, **kwargs): + return StreamingErrorResponse() + + monkeypatch.setattr("pageindex.cloud_api.requests.request", fake_request) + + stream = PageIndexClient("pi-test").chat_completions( + [{"role": "user", "content": "hi"}], + stream=True, + ) + + with pytest.raises(PageIndexAPIError, match="Failed to stream chat completion: stream stalled"): + list(stream) + + +def test_api_errors_are_pageindex_api_error(monkeypatch): + def fake_request(*args, **kwargs): + return FakeResponse(status_code=500, text="server error") + + monkeypatch.setattr("pageindex.cloud_api.requests.request", fake_request) + + with pytest.raises(PageIndexAPIError, match="Failed to get document metadata"): + PageIndexClient("pi-test").get_document("doc-1") + + +def test_network_errors_are_wrapped_as_pageindex_api_error(monkeypatch): + def fake_request(*args, **kwargs): + raise requests.Timeout("slow network") + + monkeypatch.setattr("pageindex.cloud_api.requests.request", fake_request) + + with pytest.raises(PageIndexAPIError, match="Failed to get document metadata: slow network"): + PageIndexClient("pi-test").get_document("doc-1") + + +def test_list_documents_validates_legacy_pagination(): + client = PageIndexClient("pi-test") + + with pytest.raises(ValueError, match="limit must be between 1 and 100"): + client.list_documents(limit=0) + with pytest.raises(ValueError, match="offset must be non-negative"): + client.list_documents(offset=-1) diff --git a/tests/test_legacy_utils_contract.py b/tests/test_legacy_utils_contract.py new file mode 100644 index 000000000..2abf5fba6 --- /dev/null +++ b/tests/test_legacy_utils_contract.py @@ -0,0 +1,106 @@ +import sys +import asyncio +from types import SimpleNamespace + +from pageindex import utils + + +def test_remove_fields_keeps_legacy_max_len(): + data = { + "title": "A long title", + "text": "hidden", + "nodes": [{"summary": "abcdefghijklmnopqrstuvwxyz"}], + } + + result = utils.remove_fields(data, fields=["text"], max_len=5) + + assert "text" not in result + assert result["title"] == "A lon..." + assert result["nodes"][0]["summary"] == "abcde..." + + +def test_create_node_mapping_keeps_legacy_page_ranges(): + tree = [ + { + "node_id": "0001", + "title": "Root", + "page_index": 1, + "nodes": [ + {"node_id": "0002", "title": "Child", "page_index": 3, "nodes": []}, + ], + } + ] + + plain = utils.create_node_mapping(tree) + ranged = utils.create_node_mapping(tree, include_page_ranges=True, max_page=8) + + assert plain["0001"]["title"] == "Root" + assert ranged["0001"]["start_index"] == 1 + assert ranged["0001"]["end_index"] == 3 + assert ranged["0002"]["start_index"] == 3 + assert ranged["0002"]["end_index"] == 8 + + +def test_create_node_mapping_prefers_existing_start_end_ranges(): + tree = [ + { + "node_id": "0001", + "title": "Root", + "start_index": 1, + "end_index": 10, + "nodes": [ + {"node_id": "0002", "title": "Child", "start_index": 3, "end_index": 5}, + ], + } + ] + + ranged = utils.create_node_mapping(tree, include_page_ranges=True, max_page=12) + + assert ranged["0001"]["start_index"] == 1 + assert ranged["0001"]["end_index"] == 10 + assert ranged["0002"]["start_index"] == 3 + assert ranged["0002"]["end_index"] == 5 + + +def test_print_tree_keeps_legacy_exclude_fields(capsys): + tree = [{"node_id": "0001", "title": "Root", "text": "hidden", "page_index": 1}] + + utils.print_tree(tree) + + out = capsys.readouterr().out + assert "Root" in out + assert "hidden" not in out + assert "page_index" not in out + + +def test_call_llm_keeps_legacy_async_openai_contract(monkeypatch): + calls = [] + + class FakeCompletions: + async def create(self, **kwargs): + calls.append(kwargs) + message = SimpleNamespace(content=" answer ") + choice = SimpleNamespace(message=message) + return SimpleNamespace(choices=[choice]) + + class FakeAsyncOpenAI: + def __init__(self, api_key): + self.api_key = api_key + self.chat = SimpleNamespace(completions=FakeCompletions()) + + fake_openai = SimpleNamespace(AsyncOpenAI=FakeAsyncOpenAI) + monkeypatch.setitem(sys.modules, "openai", fake_openai) + + result = asyncio.run(utils.call_llm( + "hello", + api_key="sk-test", + model="gpt-test", + temperature=0.2, + )) + + assert result == "answer" + assert calls == [{ + "model": "gpt-test", + "messages": [{"role": "user", "content": "hello"}], + "temperature": 0.2, + }] From 37ae54551216035ffa8e59355ade361fe60758f6 Mon Sep 17 00:00:00 2001 From: XinyanZhou Date: Tue, 21 Apr 2026 20:44:40 +0800 Subject: [PATCH 2/3] corner cases fixed --- pageindex/client.py | 2 +- pageindex/cloud_api.py | 12 ++++++------ pageindex/utils.py | 2 +- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/pageindex/client.py b/pageindex/client.py index fee7a0ea7..a98976889 100644 --- a/pageindex/client.py +++ b/pageindex/client.py @@ -45,7 +45,7 @@ class PageIndexClient: def __init__(self, api_key: str | None = None, model: str = None, retrieve_model: str = None, storage_path: str = None, storage=None, index_config: IndexConfig | dict = None): - if api_key is not None: + if api_key is not None and api_key != "": self._init_cloud(api_key) else: self._init_local(model, retrieve_model, storage_path, storage, index_config) diff --git a/pageindex/cloud_api.py b/pageindex/cloud_api.py index 2962299b5..51d06820c 100644 --- a/pageindex/cloud_api.py +++ b/pageindex/cloud_api.py @@ -44,7 +44,6 @@ def submit_document( beta_headers: list[str] | None = None, folder_id: str | None = None, ) -> dict[str, Any]: - files = {"file": open(file_path, "rb")} data: dict[str, Any] = {"if_retrieval": True} if mode is not None: data["mode"] = mode @@ -53,16 +52,14 @@ def submit_document( if folder_id is not None: data["folder_id"] = folder_id - try: + with open(file_path, "rb") as f: response = self._request( "POST", "/doc/", "Failed to submit document", - files=files, + files={"file": f}, data=data, ) - finally: - files["file"].close() return response.json() @@ -166,7 +163,10 @@ def _stream_chat_response(self, response: requests.Response) -> Iterator[str]: chunk = json.loads(data) except json.JSONDecodeError: continue - content = chunk.get("choices", [{}])[0].get("delta", {}).get("content", "") + choices = chunk.get("choices") or [] + if not choices: + continue + content = choices[0].get("delta", {}).get("content", "") if content: yield content except requests.RequestException as e: diff --git a/pageindex/utils.py b/pageindex/utils.py index 1e0080d21..8cfe1841e 100644 --- a/pageindex/utils.py +++ b/pageindex/utils.py @@ -767,7 +767,7 @@ def print_tree(tree, exclude_fields=None, indent=None): summary_str = f" — {summary[:60]}..." if summary else "" print(' ' * indent + f"[{node.get('node_id', '?')}] {node.get('title', '')}{summary_str}") if node.get('nodes'): - print_tree(node['nodes'], indent + 1) + print_tree(node['nodes'], exclude_fields=exclude_fields, indent=indent + 1) def print_wrapped(text, width=100): for line in text.splitlines(): From 882b9fb9c57c93a91ff1cc11007a1b2d97091d03 Mon Sep 17 00:00:00 2001 From: saccharin98 Date: Wed, 29 Apr 2026 19:27:16 +0800 Subject: [PATCH 3/3] fix: mock behavior of old SDK --- pageindex/client.py | 4 +++- pageindex/cloud_api.py | 9 +++------ pyproject.toml | 1 + tests/test_legacy_sdk_contract.py | 26 ++++++++++++++++++++++---- 4 files changed, 29 insertions(+), 11 deletions(-) diff --git a/pageindex/client.py b/pageindex/client.py index a98976889..57cf47534 100644 --- a/pageindex/client.py +++ b/pageindex/client.py @@ -42,6 +42,8 @@ class PageIndexClient: # Or use LocalClient / CloudClient for explicit mode selection """ + BASE_URL = "https://api.pageindex.ai" + def __init__(self, api_key: str | None = None, model: str = None, retrieve_model: str = None, storage_path: str = None, storage=None, index_config: IndexConfig | dict = None): @@ -54,7 +56,7 @@ def _init_cloud(self, api_key: str): from .backend.cloud import CloudBackend from .cloud_api import LegacyCloudAPI self._backend = CloudBackend(api_key=api_key) - self._legacy_cloud_api = LegacyCloudAPI(api_key=api_key) + self._legacy_cloud_api = LegacyCloudAPI(api_key=api_key, base_url=self.BASE_URL) def _init_local(self, model: str = None, retrieve_model: str = None, storage_path: str = None, storage=None, diff --git a/pageindex/cloud_api.py b/pageindex/cloud_api.py index 51d06820c..eaf338584 100644 --- a/pageindex/cloud_api.py +++ b/pageindex/cloud_api.py @@ -12,21 +12,19 @@ class LegacyCloudAPI: """Compatibility layer for the pageindex 0.2.x cloud SDK API.""" BASE_URL = "https://api.pageindex.ai" - REQUEST_TIMEOUT = 30 - STREAM_TIMEOUT = (30, None) - def __init__(self, api_key: str): + def __init__(self, api_key: str, base_url: str | None = None): self.api_key = api_key + self.base_url = base_url or self.BASE_URL def _headers(self) -> dict[str, str]: return {"api_key": self.api_key} def _request(self, method: str, path: str, error_prefix: str, **kwargs) -> requests.Response: - kwargs.setdefault("timeout", self.REQUEST_TIMEOUT) try: response = requests.request( method, - f"{self.BASE_URL}{path}", + f"{self.base_url}{path}", headers=self._headers(), **kwargs, ) @@ -138,7 +136,6 @@ def chat_completions( "Failed to get chat completion", json=payload, stream=stream, - timeout=self.STREAM_TIMEOUT if stream else self.REQUEST_TIMEOUT, ) if stream: diff --git a/pyproject.toml b/pyproject.toml index 3a72d8773..9b29851f2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ pymupdf = ">=1.26.0" PyPDF2 = ">=3.0.0" python-dotenv = ">=1.0.0" pyyaml = ">=6.0" +openai = ">=1.70.0" openai-agents = ">=0.1.0" requests = ">=2.28.0" httpx = {extras = ["socks"], version = ">=0.28.1"} diff --git a/tests/test_legacy_sdk_contract.py b/tests/test_legacy_sdk_contract.py index 05353134e..521fd31bc 100644 --- a/tests/test_legacy_sdk_contract.py +++ b/tests/test_legacy_sdk_contract.py @@ -56,6 +56,24 @@ def test_legacy_methods_exist(): assert callable(getattr(client, method_name)) +def test_legacy_base_url_can_be_overridden_from_client(monkeypatch): + calls = [] + + def fake_request(method, url, headers=None, **kwargs): + calls.append({"method": method, "url": url, "headers": headers}) + return FakeResponse(payload={"id": "doc-1"}) + + monkeypatch.setattr("pageindex.cloud_api.requests.request", fake_request) + monkeypatch.setattr(PageIndexClient, "BASE_URL", "https://staging.pageindex.test") + + result = PageIndexClient("pi-test").get_document("doc-1") + + assert result == {"id": "doc-1"} + assert calls[0]["method"] == "GET" + assert calls[0]["url"] == "https://staging.pageindex.test/doc/doc-1/metadata/" + assert calls[0]["headers"] == {"api_key": "pi-test"} + + def test_submit_document_uses_legacy_endpoint(monkeypatch, tmp_path): calls = [] @@ -66,7 +84,7 @@ def fake_request(method, url, headers=None, files=None, data=None, **kwargs): "headers": headers, "data": data, "files": files, - "timeout": kwargs.get("timeout"), + "kwargs": kwargs, }) return FakeResponse(payload={"doc_id": "doc-1"}) @@ -85,7 +103,7 @@ def fake_request(method, url, headers=None, files=None, data=None, **kwargs): assert calls[0]["method"] == "POST" assert calls[0]["url"] == "https://api.pageindex.ai/doc/" assert calls[0]["headers"] == {"api_key": "pi-test"} - assert calls[0]["timeout"] == 30 + assert "timeout" not in calls[0]["kwargs"] assert calls[0]["data"]["if_retrieval"] is True assert calls[0]["data"]["mode"] == "mcp" assert calls[0]["data"]["beta_headers"] == '["block_reference"]' @@ -181,7 +199,7 @@ def test_chat_completions_stream_parses_text_chunks(monkeypatch): ] def fake_request(method, url, **kwargs): - calls.append({"method": method, "url": url, "timeout": kwargs.get("timeout")}) + calls.append({"method": method, "url": url, "kwargs": kwargs}) return FakeResponse(lines=lines) monkeypatch.setattr("pageindex.cloud_api.requests.request", fake_request) @@ -192,7 +210,7 @@ def fake_request(method, url, **kwargs): )) assert chunks == ["hel", "lo"] - assert calls[0]["timeout"] == (30, None) + assert "timeout" not in calls[0]["kwargs"] def test_chat_completions_stream_metadata_returns_raw_chunks(monkeypatch):