diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..4b8e69b --- /dev/null +++ b/.env.example @@ -0,0 +1,7 @@ +OPENAI_API_KEY= +CLASSIFY_MODEL=o4-mini +EXTRACT_MODEL=o4-mini +OCR_MODEL=o4-mini +EXTRACTLY_TIMEOUT_S=40 +EXTRACTLY_MAX_RETRIES=2 +EXTRACTLY_RETRY_BACKOFF_S=1.5 diff --git a/Home.py b/Home.py index b1fe57c..42e1203 100644 --- a/Home.py +++ b/Home.py @@ -1,155 +1,133 @@ -""" -Landing page โ€“ stylish hero header + live stats. -""" +from __future__ import annotations -from datetime import datetime, timezone +from datetime import datetime +from pathlib import Path import streamlit as st -from utils.utils import load_feedback -from dotenv import load_dotenv -from utils.ui_components import inject_logo, inject_common_styles - -# Load API key from .env -load_dotenv(override=True) - -st.set_page_config("Extractly", page_icon="๐Ÿช„", layout="wide") - -# Inject logo and common styles -inject_logo("data/assets/data_reply.svg", height="80px") # Adjust height as needed -inject_common_styles() - -# Theme-adaptive CSS using Streamlit's CSS variables -if "home_css" not in st.session_state: - st.markdown( - """ - - """, - unsafe_allow_html=True, - ) - st.session_state.home_css = True -# Hero header +from extractly.config import load_config +from extractly.domain.run_store import RunStore +from extractly.logging import setup_logging +from extractly.ui.components import inject_branding, inject_global_styles, section_title + + +config = load_config() +setup_logging() + +st.set_page_config(page_title="Extractly", page_icon="โœจ", layout="wide") + +inject_branding(Path("data/assets/data_reply.svg")) +inject_global_styles() + +run_store = RunStore(config.run_store_dir) +runs = run_store.list_runs() + st.markdown( """ -
-

🪄 Extractly

-

AI-powered metadata classification & extraction for every document.

-
-""", +
+

Extractly — Document Metadata Extraction Studio

+

Design schemas, classify incoming documents, and extract structured metadata in minutes. Built for + client-ready demos with traceability, exports, and run history baked in.

+
+ """, unsafe_allow_html=True, ) -# Live stats with enhanced confidence metrics -feedback = load_feedback() -today_utc = datetime.now(timezone.utc).date() - -total_docs = len({r["doc_id"] for r in feedback}) -total_fields_corrected = sum(len(r.get("fields_corrected", [])) for r in feedback) - -docs_today = 0 -high_confidence_docs = 0 - -for r in feedback: - try: - if datetime.fromisoformat(r["timestamp"]).date() == today_utc: - docs_today += 1 - - # Count high confidence extractions - if r.get("metadata_extracted"): - non_empty_fields = sum( - bool(v and str(v).strip()) for v in r["metadata_extracted"].values() - ) - total_fields = len(r["metadata_extracted"]) - if total_fields > 0 and (non_empty_fields / total_fields) >= 0.7: - high_confidence_docs += 1 - except Exception: - continue - -# Calculate success rate percentage -success_rate = int((high_confidence_docs / total_docs) * 100) if total_docs > 0 else 0 - -# Metric cards -cols = st.columns(4) -values = [ - ("Docs Today", docs_today, None), - ("Total Docs", total_docs, None), - ("Success Rate", f"{success_rate}%", success_rate), - ("Fields Corrected", total_fields_corrected, None), -] - -for col, (label, val, rate) in zip(cols, values): - # Color coding for success rate - color_style = "" - if label == "Success Rate": - if success_rate >= 80: - color_style = "color: #10b981;" # green - elif success_rate >= 60: - color_style = "color: #f59e0b;" # yellow - else: - color_style = "color: #ef4444;" # red - - col.markdown( - f""" -
-

{val}

-

{label}

+cta_cols = st.columns([1, 1, 2]) +with cta_cols[0]: + st.page_link("pages/1_Schema_Studio.py", label="๐Ÿš€ Build a schema", use_container_width=True) +with cta_cols[1]: + st.page_link("pages/2_Extract.py", label="โšก Run extraction", use_container_width=True) + +st.markdown("
", unsafe_allow_html=True) + +section_title("How it works", "A streamlined workflow your clients understand in seconds.") +steps = st.columns(3) +steps[0].markdown( + """ +
+ Step A — Define a schema +

Design fields, types, and requirements in Schema Studio or import JSON templates.

""", - unsafe_allow_html=True, - ) + unsafe_allow_html=True, +) +steps[1].markdown( + """ +
+ Step B — Upload documents +

Batch PDFs, images, or text. Enable OCR or fast mode depending on fidelity.

+
+ """, + unsafe_allow_html=True, +) +steps[2].markdown( + """ +
+ Step C — Review results +

View JSON, confidence scores, warnings, and exportable tables.

+
+ """, + unsafe_allow_html=True, +) -st.markdown("---") +st.markdown("
", unsafe_allow_html=True) -st.markdown( - '', +section_title("Product highlights", "Purpose-built for metadata extraction teams and demos.") +features = st.columns(3) +features[0].markdown( + """ +
+

Schema Studio

+

Field editor, JSON preview, templates, and validation in one place.

+
+ """, + unsafe_allow_html=True, +) +features[1].markdown( + """ +
+

Extraction Pipeline

+

Classification, extraction, validation, and export with transparent logs.

+
+ """, + unsafe_allow_html=True, +) +features[2].markdown( + """ +
+

Run History

+

Every run is stored locally with artifacts for traceability and demos.

+
+ """, unsafe_allow_html=True, ) + +st.markdown("
", unsafe_allow_html=True) + +section_title("Live workspace snapshot") +col_a, col_b, col_c = st.columns(3) +col_a.metric("Runs stored", len(runs)) +latest_run = runs[0]["started_at"] if runs else "โ€”" +col_b.metric("Latest run", latest_run) +col_c.metric("Schemas ready", len(list(config.schema_dir.glob("*.json")))) + +st.markdown("---") + +section_title("Demo flow") +st.write( + "Use the sample schemas and documents shipped in the repo to walk through a full demo. " + "Start in Schema Studio, then upload a sample document in Extract, and finish in Results." +) + +sample_dir = config.sample_data_dir +if sample_dir.exists(): + samples = [p.name for p in sample_dir.glob("*.txt")] + if samples: + st.caption(f"Sample docs: {', '.join(samples)}") + +st.info( + "Need configuration? Visit Settings to review model choice, retries, and environment checks.", + icon="โš™๏ธ", +) + +st.caption(f"Last refreshed: {datetime.now().strftime('%Y-%m-%d %H:%M')}") diff --git a/README.md b/README.md index 1ba8712..8f7c04f 100644 --- a/README.md +++ b/README.md @@ -1,60 +1,79 @@ -# Extractly +# Extractly โ€” Document Metadata Extraction Studio -## Overview -A plug-and-play Python framework to classify and extract metadata from PDFs and images using a hybrid OCR + GPT pipeline. Supports built-in and custom document types with editable schemas. +Extractly is a Streamlit app for defining document schemas, classifying incoming files, and extracting structured metadata with traceable runs. The current version delivers a client-ready demo experience with a clean information architecture and modular codebase. ## Features -- **Modular pipeline**: Preprocessing, classification, extraction, validation, export. -- **Hybrid OCR + LLM**: LLM (configurable) for robust extraction. -- **Custom schemas**: Define new document types and field descriptors via JSON in `schemas/` or input in UI. -- **Configurable models**: Choose between `gpt-o4-mini`, `gpt-o3`, `gpt-4o`, etc. 
-- **Validation & context**: Extracted values accompanied by context snippets and LLM reasoning. -- **CLI & UI**: Streamlit app and CLI interface for automation and demos. -- **Caching & metrics**: Preprocess results cached to speed up repeated runs; timings logged per step. -- **Export options**: JSON, CSV, Excel downloads. - -## Getting Started -1. **Clone repo** - ```bash - git clone https://github.com/yourorg/universal_extractor.git - cd universal_extractor - ``` -2. **Create `.env` with your OpenAI key** - ```bash - echo "OPENAI_API_KEY=your_api_key_here" > .env - ``` -3. **Run locally** - ```bash - pip install -r requirements.txt - streamlit run app.py - ``` -4. **Or with Docker** - ```bash - docker build -t universal_extractor . - docker run -e OPENAI_API_KEY=$OPENAI_API_KEY -p 8501:8501 universal_extractor - ``` - -## Directory Structure -- `app.py`: Streamlit frontend -- `cli.py`: Command-line interface -- `schemas/`: JSON schema files for built-in and custom types -- `utils/`: Core modules (preprocess, schema management, OpenAI client, classification, extraction) - - -## Extending -- Add new schema JSON to `schemas/`, restart app, it appears automatically. -- Implement additional OCR engines by extending `utils/preprocess.py`. -- Swap or fine-tune models via `utils/openai_client.py`. - - -# Future Work - -## Code Improvements -- Metaprompting sui tipi di documenti attesi -- Aggiustare cards in Home (feedbacks are scuffed overall) -- Rivedere la confidence (mettere nel json dei controlli extra tipo rgex o agenti ) -- TO DISCUSS: Improve from errors, saving last N feedbacks for a give doc type (only incorrect ones?) and pass in context - -## Business Improvements -- Tracciare kpi -- Fare sides fighe in cui dici flusso ed โ€œagentiโ€ +- **Schema Studio**: create, edit, validate, and export schemas with live JSON preview. +- **Extraction pipeline**: classify โ†’ extract โ†’ validate with confidence scores and OCR support. 
+- **Run history**: every run is stored locally with outputs, warnings, and errors. +- **Results workspace**: table view, JSON view, per-field confidence, and CSV/JSON exports. +- **Configurable LLM usage**: timeouts, retries, model selection via `.env`. + +## Quickstart +```bash +# 1) Install dependencies +pip install -e . + +# 2) Set API key +cp .env.example .env +# edit .env and set OPENAI_API_KEY + +# 3) Run the app +streamlit run Home.py +``` + +## Environment Variables +``` +OPENAI_API_KEY=your_key_here +CLASSIFY_MODEL=o4-mini +EXTRACT_MODEL=o4-mini +OCR_MODEL=o4-mini +EXTRACTLY_TIMEOUT_S=40 +EXTRACTLY_MAX_RETRIES=2 +EXTRACTLY_RETRY_BACKOFF_S=1.5 +``` + +## Information Architecture +- **Home**: landing page + demo entry points. +- **Schema Studio**: schema creation, validation, templates, import/export. +- **Extract**: upload files and run classification + extraction. +- **Results**: browse run history and export data. +- **Settings**: environment checks and config visibility. + +## Repo Structure +``` +Home.py +pages/ + 1_Schema_Studio.py + 2_Extract.py + 3_Results.py + 4_Settings.py +src/extractly/ + config.py + logging.py + domain/ + integrations/ + pipeline/ + ui/ +``` + +## Demo Script (5 minutes) +1. Open **Home** and describe the workflow. +2. Navigate to **Schema Studio** and load the โ€œInvoice Liteโ€ template. +3. Save the schema, then go to **Extract**. +4. Upload `data/sample_docs/sample_invoice.txt` and run extraction. +5. Open **Results**, select the run, and export JSON/CSV. + +## Samples +- Sample docs are in `data/sample_docs/`. +- Example schemas live in `schemas/`. + +## Tests +```bash +pytest +``` + +## Notes +- Runs are stored in `./runs` (configurable). +- Configure models and timeouts via `.env`. +- The app relies on Streamlit and the OpenAI API. No secrets are committed. 
diff --git a/data/sample_docs/sample_invoice.txt b/data/sample_docs/sample_invoice.txt new file mode 100644 index 0000000..6100635 --- /dev/null +++ b/data/sample_docs/sample_invoice.txt @@ -0,0 +1,8 @@ +Invoice +Invoice Number: INV-2025-0042 +Invoice Date: 2025-01-15 +Supplier: Northwind Analytics LLC +Client: Contoso Retail +Total Amount: 1,248.50 EUR +VAT: 22% +Payment Terms: Net 30 diff --git a/data/sample_docs/sample_resume.txt b/data/sample_docs/sample_resume.txt new file mode 100644 index 0000000..1aac74a --- /dev/null +++ b/data/sample_docs/sample_resume.txt @@ -0,0 +1,6 @@ +Resume +Name: Jordan Lee +Role: Product Engineer +Experience: 6 years +Location: Milan, IT +Skills: Streamlit, Python, OCR, LLM pipelines diff --git a/pages/1Schemas Builder.py b/pages/1Schemas Builder.py deleted file mode 100644 index 495198d..0000000 --- a/pages/1Schemas Builder.py +++ /dev/null @@ -1,218 +0,0 @@ -# sourcery skip: swap-if-else-branches, use-named-expression -import pandas as pd -import streamlit as st -import json -from pathlib import Path -from src.schema_manager import SchemaManager -from utils.ui_components import inject_logo, inject_common_styles - -# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ constants / init โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ -# Get the project root directory (parent of the pages directory) -PROJECT_ROOT = Path(__file__).parent.parent -DATA_DIR = PROJECT_ROOT / "schemas" -CUSTOM_JSON = DATA_DIR / "custom_schemas.json" -DATA_DIR.mkdir(exist_ok=True) - -SM = SchemaManager() # load built-ins -custom_data = {} -if CUSTOM_JSON.exists(): # preload customs - custom_data = json.loads(CUSTOM_JSON.read_text()) - SM.add_custom(custom_data) - -# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ session defaults โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ -st.session_state.setdefault("editing_doc", None) 
-st.session_state.setdefault("field_data", [{"name": "", "description": ""}]) -st.session_state.setdefault("rename_open", False) -st.session_state.setdefault("reset_now", False) -st.session_state.setdefault("schema_saved", False) # New flag for handling save refresh - -# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ handle one-shot reset flag BEFORE widgets โ”€โ”€โ”€โ”€โ”€โ”€โ”€ -if st.session_state.reset_now: - st.session_state.pop("doc_name_input", None) # clear text box value - st.session_state.editing_doc = None - st.session_state.field_data = [{"name": "", "description": ""}] - st.session_state.reset_now = False # consume flag - -# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ handle schema save refresh โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ -if st.session_state.schema_saved: - st.session_state.schema_saved = False # consume flag - st.session_state.pop("doc_name_input", None) # clear text box value - st.session_state.pop("field_editor", None) # clear data editor - st.session_state.editing_doc = None - st.session_state.field_data = [{"name": "", "description": ""}] - st.rerun() # This will refresh the page and update the sidebar -# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ page meta โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ -st.set_page_config("Schema Builder", "๐Ÿงฌ", layout="wide") -st.title("๐Ÿงฌ Schema Builder") - -inject_logo("data/assets/data_reply.svg", height="80px") # Adjust height as needed -inject_common_styles() - -# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ sidebar: schema list & actions โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ -with st.sidebar: - st.subheader("๐Ÿ“ฆ Schemas") - - for dt in SM.get_types(): - col_name, col_del = st.columns([3, 1]) - with col_name: - if st.button(dt, key=f"sel_{dt}"): - st.session_state.field_data = SM.get(dt) or [ - {"name": "", "description": ""} - ] - 
st.session_state.editing_doc = dt - st.session_state.rename_open = False - st.rerun() - - with col_del: - if st.button("๐Ÿ—‘๏ธ", key=f"del_{dt}", help="Delete"): - SM.delete(dt) - custom_data.pop(dt, None) - CUSTOM_JSON.write_text( - json.dumps(custom_data, indent=2, ensure_ascii=False) - ) - if st.session_state.editing_doc == dt: - st.session_state.editing_doc = None - st.session_state.rename_open = False - st.session_state.field_data = [{"name": "", "description": ""}] - st.rerun() - - # โ”€โ”€ Rename + Clear all (same row) โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ - col_ren, col_clear = st.columns(2) - - # โ”€โ”€ Rename current โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ - with col_ren: - if st.button("โœ๏ธ Rename current", disabled=not st.session_state.editing_doc): - st.session_state.rename_open = not st.session_state.rename_open - - if st.session_state.rename_open and st.session_state.editing_doc: - with st.expander("Rename schema", expanded=True): - new_name = st.text_input( - "New doc-type name", - value=st.session_state.editing_doc, - key="rename_input", - ) - - # โœ… sidebar-safe: no st.columns(), just stacked buttons - if st.button("โœ”๏ธ Confirm rename"): - old = st.session_state.editing_doc - SM.rename(old, new_name) - custom_data[new_name] = custom_data.pop(old) - CUSTOM_JSON.write_text( - json.dumps(custom_data, indent=2, ensure_ascii=False) - ) - st.session_state.editing_doc = new_name - st.session_state.field_data = SM.get(new_name) - st.session_state.rename_open = False - st.rerun() - - if st.button("โœ–๏ธ Cancel"): - st.session_state.rename_open = False - - # โ”€โ”€ Clear all โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ - with col_clear: - if st.button("๐Ÿšฎ Clear all CUSTOM schemas"): - 
custom_data.clear() - CUSTOM_JSON.unlink(missing_ok=True) - st.session_state.editing_doc = None - st.session_state.rename_open = False - st.session_state.field_data = [{"name": "", "description": ""}] - st.rerun() - - # โ”€โ”€ JSON import โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ - st.subheader("โ‡ก Import JSON schema file") - up = st.file_uploader( - label=" ", - type=["json"], - label_visibility="collapsed", - ) - if up: - try: - data = json.load(up) - SM.add_custom(data) - custom_data.update(data) - CUSTOM_JSON.write_text( - json.dumps(custom_data, indent=2, ensure_ascii=False) - ) - st.success(f"Imported {len(data)} doc-types.") - st.rerun() - except Exception as e: - st.error(f"Bad JSON: {e}") - -# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ MAIN AREA โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ -doc_name = st.text_input( - "Document type name (create new or edit existing)", - value=st.session_state.editing_doc or "", - key="doc_name_input", -) - -if not doc_name: - st.info("Enter a document type name to begin.") - st.stop() - -# โ”€โ”€ keep or clear table based on doc_name validity โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ -if doc_name != st.session_state.editing_doc: - if doc_name in SM.get_types(): # switch to known schema - st.session_state.editing_doc = doc_name - st.session_state.field_data = SM.get(doc_name) or [ - {"name": "", "description": ""} - ] - else: # unknown name โ†’ blank table - st.session_state.editing_doc = None - st.session_state.field_data = [{"name": "", "description": ""}] - -st.subheader(f"Fields for **{doc_name}**") - -schema_desc = st.text_area( - "High-level description", - value=SM.get_description(doc_name), - placeholder="e.g. 
Italian electronic invoice issued by suppliersโ€ฆ", -) - -raw_table = st.data_editor( - st.session_state.field_data, # static snapshot - num_rows="dynamic", - width="stretch", - key="field_editor", -) - -# normalise โ†“ -table_rows = ( - raw_table.fillna("").to_dict(orient="records") - if isinstance(raw_table, pd.DataFrame) - else raw_table -) -table_rows = [ - { - "name": r.get("name", " ").strip() if r.get("name", " ") else "", - "description": r.get("description", r.get("description ", " ")).strip() - if r.get("description", r.get("description ", " ")) - else "", - } - for r in table_rows -] - - -# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ save / reset buttons โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ -col_save, col_reset = st.columns(2) - -with col_save: - if st.button("๐Ÿ’พ Save schema"): - clean = [row for row in table_rows if row["name"]] - if not clean: - st.error("Must have at least one field.") - else: - payload = {"description": schema_desc.strip(), "fields": clean} - SM.add_custom({doc_name: payload}) - custom_data[doc_name] = payload # โ† store whole dict - CUSTOM_JSON.write_text( - json.dumps(custom_data, indent=2, ensure_ascii=False) - ) - st.success(f"Saved {len(clean)} fields for '{doc_name}'.") - # Set the flag to trigger page refresh and cleanup - st.session_state.schema_saved = True - # Trigger immediate rerun to refresh the page - st.rerun() -with col_reset: - if st.button("โ†บ Reset"): - st.session_state.reset_now = True - st.rerun() diff --git a/pages/1_Schema_Studio.py b/pages/1_Schema_Studio.py new file mode 100644 index 0000000..32b1bbd --- /dev/null +++ b/pages/1_Schema_Studio.py @@ -0,0 +1,170 @@ +from __future__ import annotations + +import json +from pathlib import Path +import streamlit as st + +from extractly.config import load_config +from extractly.domain.schema_store import SchemaStore, schemas_to_table, table_to_schema +from extractly.domain.validation import validate_schema 
+from extractly.logging import setup_logging +from extractly.ui.components import inject_branding, inject_global_styles, section_title + + +config = load_config() +setup_logging() +store = SchemaStore(config.schema_dir) + +st.set_page_config(page_title="Schema Studio", page_icon="๐Ÿงฌ", layout="wide") + +inject_branding(Path("data/assets/data_reply.svg")) +inject_global_styles() + +st.title("๐Ÿงฌ Schema Studio") +st.caption("Design, validate, and version schemas used in extraction runs.") + +TEMPLATES = { + "Invoice Lite": { + "description": "Basic invoice metadata for demo flows.", + "fields": [ + {"name": "Invoice Number", "type": "string", "required": True}, + {"name": "Invoice Date", "type": "date", "required": True}, + {"name": "Supplier", "type": "string", "required": True}, + {"name": "Total Amount", "type": "number", "required": True}, + ], + }, + "Resume Snapshot": { + "description": "Lightweight resume extraction fields.", + "fields": [ + {"name": "Candidate Name", "type": "string", "required": True}, + {"name": "Primary Role", "type": "string", "required": True}, + {"name": "Years of Experience", "type": "integer"}, + {"name": "Location", "type": "string"}, + ], + }, +} + +schemas = store.list_schemas() + +with st.sidebar: + st.subheader("Schemas") + if schemas: + selected_name = st.selectbox( + "Choose schema", + options=[schema.name for schema in schemas], + index=0, + ) + else: + selected_name = None + st.caption("No schemas yet. 
Create one below.") + st.markdown("---") + st.subheader("Templates") + template_choice = st.selectbox("Load template", options=["โ€”"] + list(TEMPLATES)) + if st.button("Use template", use_container_width=True): + if template_choice != "โ€”": + template_payload = TEMPLATES[template_choice] + st.session_state["schema_payload"] = { + "name": template_choice, + "description": template_payload["description"], + "rows": template_payload["fields"], + } + st.rerun() + + st.markdown("---") + st.subheader("Import / Export") + upload = st.file_uploader("Import schema JSON", type=["json"]) + if upload: + try: + payload = json.load(upload) + imported = store.import_payload(payload) + st.success(f"Imported {len(imported)} schema(s).") + st.rerun() + except Exception as exc: + st.error(f"Import failed: {exc}") + + if selected_name: + schema = store.get_schema(selected_name) + if schema: + export_json = store.export_schema(schema) + st.download_button( + "Download schema JSON", + data=export_json, + file_name=f"{selected_name}.json", + mime="application/json", + use_container_width=True, + ) + + if selected_name and st.button("Delete schema", type="secondary", use_container_width=True): + store.delete_schema(selected_name) + st.success("Schema deleted.") + st.rerun() + +if schemas and selected_name: + active_schema = store.get_schema(selected_name) +else: + active_schema = None + +payload = st.session_state.get( + "schema_payload", + { + "name": active_schema.name if active_schema else "", + "description": active_schema.description if active_schema else "", + "rows": schemas_to_table(active_schema) if active_schema else [], + }, +) + +section_title("Schema editor") +name = st.text_input("Schema name", value=payload.get("name", "")) +description = st.text_area("Description", value=payload.get("description", "")) + +rows = payload.get("rows", []) + +data = st.data_editor( + rows, + num_rows="dynamic", + use_container_width=True, + column_config={ + "name": 
st.column_config.TextColumn("Field name", required=True), + "type": st.column_config.SelectboxColumn( + "Type", + options=["string", "number", "integer", "boolean", "date", "enum", "object", "array"], + required=True, + ), + "required": st.column_config.CheckboxColumn("Required"), + "description": st.column_config.TextColumn("Description"), + "example": st.column_config.TextColumn("Example"), + "enum": st.column_config.TextColumn("Enum values (comma-separated)"), + }, + key="schema_editor", +) + +schema = table_to_schema(name=name, description=description, rows=data) +validation = validate_schema(schema) + +col_a, col_b = st.columns([1, 1]) +with col_a: + if st.button("๐Ÿ’พ Save schema", use_container_width=True): + result = store.save_schema(schema) + if result.is_valid: + st.success("Schema saved.") + st.session_state.pop("schema_payload", None) + st.rerun() + else: + st.error("Schema failed validation. Fix errors below.") + +with col_b: + if st.button("Reset", use_container_width=True, type="secondary"): + st.session_state.pop("schema_payload", None) + st.rerun() + +section_title("Validation") +if validation.errors: + st.error("\n".join(validation.errors)) +else: + st.success("Schema is valid and ready for extraction.") + +if validation.warnings: + st.warning("\n".join(validation.warnings)) + +section_title("Live JSON preview") +st.code(store.export_schema(schema), language="json") diff --git a/pages/2Extraction.py b/pages/2Extraction.py deleted file mode 100644 index 8332509..0000000 --- a/pages/2Extraction.py +++ /dev/null @@ -1,605 +0,0 @@ -""" -Batch Extraction โ€“ multi-doc classification, OCR, extraction & correction UI. 
-""" - -from __future__ import annotations -import contextlib -import io -import json -from datetime import datetime, timezone -import streamlit as st -import pandas as pd - -from src.ocr_engine import run_ocr -from utils.preprocess import preprocess -from src.schema_manager import SchemaManager -from src.classifier import classify -from src.extractor import extract -from utils.ui_components import inject_logo, inject_common_styles -from utils.utils import ( - generate_doc_id, - load_feedback, - upsert_feedback, - diff_fields, -) - -# โ”€โ”€โ”€โ”€โ”€ page & globals โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ -st.set_page_config("Extraction", page_icon="๐Ÿ”", layout="wide") -st.title("๐Ÿ” Extraction") -schema_mgr = SchemaManager() - -inject_logo("data/assets/data_reply.svg", height="80px") # Adjust height as needed -inject_common_styles() - -# โ”€โ”€โ”€โ”€โ”€ CSS for image hover zoom โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ -st.markdown( - """ - -""", - unsafe_allow_html=True, -) - -# โ”€โ”€โ”€โ”€โ”€ sidebar: settings โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ -with st.sidebar: - st.header("โš™๏ธ Options") - run_ocr_checkbox = st.checkbox("Run OCR first", value=False) - calc_conf = st.checkbox("Compute confidences", value=False) - conf_threshold = st.slider( - "Confidence threshold (%)", - 0, - 100, - 70, - help="Documents below this confidence will be flagged", - ) - # NEW: System Prompts Section - st.header("๐Ÿค– System Prompts") - - # Updated default prompts - DEFAULT_CLASSIFIER_PROMPT = """You are an expert document classifier specialized in analyzing business and legal documents. - Your task is to classify the document type based on visual layout, text content, headers, logos, and structural elements. 
Consider: - - Document formatting and layout patterns - - Official headers, letterheads, and logos - - Specific terminology and field labels - - Regulatory compliance markers - - Standard document structures - - Respond with only the most accurate document type from the provided list. If unsure, choose "Unknown". - """ - - DEFAULT_EXTRACTOR_PROMPT = """You are a precise metadata extraction specialist. Your task is to extract specific field values from documents with high accuracy. - - Instructions: - 1. Analyze the document image carefully for text, tables, and structured data - 2. Extract only the exact values for the requested fields - 3. Use OCR context when provided to improve accuracy - 4. If a field is not clearly visible or readable, return null - 5. Maintain original formatting for dates, numbers, and codes - 6. For confidence scores, rate 0.0-1.0 based on text clarity and certainty - - Return valid JSON with three sections: - - "metadata": field values as key-value pairs - - "snippets": supporting text evidence for each field - - "confidence": confidence scores (0.0-1.0) for each extraction - """ - - with st.expander("๐Ÿ“ Edit Prompts", expanded=False): - st.subheader("Classification Prompt") - classifier_prompt = st.text_area( - "System prompt for document classification:", - value=st.session_state.get("classifier_prompt", DEFAULT_CLASSIFIER_PROMPT), - height=150, - help="This prompt guides how the AI classifies document types", - key="classifier_prompt_input", - ) - - st.subheader("Extraction Prompt") - extractor_prompt = st.text_area( - "System prompt for metadata extraction:", - value=st.session_state.get("extractor_prompt", DEFAULT_EXTRACTOR_PROMPT), - height=200, - help="This prompt guides how the AI extracts metadata fields", - key="extractor_prompt_input", - ) - - col1, col2 = st.columns(2) - with col1: - if st.button("๐Ÿ’พ Save Prompts"): - st.session_state["classifier_prompt"] = classifier_prompt - st.session_state["extractor_prompt"] = 
extractor_prompt - st.success("Prompts saved!") - - with col2: - if st.button("๐Ÿ”„ Reset to Default"): - st.session_state["classifier_prompt"] = DEFAULT_CLASSIFIER_PROMPT - st.session_state["extractor_prompt"] = DEFAULT_EXTRACTOR_PROMPT - st.rerun() - - # โœ‚๏ธŽโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ OCR-preview toggle โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โœ‚๏ธŽ - if run_ocr_checkbox and "ocr_map" in st.session_state: - with st.sidebar.expander("๐Ÿ” Preview raw OCR text", expanded=False): - for name, ocr_txt in st.session_state["ocr_map"].items(): - st.markdown(f"**{name}**") - st.text_area( - label=" ", - value=ocr_txt[:10_000], - height=200, - key=f"ocr_{name}", - label_visibility="collapsed", - ) - -# โ”€โ”€โ”€โ”€โ”€ file uploader โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ -files = st.file_uploader( - "Upload PDFs or images", - type=["pdf", "png", "jpg", "jpeg"], - accept_multiple_files=True, -) - -if not files: - st.info("Awaiting uploads โ€ฆ") - st.stop() - -# โ”€โ”€โ”€โ”€โ”€ load past corrections by filename โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ -past = load_feedback() -corrected_map = {} -past_metadata_map = {} -for r in past: - if r.get("file_name"): - corrected_map[r["file_name"]] = r["doc_type"] - if r.get("metadata_corrected") and r["metadata_corrected"] != "{}": - with contextlib.suppress(json.JSONDecodeError): - past_metadata_map[r["file_name"]] = json.loads(r["metadata_corrected"]) - -# โ”€โ”€โ”€โ”€โ”€ ensure session state โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ -st.session_state.setdefault("doc_rows", []) -st.session_state.setdefault("extracted", False) -doc_rows: list[dict] = st.session_state["doc_rows"] - -# โ”€โ”€โ”€โ”€โ”€ run buttons 
โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ -c1, c2, c3 = st.columns(3) -seq_clicked = c1.button("๐Ÿš€ Classify & Extract", width="stretch") -classify_clicked = c2.button("โ–ถ๏ธ Classify Only", width="stretch") -extract_clicked = c3.button("โšก Extract All", disabled=not doc_rows, width="stretch") - -if st.button("๐Ÿ”„ Start over (keep uploads)", key="reset_all", type="secondary"): - for k in ("doc_rows", "extracted", "ocr_map", "ocr_preview"): - st.session_state.pop(k, None) - st.toast("Workspace cleared โ€“ you can run the pipeline again.", icon="๐Ÿ”„") - st.rerun() - -st.markdown("
", unsafe_allow_html=True) - -# โ”€โ”€โ”€โ”€โ”€ 1๏ธโƒฃ on-click: build doc_rows โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ -if classify_clicked or seq_clicked: - st.session_state["extracted"] = False - st.session_state["ocr_preview"] = "" - st.session_state["doc_rows"] = [] - prog = st.progress(0.0, "Classifyingโ€ฆ") - cls_results = [] - - # CHANGE: Add Unknown/Other to classification candidates - classification_candidates = schema_mgr.get_types() + ["Unknown", "Other"] - - for i, up in enumerate(files, start=1): - images = preprocess(up) - fname = up.name - doc_id = generate_doc_id(up) or f"id_{i}" - ocr_txt = "" - - if run_ocr_checkbox: - ocr_txt = run_ocr(images) # Removed engine parameter - LLM only - st.session_state.setdefault("ocr_map", {}) - st.session_state["ocr_map"][fname] = ocr_txt - if ( - "ocr_preview" in st.session_state - and not st.session_state["ocr_preview"] - ): - st.session_state["ocr_preview"] = ocr_txt - - with io.BytesIO() as buf: - img = images[0].copy() - img.thumbnail((140, 140)) - img.save(buf, format="PNG") - thumb = buf.getvalue() - - # ---------- build rows with enhanced classification ------------------------- - if fname in corrected_map: - doc_type = corrected_map[fname] - confidence = None - reasoning = "retrieved from past correction" - else: - cls_resp = classify( - images, - classification_candidates, - use_confidence=calc_conf, - n_votes=5, - system_prompt=st.session_state.get( - "classifier_prompt", DEFAULT_CLASSIFIER_PROMPT - ), - ) - doc_type = cls_resp["doc_type"] - confidence = cls_resp.get("confidence") - reasoning = cls_resp.get("reasoning", "") - - cls_results.append( - { - "file_name": fname, - "doc_id": doc_id, - "thumb": thumb, - "images": images, - "detected": doc_type, - "final_type": doc_type, - "reasoning": reasoning, - "fields": None, - "fields_corrected": None, - "confidence": confidence, - } - ) - prog.progress(i / len(files), f"Classifying: {fname}") - 
- st.session_state["doc_rows"] = cls_results - doc_rows = st.session_state["doc_rows"] - st.toast("Classification finished โ€“ adjust below if needed.", icon="โœ…") - if not seq_clicked: - st.rerun() - -# โ”€โ”€โ”€โ”€โ”€ 2๏ธโƒฃ render classification review โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ -if doc_rows and not st.session_state.get("extracted"): - st.subheader("๐Ÿ“‘ Review detected document types") - - for idx, row in enumerate(doc_rows): - # CHANGE: Enhanced confidence display and unrecognized document handling - confidence_val = row.get("confidence") - conf_pct = int(confidence_val * 100) if confidence_val is not None else None - is_unrecognized = row["final_type"] in ["Unknown", "Other"] - is_low_confidence = conf_pct is not None and conf_pct < conf_threshold - - # Color-coded title based on status - if is_unrecognized or is_low_confidence: - status_emoji = "โš ๏ธ" - title_color = "๐Ÿ”ด" - else: - status_emoji = "โœ…" - title_color = "๐ŸŸข" - - # Build title with conditional confidence display - if calc_conf and conf_pct is not None: - title = f"{status_emoji} {row['file_name']} โ€” {row['final_type']} {title_color} {conf_pct}%" - else: - title = f"{status_emoji} {row['file_name']} โ€” {row['final_type']}" - - with st.expander( - title, - expanded=is_unrecognized - or is_low_confidence, # Auto-expand problematic ones - ): - col1, col2, col3 = st.columns([1, 2, 1]) - - with col1: - st.image(row["thumb"], width=140) - - with col2: - # CHANGE: Enhanced type selection with Unknown/Other - choices = schema_mgr.get_types() + ["Unknown", "Other"] - if row["final_type"] not in choices: - choices.insert(0, row["final_type"]) - - sel = st.selectbox( - "Document type", - choices, - index=choices.index(row["final_type"]), - key=f"type_{row['doc_id']}", - ) - st.session_state["doc_rows"][idx]["final_type"] = sel - - # Enhanced status display - if is_unrecognized: - st.error("โŒ Unrecognized document type") - elif 
is_low_confidence: - st.warning(f"โš ๏ธ Low confidence ({conf_pct}% < {conf_threshold}%)") - with col3: - # CHANGE: Enhanced confidence display in % - only show if calc_conf is enabled - if calc_conf and conf_pct is not None: - if conf_pct >= 80: - st.success(f"๐ŸŸข {conf_pct}%") - elif conf_pct >= 60: - st.warning(f"๐ŸŸก {conf_pct}%") - else: - st.error(f"๐Ÿ”ด {conf_pct}%") - - st.caption(f"Threshold: {conf_threshold}%") - elif calc_conf: - st.caption("Confidence not computed") - - if st.button("๐Ÿ’พ Save type corrections"): - for row in doc_rows: - upsert_feedback( - { - "doc_id": row["doc_id"], - "file_name": row["file_name"], - "doc_type": row["final_type"], - "metadata_extracted": "{}", - "metadata_corrected": "{}", - "timestamp": datetime.now(timezone.utc).isoformat(), - } - ) - st.toast("Document types saved ๐Ÿ‘", icon="๐Ÿ’พ") - -st.markdown("

", unsafe_allow_html=True) - -# โ”€โ”€โ”€โ”€โ”€ 3๏ธโƒฃ ON-CLICK: run extraction with unrecognized document handling โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ -if (extract_clicked or seq_clicked) and not st.session_state.get("extracted"): - st.subheader("๐Ÿ“ฆ Extracting Data...") - prog = st.progress(0.0, "Extractingโ€ฆ") - - for i, row in enumerate(doc_rows, start=1): - if "cancel_extraction" in st.session_state: - break - - # CHANGE: Skip extraction for unrecognized documents - if row["final_type"] in ["Unknown", "Other"]: - st.session_state["doc_rows"][i - 1]["fields"] = {} - st.session_state["doc_rows"][i - 1]["field_conf"] = {} - st.session_state["doc_rows"][i - 1]["fields_corrected"] = {} - prog.progress( - i / len(doc_rows), f"Skipping unrecognized: {row['file_name']}" - ) - continue - - schema = schema_mgr.get(row["final_type"]) or [] - if not schema: - st.session_state["doc_rows"][i - 1]["fields"] = {} - st.session_state["doc_rows"][i - 1]["field_conf"] = {} - st.session_state["doc_rows"][i - 1]["fields_corrected"] = {} - continue - - ocr_txt = None - if run_ocr_checkbox: - ocr_txt = run_ocr(row["images"]) # Removed engine parameter - LLM only - out = extract( - row["images"], - schema, - ocr_text=ocr_txt, - with_confidence=calc_conf, - system_prompt=st.session_state.get( - "extractor_prompt", DEFAULT_EXTRACTOR_PROMPT - ), - ) or {"metadata": {}, "confidence": {}} - - if not any(out["metadata"].values()): - st.toast(f"No fields detected in {row['file_name']}", icon="โš ๏ธ") - - st.session_state["doc_rows"][i - 1]["fields"] = out["metadata"] - st.session_state["doc_rows"][i - 1]["field_conf"] = out.get("confidence", {}) - - file_name = row["file_name"] - if file_name in past_metadata_map: - st.session_state["doc_rows"][i - 1]["fields_corrected"] = past_metadata_map[ - file_name - ] - else: - st.session_state["doc_rows"][i - 1]["fields_corrected"] = out[ - "metadata" - ].copy() - - prog.progress(i / len(doc_rows), f"Extracting from: 
{row['file_name']}") - - if st.session_state.get("extracted") is False: - st.button("๐Ÿ›‘ Cancel extraction", key="cancel_extraction") - - st.session_state["extracted"] = True - extract_clicked = False - seq_clicked = False - st.toast("Extraction completed โ€“ review below", icon="๐Ÿ“ฆ") - st.rerun() - -# โ”€โ”€โ”€โ”€โ”€ 4๏ธโƒฃ RENDER: extraction review & correction UI โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ -if st.session_state.get("extracted"): - st.subheader("๐Ÿ“ฆ Review and Correct Extracted Data") - - for i, row in enumerate(doc_rows): - # CHANGE: Handle unrecognized documents in review - if row["final_type"] in ["Unknown", "Other"]: - with st.expander( - f"โš ๏ธ {row['file_name']} โ€” {row['final_type']} (SKIPPED)", expanded=False - ): - st.warning( - "This document was skipped because it's unrecognized. Please reclassify it first." - ) - continue - - with st.expander(f"{row['file_name']} โ€” {row['final_type']}", expanded=True): - if not row.get("fields_corrected"): - st.warning("No fields were extracted or loaded for this document type.") - continue - - # CHANGE: Enhanced confidence display in data editor (% format) - row_conf = row.get("field_conf", {}) - # Convert confidence to percentage format - conf_display = {} - for k, v in row_conf.items(): - if isinstance(v, (int, float)): - conf_display[k] = f"{int(v * 100)}%" - else: - conf_display[k] = str(v) if v else "" - - df = pd.DataFrame( - { - "Field": list(row["fields_corrected"].keys()), - "Value": list(row["fields_corrected"].values()), - "Conf.": [ - conf_display.get(k, "") for k in row["fields_corrected"] - ], # CHANGE: Show % - } - ) - - edited_df = st.data_editor( - df, - key=f"grid_{row['doc_id']}", - disabled=["Field", "Conf."], - width="stretch", - ) - - updated_values = pd.Series( - edited_df.Value.values, index=edited_df.Field - ).to_dict() - - st.session_state["doc_rows"][i]["fields_corrected"] = updated_values - - # โ”€โ”€ enhanced action buttons 
โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ - col_re, col_save, col_download = st.columns([0.08, 0.08, 0.08]) - - with col_re: - if st.button("โ†ป Re-extract", key=f"reextract_{row['doc_id']}"): - schema = schema_mgr.get(row["final_type"]) or [] - ocr_txt = ( - st.session_state.get("ocr_map", {}).get(row["file_name"]) - if run_ocr_checkbox - else None - ) - new_out = extract( - row["images"], - schema, - ocr_text=ocr_txt, - with_confidence=calc_conf, - system_prompt=st.session_state.get( - "extractor_prompt", DEFAULT_EXTRACTOR_PROMPT - ), - ) or {"metadata": {}, "confidence": {}} - - row["fields"] = new_out["metadata"] - row["field_conf"] = new_out.get("confidence", {}) - row["fields_corrected"] = new_out["metadata"].copy() - st.toast(f"Re-extracted {row['file_name']}", icon="โœ…") - st.rerun() - - with col_save: - if st.button("๐Ÿ’พ Save", key=f"save_{row['doc_id']}"): - current_row = st.session_state["doc_rows"][i] - changed = diff_fields( - current_row["fields"], current_row["fields_corrected"] - ) - upsert_feedback( - { - "doc_id": current_row["doc_id"], - "file_name": current_row["file_name"], - "doc_type": current_row["final_type"], - "metadata_extracted": json.dumps( - current_row["fields"], ensure_ascii=False - ), - "metadata_corrected": json.dumps( - current_row["fields_corrected"], ensure_ascii=False - ), - "fields_corrected": changed, - "timestamp": datetime.now(timezone.utc).isoformat(), - } - ) - st.toast(f"Saved {current_row['file_name']}", icon="๐Ÿ’พ") - - # NEW: Individual download button - with col_download: - if st.button("๐Ÿ“„ JSON", key=f"download_{row['doc_id']}"): - # Fix: Handle None confidence properly - confidence_val = row.get("confidence") - if confidence_val is not None: - confidence_display = f"{int(confidence_val * 100)}%" - else: - confidence_display = "N/A" - - result = { - "document_info": { - "filename": row["file_name"], - "document_type": row["final_type"], - 
"confidence": confidence_display, # Fixed - "timestamp": datetime.now().isoformat(), - }, - "original_extraction": row.get("fields", {}), - "corrected_metadata": row.get("fields_corrected", {}), - "confidence_scores": { - k: f"{int(v * 100)}%" - for k, v in row.get("field_conf", {}).items() - if isinstance(v, (int, float)) - and v is not None # Added None check - }, - "processing_info": { - "ocr_used": run_ocr_checkbox, - "confidence_threshold": f"{conf_threshold}%", - }, - } - - # Enhanced bulk save with session summary - col1, col2 = st.columns([3, 1]) - with col1: - if st.button("๐Ÿ’พ Save all corrections", width="stretch", type="primary"): - saved_count = 0 - for row in st.session_state["doc_rows"]: - if row["final_type"] not in ["Unknown", "Other"]: - upsert_feedback( - { - "doc_id": row["doc_id"], - "file_name": row["file_name"], - "doc_type": row["final_type"], - "metadata_extracted": json.dumps( - row["fields"], ensure_ascii=False - ), - "metadata_corrected": json.dumps( - row["fields_corrected"], ensure_ascii=False - ), - "timestamp": datetime.now(timezone.utc).isoformat(), - } - ) - saved_count += 1 - st.toast(f"Saved {saved_count} documents โ€“ thank you!", icon="๐Ÿ’พ") - - with col2: - # NEW: Bulk download button - if st.button("๐Ÿ“ฅ Download All JSON", width="stretch"): - # Prepare bulk export - bulk_results = { - "export_info": { - "timestamp": datetime.now().isoformat(), - "total_documents": len(doc_rows), - "processed_documents": len( - [ - r - for r in doc_rows - if r["final_type"] not in ["Unknown", "Other"] - ] - ), - "confidence_threshold": f"{conf_threshold}%", - }, - "documents": [], - } - - for row in doc_rows: - confidence_val = row.get("confidence") - if confidence_val is not None: - confidence_display = f"{int(confidence_val * 100)}%" - else: - confidence_display = "N/A" - - doc_result = { - "filename": row["file_name"], - "document_type": row["final_type"], - "confidence": confidence_display, # Fixed - "original_extraction": 
row.get("fields", {}), - "corrected_metadata": row.get("fields_corrected", {}), - "confidence_scores": { - k: f"{int(v * 100)}%" - for k, v in row.get("field_conf", {}).items() - if isinstance(v, (int, float)) - and v is not None # Added None check - }, - } - bulk_results["documents"].append(doc_result) - - json_str = json.dumps(bulk_results, indent=2, ensure_ascii=False) - st.download_button( - label="๐Ÿ“„ Download Complete Session", - data=json_str, - file_name=f"bulk_extraction_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json", - mime="application/json", - key="bulk_download", - ) diff --git a/pages/2_Extract.py b/pages/2_Extract.py new file mode 100644 index 0000000..19a5b91 --- /dev/null +++ b/pages/2_Extract.py @@ -0,0 +1,129 @@ +from __future__ import annotations + +from datetime import datetime +from pathlib import Path +import streamlit as st +from PIL import Image + +from extractly.config import load_config +from extractly.domain.run_store import RunStore +from extractly.domain.schema_store import SchemaStore +from extractly.integrations.preprocess import preprocess +from extractly.pipeline.classification import DEFAULT_CLASSIFIER_PROMPT +from extractly.pipeline.extraction import DEFAULT_EXTRACTION_PROMPT +from extractly.pipeline.runner import PipelineOptions, run_pipeline +from extractly.logging import setup_logging +from extractly.ui.components import inject_branding, inject_global_styles, section_title + + +config = load_config() +setup_logging() +store = SchemaStore(config.schema_dir) +run_store = RunStore(config.run_store_dir) + +st.set_page_config(page_title="Extract", page_icon="โšก", layout="wide") + +inject_branding(Path("data/assets/data_reply.svg")) +inject_global_styles() + +st.title("โšก Run Extraction") +st.caption("Upload documents, select a schema, and run the extraction pipeline.") + +schemas = store.list_schemas() +if not schemas: + st.warning("No schemas found. 
Create one in Schema Studio first.") + st.stop() + +schema_name = st.selectbox("Schema", options=[schema.name for schema in schemas]) +active_schema = store.get_schema(schema_name) + +left, right = st.columns([2, 1]) +with left: + files = st.file_uploader( + "Upload documents", + type=["pdf", "png", "jpg", "jpeg", "txt"], + accept_multiple_files=True, + ) + +with right: + section_title("Pipeline options") + enable_ocr = st.toggle("Enable OCR", value=False) + compute_conf = st.toggle("Field confidence", value=True) + mode = st.radio("Mode", options=["fast", "accurate"], horizontal=True) + +with st.expander("Advanced prompts", expanded=False): + classifier_prompt = st.text_area( + "Classifier prompt", + value=st.session_state.get("classifier_prompt", DEFAULT_CLASSIFIER_PROMPT), + height=140, + ) + extractor_prompt = st.text_area( + "Extraction prompt", + value=st.session_state.get("extractor_prompt", DEFAULT_EXTRACTION_PROMPT), + height=160, + ) + if st.button("Save prompts"): + st.session_state["classifier_prompt"] = classifier_prompt + st.session_state["extractor_prompt"] = extractor_prompt + st.success("Prompts saved.") + +st.markdown("---") + +section_title("Pipeline steps") +steps = st.columns(5) +steps[0].markdown("โœ… 1. Parse") +steps[1].markdown("โœ… 2. Classify") +steps[2].markdown("โœ… 3. Extract") +steps[3].markdown("โœ… 4. Validate") +steps[4].markdown("โœ… 5. 
Export") + +if st.button("Run extraction", type="primary", use_container_width=True): + if not files: + st.error("Upload at least one document.") + st.stop() + + parsed_files = [] + progress = st.progress(0.0, "Parsing files") + + for idx, upload in enumerate(files, start=1): + filename = upload.name + if filename.lower().endswith(".txt"): + content = upload.read().decode("utf-8", errors="ignore") + blank_image = Image.new("RGB", (800, 1000), color="white") + images = [blank_image] + parsed_files.append( + { + "name": filename, + "images": images, + "ocr_text": content, + "doc_type_override": active_schema.name, + } + ) + else: + images = preprocess(upload, filename) + parsed_files.append({"name": filename, "images": images}) + + progress.progress(idx / len(files), f"Parsed {filename}") + + options = PipelineOptions( + enable_ocr=enable_ocr, + compute_confidence=compute_conf, + mode=mode, + classifier_prompt=st.session_state.get("classifier_prompt"), + extraction_prompt=st.session_state.get("extractor_prompt"), + ) + + run = run_pipeline( + files=parsed_files, + schema=active_schema, + candidates=[schema.name for schema in schemas] + ["Unknown", "Other"], + run_store=run_store, + options=options, + ) + + st.session_state["latest_run_id"] = run.run_id + st.success("Extraction completed.") + st.page_link("pages/3_Results.py", label="View results", use_container_width=True) + + st.caption(f"Run {run.run_id} stored at {config.run_store_dir}") + st.code(f"Run completed at {datetime.now().isoformat()}") diff --git a/pages/3_Results.py b/pages/3_Results.py new file mode 100644 index 0000000..919385f --- /dev/null +++ b/pages/3_Results.py @@ -0,0 +1,131 @@ +from __future__ import annotations + +import csv +import io +import json +from pathlib import Path +import streamlit as st + +from extractly.config import load_config +from extractly.domain.run_store import RunStore +from extractly.logging import setup_logging +from extractly.ui.components import inject_branding, 
inject_global_styles, section_title + + +config = load_config() +setup_logging() +run_store = RunStore(config.run_store_dir) + +st.set_page_config(page_title="Results", page_icon="๐Ÿ“Š", layout="wide") + +inject_branding(Path("data/assets/data_reply.svg")) +inject_global_styles() + +st.title("๐Ÿ“Š Results") +st.caption("Browse extraction runs, review outputs, and export data.") + +runs = run_store.list_runs() +if not runs: + st.info("No runs yet. Run an extraction first.") + st.stop() + +run_ids = [run["run_id"] for run in runs] +latest_id = st.session_state.get("latest_run_id") +selected_id = st.selectbox( + "Select a run", + options=run_ids, + index=run_ids.index(latest_id) if latest_id in run_ids else 0, +) + +run = run_store.load(selected_id) +if not run: + st.error("Run not found.") + st.stop() + +section_title("Run summary") +summary_cols = st.columns(4) +summary_cols[0].metric("Run ID", run["run_id"]) +summary_cols[1].metric("Schema", run.get("schema_name", "โ€”")) +summary_cols[2].metric("Mode", run.get("mode", "โ€”")) +summary_cols[3].metric("Documents", len(run.get("documents", []))) + +st.markdown("---") + +section_title("Documents") +doc_rows = [] +for doc in run.get("documents", []): + doc_rows.append( + { + "filename": doc.get("filename"), + "document_type": doc.get("document_type"), + "confidence": doc.get("confidence"), + "warnings": len(doc.get("warnings", [])), + "errors": len(doc.get("errors", [])), + } + ) + +st.dataframe(doc_rows, use_container_width=True) + +selected_doc_name = st.selectbox( + "View document details", + options=[doc["filename"] for doc in run.get("documents", [])], +) + +selected_doc = next( + (doc for doc in run.get("documents", []) if doc["filename"] == selected_doc_name), + None, +) + +if selected_doc: + section_title("Extracted fields") + field_rows = [ + { + "field": key, + "value": value, + "confidence": selected_doc.get("field_confidence", {}).get(key, ""), + } + for key, value in selected_doc.get("corrected", 
{}).items() + ] + st.dataframe(field_rows, use_container_width=True) + + section_title("JSON output") + st.json(selected_doc.get("corrected", {})) + + if selected_doc.get("warnings"): + st.warning("\n".join(selected_doc.get("warnings"))) + if selected_doc.get("errors"): + st.error("\n".join(selected_doc.get("errors"))) + +section_title("Exports") + +json_payload = json.dumps(run, indent=2, ensure_ascii=False) + +st.download_button( + "Download run JSON", + data=json_payload, + file_name=f"{selected_id}.json", + mime="application/json", +) + +csv_buffer = io.StringIO() +fieldnames = {"filename", "document_type", "confidence"} +for doc in run.get("documents", []): + fieldnames.update(doc.get("corrected", {}).keys()) + +writer = csv.DictWriter(csv_buffer, fieldnames=sorted(fieldnames)) +writer.writeheader() +for doc in run.get("documents", []): + row = { + "filename": doc.get("filename"), + "document_type": doc.get("document_type"), + "confidence": doc.get("confidence"), + } + row.update(doc.get("corrected", {})) + writer.writerow(row) + +st.download_button( + "Download CSV", + data=csv_buffer.getvalue(), + file_name=f"{selected_id}.csv", + mime="text/csv", +) diff --git a/pages/4_Settings.py b/pages/4_Settings.py new file mode 100644 index 0000000..1605316 --- /dev/null +++ b/pages/4_Settings.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +from pathlib import Path +import streamlit as st + +from extractly.config import load_config +from extractly.ui.components import inject_branding, inject_global_styles, section_title +from extractly.logging import setup_logging + + +config = load_config() +setup_logging() + +st.set_page_config(page_title="Settings", page_icon="โš™๏ธ", layout="wide") + +inject_branding(Path("data/assets/data_reply.svg")) +inject_global_styles() + +st.title("โš™๏ธ Settings") +st.caption("Review configuration, models, and environment status.") + +section_title("Environment") +cols = st.columns(3) +cols[0].metric("OpenAI key", "โœ… Found" 
if config.openai_api_key else "โŒ Missing") +cols[1].metric("Timeout (s)", config.request_timeout_s) +cols[2].metric("Max retries", config.max_retries) + +section_title("Models") +model_cols = st.columns(3) +model_cols[0].text_input("Classifier model", value=config.classify_model, disabled=True) +model_cols[1].text_input("Extractor model", value=config.extract_model, disabled=True) +model_cols[2].text_input("OCR model", value=config.ocr_model, disabled=True) + +section_title("Directories") +st.write(f"Schemas: `{config.schema_dir}`") +st.write(f"Runs: `{config.run_store_dir}`") +st.write(f"Sample docs: `{config.sample_data_dir}`") + +section_title("Notes") +st.info( + "To update models or pipeline settings, set the environment variables in your `.env` file. " + "Run `streamlit run Home.py` after changing them.", + icon="๐Ÿ“", +) diff --git a/schemas/demo_invoice.json b/schemas/demo_invoice.json new file mode 100644 index 0000000..f6cdd26 --- /dev/null +++ b/schemas/demo_invoice.json @@ -0,0 +1,48 @@ +{ + "Invoice Demo": { + "description": "Client-ready demo schema for invoice metadata.", + "version": "v1", + "fields": [ + { + "name": "Invoice Number", + "type": "string", + "required": true, + "description": "Unique invoice identifier", + "example": "INV-2025-0042" + }, + { + "name": "Invoice Date", + "type": "date", + "required": true, + "description": "Issue date", + "example": "2025-01-15" + }, + { + "name": "Supplier", + "type": "string", + "required": true, + "description": "Supplier company name" + }, + { + "name": "Client", + "type": "string", + "required": true, + "description": "Client company name" + }, + { + "name": "Total Amount", + "type": "number", + "required": true, + "description": "Invoice total", + "example": "1248.50" + }, + { + "name": "Payment Terms", + "type": "string", + "required": false, + "description": "Payment terms", + "example": "Net 30" + } + ] + } +} diff --git a/src/classifier.py b/src/classifier.py index 3dd7c8e..b165a28 100644 
--- a/src/classifier.py +++ b/src/classifier.py @@ -1,12 +1,8 @@ -# utils/classifier.py +from __future__ import annotations -import os -import io -import base64 from PIL import Image -from utils.openai_client import get_chat_completion -from statistics import mode, StatisticsError -from utils.utils import DEFAULT_OPENAI_MODEL + +from extractly.pipeline.classification import classify_document def classify( @@ -14,39 +10,13 @@ def classify( candidates: list[str], *, use_confidence: bool = False, - n_votes: int = 5, # number of self-consistency calls + n_votes: int = 5, system_prompt: str = "", ) -> dict: - """Returns {'doc_type': str} or additionally a 'confidence' field.""" - - def _single_vote() -> str: - buf = io.BytesIO() - images[0].save(buf, format="PNG") - data_uri = f"data:image/png;base64,{base64.b64encode(buf.getvalue()).decode()}" - - prompt = f"Choose one type from: {candidates}. Return only the type." - msgs = [ - {"role": "system", "content": system_prompt}, - { - "role": "user", - "content": [ - {"type": "text", "text": prompt}, - {"type": "image_url", "image_url": {"url": data_uri}}, - ], - }, - ] - return get_chat_completion( - msgs, model=os.getenv("CLASSIFY_MODEL", DEFAULT_OPENAI_MODEL) - ).strip() - - if not use_confidence: - return {"doc_type": _single_vote()} - - votes = [_single_vote() for _ in range(n_votes)] - try: - best = mode(votes) - confidence = votes.count(best) / n_votes - except StatisticsError: # all votes different - best, confidence = votes[0], 1 / n_votes - - return {"doc_type": best, "confidence": confidence} + return classify_document( + images, + candidates, + use_confidence=use_confidence, + n_votes=n_votes, + system_prompt=system_prompt or None, + ) diff --git a/src/extractly/__init__.py b/src/extractly/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/extractly/config.py b/src/extractly/config.py new file mode 100644 index 0000000..8012926 --- /dev/null +++ b/src/extractly/config.py @@ -0,0 +1,42 @@ 
+from __future__ import annotations + +import os +from dataclasses import dataclass +from pathlib import Path +from dotenv import load_dotenv + + +PROJECT_ROOT = Path(__file__).resolve().parents[2] + + +@dataclass(frozen=True) +class AppConfig: + app_name: str + openai_api_key: str | None + classify_model: str + extract_model: str + ocr_model: str + request_timeout_s: int + max_retries: int + retry_backoff_s: float + run_store_dir: Path + schema_dir: Path + sample_data_dir: Path + + +def load_config() -> AppConfig: + load_dotenv(override=True) + + return AppConfig( + app_name=os.getenv("EXTRACTLY_APP_NAME", "Extractly"), + openai_api_key=os.getenv("OPENAI_API_KEY"), + classify_model=os.getenv("CLASSIFY_MODEL", "o4-mini"), + extract_model=os.getenv("EXTRACT_MODEL", "o4-mini"), + ocr_model=os.getenv("OCR_MODEL", "o4-mini"), + request_timeout_s=int(os.getenv("EXTRACTLY_TIMEOUT_S", "40")), + max_retries=int(os.getenv("EXTRACTLY_MAX_RETRIES", "2")), + retry_backoff_s=float(os.getenv("EXTRACTLY_RETRY_BACKOFF_S", "1.5")), + run_store_dir=Path(os.getenv("EXTRACTLY_RUNS_DIR", PROJECT_ROOT / "runs")), + schema_dir=Path(os.getenv("EXTRACTLY_SCHEMAS_DIR", PROJECT_ROOT / "schemas")), + sample_data_dir=Path(os.getenv("EXTRACTLY_SAMPLE_DATA_DIR", PROJECT_ROOT / "data" / "sample_docs")), + ) diff --git a/src/extractly/domain/__init__.py b/src/extractly/domain/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/extractly/domain/models.py b/src/extractly/domain/models.py new file mode 100644 index 0000000..32b99c6 --- /dev/null +++ b/src/extractly/domain/models.py @@ -0,0 +1,53 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + + +FIELD_TYPES = { + "string", + "number", + "integer", + "boolean", + "date", + "enum", + "object", + "array", +} + + +@dataclass +class SchemaField: + name: str + field_type: str = "string" + required: bool = False + description: str = "" + example: str = "" + enum_values: 
list[str] = field(default_factory=list) + + def to_dict(self) -> dict[str, Any]: + payload = { + "name": self.name, + "type": self.field_type, + "required": self.required, + "description": self.description, + "example": self.example, + } + if self.enum_values: + payload["enum"] = self.enum_values + return payload + + +@dataclass +class DocumentSchema: + name: str + description: str = "" + fields: list[SchemaField] = field(default_factory=list) + version: str = "v1" + + def to_dict(self) -> dict[str, Any]: + return { + "description": self.description, + "version": self.version, + "fields": [field.to_dict() for field in self.fields], + } diff --git a/src/extractly/domain/run_store.py b/src/extractly/domain/run_store.py new file mode 100644 index 0000000..70f2c1f --- /dev/null +++ b/src/extractly/domain/run_store.py @@ -0,0 +1,90 @@ +from __future__ import annotations + +import json +from dataclasses import dataclass, field +from datetime import datetime, timezone +from pathlib import Path +from typing import Any +from uuid import uuid4 + + +@dataclass +class RunDocument: + filename: str + document_type: str + confidence: float | None + extracted: dict[str, Any] + corrected: dict[str, Any] + field_confidence: dict[str, float] = field(default_factory=dict) + warnings: list[str] = field(default_factory=list) + errors: list[str] = field(default_factory=list) + + +@dataclass +class ExtractionRun: + run_id: str + started_at: str + schema_name: str + mode: str + documents: list[RunDocument] + status: str = "completed" + logs: list[str] = field(default_factory=list) + + def to_dict(self) -> dict[str, Any]: + return { + "run_id": self.run_id, + "started_at": self.started_at, + "schema_name": self.schema_name, + "mode": self.mode, + "status": self.status, + "logs": self.logs, + "documents": [ + { + "filename": doc.filename, + "document_type": doc.document_type, + "confidence": doc.confidence, + "extracted": doc.extracted, + "corrected": doc.corrected, + "field_confidence": 
doc.field_confidence, + "warnings": doc.warnings, + "errors": doc.errors, + } + for doc in self.documents + ], + } + + +class RunStore: + def __init__(self, base_dir: Path): + self.base_dir = base_dir + self.base_dir.mkdir(parents=True, exist_ok=True) + + def create_run_id(self) -> str: + timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ") + return f"run_{timestamp}_{uuid4().hex[:6]}" + + def save(self, run: ExtractionRun) -> Path: + run_dir = self.base_dir / run.run_id + run_dir.mkdir(parents=True, exist_ok=True) + run_path = run_dir / "run.json" + with run_path.open("w", encoding="utf-8") as fp: + json.dump(run.to_dict(), fp, indent=2, ensure_ascii=False) + return run_path + + def list_runs(self) -> list[dict[str, Any]]: + runs: list[dict[str, Any]] = [] + for run_dir in sorted(self.base_dir.glob("run_*"), reverse=True): + run_path = run_dir / "run.json" + if not run_path.exists(): + continue + with run_path.open("r", encoding="utf-8") as fp: + payload = json.load(fp) + runs.append(payload) + return runs + + def load(self, run_id: str) -> dict[str, Any] | None: + run_path = self.base_dir / run_id / "run.json" + if not run_path.exists(): + return None + with run_path.open("r", encoding="utf-8") as fp: + return json.load(fp) diff --git a/src/extractly/domain/schema_store.py b/src/extractly/domain/schema_store.py new file mode 100644 index 0000000..59092bb --- /dev/null +++ b/src/extractly/domain/schema_store.py @@ -0,0 +1,138 @@ +from __future__ import annotations + +import json +import re +from pathlib import Path +from typing import Iterable + +from extractly.domain.models import DocumentSchema, SchemaField +from extractly.domain.validation import validate_schema, ValidationResult + + +class SchemaStore: + def __init__(self, schema_dir: Path): + self.schema_dir = schema_dir + self.schema_dir.mkdir(parents=True, exist_ok=True) + + def list_schemas(self) -> list[DocumentSchema]: + schemas: list[DocumentSchema] = [] + for file_path in 
class SchemaStore:
    """JSON-file catalogue of document schemas.

    Every ``*.json`` file in ``schema_dir`` maps schema names to either a
    bare list of field dicts or an object with ``description``, ``version``
    and ``fields`` keys.
    """

    def __init__(self, schema_dir: Path):
        """Remember ``schema_dir`` and create it (with parents) when missing."""
        self.schema_dir = schema_dir
        self.schema_dir.mkdir(parents=True, exist_ok=True)

    def list_schemas(self) -> list[DocumentSchema]:
        """Return every schema found in ``schema_dir``, sorted by name (case-insensitive)."""
        schemas: list[DocumentSchema] = []
        for file_path in sorted(self.schema_dir.glob("*.json")):
            schemas.extend(self._parse_payload(self._read_payload(file_path)))
        return sorted(schemas, key=lambda s: s.name.lower())

    def get_schema(self, name: str) -> DocumentSchema | None:
        """Return the first schema whose name equals ``name`` exactly, else ``None``."""
        return next((s for s in self.list_schemas() if s.name == name), None)

    def save_schema(self, schema: DocumentSchema) -> ValidationResult:
        """Validate ``schema`` and, when valid, persist it; return the validation result.

        The target file is derived from the slugified name. Distinct names
        can share a slug (e.g. "Invoice A" and "invoice-a"), so the schema is
        merged into whatever the file already holds instead of replacing the
        whole file and silently dropping the other schema.
        """
        validation = validate_schema(schema)
        if not validation.is_valid:
            return validation

        file_path = self.schema_dir / f"{self._slugify(schema.name)}.json"
        try:
            existing = self._read_payload(file_path)
        except (OSError, ValueError):
            # Missing or unparsable file: start from scratch.
            existing = None
        payload = existing if isinstance(existing, dict) else {}
        payload[schema.name] = schema.to_dict()
        self._write_payload(file_path, payload)
        return validation

    def delete_schema(self, name: str) -> bool:
        """Remove ``name`` from every file that defines it; return whether any did.

        Only files that contained the schema are rewritten, and a file is
        deleted once its last schema is removed.
        """
        deleted = False
        # Materialise the listing first: files may be unlinked while looping.
        for file_path in sorted(self.schema_dir.glob("*.json")):
            payload = self._read_payload(file_path)
            if not isinstance(payload, dict) or name not in payload:
                continue
            del payload[name]
            deleted = True
            if payload:
                self._write_payload(file_path, payload)
            else:
                file_path.unlink(missing_ok=True)
        return deleted

    def import_payload(self, payload: dict) -> list[DocumentSchema]:
        """Parse ``payload`` and save its schemas; return only those actually saved.

        Schemas that fail validation are not written by ``save_schema`` and
        are therefore left out of the returned list.
        """
        imported: list[DocumentSchema] = []
        for schema in self._parse_payload(payload):
            if self.save_schema(schema).is_valid:
                imported.append(schema)
        return imported

    def export_schema(self, schema: DocumentSchema) -> str:
        """Return ``schema`` as pretty-printed JSON keyed by its name."""
        return json.dumps({schema.name: schema.to_dict()}, indent=2, ensure_ascii=False)

    @staticmethod
    def _slugify(name: str) -> str:
        """Turn a schema name into a lower-case file stem; fall back to ``schema``."""
        cleaned = re.sub(r"[^a-zA-Z0-9_-]+", "-", name.strip())
        return cleaned.strip("-").lower() or "schema"

    @staticmethod
    def _read_payload(file_path: Path):
        """Load and return the JSON document stored at ``file_path``."""
        with file_path.open("r", encoding="utf-8") as fp:
            return json.load(fp)

    @staticmethod
    def _write_payload(file_path: Path, payload: dict) -> None:
        """Write ``payload`` to ``file_path`` as indented UTF-8 JSON."""
        with file_path.open("w", encoding="utf-8") as fp:
            json.dump(payload, fp, indent=2, ensure_ascii=False)

    def _parse_payload(self, payload: dict) -> list[DocumentSchema]:
        """Build schemas from a ``{name: definition}`` mapping.

        A definition is either a list of field dicts or an object with
        ``description`` / ``version`` / ``fields``. Entries of any other
        shape, and field rows that are not dicts, are skipped rather than
        raising, so one malformed entry does not block the rest.
        """
        schemas: list[DocumentSchema] = []
        if not isinstance(payload, dict):
            return schemas

        for name, data in payload.items():
            if isinstance(data, list):
                description, version, raw_fields = "", "v1", data
            elif isinstance(data, dict):
                description = data.get("description", "")
                version = data.get("version", "v1")
                raw_fields = data.get("fields", [])
            else:
                continue

            if not isinstance(raw_fields, list):
                raw_fields = []
            fields = [
                self._parse_field(row) for row in raw_fields if isinstance(row, dict)
            ]
            schemas.append(
                DocumentSchema(
                    name=name,
                    description=description,
                    fields=fields,
                    version=version,
                )
            )
        return schemas

    @staticmethod
    def _parse_field(field: dict) -> SchemaField:
        """Build a ``SchemaField`` from one field dict.

        Both the exported key names (``type``, ``enum``) and the attribute
        names (``field_type``, ``enum_values``) are accepted. A JSON ``null``
        for a text attribute becomes an empty string, not the text "None".
        """

        def text(key: str) -> str:
            value = field.get(key, "")
            return "" if value is None else str(value)

        return SchemaField(
            name=text("name").strip(),
            field_type=field.get("type", field.get("field_type", "string")),
            required=bool(field.get("required", False)),
            description=text("description"),
            example=text("example"),
            enum_values=list(field.get("enum", field.get("enum_values", [])) or []),
        )
def schemas_to_table(schema: DocumentSchema) -> list[dict]:
    """Flatten a schema's fields into editor rows, one dict per field.

    Enum values are joined into a single comma-separated string so they fit
    in one table cell.
    """
    return [
        {
            "name": item.name,
            "type": item.field_type,
            "required": item.required,
            "description": item.description,
            "example": item.example,
            "enum": ", ".join(item.enum_values),
        }
        for item in schema.fields
    ]


def _is_blank_cell(value: object) -> bool:
    """Return True for an empty table cell, i.e. ``None`` or a float NaN."""
    # NaN is the only float that differs from itself.
    return value is None or (isinstance(value, float) and value != value)


def _cell_text(value: object) -> str:
    """Return a cell as stripped text; empty cells become ``""``, not "None"/"nan"."""
    return "" if _is_blank_cell(value) else str(value).strip()


def table_to_schema(name: str, description: str, rows: Iterable[dict]) -> DocumentSchema:
    """Build a ``DocumentSchema`` from editor rows (the inverse of ``schemas_to_table``).

    Each row is a dict with optional ``name``, ``type``, ``required``,
    ``description``, ``example`` and ``enum`` keys; ``enum`` is a
    comma-separated string. Empty cells (``None`` / NaN) are treated like
    missing keys, so they no longer produce the literal text "None" or a
    spurious ``required=True``.
    """
    fields: list[SchemaField] = []
    for row in rows:
        enum_values = [
            part.strip()
            for part in _cell_text(row.get("enum", "")).split(",")
            if part.strip()
        ]
        raw_type = row.get("type", "string")
        raw_required = row.get("required", False)
        fields.append(
            SchemaField(
                name=_cell_text(row.get("name", "")),
                field_type="string" if _is_blank_cell(raw_type) else raw_type,
                required=False if _is_blank_cell(raw_required) else bool(raw_required),
                description=_cell_text(row.get("description", "")),
                example=_cell_text(row.get("example", "")),
                enum_values=enum_values,
            )
        )
    return DocumentSchema(name=name, description=description, fields=fields)
@dataclass
class ValidationResult:
    """Outcome of a schema check: blocking errors plus advisory warnings."""

    errors: list[str] = field(default_factory=list)
    warnings: list[str] = field(default_factory=list)

    @property
    def is_valid(self) -> bool:
        """True when no errors were recorded (warnings do not count)."""
        return len(self.errors) == 0


def validate_schema(schema: DocumentSchema) -> ValidationResult:
    """Check ``schema`` and collect every problem found.

    Errors are recorded for a blank schema name, unnamed or duplicated
    fields, field types outside ``FIELD_TYPES``, enum fields without values,
    repeated enum values, and a schema with no fields at all. A non-empty
    description shorter than 10 characters only yields a warning.
    """
    outcome = ValidationResult()
    errors = outcome.errors

    if schema.name.strip() == "":
        errors.append("Schema name is required.")

    used_names: set[str] = set()
    position = 0
    for spec in schema.fields:
        position += 1  # 1-based position used in the "Field #n" message
        label = spec.name

        # An unnamed field cannot be referred to in later messages; stop here.
        if label.strip() == "":
            errors.append(f"Field #{position} is missing a name.")
            continue

        if label in used_names:
            errors.append(f"Field name '{label}' is duplicated.")
        else:
            used_names.add(label)

        kind = spec.field_type
        if kind not in FIELD_TYPES:
            errors.append(
                f"Field '{label}' has invalid type '{kind}'."
            )

        options = spec.enum_values
        if kind == "enum" and not options:
            errors.append(
                f"Field '{label}' is enum but has no enum values."
            )

        if options and len(set(options)) != len(options):
            errors.append(
                f"Field '{label}' has duplicate enum values."
            )

    if not schema.fields:
        errors.append("At least one field is required.")

    description = schema.description
    if description and len(description) < 10:
        outcome.warnings.append("Schema description is very short.")

    return outcome
def _page_to_data_uri(page: Image.Image) -> str:
    """Encode ``page`` as a base64 PNG ``data:`` URI.

    Pages whose mode is not in the PNG-writable set below (for example
    CMYK scans) are converted to RGB first, because saving them as PNG
    would otherwise raise.
    """
    if page.mode not in {"1", "L", "LA", "I", "I;16", "P", "RGB", "RGBA"}:
        page = page.convert("RGB")
    buf = io.BytesIO()
    page.save(buf, format="PNG")
    return f"data:image/png;base64,{base64.b64encode(buf.getvalue()).decode()}"


def ocr_page(page: Image.Image) -> str:
    """Send one page image to the configured OCR model and return its text reply."""
    config = load_config()
    messages = [
        {"role": "system", "content": _OCR_SYSTEM_PROMPT},
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": _page_to_data_uri(page)}},
            ],
        },
    ]

    return get_chat_completion(messages, model=config.ocr_model, temperature=0.0)


def run_ocr(pages: Iterable[Image.Image]) -> str:
    """OCR every page and join the per-page text with blank lines.

    This is best-effort: a page that fails is logged with its 1-based
    position and contributes an empty string, so one bad page does not
    discard the text of the others.
    """
    outputs: list[str] = []
    for number, page in enumerate(pages, start=1):
        try:
            outputs.append(ocr_page(page))
        except Exception as exc:
            logger.error("OCR failed for page %s: %s", number, exc)
            outputs.append("")
    return "\n\n".join(outputs).strip()
def get_chat_completion(
    messages: list[dict[str, Any]],
    *,
    model: str,
    temperature: float = 0.2,
    max_retries: int | None = None,
    timeout_s: int | None = None,
) -> str:
    """Run one chat completion and return the reply text ("" when the reply has none).

    Args:
        messages: Chat messages passed to the API unchanged.
        model: Model name to call.
        temperature: Sampling temperature forwarded to the API.
        max_retries: Extra attempts after the first; defaults to the
            configured value. Negative values are treated as 0.
        timeout_s: Per-request timeout; defaults to the configured value.

    Raises:
        RuntimeError: If no API key is configured.
        Exception: The error of the final attempt once all attempts failed.
    """
    config = load_config()
    retries = max_retries if max_retries is not None else config.max_retries
    timeout = timeout_s if timeout_s is not None else config.request_timeout_s

    if not config.openai_api_key:
        raise RuntimeError("OPENAI_API_KEY is not set. Add it to your environment.")

    # Always make at least one request: with a negative retry setting the
    # loop would otherwise never run and "" would be returned without any
    # call having been made.
    attempts = max(0, retries) + 1

    # NOTE(review): some models may reject a non-default temperature; confirm
    # against the models configured for classification/extraction/OCR.
    for attempt in range(1, attempts + 1):
        try:
            response = get_client().chat.completions.create(
                model=model,
                messages=messages,
                temperature=temperature,
                timeout=timeout,
            )
            return response.choices[0].message.content or ""
        except Exception as exc:
            logger.warning("OpenAI request failed (attempt %s): %s", attempt, exc)
            if attempt >= attempts:
                raise
            # Linear backoff: wait longer after each failed attempt.
            time.sleep(config.retry_backoff_s * attempt)

    # Not reachable: every iteration above either returns or raises.
    return ""
def _image_to_data_uri(image: Image.Image) -> str:
    """Encode ``image`` as a base64 PNG ``data:`` URI."""
    buf = io.BytesIO()
    image.save(buf, format="PNG")
    return f"data:image/png;base64,{base64.b64encode(buf.getvalue()).decode()}"


def _normalise_label(raw: str, candidates: list[str]) -> str:
    """Clean a model reply and map it onto the caller's spelling of a candidate.

    Surrounding whitespace, quotes and back-ticks are removed. The result is
    matched case-insensitively against ``candidates``, with and without a
    trailing period. A reply that matches nothing is returned cleaned but
    otherwise unchanged; an empty reply becomes "Unknown".
    """
    label = raw.strip().strip("\"'`").strip()
    if not label:
        return "Unknown"
    for attempt in (label, label.rstrip(".").strip()):
        folded = attempt.casefold()
        for candidate in candidates:
            if str(candidate).casefold() == folded:
                return candidate
    return label


def classify_document(
    images: list[Image.Image],
    candidates: list[str],
    *,
    use_confidence: bool = False,
    n_votes: int = 3,
    system_prompt: str | None = None,
) -> dict[str, Any]:
    """Classify a document from its first page image.

    Args:
        images: Page images; only the first one is sent to the model.
        candidates: Labels the model is asked to choose from.
        use_confidence: When true, ask ``n_votes`` times and report the share
            of votes won by the most common label as ``confidence``.
        n_votes: Number of votes; values below 1 are treated as 1.
        system_prompt: Optional replacement for the default classifier prompt.

    Returns:
        ``{"doc_type": label}``, plus ``"confidence"`` when requested.

    Raises:
        ValueError: If ``images`` is empty.
    """
    if not images:
        raise ValueError("classify_document requires at least one page image.")

    config = load_config()
    prompt = system_prompt or DEFAULT_CLASSIFIER_PROMPT

    # Build the request once: re-encoding the page to PNG/base64 for every
    # vote is wasted work, the payload is identical each time.
    messages = [
        {"role": "system", "content": prompt},
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": f"Choose one type from: {candidates}.",
                },
                {
                    "type": "image_url",
                    "image_url": {"url": _image_to_data_uri(images[0])},
                },
            ],
        },
    ]

    def _single_vote() -> str:
        reply = get_chat_completion(messages, model=config.classify_model)
        return _normalise_label(reply, candidates)

    if not use_confidence:
        return {"doc_type": _single_vote()}

    total = max(1, n_votes)
    votes = [_single_vote() for _ in range(total)]
    try:
        best = mode(votes)
        confidence = votes.count(best) / total
    except StatisticsError:
        # Only older Python versions raise here, on a tie between labels.
        best, confidence = votes[0], 1 / total

    return {"doc_type": best, "confidence": confidence}
typing import Any + +from PIL import Image + +from extractly.config import load_config +from extractly.domain.models import SchemaField +from extractly.integrations.openai_client import get_chat_completion +from extractly.logging import get_logger + + +logger = get_logger(__name__) + + +DEFAULT_EXTRACTION_PROMPT = """ +You are a metadata extraction specialist. Extract the requested fields with high accuracy. +Return JSON with keys: metadata, snippets, confidence. Use null when a field is missing. +""" + + +def _image_to_data_uri(image: Image.Image) -> str: + buf = io.BytesIO() + image.save(buf, format="PNG") + return f"data:image/png;base64,{base64.b64encode(buf.getvalue()).decode()}" + + +def _truncate(text: str, max_chars: int = 64_000) -> str: + return text[:max_chars] + + +def _schema_payload(fields: list[SchemaField]) -> dict[str, Any]: + return { + field.name: { + "type": field.field_type, + "required": field.required, + "description": field.description, + "example": field.example, + "enum": field.enum_values, + } + for field in fields + } + + +def extract_metadata( + images: list[Image.Image], + fields: list[SchemaField], + *, + ocr_text: str | None = None, + with_confidence: bool = False, + system_prompt: str | None = None, +) -> dict[str, Any]: + config = load_config() + field_names = [field.name for field in fields] + + schema_json = json.dumps(_schema_payload(fields), ensure_ascii=False) + + user_content = [ + {"type": "image_url", "image_url": {"url": _image_to_data_uri(images[0])}}, + {"type": "text", "text": f"Schema: {schema_json}"}, + ] + if ocr_text: + user_content.append( + { + "type": "text", + "text": "OCR context:\n" + _truncate(ocr_text), + } + ) + + messages = [ + {"role": "system", "content": system_prompt or DEFAULT_EXTRACTION_PROMPT}, + {"role": "user", "content": user_content}, + ] + + response = get_chat_completion(messages, model=config.extract_model) + if not response.strip(): + raise RuntimeError("Empty extraction response") + + raw: 
dict[str, Any] | None = None + with contextlib.suppress(json.JSONDecodeError): + raw = json.loads(response) + + if raw is None and (match := re.search(r"\{.*\}", response, flags=re.S)): + with contextlib.suppress(json.JSONDecodeError): + raw = json.loads(match.group()) + + if not isinstance(raw, dict): + logger.error("Bad extraction JSON. Returning blanks.") + raw = {} + + raw_meta = raw.get("metadata") if isinstance(raw.get("metadata"), dict) else {} + raw_snippets = ( + raw.get("snippets") if isinstance(raw.get("snippets"), dict) else {} + ) + raw_conf = ( + raw.get("confidence") if isinstance(raw.get("confidence"), dict) else {} + ) + + + if with_confidence and not raw_conf: + raw_conf = {name: 1.0 if raw_meta.get(name) else 0.0 for name in field_names} + + metadata = {name: raw_meta.get(name) for name in field_names} + snippets = {name: raw_snippets.get(name) for name in field_names} + + confidence = {} + if with_confidence: + confidence = {name: raw_conf.get(name, 0.0) for name in field_names} + + return { + "metadata": metadata, + "snippets": snippets, + "confidence": confidence, + } diff --git a/src/extractly/pipeline/runner.py b/src/extractly/pipeline/runner.py new file mode 100644 index 0000000..20bb7bd --- /dev/null +++ b/src/extractly/pipeline/runner.py @@ -0,0 +1,110 @@ +from __future__ import annotations + +from dataclasses import dataclass +from datetime import datetime, timezone +from typing import Any + +from PIL import Image + +from extractly.domain.models import DocumentSchema +from extractly.domain.run_store import ExtractionRun, RunDocument, RunStore +from extractly.integrations.ocr import run_ocr +from extractly.pipeline.classification import classify_document +from extractly.pipeline.extraction import extract_metadata +from extractly.logging import get_logger + + +logger = get_logger(__name__) + + +@dataclass +class PipelineOptions: + enable_ocr: bool = False + compute_confidence: bool = False + mode: str = "fast" + classifier_prompt: str | 
@dataclass
class PipelineOptions:
    """Switches and prompt overrides for one pipeline run."""

    enable_ocr: bool = False
    compute_confidence: bool = False
    mode: str = "fast"
    classifier_prompt: str | None = None
    extraction_prompt: str | None = None


def run_pipeline(
    *,
    files: list[dict[str, Any]],
    schema: DocumentSchema,
    candidates: list[str],
    run_store: RunStore,
    options: PipelineOptions,
) -> ExtractionRun:
    """Classify and extract every file, then persist and return the run.

    Args:
        files: One dict per document with ``name`` and ``images`` and,
            optionally, ``ocr_text`` and ``doc_type_override``.
        schema: Schema whose fields are extracted from each document.
        candidates: Document types offered to the classifier.
        run_store: Store that issues the run id and saves the result.
        options: OCR / confidence switches and prompt overrides.

    A classification or extraction failure is recorded in that document's
    ``errors`` and processing continues, so one bad document cannot discard
    the results of the others.
    """
    run_id = run_store.create_run_id()
    # Stamp the start before any work is done, not once the loop has finished.
    started_at = datetime.now(timezone.utc).isoformat()
    logs: list[str] = []
    documents: list[RunDocument] = []

    for payload in files:
        filename = payload["name"]
        images: list[Image.Image] = payload["images"]

        warnings: list[str] = []
        errors: list[str] = []
        extracted: dict[str, Any] = {}
        field_confidence: dict[str, float] = {}

        logs.append(f"Parsing {filename}")
        ocr_text = payload.get("ocr_text")
        if ocr_text is None and options.enable_ocr:
            ocr_text = run_ocr(images)

        doc_type_override = payload.get("doc_type_override")
        if doc_type_override:
            doc_type = doc_type_override
            confidence = None
            logs.append(f"Using provided document type for {filename}: {doc_type}")
        else:
            try:
                classification = classify_document(
                    images,
                    candidates,
                    use_confidence=options.compute_confidence,
                    system_prompt=options.classifier_prompt,
                )
            except Exception as exc:
                logger.error("Classification failed for %s: %s", filename, exc)
                errors.append(f"Classification failed: {exc}")
                logs.append(f"Classification failed for {filename}")
                doc_type, confidence = "Unknown", None
            else:
                # An empty label is treated the same as a missing one.
                doc_type = classification.get("doc_type") or "Unknown"
                confidence = classification.get("confidence")
                logs.append(f"Classified {filename} as {doc_type}")

        if doc_type in {"Unknown", "Other"}:
            warnings.append("Document type is unknown. Extraction skipped.")
        else:
            try:
                extraction = extract_metadata(
                    images,
                    schema.fields,
                    ocr_text=ocr_text,
                    with_confidence=options.compute_confidence,
                    system_prompt=options.extraction_prompt,
                )
                extracted = extraction.get("metadata", {})
                field_confidence = extraction.get("confidence", {})
            except Exception as exc:
                logger.error("Extraction failed for %s: %s", filename, exc)
                errors.append(str(exc))

        documents.append(
            RunDocument(
                filename=filename,
                document_type=doc_type,
                confidence=confidence,
                extracted=extracted,
                # Separate dict so later corrections leave the raw extraction intact.
                corrected=extracted.copy(),
                field_confidence=field_confidence,
                warnings=warnings,
                errors=errors,
            )
        )

    run = ExtractionRun(
        run_id=run_id,
        started_at=started_at,
        schema_name=schema.name,
        mode=options.mode,
        documents=documents,
        logs=logs,
    )
    run_store.save(run)
    return run
{title}
", unsafe_allow_html=True) + if subtitle: + st.caption(subtitle) diff --git a/src/extractor.py b/src/extractor.py index 1c27ff7..82a3af7 100644 --- a/src/extractor.py +++ b/src/extractor.py @@ -1,111 +1,35 @@ -import contextlib -import os -import io -import base64 -import json -import logging -import re -from typing import List, Dict -from PIL import Image -from utils.openai_client import get_chat_completion -from utils.utils import DEFAULT_OPENAI_MODEL -from utils.confidence_utils import score_confidence +from __future__ import annotations +from PIL import Image -def _truncate(txt: str, max_chars: int = 64_000) -> str: - """Guard-rail so we don't blow up the context window with huge OCR dumps.""" - return txt[:max_chars] +from extractly.domain.models import SchemaField +from extractly.pipeline.extraction import extract_metadata def extract( - images: List[Image.Image], - schema: List[Dict], - ocr_text: str | None = None, # Changed from Mapping to str - *, # keyword-only "tuning" flags + images: list[Image.Image], + schema: list[dict], + ocr_text: str | None = None, + *, with_confidence: bool = False, system_prompt: str = "", -) -> Dict: - """ - Return a dict with exactly three top-level keys: - metadata โ€“ field -> value - snippets โ€“ field -> supporting text (ocr or model) - confidence โ€“ field -> float in [0,1] (empty if with_confidence=False) - """ - # 1๏ธโƒฃ template & helpers -------------------------------------------------- - field_names = [f["name"] for f in schema] - blank_dict = {n: None for n in field_names} - schema_json = json.dumps(blank_dict, ensure_ascii=False) - - buf = io.BytesIO() - images[0].save(buf, format="PNG") - data_uri = f"data:image/png;base64,{base64.b64encode(buf.getvalue()).decode()}" - - usr = [ - {"type": "image_url", "image_url": {"url": data_uri}}, - {"type": "text", "text": f"Fields schema: {schema_json}"}, - ] - - if ocr_text: - usr.append( - { - "type": "text", - "text": "Extra context (OCR dump):\n\n" + 
_truncate(ocr_text), - } +) -> dict: + fields = [ + SchemaField( + name=field.get("name", ""), + field_type=field.get("type", "string"), + required=field.get("required", False), + description=field.get("description", ""), + example=field.get("example", ""), + enum_values=list(field.get("enum", []) or []), ) - - messages = [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": usr}, + for field in schema ] - # 2๏ธโƒฃ LLM call ------------------------------------------------------------ - resp = get_chat_completion( - messages, model=os.getenv("EXTRACT_MODEL", DEFAULT_OPENAI_MODEL) + return extract_metadata( + images, + fields, + ocr_text=ocr_text, + with_confidence=with_confidence, + system_prompt=system_prompt or None, ) - - if not resp.strip(): - raise RuntimeError("empty LLM response") - - # 3๏ธโƒฃ JSON-safe parse ----------------------------------------------------- - raw: Dict | None = None - with contextlib.suppress(json.JSONDecodeError): - raw = json.loads(resp) - - if raw is None and (m := re.search(r"\{.*\}", resp, flags=re.S)): - with contextlib.suppress(json.JSONDecodeError): - raw = json.loads(m.group()) - - if not isinstance(raw, dict): - logging.error("Bad extraction JSON โ†’ returning blanks") - raw = {} - - # 4๏ธโƒฃ normalise sections -------------------------------------------------- - raw_meta = raw.get("metadata") or {} - raw_snip = raw.get("snippets") or {} - raw_conf = raw.get("confidence") or {} - - # force dicts (LLMs sometimes give a scalar there) - if not isinstance(raw_meta, dict): - raw_meta = {} - if not isinstance(raw_snip, dict): - raw_snip = {} - if not isinstance(raw_conf, dict): - raw_conf = {} - - # merge external OCR into snippets (OCR wins) - if ocr_text: # Changed from isinstance(ocr_text, dict) check - raw_snip["ocr_content"] = ocr_text - - # 5๏ธโƒฃ fallback confidence -------------------------------------------------- - if with_confidence and not raw_conf: - # heuristic: 1.0 if field present & not 
null, else 0.0 - raw_conf = score_confidence(raw_meta, schema) - - # 6๏ธโƒฃ final payload with *exact* keys ------------------------------------- - return { - "metadata": {n: raw_meta.get(n) for n in field_names}, - "snippets": {n: raw_snip.get(n) for n in field_names}, - "confidence": {n: raw_conf.get(n) for n in field_names} - if with_confidence - else {}, - } diff --git a/src/ocr_engine.py b/src/ocr_engine.py index a913cb3..3c81505 100644 --- a/src/ocr_engine.py +++ b/src/ocr_engine.py @@ -1,73 +1,12 @@ -""" -src/ocr_engine.py -A super-thin wrapper so you can swap engines without touching the rest of the app. -""" - from __future__ import annotations -from PIL import Image -import io -import base64 -import os -import logging -from openai import OpenAI -from utils.utils import DEFAULT_OPENAI_MODEL - -api_key = os.getenv("OPENAI_API_KEY") -if not api_key: - logging.error("OPENAI_API_KEY not found in env") - -# โ”€โ”€ local OpenAI helper โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ -_client = OpenAI(api_key=api_key) # โ‘  -_VISION_MODEL = os.getenv("OCR_MODEL", DEFAULT_OPENAI_MODEL) # โ‘ก - -def _ocr_llm(page: Image.Image) -> str: - """Do a *vision* chat-completion round-trip and return plain text.""" - buf = io.BytesIO() - page.save(buf, format="PNG") - data_uri = f"data:image/png;base64,{base64.b64encode(buf.getvalue()).decode()}" - - system_prompt = """ - You are an expert OCR (Optical Character Recognition) assistant. Your task is to extract ALL visible text from the document image with perfect accuracy. - - Instructions: - 1. Read every piece of text visible in the image, including: - - Headers, titles, and headings - - Body text and paragraphs - - Table contents and data - - Form fields and labels - - Numbers, dates, and codes - - Fine print and footnotes - - Watermarks or stamps (if readable) +from PIL import Image - 2. 
Maintain the logical reading order (top to bottom, left to right) - 3. Preserve line breaks and spacing where meaningful - 4. Return ONLY the literal text - no commentary, no JSON formatting - 5. If text is unclear or partially obscured, make your best attempt - 6. Format as plain text but keep line breaks as they appear so humans can read it easily - """ +from extractly.integrations.ocr import run_ocr as _run_ocr - msg = [ - {"role": "system", "content": system_prompt}, - { - "role": "user", - "content": [ - {"type": "image_url", "image_url": {"url": data_uri}}, - ], - }, - ] - try: - resp = _client.chat.completions.create(model=_VISION_MODEL, messages=msg) - return resp.choices[0].message.content.strip() - except Exception as e: - logging.error(f"LLM-OCR failed: {e}") - return "" +def run_ocr(pages: list[Image.Image]) -> str: + return _run_ocr(pages) -def run_ocr(pages: list[Image.Image]) -> str: - """ - Concatenate OCR text from **all** pages with double newlines. - Now LLM-only. - """ - return "\n\n".join(_ocr_llm(p) for p in pages) +__all__ = ["run_ocr"] diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..2d24259 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,7 @@ +import sys +from pathlib import Path + +PROJECT_ROOT = Path(__file__).resolve().parents[1] +SRC_PATH = PROJECT_ROOT / "src" +if str(SRC_PATH) not in sys.path: + sys.path.insert(0, str(SRC_PATH)) diff --git a/tests/test_run_store.py b/tests/test_run_store.py new file mode 100644 index 0000000..e68e533 --- /dev/null +++ b/tests/test_run_store.py @@ -0,0 +1,30 @@ +from pathlib import Path + +from extractly.domain.run_store import ExtractionRun, RunDocument, RunStore + + +def test_run_store_round_trip(tmp_path: Path): + store = RunStore(tmp_path) + run = ExtractionRun( + run_id=store.create_run_id(), + started_at="2025-01-01T00:00:00Z", + schema_name="Invoice", + mode="fast", + documents=[ + RunDocument( + filename="sample.pdf", + document_type="Invoice", + 
confidence=0.8, + extracted={"Invoice Number": "INV-1"}, + corrected={"Invoice Number": "INV-1"}, + field_confidence={"Invoice Number": 0.8}, + ) + ], + ) + + store.save(run) + runs = store.list_runs() + assert runs + loaded = store.load(run.run_id) + assert loaded is not None + assert loaded["schema_name"] == "Invoice" diff --git a/tests/test_schema_validation.py b/tests/test_schema_validation.py new file mode 100644 index 0000000..d333ebf --- /dev/null +++ b/tests/test_schema_validation.py @@ -0,0 +1,34 @@ +from extractly.domain.models import DocumentSchema, SchemaField +from extractly.domain.validation import validate_schema + + +def test_validate_schema_detects_errors(): + schema = DocumentSchema( + name="", + fields=[ + SchemaField(name="", field_type="string"), + SchemaField(name="status", field_type="enum", enum_values=[]), + SchemaField(name="status", field_type="string"), + ], + ) + + result = validate_schema(schema) + assert not result.is_valid + assert "Schema name is required." in result.errors + assert any("Field #1" in err for err in result.errors) + assert any("enum" in err for err in result.errors) + assert any("duplicated" in err for err in result.errors) + + +def test_validate_schema_accepts_valid_schema(): + schema = DocumentSchema( + name="Invoice", + description="Invoice metadata", + fields=[ + SchemaField(name="Invoice Number", field_type="string", required=True), + SchemaField(name="Total", field_type="number"), + ], + ) + + result = validate_schema(schema) + assert result.is_valid