From 7c52791342ec8c03d3e6f5bbafa35c03750ddff0 Mon Sep 17 00:00:00 2001
From: XinyanZhou
Date: Mon, 20 Apr 2026 16:48:33 +0800
Subject: [PATCH] feat: allow api_base (base_url) in .env

---
 README.md                      |  31 ++++++++
 pageindex/client.py            |  33 ++++++--
 pageindex/index/utils.py       |  47 +++++++++++
 pageindex/utils.py             |  47 ++++++++++-
 run_pageindex.py               |   5 ++
 tests/test_client.py           |  48 ++++++++++++
 tests/test_litellm_api_base.py | 137 +++++++++++++++++++++++++++++++++
 7 files changed, 341 insertions(+), 7 deletions(-)
 create mode 100644 tests/test_litellm_api_base.py

diff --git a/README.md b/README.md
index a85fbd01d..5a4080915 100644
--- a/README.md
+++ b/README.md
@@ -170,6 +170,7 @@ You can customize the processing with additional optional arguments:
 ```
 --model                 LLM model to use (default: gpt-4o-2024-11-20)
+--base-url              Base URL for OpenAI-compatible API providers
 --toc-check-pages       Pages to check for table of contents (default: 20)
 --max-pages-per-node    Max pages per node (default: 10)
 --max-tokens-per-node   Max tokens per node (default: 20000)
@@ -179,6 +180,36 @@ You can customize the processing with additional optional arguments:
 ```
+<details>
+<summary>OpenAI-compatible API base URL</summary>
+
+PageIndex uses LiteLLM for local LLM calls. To point requests at an OpenAI-compatible endpoint, set `OPENAI_BASE_URL` (the legacy `OPENAI_API_BASE` and `CHATGPT_API_BASE` variables are also honored) or pass `--base-url`:
+
+```bash
+OPENAI_BASE_URL=http://localhost:11434/v1 \
+python3 run_pageindex.py --pdf_path /path/to/your/document.pdf --model openai/llama3.1
+```
+
+You can also pass the endpoint directly:
+
+```bash
+python3 run_pageindex.py --pdf_path /path/to/your/document.pdf \
+    --model openai/llama3.1 \
+    --base-url http://localhost:11434/v1
+```
+
+The same setting is available from the Python SDK in local mode:
+
+```python
+from pageindex import PageIndexClient
+
+client = PageIndexClient(
+    model="ollama/llama3.1",
+    base_url="http://localhost:11434/v1",
+)
+```
+
+</details>
 
 <details>
 <summary>Markdown support</summary>
 
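Since both the CLI flag and the SDK argument ultimately read and write `os.environ`, an entry such as `OPENAI_BASE_URL=http://localhost:11434/v1` in a `.env` file works as well, assuming the project loads that file into the process environment (dotenv-style loading, per the commit subject, is an assumption here). A minimal standalone sketch of the lookup precedence this patch implements, with placeholder URLs:

```python
import os

# Sketch of the lookup order used by _configured_openai_base_url() and
# _litellm_api_base_kwargs() below: OPENAI_BASE_URL wins over the legacy names.
os.environ["CHATGPT_API_BASE"] = "http://legacy.example/v1"    # placeholder
os.environ["OPENAI_API_BASE"] = "http://api-base.example/v1"   # placeholder
os.environ["OPENAI_BASE_URL"] = "http://localhost:11434/v1"    # placeholder

api_base = (
    os.getenv("OPENAI_BASE_URL")
    or os.getenv("OPENAI_API_BASE")
    or os.getenv("CHATGPT_API_BASE")
)
print(api_base)  # http://localhost:11434/v1, because OPENAI_BASE_URL takes priority
```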
diff --git a/pageindex/client.py b/pageindex/client.py
index 806ebb638..67c6b687e 100644
--- a/pageindex/client.py
+++ b/pageindex/client.py
@@ -1,5 +1,7 @@
 # pageindex/client.py
 from __future__ import annotations
+
+import os
 from pathlib import Path
 from .collection import Collection
 from .config import IndexConfig
@@ -16,6 +18,14 @@ def _normalize_retrieve_model(model: str) -> str:
     return f"litellm/{model}"
 
 
+def _configured_openai_base_url() -> str | None:
+    return (
+        os.getenv("OPENAI_BASE_URL")
+        or os.getenv("OPENAI_API_BASE")
+        or os.getenv("CHATGPT_API_BASE")
+    )
+
+
 class PageIndexClient:
     """PageIndex client — supports both local and cloud modes.
 
@@ -24,6 +34,7 @@ class PageIndexClient:
             and local-only params (model, storage_path, index_config, …) are ignored.
         model: LLM model for indexing (local mode only, default: gpt-4o-2024-11-20).
         retrieve_model: LLM model for agent QA (local mode only, default: same as model).
+        base_url: Base URL for OpenAI-compatible LLM endpoints (local mode only).
         storage_path: Directory for SQLite DB and files (local mode only, default: ./.pageindex).
         storage: Custom StorageEngine instance (local mode only).
         index_config: Advanced indexing parameters (local mode only, optional).
@@ -41,11 +52,12 @@ class PageIndexClient:
 
     def __init__(self, api_key: str = None, model: str = None,
                  retrieve_model: str = None, storage_path: str = None,
-                 storage=None, index_config: IndexConfig | dict = None):
+                 storage=None, index_config: IndexConfig | dict = None,
+                 base_url: str = None):
         if api_key:
             self._init_cloud(api_key)
         else:
-            self._init_local(model, retrieve_model, storage_path, storage, index_config)
+            self._init_local(model, retrieve_model, storage_path, storage, index_config, base_url)
 
     def _init_cloud(self, api_key: str):
         from .backend.cloud import CloudBackend
@@ -53,7 +65,11 @@
     def _init_local(self, model: str = None, retrieve_model: str = None,
                     storage_path: str = None, storage=None,
-                    index_config: IndexConfig | dict = None):
+                    index_config: IndexConfig | dict = None,
+                    base_url: str = None):
+        if base_url:
+            os.environ["OPENAI_BASE_URL"] = base_url
+
         # Build IndexConfig: merge model/retrieve_model with index_config
         overrides = {}
         if model:
@@ -89,14 +105,17 @@ def _validate_llm_provider(model: str) -> None:
     """Validate model and check API key via litellm. Warns if key seems missing."""
     try:
         import litellm
+        from .index.utils import _model_uses_openai_base_url
         litellm.model_cost_map_url = ""
         _, provider, _, _ = litellm.get_llm_provider(model=model)
     except Exception:
         return
 
+    if _configured_openai_base_url() and _model_uses_openai_base_url(model):
+        return
+
     key = litellm.get_api_key(llm_provider=provider, dynamic_api_key=None)
     if not key:
-        import os
         common_var = f"{provider.upper()}_API_KEY"
         if not os.getenv(common_var):
             from .errors import PageIndexError
@@ -130,6 +149,7 @@ class LocalClient(PageIndexClient):
     Args:
         model: LLM model for indexing (default: gpt-4o-2024-11-20)
        retrieve_model: LLM model for agent QA (default: same as model)
+        base_url: Base URL for OpenAI-compatible LLM endpoints.
         storage_path: Directory for SQLite DB and files (default: ./.pageindex)
         storage: Custom StorageEngine instance (default: SQLiteStorage)
         index_config: Advanced indexing parameters. Pass an IndexConfig instance
@@ -150,8 +170,9 @@ class LocalClient(PageIndexClient):
 
     def __init__(self, model: str = None, retrieve_model: str = None,
                  storage_path: str = None, storage=None,
-                 index_config: IndexConfig | dict = None):
-        self._init_local(model, retrieve_model, storage_path, storage, index_config)
+                 index_config: IndexConfig | dict = None,
+                 base_url: str = None):
+        self._init_local(model, retrieve_model, storage_path, storage, index_config, base_url)
 
 
 class CloudClient(PageIndexClient):
diff --git a/pageindex/index/utils.py b/pageindex/index/utils.py
index f416d6d3d..5b86e97c9 100644
--- a/pageindex/index/utils.py
+++ b/pageindex/index/utils.py
@@ -1,5 +1,6 @@
 import litellm
 import logging
+import os
 import time
 import json
 import copy
@@ -10,6 +11,50 @@
 logger = logging.getLogger(__name__)
 
 
+_OPENAI_BASE_URL_PROVIDERS = {
+    "openai",
+    "openai_like",
+    "custom_openai",
+    "text-completion-openai",
+    "aiohttp_openai",
+    "ollama",
+    "ollama_chat",
+    "lm_studio",
+    "hosted_vllm",
+    "vllm",
+    "llamafile",
+    "xinference",
+    "oobabooga",
+}
+
+
+def _normalize_litellm_model(model):
+    return model.removeprefix("litellm/") if model else model
+
+
+def _model_uses_openai_base_url(model):
+    model = _normalize_litellm_model(model)
+    if not model:
+        return False
+    if "/" in model:
+        provider = model.split("/", 1)[0]
+        return provider in _OPENAI_BASE_URL_PROVIDERS
+    try:
+        _, provider, _, _ = litellm.get_llm_provider(model=model)
+        return provider in _OPENAI_BASE_URL_PROVIDERS
+    except Exception:
+        return True
+
+
+def _litellm_api_base_kwargs(model):
+    api_base = (
+        os.getenv("OPENAI_BASE_URL")
+        or os.getenv("OPENAI_API_BASE")
+        or os.getenv("CHATGPT_API_BASE")
+    )
+    return {"api_base": api_base} if api_base and _model_uses_openai_base_url(model) else {}
+
+
 def count_tokens(text, model=None):
     if not text:
         return 0
@@ -28,6 +73,7 @@ def llm_completion(model, prompt, chat_history=None, return_finish_reason=False):
             model=model,
             messages=messages,
             temperature=0,
+            **_litellm_api_base_kwargs(model),
         )
         content = response.choices[0].message.content
         if return_finish_reason:
@@ -57,6 +103,7 @@ async def llm_acompletion(model, prompt):
             model=model,
             messages=messages,
             temperature=0,
+            **_litellm_api_base_kwargs(model),
         )
         return response.choices[0].message.content
     except Exception as e:
diff --git a/pageindex/utils.py b/pageindex/utils.py
index f00ccf3a7..ea0789a11 100644
--- a/pageindex/utils.py
+++ b/pageindex/utils.py
@@ -23,6 +23,50 @@
 litellm.drop_params = True
 
 
+_OPENAI_BASE_URL_PROVIDERS = {
+    "openai",
+    "openai_like",
+    "custom_openai",
+    "text-completion-openai",
+    "aiohttp_openai",
+    "ollama",
+    "ollama_chat",
+    "lm_studio",
+    "hosted_vllm",
+    "vllm",
+    "llamafile",
+    "xinference",
+    "oobabooga",
+}
+
+
+def _normalize_litellm_model(model):
+    return model.removeprefix("litellm/") if model else model
+
+
+def _model_uses_openai_base_url(model):
+    model = _normalize_litellm_model(model)
+    if not model:
+        return False
+    if "/" in model:
+        provider = model.split("/", 1)[0]
+        return provider in _OPENAI_BASE_URL_PROVIDERS
+    try:
+        _, provider, _, _ = litellm.get_llm_provider(model=model)
+        return provider in _OPENAI_BASE_URL_PROVIDERS
+    except Exception:
+        return True
+
+
+def _litellm_api_base_kwargs(model):
+    api_base = (
+        os.getenv("OPENAI_BASE_URL")
+        or os.getenv("OPENAI_API_BASE")
+        or os.getenv("CHATGPT_API_BASE")
+    )
+    return {"api_base": api_base} if api_base and _model_uses_openai_base_url(model) else {}
+
+
 def count_tokens(text, model=None):
     if not text:
         return 0
@@ -40,6 +84,7 @@ def llm_completion(model, prompt, chat_history=None, return_finish_reason=False):
             model=model,
             messages=messages,
             temperature=0,
+            **_litellm_api_base_kwargs(model),
         )
         content = response.choices[0].message.content
         if return_finish_reason:
@@ -70,6 +115,7 @@ async def llm_acompletion(model, prompt):
             model=model,
             messages=messages,
             temperature=0,
+            **_litellm_api_base_kwargs(model),
         )
         return response.choices[0].message.content
     except Exception as e:
@@ -707,4 +753,3 @@ def print_tree(tree, indent=0):
 def print_wrapped(text, width=100):
     for line in text.splitlines():
         print(textwrap.fill(line, width=width))
-
diff --git a/run_pageindex.py b/run_pageindex.py
index a2d4c3185..9ef27880d 100644
--- a/run_pageindex.py
+++ b/run_pageindex.py
@@ -12,6 +12,8 @@
     parser.add_argument('--md_path', type=str, help='Path to the Markdown file')
     parser.add_argument('--model', type=str, default=None, help='Model to use')
+    parser.add_argument('--base-url', '--api-base', dest='base_url', type=str, default=None,
+                        help='Base URL for OpenAI-compatible API providers')
     parser.add_argument('--toc-check-pages', type=int, default=None,
                         help='Number of pages to check for table of contents (PDF only)')
@@ -44,6 +46,9 @@
     if args.pdf_path and args.md_path:
         raise ValueError("Only one of --pdf_path or --md_path can be specified")
 
+    if args.base_url:
+        os.environ["OPENAI_BASE_URL"] = args.base_url
+
     # Build IndexConfig from CLI args (None values use defaults)
     config_overrides = {
         k: v for k, v in {
diff --git a/tests/test_client.py b/tests/test_client.py
index 2c78c92cc..9b29b59e3 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -1,4 +1,6 @@
 # tests/sdk/test_client.py
+import os
+
 import pytest
 
 from pageindex.client import PageIndexClient, LocalClient, CloudClient
@@ -49,3 +51,49 @@ class FakeParser:
         def supported_extensions(self): return [".txt"]
         def parse(self, file_path, **kwargs): pass
     client.register_parser(FakeParser())
+
+
+def test_pageindex_client_base_url_configures_local_openai_compatible_backend(monkeypatch, tmp_path):
+    monkeypatch.delenv("OPENAI_BASE_URL", raising=False)
+    monkeypatch.delenv("OLLAMA_API_KEY", raising=False)
+
+    client = PageIndexClient(
+        model="ollama/llama3.1",
+        base_url="http://example.test/v1",
+        storage_path=str(tmp_path / "pi"),
+    )
+
+    assert isinstance(client, PageIndexClient)
+    assert client._backend._model == "ollama/llama3.1"
+    assert os.environ["OPENAI_BASE_URL"] == "http://example.test/v1"
+
+
+def test_local_client_accepts_base_url(monkeypatch, tmp_path):
+    monkeypatch.delenv("OPENAI_BASE_URL", raising=False)
+    monkeypatch.delenv("OLLAMA_API_KEY", raising=False)
+
+    client = LocalClient(
+        model="ollama/llama3.1",
+        base_url="http://example.test/v1",
+        storage_path=str(tmp_path / "pi"),
+    )
+
+    assert isinstance(client, PageIndexClient)
+    assert os.environ["OPENAI_BASE_URL"] == "http://example.test/v1"
+
+
+def test_pageindex_client_accepts_openai_api_base_env_for_local_compatible_backend(
+    monkeypatch,
+    tmp_path,
+):
+    monkeypatch.delenv("OPENAI_BASE_URL", raising=False)
+    monkeypatch.delenv("OLLAMA_API_KEY", raising=False)
+    monkeypatch.setenv("OPENAI_API_BASE", "http://api-base.example/v1")
+
+    client = PageIndexClient(
+        model="ollama/llama3.1",
+        storage_path=str(tmp_path / "pi"),
+    )
+
+    assert isinstance(client, PageIndexClient)
+    assert client._backend._model == "ollama/llama3.1"
diff --git a/tests/test_litellm_api_base.py b/tests/test_litellm_api_base.py
new file mode 100644
index 000000000..76501274b
--- /dev/null
+++ b/tests/test_litellm_api_base.py
@@ -0,0 +1,137 @@
+import asyncio
+from types import SimpleNamespace
+from unittest.mock import AsyncMock, patch
+
+import pytest
+
+
+def _make_response(content="ok"):
+    message = SimpleNamespace(content=content)
+    choice = SimpleNamespace(message=message, finish_reason="stop")
+    return SimpleNamespace(choices=[choice])
+
+
+def test_index_utils_completion_passes_openai_base_url(monkeypatch):
+    from pageindex.index import utils
+
+    monkeypatch.setenv("OPENAI_BASE_URL", "http://example.test/v1")
+
+    with patch.object(utils.litellm, "completion", return_value=_make_response()) as mock_completion:
+        assert utils.llm_completion("gpt-4o", "hello") == "ok"
+
+    assert mock_completion.call_args.kwargs["api_base"] == "http://example.test/v1"
+
+
+def test_index_utils_completion_passes_openai_base_url_for_openai_prefix(monkeypatch):
+    from pageindex.index import utils
+
+    monkeypatch.setenv("OPENAI_BASE_URL", "http://example.test/v1")
+
+    with patch.object(utils.litellm, "completion", return_value=_make_response()) as mock_completion:
+        assert utils.llm_completion("openai/gpt-4o", "hello") == "ok"
+
+    assert mock_completion.call_args.kwargs["api_base"] == "http://example.test/v1"
+
+
+@pytest.mark.parametrize("model", [
+    "ollama/llama3.1",
+    "lm_studio/llama3.1",
+    "hosted_vllm/llama3.1",
+    "litellm/ollama/llama3.1",
+])
+def test_index_utils_completion_passes_openai_base_url_for_compatible_provider_prefixes(
+    monkeypatch,
+    model,
+):
+    from pageindex.index import utils
+
+    monkeypatch.setenv("OPENAI_BASE_URL", "http://example.test/v1")
+
+    with patch.object(utils.litellm, "completion", return_value=_make_response()) as mock_completion:
+        assert utils.llm_completion(model, "hello") == "ok"
+
+    assert mock_completion.call_args.kwargs["api_base"] == "http://example.test/v1"
+
+
+def test_index_utils_completion_omits_openai_base_url_for_anthropic(monkeypatch):
+    from pageindex.index import utils
+
+    monkeypatch.setenv("OPENAI_BASE_URL", "http://example.test/v1")
+
+    with patch.object(utils.litellm, "completion", return_value=_make_response()) as mock_completion:
+        assert utils.llm_completion("anthropic/claude-3-5-sonnet-20241022", "hello") == "ok"
+
+    assert "api_base" not in mock_completion.call_args.kwargs
+
+
+def test_index_utils_completion_omits_openai_base_url_for_gemini(monkeypatch):
+    from pageindex.index import utils
+
+    monkeypatch.setenv("OPENAI_BASE_URL", "http://example.test/v1")
+
+    with patch.object(utils.litellm, "completion", return_value=_make_response()) as mock_completion:
+        assert utils.llm_completion("gemini/gemini-1.5-pro", "hello") == "ok"
+
+    assert "api_base" not in mock_completion.call_args.kwargs
+
+
+def test_index_utils_completion_omits_api_base_by_default(monkeypatch):
+    from pageindex.index import utils
+
+    monkeypatch.delenv("OPENAI_BASE_URL", raising=False)
+    monkeypatch.delenv("OPENAI_API_BASE", raising=False)
+    monkeypatch.delenv("CHATGPT_API_BASE", raising=False)
+
+    with patch.object(utils.litellm, "completion", return_value=_make_response()) as mock_completion:
+        assert utils.llm_completion("gpt-4o", "hello") == "ok"
+
+    assert "api_base" not in mock_completion.call_args.kwargs
+
+
+def test_index_utils_acompletion_passes_openai_api_base(monkeypatch):
+    from pageindex.index import utils
+
+    monkeypatch.delenv("OPENAI_BASE_URL", raising=False)
+    monkeypatch.setenv("OPENAI_API_BASE", "http://api-base.example/v1")
+
+    with patch.object(utils.litellm, "acompletion", new_callable=AsyncMock) as mock_acompletion:
+        mock_acompletion.return_value = _make_response("async ok")
ok") + result = asyncio.run(utils.llm_acompletion("gpt-4o", "hello")) + + assert result == "async ok" + assert mock_acompletion.call_args.kwargs["api_base"] == "http://api-base.example/v1" + + +def test_legacy_utils_completion_passes_chatgpt_api_base(monkeypatch): + import pageindex.utils as utils + + monkeypatch.delenv("OPENAI_BASE_URL", raising=False) + monkeypatch.delenv("OPENAI_API_BASE", raising=False) + monkeypatch.setenv("CHATGPT_API_BASE", "http://legacy.example/v1") + + with patch.object(utils.litellm, "completion", return_value=_make_response()) as mock_completion: + assert utils.llm_completion("gpt-4o", "hello") == "ok" + + assert mock_completion.call_args.kwargs["api_base"] == "http://legacy.example/v1" + + +def test_legacy_utils_completion_passes_openai_base_url_for_lm_studio(monkeypatch): + import pageindex.utils as utils + + monkeypatch.setenv("OPENAI_BASE_URL", "http://example.test/v1") + + with patch.object(utils.litellm, "completion", return_value=_make_response()) as mock_completion: + assert utils.llm_completion("lm_studio/llama3.1", "hello") == "ok" + + assert mock_completion.call_args.kwargs["api_base"] == "http://example.test/v1" + + +def test_legacy_utils_completion_omits_openai_base_url_for_anthropic(monkeypatch): + import pageindex.utils as utils + + monkeypatch.setenv("OPENAI_BASE_URL", "http://example.test/v1") + + with patch.object(utils.litellm, "completion", return_value=_make_response()) as mock_completion: + assert utils.llm_completion("anthropic/claude-3-5-sonnet-20241022", "hello") == "ok" + + assert "api_base" not in mock_completion.call_args.kwargs