diff --git a/pageindex/client.py b/pageindex/client.py index 894dab181..d94106ba3 100644 --- a/pageindex/client.py +++ b/pageindex/client.py @@ -32,7 +32,7 @@ class PageIndexClient: For agent-based QA, see examples/agentic_vectorless_rag_demo.py. """ - def __init__(self, api_key: str = None, model: str = None, retrieve_model: str = None, workspace: str = None): + def __init__(self, api_key: str = None, model: str = None, retrieve_model: str = None, workspace: str = None, llm_kwargs: dict | None = None): if api_key: os.environ["OPENAI_API_KEY"] = api_key elif not os.getenv("OPENAI_API_KEY") and os.getenv("CHATGPT_API_KEY"): @@ -43,14 +43,18 @@ def __init__(self, api_key: str = None, model: str = None, retrieve_model: str = overrides["model"] = model if retrieve_model: overrides["retrieve_model"] = retrieve_model + if llm_kwargs is not None: + overrides["llm_kwargs"] = llm_kwargs opt = ConfigLoader().load(overrides or None) self.model = opt.model self.retrieve_model = _normalize_retrieve_model(opt.retrieve_model or self.model) + self.llm_kwargs = opt.llm_kwargs if self.workspace: self.workspace.mkdir(parents=True, exist_ok=True) self.documents = {} if self.workspace: self._load_workspace() + def index(self, file_path: str, mode: str = "auto") -> str: """Index a document. Returns a document_id.""" @@ -71,6 +75,7 @@ def index(self, file_path: str, mode: str = "auto") -> str: result = page_index( doc=file_path, model=self.model, + llm_kwargs=self.llm_kwargs, if_add_node_summary='yes', if_add_node_text='yes', if_add_node_id='yes', @@ -102,6 +107,7 @@ def index(self, file_path: str, mode: str = "auto") -> str: if_add_node_summary='yes', summary_token_threshold=200, model=self.model, + llm_kwargs=self.llm_kwargs, if_add_doc_description='yes', if_add_node_text='yes', if_add_node_id='yes' diff --git a/pageindex/config.yaml b/pageindex/config.yaml index 591fe9331..bd7cd0be8 100644 --- a/pageindex/config.yaml +++ b/pageindex/config.yaml @@ -1,4 +1,5 @@ model: "gpt-4o-2024-11-20" +llm_kwargs: {} # model: "anthropic/claude-sonnet-4-6" retrieve_model: "gpt-5.4" # defaults to `model` if not set toc_check_page_num: 20 diff --git a/pageindex/page_index.py b/pageindex/page_index.py index 9004309fb..dbc0cdfcd 100644 --- a/pageindex/page_index.py +++ b/pageindex/page_index.py @@ -1073,44 +1073,45 @@ def page_index_main(doc, opt=None): if not is_valid_pdf: raise ValueError("Unsupported input type. Expected a PDF file path or BytesIO object.") - print('Parsing PDF...') - page_list = get_page_tokens(doc, model=opt.model) - - logger.info({'total_page_number': len(page_list)}) - logger.info({'total_token': sum([page[1] for page in page_list])}) - - async def page_index_builder(): - structure = await tree_parser(page_list, opt, doc=doc, logger=logger) - if opt.if_add_node_id == 'yes': - write_node_id(structure) - if opt.if_add_node_text == 'yes': - add_node_text(structure, page_list) - if opt.if_add_node_summary == 'yes': - if opt.if_add_node_text == 'no': + with llm_kwargs_scope(getattr(opt, "llm_kwargs", None)): + print('Parsing PDF...') + page_list = get_page_tokens(doc, model=opt.model) + + logger.info({'total_page_number': len(page_list)}) + logger.info({'total_token': sum([page[1] for page in page_list])}) + + async def page_index_builder(): + structure = await tree_parser(page_list, opt, doc=doc, logger=logger) + if opt.if_add_node_id == 'yes': + write_node_id(structure) + if opt.if_add_node_text == 'yes': add_node_text(structure, page_list) - await generate_summaries_for_structure(structure, model=opt.model) - if opt.if_add_node_text == 'no': - remove_structure_text(structure) - if opt.if_add_doc_description == 'yes': - # Create a clean structure without unnecessary fields for description generation - clean_structure = create_clean_structure_for_description(structure) - doc_description = generate_doc_description(clean_structure, model=opt.model) - structure = format_structure(structure, order=['title', 'node_id', 'start_index', 'end_index', 'summary', 'text', 'nodes']) - return { - 'doc_name': get_pdf_name(doc), - 'doc_description': doc_description, - 'structure': structure, - } - structure = format_structure(structure, order=['title', 'node_id', 'start_index', 'end_index', 'summary', 'text', 'nodes']) - return { - 'doc_name': get_pdf_name(doc), - 'structure': structure, - } + if opt.if_add_node_summary == 'yes': + if opt.if_add_node_text == 'no': + add_node_text(structure, page_list) + await generate_summaries_for_structure(structure, model=opt.model) + if opt.if_add_node_text == 'no': + remove_structure_text(structure) + if opt.if_add_doc_description == 'yes': + # Create a clean structure without unnecessary fields for description generation + clean_structure = create_clean_structure_for_description(structure) + doc_description = generate_doc_description(clean_structure, model=opt.model) + structure = format_structure(structure, order=['title', 'node_id', 'start_index', 'end_index', 'summary', 'text', 'nodes']) + return { + 'doc_name': get_pdf_name(doc), + 'doc_description': doc_description, + 'structure': structure, + } + structure = format_structure(structure, order=['title', 'node_id', 'start_index', 'end_index', 'summary', 'text', 'nodes']) + return { + 'doc_name': get_pdf_name(doc), + 'structure': structure, + } - return asyncio.run(page_index_builder()) + return asyncio.run(page_index_builder()) -def page_index(doc, model=None, toc_check_page_num=None, max_page_num_each_node=None, max_token_num_each_node=None, +def page_index(doc, model=None, llm_kwargs=None, toc_check_page_num=None, max_page_num_each_node=None, max_token_num_each_node=None, if_add_node_id=None, if_add_node_summary=None, if_add_doc_description=None, if_add_node_text=None): user_opt = { @@ -1151,4 +1152,4 @@ def validate_and_truncate_physical_indices(toc_with_page_number, page_list_lengt if truncated_items: print(f"Truncated {len(truncated_items)} TOC items that exceeded document length") - return toc_with_page_number \ No newline at end of file + return toc_with_page_number diff --git a/pageindex/page_index_md.py b/pageindex/page_index_md.py index 5a5971690..79b357d66 100644 --- a/pageindex/page_index_md.py +++ b/pageindex/page_index_md.py @@ -240,64 +240,65 @@ def clean_tree_for_output(tree_nodes): return cleaned_nodes -async def md_to_tree(md_path, if_thinning=False, min_token_threshold=None, if_add_node_summary='no', summary_token_threshold=None, model=None, if_add_doc_description='no', if_add_node_text='no', if_add_node_id='yes'): - with open(md_path, 'r', encoding='utf-8') as f: - markdown_content = f.read() - line_count = markdown_content.count('\n') + 1 +async def md_to_tree(md_path, if_thinning=False, min_token_threshold=None, if_add_node_summary='no', summary_token_threshold=None, model=None, llm_kwargs=None, if_add_doc_description='no', if_add_node_text='no', if_add_node_id='yes'): + with llm_kwargs_scope(llm_kwargs): + with open(md_path, 'r', encoding='utf-8') as f: + markdown_content = f.read() + line_count = markdown_content.count('\n') + 1 - print(f"Extracting nodes from markdown...") - node_list, markdown_lines = extract_nodes_from_markdown(markdown_content) + print(f"Extracting nodes from markdown...") + node_list, markdown_lines = extract_nodes_from_markdown(markdown_content) - print(f"Extracting text content from nodes...") - nodes_with_content = extract_node_text_content(node_list, markdown_lines) - - if if_thinning: - nodes_with_content = update_node_list_with_text_token_count(nodes_with_content, model=model) - print(f"Thinning nodes...") - nodes_with_content = tree_thinning_for_index(nodes_with_content, min_token_threshold, model=model) - - print(f"Building tree from nodes...") - tree_structure = build_tree_from_nodes(nodes_with_content) - - if if_add_node_id == 'yes': - write_node_id(tree_structure) - - print(f"Formatting tree structure...") - - if if_add_node_summary == 'yes': - # Always include text for summary generation - tree_structure = format_structure(tree_structure, order = ['title', 'node_id', 'line_num', 'summary', 'prefix_summary', 'text', 'nodes']) + print(f"Extracting text content from nodes...") + nodes_with_content = extract_node_text_content(node_list, markdown_lines) - print(f"Generating summaries for each node...") - tree_structure = await generate_summaries_for_structure_md(tree_structure, summary_token_threshold=summary_token_threshold, model=model) + if if_thinning: + nodes_with_content = update_node_list_with_text_token_count(nodes_with_content, model=model) + print(f"Thinning nodes...") + nodes_with_content = tree_thinning_for_index(nodes_with_content, min_token_threshold, model=model) - if if_add_node_text == 'no': - # Remove text after summary generation if not requested - tree_structure = format_structure(tree_structure, order = ['title', 'node_id', 'line_num', 'summary', 'prefix_summary', 'nodes']) + print(f"Building tree from nodes...") + tree_structure = build_tree_from_nodes(nodes_with_content) + + if if_add_node_id == 'yes': + write_node_id(tree_structure) + + print(f"Formatting tree structure...") - if if_add_doc_description == 'yes': - print(f"Generating document description...") - # Create a clean structure without unnecessary fields for description generation - clean_structure = create_clean_structure_for_description(tree_structure) - doc_description = generate_doc_description(clean_structure, model=model) - return { - 'doc_name': os.path.splitext(os.path.basename(md_path))[0], - 'doc_description': doc_description, - 'line_count': line_count, - 'structure': tree_structure, - } - else: - # No summaries needed, format based on text preference - if if_add_node_text == 'yes': + if if_add_node_summary == 'yes': + # Always include text for summary generation tree_structure = format_structure(tree_structure, order = ['title', 'node_id', 'line_num', 'summary', 'prefix_summary', 'text', 'nodes']) + + print(f"Generating summaries for each node...") + tree_structure = await generate_summaries_for_structure_md(tree_structure, summary_token_threshold=summary_token_threshold, model=model) + + if if_add_node_text == 'no': + # Remove text after summary generation if not requested + tree_structure = format_structure(tree_structure, order = ['title', 'node_id', 'line_num', 'summary', 'prefix_summary', 'nodes']) + + if if_add_doc_description == 'yes': + print(f"Generating document description...") + # Create a clean structure without unnecessary fields for description generation + clean_structure = create_clean_structure_for_description(tree_structure) + doc_description = generate_doc_description(clean_structure, model=model) + return { + 'doc_name': os.path.splitext(os.path.basename(md_path))[0], + 'doc_description': doc_description, + 'line_count': line_count, + 'structure': tree_structure, + } else: - tree_structure = format_structure(tree_structure, order = ['title', 'node_id', 'line_num', 'summary', 'prefix_summary', 'nodes']) - - return { - 'doc_name': os.path.splitext(os.path.basename(md_path))[0], - 'line_count': line_count, - 'structure': tree_structure, - } + # No summaries needed, format based on text preference + if if_add_node_text == 'yes': + tree_structure = format_structure(tree_structure, order = ['title', 'node_id', 'line_num', 'summary', 'prefix_summary', 'text', 'nodes']) + else: + tree_structure = format_structure(tree_structure, order = ['title', 'node_id', 'line_num', 'summary', 'prefix_summary', 'nodes']) + + return { + 'doc_name': os.path.splitext(os.path.basename(md_path))[0], + 'line_count': line_count, + 'structure': tree_structure, + } if __name__ == "__main__": @@ -339,4 +340,4 @@ async def md_to_tree(md_path, if_thinning=False, min_token_threshold=None, if_ad with open(output_path, 'w', encoding='utf-8') as f: json.dump(tree_structure, f, indent=2, ensure_ascii=False) - print(f"\nTree structure saved to: {output_path}") \ No newline at end of file + print(f"\nTree structure saved to: {output_path}") diff --git a/pageindex/utils.py b/pageindex/utils.py index f00ccf3a7..09398a596 100644 --- a/pageindex/utils.py +++ b/pageindex/utils.py @@ -10,6 +10,8 @@ import asyncio import pymupdf from io import BytesIO +from contextlib import contextmanager +from contextvars import ContextVar from dotenv import load_dotenv load_dotenv() import logging @@ -22,6 +24,20 @@ os.environ["OPENAI_API_KEY"] = os.getenv("CHATGPT_API_KEY") litellm.drop_params = True +_DEFAULT_LLM_KWARGS = ContextVar("default_llm_kwargs", default={}) + + +@contextmanager +def llm_kwargs_scope(llm_kwargs=None): + token = _DEFAULT_LLM_KWARGS.set(dict(llm_kwargs or {})) + try: + yield + finally: + _DEFAULT_LLM_KWARGS.reset(token) + + +def _merge_llm_kwargs(llm_kwargs=None): + return dict(_DEFAULT_LLM_KWARGS.get() if llm_kwargs is None else llm_kwargs) def count_tokens(text, model=None): if not text: @@ -29,10 +45,11 @@ def count_tokens(text, model=None): return litellm.token_counter(model=model, text=text) -def llm_completion(model, prompt, chat_history=None, return_finish_reason=False): +def llm_completion(model, prompt, chat_history=None, return_finish_reason=False, llm_kwargs=None): if model: model = model.removeprefix("litellm/") max_retries = 10 + llm_kwargs = _merge_llm_kwargs(llm_kwargs) messages = list(chat_history) + [{"role": "user", "content": prompt}] if chat_history else [{"role": "user", "content": prompt}] for i in range(max_retries): try: @@ -40,6 +57,7 @@ def llm_completion(model, prompt, chat_history=None, return_finish_reason=False) model=model, messages=messages, temperature=0, + **llm_kwargs ) content = response.choices[0].message.content if return_finish_reason: @@ -59,10 +77,11 @@ def llm_completion(model, prompt, chat_history=None, return_finish_reason=False) -async def llm_acompletion(model, prompt): +async def llm_acompletion(model, prompt, llm_kwargs=None): if model: model = model.removeprefix("litellm/") max_retries = 10 + llm_kwargs = _merge_llm_kwargs(llm_kwargs) messages = [{"role": "user", "content": prompt}] for i in range(max_retries): try: @@ -70,6 +89,7 @@ async def llm_acompletion(model, prompt): model=model, messages=messages, temperature=0, + **llm_kwargs ) return response.choices[0].message.content except Exception as e: @@ -707,4 +727,3 @@ def print_tree(tree, indent=0): def print_wrapped(text, width=100): for line in text.splitlines(): print(textwrap.fill(line, width=width)) -