diff --git a/.env.example b/.env.example
new file mode 100644
index 0000000..4b8e69b
--- /dev/null
+++ b/.env.example
@@ -0,0 +1,7 @@
+OPENAI_API_KEY=
+CLASSIFY_MODEL=o4-mini
+EXTRACT_MODEL=o4-mini
+OCR_MODEL=o4-mini
+EXTRACTLY_TIMEOUT_S=40
+EXTRACTLY_MAX_RETRIES=2
+EXTRACTLY_RETRY_BACKOFF_S=1.5
diff --git a/Home.py b/Home.py
index b1fe57c..42e1203 100644
--- a/Home.py
+++ b/Home.py
@@ -1,155 +1,133 @@
-"""
-Landing page โ stylish hero header + live stats.
-"""
+from __future__ import annotations
-from datetime import datetime, timezone
+from datetime import datetime
+from pathlib import Path
import streamlit as st
-from utils.utils import load_feedback
-from dotenv import load_dotenv
-from utils.ui_components import inject_logo, inject_common_styles
-
-# Load API key from .env
-load_dotenv(override=True)
-
-st.set_page_config("Extractly", page_icon="๐ช", layout="wide")
-
-# Inject logo and common styles
-inject_logo("data/assets/data_reply.svg", height="80px") # Adjust height as needed
-inject_common_styles()
-
-# Theme-adaptive CSS using Streamlit's CSS variables
-if "home_css" not in st.session_state:
- st.markdown(
- """
-
- """,
- unsafe_allow_html=True,
- )
- st.session_state.home_css = True
-# Hero header
+from extractly.config import load_config
+from extractly.domain.run_store import RunStore
+from extractly.logging import setup_logging
+from extractly.ui.components import inject_branding, inject_global_styles, section_title
+
+
+config = load_config()
+setup_logging()
+
+st.set_page_config(page_title="Extractly", page_icon="โจ", layout="wide")
+
+inject_branding(Path("data/assets/data_reply.svg"))
+inject_global_styles()
+
+run_store = RunStore(config.run_store_dir)
+runs = run_store.list_runs()
+
st.markdown(
"""
-
-
{val}
-
{label}
+cta_cols = st.columns([1, 1, 2])
+with cta_cols[0]:
+ st.page_link("pages/1_Schema_Studio.py", label="๐ Build a schema", use_container_width=True)
+with cta_cols[1]:
+ st.page_link("pages/2_Extract.py", label="โก Run extraction", use_container_width=True)
+
+st.markdown("
", unsafe_allow_html=True)
+
+section_title("How it works", "A streamlined workflow your clients understand in seconds.")
+steps = st.columns(3)
+steps[0].markdown(
+ """
+
""",
- unsafe_allow_html=True,
- )
+ unsafe_allow_html=True,
+)
+steps[1].markdown(
+ """
+
+ """,
+ unsafe_allow_html=True,
+)
+steps[2].markdown(
+ """
+
+ """,
+ unsafe_allow_html=True,
+)
-st.markdown("---")
+st.markdown("
", unsafe_allow_html=True)
-st.markdown(
- '',
+section_title("Product highlights", "Purpose-built for metadata extraction teams and demos.")
+features = st.columns(3)
+features[0].markdown(
+ """
+
+ """,
+ unsafe_allow_html=True,
+)
+features[1].markdown(
+ """
+
+ """,
+ unsafe_allow_html=True,
+)
+features[2].markdown(
+ """
+
+ """,
unsafe_allow_html=True,
)
+
+st.markdown("
", unsafe_allow_html=True)
+
+section_title("Live workspace snapshot")
+col_a, col_b, col_c = st.columns(3)
+col_a.metric("Runs stored", len(runs))
+latest_run = runs[0]["started_at"] if runs else "โ"
+col_b.metric("Latest run", latest_run)
+col_c.metric("Schemas ready", len(list(config.schema_dir.glob("*.json"))))
+
+st.markdown("---")
+
+section_title("Demo flow")
+st.write(
+ "Use the sample schemas and documents shipped in the repo to walk through a full demo. "
+ "Start in Schema Studio, then upload a sample document in Extract, and finish in Results."
+)
+
+sample_dir = config.sample_data_dir
+if sample_dir.exists():
+ samples = [p.name for p in sample_dir.glob("*.txt")]
+ if samples:
+ st.caption(f"Sample docs: {', '.join(samples)}")
+
+st.info(
+ "Need configuration? Visit Settings to review model choice, retries, and environment checks.",
+ icon="โ๏ธ",
+)
+
+st.caption(f"Last refreshed: {datetime.now().strftime('%Y-%m-%d %H:%M')}")
diff --git a/README.md b/README.md
index 1ba8712..8f7c04f 100644
--- a/README.md
+++ b/README.md
@@ -1,60 +1,79 @@
-# Extractly
+# Extractly โ Document Metadata Extraction Studio
-## Overview
-A plug-and-play Python framework to classify and extract metadata from PDFs and images using a hybrid OCR + GPT pipeline. Supports built-in and custom document types with editable schemas.
+Extractly is a Streamlit app for defining document schemas, classifying incoming files, and extracting structured metadata with traceable runs. The current version delivers a client-ready demo experience with a clean information architecture and modular codebase.
## Features
-- **Modular pipeline**: Preprocessing, classification, extraction, validation, export.
-- **Hybrid OCR + LLM**: LLM (configurable) for robust extraction.
-- **Custom schemas**: Define new document types and field descriptors via JSON in `schemas/` or input in UI.
-- **Configurable models**: Choose between `gpt-o4-mini`, `gpt-o3`, `gpt-4o`, etc.
-- **Validation & context**: Extracted values accompanied by context snippets and LLM reasoning.
-- **CLI & UI**: Streamlit app and CLI interface for automation and demos.
-- **Caching & metrics**: Preprocess results cached to speed up repeated runs; timings logged per step.
-- **Export options**: JSON, CSV, Excel downloads.
-
-## Getting Started
-1. **Clone repo**
- ```bash
- git clone https://github.com/yourorg/universal_extractor.git
- cd universal_extractor
- ```
-2. **Create `.env` with your OpenAI key**
- ```bash
- echo "OPENAI_API_KEY=your_api_key_here" > .env
- ```
-3. **Run locally**
- ```bash
- pip install -r requirements.txt
- streamlit run app.py
- ```
-4. **Or with Docker**
- ```bash
- docker build -t universal_extractor .
- docker run -e OPENAI_API_KEY=$OPENAI_API_KEY -p 8501:8501 universal_extractor
- ```
-
-## Directory Structure
-- `app.py`: Streamlit frontend
-- `cli.py`: Command-line interface
-- `schemas/`: JSON schema files for built-in and custom types
-- `utils/`: Core modules (preprocess, schema management, OpenAI client, classification, extraction)
-
-
-## Extending
-- Add new schema JSON to `schemas/`, restart app, it appears automatically.
-- Implement additional OCR engines by extending `utils/preprocess.py`.
-- Swap or fine-tune models via `utils/openai_client.py`.
-
-
-# Future Work
-
-## Code Improvements
-- Metaprompting sui tipi di documenti attesi
-- Aggiustare cards in Home (feedbacks are scuffed overall)
-- Rivedere la confidence (mettere nel json dei controlli extra tipo rgex o agenti )
-- TO DISCUSS: Improve from errors, saving last N feedbacks for a give doc type (only incorrect ones?) and pass in context
-
-## Business Improvements
-- Tracciare kpi
-- Fare sides fighe in cui dici flusso ed โagentiโ
+- **Schema Studio**: create, edit, validate, and export schemas with live JSON preview.
+- **Extraction pipeline**: classify โ extract โ validate with confidence scores and OCR support.
+- **Run history**: every run is stored locally with outputs, warnings, and errors.
+- **Results workspace**: table view, JSON view, per-field confidence, and CSV/JSON exports.
+- **Configurable LLM usage**: timeouts, retries, model selection via `.env`.
+
+## Quickstart
+```bash
+# 1) Install dependencies
+pip install -e .
+
+# 2) Set API key
+cp .env.example .env
+# edit .env and set OPENAI_API_KEY
+
+# 3) Run the app
+streamlit run Home.py
+```
+
+## Environment Variables
+```
+OPENAI_API_KEY=your_key_here
+CLASSIFY_MODEL=o4-mini
+EXTRACT_MODEL=o4-mini
+OCR_MODEL=o4-mini
+EXTRACTLY_TIMEOUT_S=40
+EXTRACTLY_MAX_RETRIES=2
+EXTRACTLY_RETRY_BACKOFF_S=1.5
+```
+
+## Information Architecture
+- **Home**: landing page + demo entry points.
+- **Schema Studio**: schema creation, validation, templates, import/export.
+- **Extract**: upload files and run classification + extraction.
+- **Results**: browse run history and export data.
+- **Settings**: environment checks and config visibility.
+
+## Repo Structure
+```
+Home.py
+pages/
+ 1_Schema_Studio.py
+ 2_Extract.py
+ 3_Results.py
+ 4_Settings.py
+src/extractly/
+ config.py
+ logging.py
+ domain/
+ integrations/
+ pipeline/
+ ui/
+```
+
+## Demo Script (5 minutes)
+1. Open **Home** and describe the workflow.
+2. Navigate to **Schema Studio** and load the โInvoice Liteโ template.
+3. Save the schema, then go to **Extract**.
+4. Upload `data/sample_docs/sample_invoice.txt` and run extraction.
+5. Open **Results**, select the run, and export JSON/CSV.
+
+## Samples
+- Sample docs are in `data/sample_docs/`.
+- Example schemas live in `schemas/`.
+
+## Tests
+```bash
+pytest
+```
+
+## Notes
+- Runs are stored in `./runs` (configurable).
+- Configure models and timeouts via `.env`.
+- The app relies on Streamlit and the OpenAI API. No secrets are committed.
diff --git a/data/sample_docs/sample_invoice.txt b/data/sample_docs/sample_invoice.txt
new file mode 100644
index 0000000..6100635
--- /dev/null
+++ b/data/sample_docs/sample_invoice.txt
@@ -0,0 +1,8 @@
+Invoice
+Invoice Number: INV-2025-0042
+Invoice Date: 2025-01-15
+Supplier: Northwind Analytics LLC
+Client: Contoso Retail
+Total Amount: 1,248.50 EUR
+VAT: 22%
+Payment Terms: Net 30
diff --git a/data/sample_docs/sample_resume.txt b/data/sample_docs/sample_resume.txt
new file mode 100644
index 0000000..1aac74a
--- /dev/null
+++ b/data/sample_docs/sample_resume.txt
@@ -0,0 +1,6 @@
+Resume
+Name: Jordan Lee
+Role: Product Engineer
+Experience: 6 years
+Location: Milan, IT
+Skills: Streamlit, Python, OCR, LLM pipelines
diff --git a/pages/1Schemas Builder.py b/pages/1Schemas Builder.py
deleted file mode 100644
index 495198d..0000000
--- a/pages/1Schemas Builder.py
+++ /dev/null
@@ -1,218 +0,0 @@
-# sourcery skip: swap-if-else-branches, use-named-expression
-import pandas as pd
-import streamlit as st
-import json
-from pathlib import Path
-from src.schema_manager import SchemaManager
-from utils.ui_components import inject_logo, inject_common_styles
-
-# โโโโโโโโโโโโโโโโโ constants / init โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
-# Get the project root directory (parent of the pages directory)
-PROJECT_ROOT = Path(__file__).parent.parent
-DATA_DIR = PROJECT_ROOT / "schemas"
-CUSTOM_JSON = DATA_DIR / "custom_schemas.json"
-DATA_DIR.mkdir(exist_ok=True)
-
-SM = SchemaManager() # load built-ins
-custom_data = {}
-if CUSTOM_JSON.exists(): # preload customs
- custom_data = json.loads(CUSTOM_JSON.read_text())
- SM.add_custom(custom_data)
-
-# โโโโโโโโโโโโโโโโโ session defaults โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
-st.session_state.setdefault("editing_doc", None)
-st.session_state.setdefault("field_data", [{"name": "", "description": ""}])
-st.session_state.setdefault("rename_open", False)
-st.session_state.setdefault("reset_now", False)
-st.session_state.setdefault("schema_saved", False) # New flag for handling save refresh
-
-# โโโโโโโโโโโโโโโโโ handle one-shot reset flag BEFORE widgets โโโโโโโ
-if st.session_state.reset_now:
- st.session_state.pop("doc_name_input", None) # clear text box value
- st.session_state.editing_doc = None
- st.session_state.field_data = [{"name": "", "description": ""}]
- st.session_state.reset_now = False # consume flag
-
-# โโโโโโโโโโโโโโโโโ handle schema save refresh โโโโโโโโโโโโโโโโโโโโโ
-if st.session_state.schema_saved:
- st.session_state.schema_saved = False # consume flag
- st.session_state.pop("doc_name_input", None) # clear text box value
- st.session_state.pop("field_editor", None) # clear data editor
- st.session_state.editing_doc = None
- st.session_state.field_data = [{"name": "", "description": ""}]
- st.rerun() # This will refresh the page and update the sidebar
-# โโโโโโโโโโโโโโโโโ page meta โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
-st.set_page_config("Schema Builder", "๐งฌ", layout="wide")
-st.title("๐งฌ Schema Builder")
-
-inject_logo("data/assets/data_reply.svg", height="80px") # Adjust height as needed
-inject_common_styles()
-
-# โโโโโโโโโโโโโโโโโ sidebar: schema list & actions โโโโโโโโโโโโโโโโโโ
-with st.sidebar:
- st.subheader("๐ฆ Schemas")
-
- for dt in SM.get_types():
- col_name, col_del = st.columns([3, 1])
- with col_name:
- if st.button(dt, key=f"sel_{dt}"):
- st.session_state.field_data = SM.get(dt) or [
- {"name": "", "description": ""}
- ]
- st.session_state.editing_doc = dt
- st.session_state.rename_open = False
- st.rerun()
-
- with col_del:
- if st.button("๐๏ธ", key=f"del_{dt}", help="Delete"):
- SM.delete(dt)
- custom_data.pop(dt, None)
- CUSTOM_JSON.write_text(
- json.dumps(custom_data, indent=2, ensure_ascii=False)
- )
- if st.session_state.editing_doc == dt:
- st.session_state.editing_doc = None
- st.session_state.rename_open = False
- st.session_state.field_data = [{"name": "", "description": ""}]
- st.rerun()
-
- # โโ Rename + Clear all (same row) โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
- col_ren, col_clear = st.columns(2)
-
- # โโ Rename current โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
- with col_ren:
- if st.button("โ๏ธ Rename current", disabled=not st.session_state.editing_doc):
- st.session_state.rename_open = not st.session_state.rename_open
-
- if st.session_state.rename_open and st.session_state.editing_doc:
- with st.expander("Rename schema", expanded=True):
- new_name = st.text_input(
- "New doc-type name",
- value=st.session_state.editing_doc,
- key="rename_input",
- )
-
- # โ
sidebar-safe: no st.columns(), just stacked buttons
- if st.button("โ๏ธ Confirm rename"):
- old = st.session_state.editing_doc
- SM.rename(old, new_name)
- custom_data[new_name] = custom_data.pop(old)
- CUSTOM_JSON.write_text(
- json.dumps(custom_data, indent=2, ensure_ascii=False)
- )
- st.session_state.editing_doc = new_name
- st.session_state.field_data = SM.get(new_name)
- st.session_state.rename_open = False
- st.rerun()
-
- if st.button("โ๏ธ Cancel"):
- st.session_state.rename_open = False
-
- # โโ Clear all โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
- with col_clear:
- if st.button("๐ฎ Clear all CUSTOM schemas"):
- custom_data.clear()
- CUSTOM_JSON.unlink(missing_ok=True)
- st.session_state.editing_doc = None
- st.session_state.rename_open = False
- st.session_state.field_data = [{"name": "", "description": ""}]
- st.rerun()
-
- # โโ JSON import โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
- st.subheader("โก Import JSON schema file")
- up = st.file_uploader(
- label=" ",
- type=["json"],
- label_visibility="collapsed",
- )
- if up:
- try:
- data = json.load(up)
- SM.add_custom(data)
- custom_data.update(data)
- CUSTOM_JSON.write_text(
- json.dumps(custom_data, indent=2, ensure_ascii=False)
- )
- st.success(f"Imported {len(data)} doc-types.")
- st.rerun()
- except Exception as e:
- st.error(f"Bad JSON: {e}")
-
-# โโโโโโโโโโโโโโโโโ MAIN AREA โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
-doc_name = st.text_input(
- "Document type name (create new or edit existing)",
- value=st.session_state.editing_doc or "",
- key="doc_name_input",
-)
-
-if not doc_name:
- st.info("Enter a document type name to begin.")
- st.stop()
-
-# โโ keep or clear table based on doc_name validity โโโโโโโโโโโโโโโโ
-if doc_name != st.session_state.editing_doc:
- if doc_name in SM.get_types(): # switch to known schema
- st.session_state.editing_doc = doc_name
- st.session_state.field_data = SM.get(doc_name) or [
- {"name": "", "description": ""}
- ]
- else: # unknown name โ blank table
- st.session_state.editing_doc = None
- st.session_state.field_data = [{"name": "", "description": ""}]
-
-st.subheader(f"Fields for **{doc_name}**")
-
-schema_desc = st.text_area(
- "High-level description",
- value=SM.get_description(doc_name),
- placeholder="e.g. Italian electronic invoice issued by suppliersโฆ",
-)
-
-raw_table = st.data_editor(
- st.session_state.field_data, # static snapshot
- num_rows="dynamic",
- width="stretch",
- key="field_editor",
-)
-
-# normalise โ
-table_rows = (
- raw_table.fillna("").to_dict(orient="records")
- if isinstance(raw_table, pd.DataFrame)
- else raw_table
-)
-table_rows = [
- {
- "name": r.get("name", " ").strip() if r.get("name", " ") else "",
- "description": r.get("description", r.get("description ", " ")).strip()
- if r.get("description", r.get("description ", " "))
- else "",
- }
- for r in table_rows
-]
-
-
-# โโโโโโโโโโโโโโโโโ save / reset buttons โโโโโโโโโโโโโโโโโโโโโโโโโโโ
-col_save, col_reset = st.columns(2)
-
-with col_save:
- if st.button("๐พ Save schema"):
- clean = [row for row in table_rows if row["name"]]
- if not clean:
- st.error("Must have at least one field.")
- else:
- payload = {"description": schema_desc.strip(), "fields": clean}
- SM.add_custom({doc_name: payload})
- custom_data[doc_name] = payload # โ store whole dict
- CUSTOM_JSON.write_text(
- json.dumps(custom_data, indent=2, ensure_ascii=False)
- )
- st.success(f"Saved {len(clean)} fields for '{doc_name}'.")
- # Set the flag to trigger page refresh and cleanup
- st.session_state.schema_saved = True
- # Trigger immediate rerun to refresh the page
- st.rerun()
-with col_reset:
- if st.button("โบ Reset"):
- st.session_state.reset_now = True
- st.rerun()
diff --git a/pages/1_Schema_Studio.py b/pages/1_Schema_Studio.py
new file mode 100644
index 0000000..32b1bbd
--- /dev/null
+++ b/pages/1_Schema_Studio.py
@@ -0,0 +1,170 @@
+from __future__ import annotations
+
+import json
+from pathlib import Path
+import streamlit as st
+
+from extractly.config import load_config
+from extractly.domain.schema_store import SchemaStore, schemas_to_table, table_to_schema
+from extractly.domain.validation import validate_schema
+from extractly.logging import setup_logging
+from extractly.ui.components import inject_branding, inject_global_styles, section_title
+
+
+config = load_config()
+setup_logging()
+store = SchemaStore(config.schema_dir)
+
+st.set_page_config(page_title="Schema Studio", page_icon="๐งฌ", layout="wide")
+
+inject_branding(Path("data/assets/data_reply.svg"))
+inject_global_styles()
+
+st.title("๐งฌ Schema Studio")
+st.caption("Design, validate, and version schemas used in extraction runs.")
+
+TEMPLATES = {
+ "Invoice Lite": {
+ "description": "Basic invoice metadata for demo flows.",
+ "fields": [
+ {"name": "Invoice Number", "type": "string", "required": True},
+ {"name": "Invoice Date", "type": "date", "required": True},
+ {"name": "Supplier", "type": "string", "required": True},
+ {"name": "Total Amount", "type": "number", "required": True},
+ ],
+ },
+ "Resume Snapshot": {
+ "description": "Lightweight resume extraction fields.",
+ "fields": [
+ {"name": "Candidate Name", "type": "string", "required": True},
+ {"name": "Primary Role", "type": "string", "required": True},
+ {"name": "Years of Experience", "type": "integer"},
+ {"name": "Location", "type": "string"},
+ ],
+ },
+}
+
+schemas = store.list_schemas()
+
+with st.sidebar:
+ st.subheader("Schemas")
+ if schemas:
+ selected_name = st.selectbox(
+ "Choose schema",
+ options=[schema.name for schema in schemas],
+ index=0,
+ )
+ else:
+ selected_name = None
+ st.caption("No schemas yet. Create one below.")
+ st.markdown("---")
+ st.subheader("Templates")
+ template_choice = st.selectbox("Load template", options=["โ"] + list(TEMPLATES))
+ if st.button("Use template", use_container_width=True):
+ if template_choice != "โ":
+ template_payload = TEMPLATES[template_choice]
+ st.session_state["schema_payload"] = {
+ "name": template_choice,
+ "description": template_payload["description"],
+ "rows": template_payload["fields"],
+ }
+ st.rerun()
+
+ st.markdown("---")
+ st.subheader("Import / Export")
+ upload = st.file_uploader("Import schema JSON", type=["json"])
+ if upload:
+ try:
+ payload = json.load(upload)
+ imported = store.import_payload(payload)
+ st.success(f"Imported {len(imported)} schema(s).")
+ st.rerun()
+ except Exception as exc:
+ st.error(f"Import failed: {exc}")
+
+ if selected_name:
+ schema = store.get_schema(selected_name)
+ if schema:
+ export_json = store.export_schema(schema)
+ st.download_button(
+ "Download schema JSON",
+ data=export_json,
+ file_name=f"{selected_name}.json",
+ mime="application/json",
+ use_container_width=True,
+ )
+
+ if selected_name and st.button("Delete schema", type="secondary", use_container_width=True):
+ store.delete_schema(selected_name)
+ st.success("Schema deleted.")
+ st.rerun()
+
+if schemas and selected_name:
+ active_schema = store.get_schema(selected_name)
+else:
+ active_schema = None
+
+payload = st.session_state.get(
+ "schema_payload",
+ {
+ "name": active_schema.name if active_schema else "",
+ "description": active_schema.description if active_schema else "",
+ "rows": schemas_to_table(active_schema) if active_schema else [],
+ },
+)
+
+section_title("Schema editor")
+name = st.text_input("Schema name", value=payload.get("name", ""))
+description = st.text_area("Description", value=payload.get("description", ""))
+
+rows = payload.get("rows", [])
+
+data = st.data_editor(
+ rows,
+ num_rows="dynamic",
+ use_container_width=True,
+ column_config={
+ "name": st.column_config.TextColumn("Field name", required=True),
+ "type": st.column_config.SelectboxColumn(
+ "Type",
+ options=["string", "number", "integer", "boolean", "date", "enum", "object", "array"],
+ required=True,
+ ),
+ "required": st.column_config.CheckboxColumn("Required"),
+ "description": st.column_config.TextColumn("Description"),
+ "example": st.column_config.TextColumn("Example"),
+ "enum": st.column_config.TextColumn("Enum values (comma-separated)"),
+ },
+ key="schema_editor",
+)
+
+schema = table_to_schema(name=name, description=description, rows=data)
+validation = validate_schema(schema)
+
+col_a, col_b = st.columns([1, 1])
+with col_a:
+ if st.button("๐พ Save schema", use_container_width=True):
+ result = store.save_schema(schema)
+ if result.is_valid:
+ st.success("Schema saved.")
+ st.session_state.pop("schema_payload", None)
+ st.rerun()
+ else:
+ st.error("Schema failed validation. Fix errors below.")
+
+with col_b:
+ if st.button("Reset", use_container_width=True, type="secondary"):
+ st.session_state.pop("schema_payload", None)
+ st.rerun()
+
+section_title("Validation")
+if validation.errors:
+ st.error("\n".join(validation.errors))
+else:
+ st.success("Schema is valid and ready for extraction.")
+
+if validation.warnings:
+ st.warning("\n".join(validation.warnings))
+
+section_title("Live JSON preview")
+st.code(store.export_schema(schema), language="json")
diff --git a/pages/2Extraction.py b/pages/2Extraction.py
deleted file mode 100644
index 8332509..0000000
--- a/pages/2Extraction.py
+++ /dev/null
@@ -1,605 +0,0 @@
-"""
-Batch Extraction โ multi-doc classification, OCR, extraction & correction UI.
-"""
-
-from __future__ import annotations
-import contextlib
-import io
-import json
-from datetime import datetime, timezone
-import streamlit as st
-import pandas as pd
-
-from src.ocr_engine import run_ocr
-from utils.preprocess import preprocess
-from src.schema_manager import SchemaManager
-from src.classifier import classify
-from src.extractor import extract
-from utils.ui_components import inject_logo, inject_common_styles
-from utils.utils import (
- generate_doc_id,
- load_feedback,
- upsert_feedback,
- diff_fields,
-)
-
-# โโโโโ page & globals โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
-st.set_page_config("Extraction", page_icon="๐", layout="wide")
-st.title("๐ Extraction")
-schema_mgr = SchemaManager()
-
-inject_logo("data/assets/data_reply.svg", height="80px") # Adjust height as needed
-inject_common_styles()
-
-# โโโโโ CSS for image hover zoom โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
-st.markdown(
- """
-
-""",
- unsafe_allow_html=True,
-)
-
-# โโโโโ sidebar: settings โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
-with st.sidebar:
- st.header("โ๏ธ Options")
- run_ocr_checkbox = st.checkbox("Run OCR first", value=False)
- calc_conf = st.checkbox("Compute confidences", value=False)
- conf_threshold = st.slider(
- "Confidence threshold (%)",
- 0,
- 100,
- 70,
- help="Documents below this confidence will be flagged",
- )
- # NEW: System Prompts Section
- st.header("๐ค System Prompts")
-
- # Updated default prompts
- DEFAULT_CLASSIFIER_PROMPT = """You are an expert document classifier specialized in analyzing business and legal documents.
- Your task is to classify the document type based on visual layout, text content, headers, logos, and structural elements. Consider:
- - Document formatting and layout patterns
- - Official headers, letterheads, and logos
- - Specific terminology and field labels
- - Regulatory compliance markers
- - Standard document structures
-
- Respond with only the most accurate document type from the provided list. If unsure, choose "Unknown".
- """
-
- DEFAULT_EXTRACTOR_PROMPT = """You are a precise metadata extraction specialist. Your task is to extract specific field values from documents with high accuracy.
-
- Instructions:
- 1. Analyze the document image carefully for text, tables, and structured data
- 2. Extract only the exact values for the requested fields
- 3. Use OCR context when provided to improve accuracy
- 4. If a field is not clearly visible or readable, return null
- 5. Maintain original formatting for dates, numbers, and codes
- 6. For confidence scores, rate 0.0-1.0 based on text clarity and certainty
-
- Return valid JSON with three sections:
- - "metadata": field values as key-value pairs
- - "snippets": supporting text evidence for each field
- - "confidence": confidence scores (0.0-1.0) for each extraction
- """
-
- with st.expander("๐ Edit Prompts", expanded=False):
- st.subheader("Classification Prompt")
- classifier_prompt = st.text_area(
- "System prompt for document classification:",
- value=st.session_state.get("classifier_prompt", DEFAULT_CLASSIFIER_PROMPT),
- height=150,
- help="This prompt guides how the AI classifies document types",
- key="classifier_prompt_input",
- )
-
- st.subheader("Extraction Prompt")
- extractor_prompt = st.text_area(
- "System prompt for metadata extraction:",
- value=st.session_state.get("extractor_prompt", DEFAULT_EXTRACTOR_PROMPT),
- height=200,
- help="This prompt guides how the AI extracts metadata fields",
- key="extractor_prompt_input",
- )
-
- col1, col2 = st.columns(2)
- with col1:
- if st.button("๐พ Save Prompts"):
- st.session_state["classifier_prompt"] = classifier_prompt
- st.session_state["extractor_prompt"] = extractor_prompt
- st.success("Prompts saved!")
-
- with col2:
- if st.button("๐ Reset to Default"):
- st.session_state["classifier_prompt"] = DEFAULT_CLASSIFIER_PROMPT
- st.session_state["extractor_prompt"] = DEFAULT_EXTRACTOR_PROMPT
- st.rerun()
-
- # โ๏ธโโโโโโโโโโโโโโโโโโโโโโโโ OCR-preview toggle โโโโโโโโโโโโโโโโโโโโโโโ๏ธ
- if run_ocr_checkbox and "ocr_map" in st.session_state:
- with st.sidebar.expander("๐ Preview raw OCR text", expanded=False):
- for name, ocr_txt in st.session_state["ocr_map"].items():
- st.markdown(f"**{name}**")
- st.text_area(
- label=" ",
- value=ocr_txt[:10_000],
- height=200,
- key=f"ocr_{name}",
- label_visibility="collapsed",
- )
-
-# โโโโโ file uploader โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
-files = st.file_uploader(
- "Upload PDFs or images",
- type=["pdf", "png", "jpg", "jpeg"],
- accept_multiple_files=True,
-)
-
-if not files:
- st.info("Awaiting uploads โฆ")
- st.stop()
-
-# โโโโโ load past corrections by filename โโโโโโโโโโโโโโโโโโโโโโโโโโโโ
-past = load_feedback()
-corrected_map = {}
-past_metadata_map = {}
-for r in past:
- if r.get("file_name"):
- corrected_map[r["file_name"]] = r["doc_type"]
- if r.get("metadata_corrected") and r["metadata_corrected"] != "{}":
- with contextlib.suppress(json.JSONDecodeError):
- past_metadata_map[r["file_name"]] = json.loads(r["metadata_corrected"])
-
-# โโโโโ ensure session state โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
-st.session_state.setdefault("doc_rows", [])
-st.session_state.setdefault("extracted", False)
-doc_rows: list[dict] = st.session_state["doc_rows"]
-
-# โโโโโ run buttons โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
-c1, c2, c3 = st.columns(3)
-seq_clicked = c1.button("๐ Classify & Extract", width="stretch")
-classify_clicked = c2.button("โถ๏ธ Classify Only", width="stretch")
-extract_clicked = c3.button("โก Extract All", disabled=not doc_rows, width="stretch")
-
-if st.button("๐ Start over (keep uploads)", key="reset_all", type="secondary"):
- for k in ("doc_rows", "extracted", "ocr_map", "ocr_preview"):
- st.session_state.pop(k, None)
- st.toast("Workspace cleared โ you can run the pipeline again.", icon="๐")
- st.rerun()
-
-st.markdown("
", unsafe_allow_html=True)
-
-# โโโโโ 1๏ธโฃ on-click: build doc_rows โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
-if classify_clicked or seq_clicked:
- st.session_state["extracted"] = False
- st.session_state["ocr_preview"] = ""
- st.session_state["doc_rows"] = []
- prog = st.progress(0.0, "Classifyingโฆ")
- cls_results = []
-
- # CHANGE: Add Unknown/Other to classification candidates
- classification_candidates = schema_mgr.get_types() + ["Unknown", "Other"]
-
- for i, up in enumerate(files, start=1):
- images = preprocess(up)
- fname = up.name
- doc_id = generate_doc_id(up) or f"id_{i}"
- ocr_txt = ""
-
- if run_ocr_checkbox:
- ocr_txt = run_ocr(images) # Removed engine parameter - LLM only
- st.session_state.setdefault("ocr_map", {})
- st.session_state["ocr_map"][fname] = ocr_txt
- if (
- "ocr_preview" in st.session_state
- and not st.session_state["ocr_preview"]
- ):
- st.session_state["ocr_preview"] = ocr_txt
-
- with io.BytesIO() as buf:
- img = images[0].copy()
- img.thumbnail((140, 140))
- img.save(buf, format="PNG")
- thumb = buf.getvalue()
-
- # ---------- build rows with enhanced classification -------------------------
- if fname in corrected_map:
- doc_type = corrected_map[fname]
- confidence = None
- reasoning = "retrieved from past correction"
- else:
- cls_resp = classify(
- images,
- classification_candidates,
- use_confidence=calc_conf,
- n_votes=5,
- system_prompt=st.session_state.get(
- "classifier_prompt", DEFAULT_CLASSIFIER_PROMPT
- ),
- )
- doc_type = cls_resp["doc_type"]
- confidence = cls_resp.get("confidence")
- reasoning = cls_resp.get("reasoning", "")
-
- cls_results.append(
- {
- "file_name": fname,
- "doc_id": doc_id,
- "thumb": thumb,
- "images": images,
- "detected": doc_type,
- "final_type": doc_type,
- "reasoning": reasoning,
- "fields": None,
- "fields_corrected": None,
- "confidence": confidence,
- }
- )
- prog.progress(i / len(files), f"Classifying: {fname}")
-
- st.session_state["doc_rows"] = cls_results
- doc_rows = st.session_state["doc_rows"]
- st.toast("Classification finished โ adjust below if needed.", icon="โ
")
- if not seq_clicked:
- st.rerun()
-
-# โโโโโ 2๏ธโฃ render classification review โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
-if doc_rows and not st.session_state.get("extracted"):
- st.subheader("๐ Review detected document types")
-
- for idx, row in enumerate(doc_rows):
- # CHANGE: Enhanced confidence display and unrecognized document handling
- confidence_val = row.get("confidence")
- conf_pct = int(confidence_val * 100) if confidence_val is not None else None
- is_unrecognized = row["final_type"] in ["Unknown", "Other"]
- is_low_confidence = conf_pct is not None and conf_pct < conf_threshold
-
- # Color-coded title based on status
- if is_unrecognized or is_low_confidence:
- status_emoji = "โ ๏ธ"
- title_color = "๐ด"
- else:
- status_emoji = "โ
"
- title_color = "๐ข"
-
- # Build title with conditional confidence display
- if calc_conf and conf_pct is not None:
- title = f"{status_emoji} {row['file_name']} โ {row['final_type']} {title_color} {conf_pct}%"
- else:
- title = f"{status_emoji} {row['file_name']} โ {row['final_type']}"
-
- with st.expander(
- title,
- expanded=is_unrecognized
- or is_low_confidence, # Auto-expand problematic ones
- ):
- col1, col2, col3 = st.columns([1, 2, 1])
-
- with col1:
- st.image(row["thumb"], width=140)
-
- with col2:
- # CHANGE: Enhanced type selection with Unknown/Other
- choices = schema_mgr.get_types() + ["Unknown", "Other"]
- if row["final_type"] not in choices:
- choices.insert(0, row["final_type"])
-
- sel = st.selectbox(
- "Document type",
- choices,
- index=choices.index(row["final_type"]),
- key=f"type_{row['doc_id']}",
- )
- st.session_state["doc_rows"][idx]["final_type"] = sel
-
- # Enhanced status display
- if is_unrecognized:
- st.error("โ Unrecognized document type")
- elif is_low_confidence:
- st.warning(f"โ ๏ธ Low confidence ({conf_pct}% < {conf_threshold}%)")
- with col3:
- # CHANGE: Enhanced confidence display in % - only show if calc_conf is enabled
- if calc_conf and conf_pct is not None:
- if conf_pct >= 80:
- st.success(f"๐ข {conf_pct}%")
- elif conf_pct >= 60:
- st.warning(f"๐ก {conf_pct}%")
- else:
- st.error(f"๐ด {conf_pct}%")
-
- st.caption(f"Threshold: {conf_threshold}%")
- elif calc_conf:
- st.caption("Confidence not computed")
-
- if st.button("๐พ Save type corrections"):
- for row in doc_rows:
- upsert_feedback(
- {
- "doc_id": row["doc_id"],
- "file_name": row["file_name"],
- "doc_type": row["final_type"],
- "metadata_extracted": "{}",
- "metadata_corrected": "{}",
- "timestamp": datetime.now(timezone.utc).isoformat(),
- }
- )
- st.toast("Document types saved ๐", icon="๐พ")
-
-st.markdown("
", unsafe_allow_html=True)
-
-# โโโโโ 3๏ธโฃ ON-CLICK: run extraction with unrecognized document handling โโโโโโโโโโโโโโโโ
-if (extract_clicked or seq_clicked) and not st.session_state.get("extracted"):
- st.subheader("๐ฆ Extracting Data...")
- prog = st.progress(0.0, "Extractingโฆ")
-
- for i, row in enumerate(doc_rows, start=1):
- if "cancel_extraction" in st.session_state:
- break
-
- # CHANGE: Skip extraction for unrecognized documents
- if row["final_type"] in ["Unknown", "Other"]:
- st.session_state["doc_rows"][i - 1]["fields"] = {}
- st.session_state["doc_rows"][i - 1]["field_conf"] = {}
- st.session_state["doc_rows"][i - 1]["fields_corrected"] = {}
- prog.progress(
- i / len(doc_rows), f"Skipping unrecognized: {row['file_name']}"
- )
- continue
-
- schema = schema_mgr.get(row["final_type"]) or []
- if not schema:
- st.session_state["doc_rows"][i - 1]["fields"] = {}
- st.session_state["doc_rows"][i - 1]["field_conf"] = {}
- st.session_state["doc_rows"][i - 1]["fields_corrected"] = {}
- continue
-
- ocr_txt = None
- if run_ocr_checkbox:
- ocr_txt = run_ocr(row["images"]) # Removed engine parameter - LLM only
- out = extract(
- row["images"],
- schema,
- ocr_text=ocr_txt,
- with_confidence=calc_conf,
- system_prompt=st.session_state.get(
- "extractor_prompt", DEFAULT_EXTRACTOR_PROMPT
- ),
- ) or {"metadata": {}, "confidence": {}}
-
- if not any(out["metadata"].values()):
- st.toast(f"No fields detected in {row['file_name']}", icon="โ ๏ธ")
-
- st.session_state["doc_rows"][i - 1]["fields"] = out["metadata"]
- st.session_state["doc_rows"][i - 1]["field_conf"] = out.get("confidence", {})
-
- file_name = row["file_name"]
- if file_name in past_metadata_map:
- st.session_state["doc_rows"][i - 1]["fields_corrected"] = past_metadata_map[
- file_name
- ]
- else:
- st.session_state["doc_rows"][i - 1]["fields_corrected"] = out[
- "metadata"
- ].copy()
-
- prog.progress(i / len(doc_rows), f"Extracting from: {row['file_name']}")
-
- if st.session_state.get("extracted") is False:
- st.button("๐ Cancel extraction", key="cancel_extraction")
-
- st.session_state["extracted"] = True
- extract_clicked = False
- seq_clicked = False
- st.toast("Extraction completed โ review below", icon="๐ฆ")
- st.rerun()
-
-# โโโโโ 4๏ธโฃ RENDER: extraction review & correction UI โโโโโโโโโโโโโโโโ
-if st.session_state.get("extracted"):
- st.subheader("๐ฆ Review and Correct Extracted Data")
-
- for i, row in enumerate(doc_rows):
- # CHANGE: Handle unrecognized documents in review
- if row["final_type"] in ["Unknown", "Other"]:
- with st.expander(
- f"โ ๏ธ {row['file_name']} โ {row['final_type']} (SKIPPED)", expanded=False
- ):
- st.warning(
- "This document was skipped because it's unrecognized. Please reclassify it first."
- )
- continue
-
- with st.expander(f"{row['file_name']} โ {row['final_type']}", expanded=True):
- if not row.get("fields_corrected"):
- st.warning("No fields were extracted or loaded for this document type.")
- continue
-
- # CHANGE: Enhanced confidence display in data editor (% format)
- row_conf = row.get("field_conf", {})
- # Convert confidence to percentage format
- conf_display = {}
- for k, v in row_conf.items():
- if isinstance(v, (int, float)):
- conf_display[k] = f"{int(v * 100)}%"
- else:
- conf_display[k] = str(v) if v else ""
-
- df = pd.DataFrame(
- {
- "Field": list(row["fields_corrected"].keys()),
- "Value": list(row["fields_corrected"].values()),
- "Conf.": [
- conf_display.get(k, "") for k in row["fields_corrected"]
- ], # CHANGE: Show %
- }
- )
-
- edited_df = st.data_editor(
- df,
- key=f"grid_{row['doc_id']}",
- disabled=["Field", "Conf."],
- width="stretch",
- )
-
- updated_values = pd.Series(
- edited_df.Value.values, index=edited_df.Field
- ).to_dict()
-
- st.session_state["doc_rows"][i]["fields_corrected"] = updated_values
-
- # โโ enhanced action buttons โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
- col_re, col_save, col_download = st.columns([0.08, 0.08, 0.08])
-
- with col_re:
- if st.button("โป Re-extract", key=f"reextract_{row['doc_id']}"):
- schema = schema_mgr.get(row["final_type"]) or []
- ocr_txt = (
- st.session_state.get("ocr_map", {}).get(row["file_name"])
- if run_ocr_checkbox
- else None
- )
- new_out = extract(
- row["images"],
- schema,
- ocr_text=ocr_txt,
- with_confidence=calc_conf,
- system_prompt=st.session_state.get(
- "extractor_prompt", DEFAULT_EXTRACTOR_PROMPT
- ),
- ) or {"metadata": {}, "confidence": {}}
-
- row["fields"] = new_out["metadata"]
- row["field_conf"] = new_out.get("confidence", {})
- row["fields_corrected"] = new_out["metadata"].copy()
- st.toast(f"Re-extracted {row['file_name']}", icon="โ
")
- st.rerun()
-
- with col_save:
- if st.button("๐พ Save", key=f"save_{row['doc_id']}"):
- current_row = st.session_state["doc_rows"][i]
- changed = diff_fields(
- current_row["fields"], current_row["fields_corrected"]
- )
- upsert_feedback(
- {
- "doc_id": current_row["doc_id"],
- "file_name": current_row["file_name"],
- "doc_type": current_row["final_type"],
- "metadata_extracted": json.dumps(
- current_row["fields"], ensure_ascii=False
- ),
- "metadata_corrected": json.dumps(
- current_row["fields_corrected"], ensure_ascii=False
- ),
- "fields_corrected": changed,
- "timestamp": datetime.now(timezone.utc).isoformat(),
- }
- )
- st.toast(f"Saved {current_row['file_name']}", icon="๐พ")
-
- # NEW: Individual download button
- with col_download:
- if st.button("๐ JSON", key=f"download_{row['doc_id']}"):
- # Fix: Handle None confidence properly
- confidence_val = row.get("confidence")
- if confidence_val is not None:
- confidence_display = f"{int(confidence_val * 100)}%"
- else:
- confidence_display = "N/A"
-
- result = {
- "document_info": {
- "filename": row["file_name"],
- "document_type": row["final_type"],
- "confidence": confidence_display, # Fixed
- "timestamp": datetime.now().isoformat(),
- },
- "original_extraction": row.get("fields", {}),
- "corrected_metadata": row.get("fields_corrected", {}),
- "confidence_scores": {
- k: f"{int(v * 100)}%"
- for k, v in row.get("field_conf", {}).items()
- if isinstance(v, (int, float))
- and v is not None # Added None check
- },
- "processing_info": {
- "ocr_used": run_ocr_checkbox,
- "confidence_threshold": f"{conf_threshold}%",
- },
- }
-
- # Enhanced bulk save with session summary
- col1, col2 = st.columns([3, 1])
- with col1:
- if st.button("๐พ Save all corrections", width="stretch", type="primary"):
- saved_count = 0
- for row in st.session_state["doc_rows"]:
- if row["final_type"] not in ["Unknown", "Other"]:
- upsert_feedback(
- {
- "doc_id": row["doc_id"],
- "file_name": row["file_name"],
- "doc_type": row["final_type"],
- "metadata_extracted": json.dumps(
- row["fields"], ensure_ascii=False
- ),
- "metadata_corrected": json.dumps(
- row["fields_corrected"], ensure_ascii=False
- ),
- "timestamp": datetime.now(timezone.utc).isoformat(),
- }
- )
- saved_count += 1
- st.toast(f"Saved {saved_count} documents โ thank you!", icon="๐พ")
-
- with col2:
- # NEW: Bulk download button
- if st.button("๐ฅ Download All JSON", width="stretch"):
- # Prepare bulk export
- bulk_results = {
- "export_info": {
- "timestamp": datetime.now().isoformat(),
- "total_documents": len(doc_rows),
- "processed_documents": len(
- [
- r
- for r in doc_rows
- if r["final_type"] not in ["Unknown", "Other"]
- ]
- ),
- "confidence_threshold": f"{conf_threshold}%",
- },
- "documents": [],
- }
-
- for row in doc_rows:
- confidence_val = row.get("confidence")
- if confidence_val is not None:
- confidence_display = f"{int(confidence_val * 100)}%"
- else:
- confidence_display = "N/A"
-
- doc_result = {
- "filename": row["file_name"],
- "document_type": row["final_type"],
- "confidence": confidence_display, # Fixed
- "original_extraction": row.get("fields", {}),
- "corrected_metadata": row.get("fields_corrected", {}),
- "confidence_scores": {
- k: f"{int(v * 100)}%"
- for k, v in row.get("field_conf", {}).items()
- if isinstance(v, (int, float))
- and v is not None # Added None check
- },
- }
- bulk_results["documents"].append(doc_result)
-
- json_str = json.dumps(bulk_results, indent=2, ensure_ascii=False)
- st.download_button(
- label="๐ Download Complete Session",
- data=json_str,
- file_name=f"bulk_extraction_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json",
- mime="application/json",
- key="bulk_download",
- )
diff --git a/pages/2_Extract.py b/pages/2_Extract.py
new file mode 100644
index 0000000..19a5b91
--- /dev/null
+++ b/pages/2_Extract.py
@@ -0,0 +1,129 @@
+from __future__ import annotations
+
+from datetime import datetime
+from pathlib import Path
+import streamlit as st
+from PIL import Image
+
+from extractly.config import load_config
+from extractly.domain.run_store import RunStore
+from extractly.domain.schema_store import SchemaStore
+from extractly.integrations.preprocess import preprocess
+from extractly.pipeline.classification import DEFAULT_CLASSIFIER_PROMPT
+from extractly.pipeline.extraction import DEFAULT_EXTRACTION_PROMPT
+from extractly.pipeline.runner import PipelineOptions, run_pipeline
+from extractly.logging import setup_logging
+from extractly.ui.components import inject_branding, inject_global_styles, section_title
+
+
+config = load_config()
+setup_logging()
+store = SchemaStore(config.schema_dir)
+run_store = RunStore(config.run_store_dir)
+
+st.set_page_config(page_title="Extract", page_icon="โก", layout="wide")
+
+inject_branding(Path("data/assets/data_reply.svg"))
+inject_global_styles()
+
+st.title("โก Run Extraction")
+st.caption("Upload documents, select a schema, and run the extraction pipeline.")
+
+schemas = store.list_schemas()
+if not schemas:
+ st.warning("No schemas found. Create one in Schema Studio first.")
+ st.stop()
+
+schema_name = st.selectbox("Schema", options=[schema.name for schema in schemas])
+active_schema = store.get_schema(schema_name)
+
+left, right = st.columns([2, 1])
+with left:
+ files = st.file_uploader(
+ "Upload documents",
+ type=["pdf", "png", "jpg", "jpeg", "txt"],
+ accept_multiple_files=True,
+ )
+
+with right:
+ section_title("Pipeline options")
+ enable_ocr = st.toggle("Enable OCR", value=False)
+ compute_conf = st.toggle("Field confidence", value=True)
+ mode = st.radio("Mode", options=["fast", "accurate"], horizontal=True)
+
+with st.expander("Advanced prompts", expanded=False):
+ classifier_prompt = st.text_area(
+ "Classifier prompt",
+ value=st.session_state.get("classifier_prompt", DEFAULT_CLASSIFIER_PROMPT),
+ height=140,
+ )
+ extractor_prompt = st.text_area(
+ "Extraction prompt",
+ value=st.session_state.get("extractor_prompt", DEFAULT_EXTRACTION_PROMPT),
+ height=160,
+ )
+ if st.button("Save prompts"):
+ st.session_state["classifier_prompt"] = classifier_prompt
+ st.session_state["extractor_prompt"] = extractor_prompt
+ st.success("Prompts saved.")
+
+st.markdown("---")
+
+section_title("Pipeline steps")
+steps = st.columns(5)
+steps[0].markdown("โ
1. Parse")
+steps[1].markdown("โ
2. Classify")
+steps[2].markdown("โ
3. Extract")
+steps[3].markdown("โ
4. Validate")
+steps[4].markdown("โ
5. Export")
+
+if st.button("Run extraction", type="primary", use_container_width=True):
+ if not files:
+ st.error("Upload at least one document.")
+ st.stop()
+
+ parsed_files = []
+ progress = st.progress(0.0, "Parsing files")
+
+ for idx, upload in enumerate(files, start=1):
+ filename = upload.name
+ if filename.lower().endswith(".txt"):
+ content = upload.read().decode("utf-8", errors="ignore")
+ blank_image = Image.new("RGB", (800, 1000), color="white")
+ images = [blank_image]
+ parsed_files.append(
+ {
+ "name": filename,
+ "images": images,
+ "ocr_text": content,
+ "doc_type_override": active_schema.name,
+ }
+ )
+ else:
+ images = preprocess(upload, filename)
+ parsed_files.append({"name": filename, "images": images})
+
+ progress.progress(idx / len(files), f"Parsed {filename}")
+
+ options = PipelineOptions(
+ enable_ocr=enable_ocr,
+ compute_confidence=compute_conf,
+ mode=mode,
+ classifier_prompt=st.session_state.get("classifier_prompt"),
+ extraction_prompt=st.session_state.get("extractor_prompt"),
+ )
+
+ run = run_pipeline(
+ files=parsed_files,
+ schema=active_schema,
+ candidates=[schema.name for schema in schemas] + ["Unknown", "Other"],
+ run_store=run_store,
+ options=options,
+ )
+
+ st.session_state["latest_run_id"] = run.run_id
+ st.success("Extraction completed.")
+ st.page_link("pages/3_Results.py", label="View results", use_container_width=True)
+
+ st.caption(f"Run {run.run_id} stored at {config.run_store_dir}")
+ st.code(f"Run completed at {datetime.now().isoformat()}")
diff --git a/pages/3_Results.py b/pages/3_Results.py
new file mode 100644
index 0000000..919385f
--- /dev/null
+++ b/pages/3_Results.py
@@ -0,0 +1,131 @@
+from __future__ import annotations
+
+import csv
+import io
+import json
+from pathlib import Path
+import streamlit as st
+
+from extractly.config import load_config
+from extractly.domain.run_store import RunStore
+from extractly.logging import setup_logging
+from extractly.ui.components import inject_branding, inject_global_styles, section_title
+
+
+config = load_config()
+setup_logging()
+run_store = RunStore(config.run_store_dir)
+
+st.set_page_config(page_title="Results", page_icon="๐", layout="wide")
+
+inject_branding(Path("data/assets/data_reply.svg"))
+inject_global_styles()
+
+st.title("๐ Results")
+st.caption("Browse extraction runs, review outputs, and export data.")
+
+runs = run_store.list_runs()
+if not runs:
+ st.info("No runs yet. Run an extraction first.")
+ st.stop()
+
+run_ids = [run["run_id"] for run in runs]
+latest_id = st.session_state.get("latest_run_id")
+selected_id = st.selectbox(
+ "Select a run",
+ options=run_ids,
+ index=run_ids.index(latest_id) if latest_id in run_ids else 0,
+)
+
+run = run_store.load(selected_id)
+if not run:
+ st.error("Run not found.")
+ st.stop()
+
+section_title("Run summary")
+summary_cols = st.columns(4)
+summary_cols[0].metric("Run ID", run["run_id"])
+summary_cols[1].metric("Schema", run.get("schema_name", "โ"))
+summary_cols[2].metric("Mode", run.get("mode", "โ"))
+summary_cols[3].metric("Documents", len(run.get("documents", [])))
+
+st.markdown("---")
+
+section_title("Documents")
+doc_rows = []
+for doc in run.get("documents", []):
+ doc_rows.append(
+ {
+ "filename": doc.get("filename"),
+ "document_type": doc.get("document_type"),
+ "confidence": doc.get("confidence"),
+ "warnings": len(doc.get("warnings", [])),
+ "errors": len(doc.get("errors", [])),
+ }
+ )
+
+st.dataframe(doc_rows, use_container_width=True)
+
+selected_doc_name = st.selectbox(
+ "View document details",
+ options=[doc["filename"] for doc in run.get("documents", [])],
+)
+
+selected_doc = next(
+ (doc for doc in run.get("documents", []) if doc["filename"] == selected_doc_name),
+ None,
+)
+
+if selected_doc:
+ section_title("Extracted fields")
+ field_rows = [
+ {
+ "field": key,
+ "value": value,
+ "confidence": selected_doc.get("field_confidence", {}).get(key, ""),
+ }
+ for key, value in selected_doc.get("corrected", {}).items()
+ ]
+ st.dataframe(field_rows, use_container_width=True)
+
+ section_title("JSON output")
+ st.json(selected_doc.get("corrected", {}))
+
+ if selected_doc.get("warnings"):
+ st.warning("\n".join(selected_doc.get("warnings")))
+ if selected_doc.get("errors"):
+ st.error("\n".join(selected_doc.get("errors")))
+
+section_title("Exports")
+
+json_payload = json.dumps(run, indent=2, ensure_ascii=False)
+
+st.download_button(
+ "Download run JSON",
+ data=json_payload,
+ file_name=f"{selected_id}.json",
+ mime="application/json",
+)
+
+csv_buffer = io.StringIO()
+fieldnames = {"filename", "document_type", "confidence"}
+for doc in run.get("documents", []):
+ fieldnames.update(doc.get("corrected", {}).keys())
+
+writer = csv.DictWriter(csv_buffer, fieldnames=sorted(fieldnames))
+writer.writeheader()
+for doc in run.get("documents", []):
+ row = {
+ "filename": doc.get("filename"),
+ "document_type": doc.get("document_type"),
+ "confidence": doc.get("confidence"),
+ }
+ row.update(doc.get("corrected", {}))
+ writer.writerow(row)
+
+st.download_button(
+ "Download CSV",
+ data=csv_buffer.getvalue(),
+ file_name=f"{selected_id}.csv",
+ mime="text/csv",
+)
diff --git a/pages/4_Settings.py b/pages/4_Settings.py
new file mode 100644
index 0000000..1605316
--- /dev/null
+++ b/pages/4_Settings.py
@@ -0,0 +1,44 @@
+from __future__ import annotations
+
+from pathlib import Path
+import streamlit as st
+
+from extractly.config import load_config
+from extractly.ui.components import inject_branding, inject_global_styles, section_title
+from extractly.logging import setup_logging
+
+
+config = load_config()
+setup_logging()
+
+st.set_page_config(page_title="Settings", page_icon="โ๏ธ", layout="wide")
+
+inject_branding(Path("data/assets/data_reply.svg"))
+inject_global_styles()
+
+st.title("โ๏ธ Settings")
+st.caption("Review configuration, models, and environment status.")
+
+section_title("Environment")
+cols = st.columns(3)
+cols[0].metric("OpenAI key", "โ
Found" if config.openai_api_key else "โ Missing")
+cols[1].metric("Timeout (s)", config.request_timeout_s)
+cols[2].metric("Max retries", config.max_retries)
+
+section_title("Models")
+model_cols = st.columns(3)
+model_cols[0].text_input("Classifier model", value=config.classify_model, disabled=True)
+model_cols[1].text_input("Extractor model", value=config.extract_model, disabled=True)
+model_cols[2].text_input("OCR model", value=config.ocr_model, disabled=True)
+
+section_title("Directories")
+st.write(f"Schemas: `{config.schema_dir}`")
+st.write(f"Runs: `{config.run_store_dir}`")
+st.write(f"Sample docs: `{config.sample_data_dir}`")
+
+section_title("Notes")
+st.info(
+ "To update models or pipeline settings, set the environment variables in your `.env` file. "
+ "Run `streamlit run Home.py` after changing them.",
+ icon="๐",
+)
diff --git a/schemas/demo_invoice.json b/schemas/demo_invoice.json
new file mode 100644
index 0000000..f6cdd26
--- /dev/null
+++ b/schemas/demo_invoice.json
@@ -0,0 +1,48 @@
+{
+ "Invoice Demo": {
+ "description": "Client-ready demo schema for invoice metadata.",
+ "version": "v1",
+ "fields": [
+ {
+ "name": "Invoice Number",
+ "type": "string",
+ "required": true,
+ "description": "Unique invoice identifier",
+ "example": "INV-2025-0042"
+ },
+ {
+ "name": "Invoice Date",
+ "type": "date",
+ "required": true,
+ "description": "Issue date",
+ "example": "2025-01-15"
+ },
+ {
+ "name": "Supplier",
+ "type": "string",
+ "required": true,
+ "description": "Supplier company name"
+ },
+ {
+ "name": "Client",
+ "type": "string",
+ "required": true,
+ "description": "Client company name"
+ },
+ {
+ "name": "Total Amount",
+ "type": "number",
+ "required": true,
+ "description": "Invoice total",
+ "example": "1248.50"
+ },
+ {
+ "name": "Payment Terms",
+ "type": "string",
+ "required": false,
+ "description": "Payment terms",
+ "example": "Net 30"
+ }
+ ]
+ }
+}
diff --git a/src/classifier.py b/src/classifier.py
index 3dd7c8e..b165a28 100644
--- a/src/classifier.py
+++ b/src/classifier.py
@@ -1,12 +1,8 @@
-# utils/classifier.py
+from __future__ import annotations
-import os
-import io
-import base64
from PIL import Image
-from utils.openai_client import get_chat_completion
-from statistics import mode, StatisticsError
-from utils.utils import DEFAULT_OPENAI_MODEL
+
+from extractly.pipeline.classification import classify_document
def classify(
@@ -14,39 +10,13 @@ def classify(
candidates: list[str],
*,
use_confidence: bool = False,
- n_votes: int = 5, # number of self-consistency calls
+ n_votes: int = 5,
system_prompt: str = "",
) -> dict:
- """Returns {'doc_type': str} or additionally a 'confidence' field."""
-
- def _single_vote() -> str:
- buf = io.BytesIO()
- images[0].save(buf, format="PNG")
- data_uri = f"data:image/png;base64,{base64.b64encode(buf.getvalue()).decode()}"
-
- prompt = f"Choose one type from: {candidates}. Return only the type."
- msgs = [
- {"role": "system", "content": system_prompt},
- {
- "role": "user",
- "content": [
- {"type": "text", "text": prompt},
- {"type": "image_url", "image_url": {"url": data_uri}},
- ],
- },
- ]
- return get_chat_completion(
- msgs, model=os.getenv("CLASSIFY_MODEL", DEFAULT_OPENAI_MODEL)
- ).strip()
-
- if not use_confidence:
- return {"doc_type": _single_vote()}
-
- votes = [_single_vote() for _ in range(n_votes)]
- try:
- best = mode(votes)
- confidence = votes.count(best) / n_votes
- except StatisticsError: # all votes different
- best, confidence = votes[0], 1 / n_votes
-
- return {"doc_type": best, "confidence": confidence}
+ return classify_document(
+ images,
+ candidates,
+ use_confidence=use_confidence,
+ n_votes=n_votes,
+ system_prompt=system_prompt or None,
+ )
diff --git a/src/extractly/__init__.py b/src/extractly/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/extractly/config.py b/src/extractly/config.py
new file mode 100644
index 0000000..8012926
--- /dev/null
+++ b/src/extractly/config.py
@@ -0,0 +1,42 @@
+from __future__ import annotations
+
+import os
+from dataclasses import dataclass
+from pathlib import Path
+from dotenv import load_dotenv
+
+
+PROJECT_ROOT = Path(__file__).resolve().parents[2]
+
+
+@dataclass(frozen=True)
+class AppConfig:
+ app_name: str
+ openai_api_key: str | None
+ classify_model: str
+ extract_model: str
+ ocr_model: str
+ request_timeout_s: int
+ max_retries: int
+ retry_backoff_s: float
+ run_store_dir: Path
+ schema_dir: Path
+ sample_data_dir: Path
+
+
+def load_config() -> AppConfig:
+ load_dotenv(override=True)
+
+ return AppConfig(
+ app_name=os.getenv("EXTRACTLY_APP_NAME", "Extractly"),
+ openai_api_key=os.getenv("OPENAI_API_KEY"),
+ classify_model=os.getenv("CLASSIFY_MODEL", "o4-mini"),
+ extract_model=os.getenv("EXTRACT_MODEL", "o4-mini"),
+ ocr_model=os.getenv("OCR_MODEL", "o4-mini"),
+ request_timeout_s=int(os.getenv("EXTRACTLY_TIMEOUT_S", "40")),
+ max_retries=int(os.getenv("EXTRACTLY_MAX_RETRIES", "2")),
+ retry_backoff_s=float(os.getenv("EXTRACTLY_RETRY_BACKOFF_S", "1.5")),
+ run_store_dir=Path(os.getenv("EXTRACTLY_RUNS_DIR", PROJECT_ROOT / "runs")),
+ schema_dir=Path(os.getenv("EXTRACTLY_SCHEMAS_DIR", PROJECT_ROOT / "schemas")),
+ sample_data_dir=Path(os.getenv("EXTRACTLY_SAMPLE_DATA_DIR", PROJECT_ROOT / "data" / "sample_docs")),
+ )
diff --git a/src/extractly/domain/__init__.py b/src/extractly/domain/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/extractly/domain/models.py b/src/extractly/domain/models.py
new file mode 100644
index 0000000..32b99c6
--- /dev/null
+++ b/src/extractly/domain/models.py
@@ -0,0 +1,53 @@
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from typing import Any
+
+
+FIELD_TYPES = {
+ "string",
+ "number",
+ "integer",
+ "boolean",
+ "date",
+ "enum",
+ "object",
+ "array",
+}
+
+
+@dataclass
+class SchemaField:
+ name: str
+ field_type: str = "string"
+ required: bool = False
+ description: str = ""
+ example: str = ""
+ enum_values: list[str] = field(default_factory=list)
+
+ def to_dict(self) -> dict[str, Any]:
+ payload = {
+ "name": self.name,
+ "type": self.field_type,
+ "required": self.required,
+ "description": self.description,
+ "example": self.example,
+ }
+ if self.enum_values:
+ payload["enum"] = self.enum_values
+ return payload
+
+
+@dataclass
+class DocumentSchema:
+ name: str
+ description: str = ""
+ fields: list[SchemaField] = field(default_factory=list)
+ version: str = "v1"
+
+ def to_dict(self) -> dict[str, Any]:
+ return {
+ "description": self.description,
+ "version": self.version,
+ "fields": [field.to_dict() for field in self.fields],
+ }
diff --git a/src/extractly/domain/run_store.py b/src/extractly/domain/run_store.py
new file mode 100644
index 0000000..70f2c1f
--- /dev/null
+++ b/src/extractly/domain/run_store.py
@@ -0,0 +1,90 @@
+from __future__ import annotations
+
+import json
+from dataclasses import dataclass, field
+from datetime import datetime, timezone
+from pathlib import Path
+from typing import Any
+from uuid import uuid4
+
+
+@dataclass
+class RunDocument:
+ filename: str
+ document_type: str
+ confidence: float | None
+ extracted: dict[str, Any]
+ corrected: dict[str, Any]
+ field_confidence: dict[str, float] = field(default_factory=dict)
+ warnings: list[str] = field(default_factory=list)
+ errors: list[str] = field(default_factory=list)
+
+
+@dataclass
+class ExtractionRun:
+ run_id: str
+ started_at: str
+ schema_name: str
+ mode: str
+ documents: list[RunDocument]
+ status: str = "completed"
+ logs: list[str] = field(default_factory=list)
+
+ def to_dict(self) -> dict[str, Any]:
+ return {
+ "run_id": self.run_id,
+ "started_at": self.started_at,
+ "schema_name": self.schema_name,
+ "mode": self.mode,
+ "status": self.status,
+ "logs": self.logs,
+ "documents": [
+ {
+ "filename": doc.filename,
+ "document_type": doc.document_type,
+ "confidence": doc.confidence,
+ "extracted": doc.extracted,
+ "corrected": doc.corrected,
+ "field_confidence": doc.field_confidence,
+ "warnings": doc.warnings,
+ "errors": doc.errors,
+ }
+ for doc in self.documents
+ ],
+ }
+
+
+class RunStore:
+ def __init__(self, base_dir: Path):
+ self.base_dir = base_dir
+ self.base_dir.mkdir(parents=True, exist_ok=True)
+
+ def create_run_id(self) -> str:
+ timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
+ return f"run_{timestamp}_{uuid4().hex[:6]}"
+
+ def save(self, run: ExtractionRun) -> Path:
+ run_dir = self.base_dir / run.run_id
+ run_dir.mkdir(parents=True, exist_ok=True)
+ run_path = run_dir / "run.json"
+ with run_path.open("w", encoding="utf-8") as fp:
+ json.dump(run.to_dict(), fp, indent=2, ensure_ascii=False)
+ return run_path
+
+ def list_runs(self) -> list[dict[str, Any]]:
+ runs: list[dict[str, Any]] = []
+ for run_dir in sorted(self.base_dir.glob("run_*"), reverse=True):
+ run_path = run_dir / "run.json"
+ if not run_path.exists():
+ continue
+ with run_path.open("r", encoding="utf-8") as fp:
+ payload = json.load(fp)
+ runs.append(payload)
+ return runs
+
+ def load(self, run_id: str) -> dict[str, Any] | None:
+ run_path = self.base_dir / run_id / "run.json"
+ if not run_path.exists():
+ return None
+ with run_path.open("r", encoding="utf-8") as fp:
+ return json.load(fp)
diff --git a/src/extractly/domain/schema_store.py b/src/extractly/domain/schema_store.py
new file mode 100644
index 0000000..59092bb
--- /dev/null
+++ b/src/extractly/domain/schema_store.py
@@ -0,0 +1,138 @@
+from __future__ import annotations
+
+import json
+import re
+from pathlib import Path
+from typing import Iterable
+
+from extractly.domain.models import DocumentSchema, SchemaField
+from extractly.domain.validation import validate_schema, ValidationResult
+
+
+class SchemaStore:
+ def __init__(self, schema_dir: Path):
+ self.schema_dir = schema_dir
+ self.schema_dir.mkdir(parents=True, exist_ok=True)
+
+ def list_schemas(self) -> list[DocumentSchema]:
+ schemas: list[DocumentSchema] = []
+ for file_path in sorted(self.schema_dir.glob("*.json")):
+ with file_path.open("r", encoding="utf-8") as fp:
+ payload = json.load(fp)
+ schemas.extend(self._parse_payload(payload))
+ return sorted(schemas, key=lambda s: s.name.lower())
+
+ def get_schema(self, name: str) -> DocumentSchema | None:
+ for schema in self.list_schemas():
+ if schema.name == name:
+ return schema
+ return None
+
+ def save_schema(self, schema: DocumentSchema) -> ValidationResult:
+ validation = validate_schema(schema)
+ if not validation.is_valid:
+ return validation
+
+ file_slug = self._slugify(schema.name)
+ file_path = self.schema_dir / f"{file_slug}.json"
+ payload = {schema.name: schema.to_dict()}
+ with file_path.open("w", encoding="utf-8") as fp:
+ json.dump(payload, fp, indent=2, ensure_ascii=False)
+ return validation
+
+ def delete_schema(self, name: str) -> bool:
+ deleted = False
+ for file_path in self.schema_dir.glob("*.json"):
+ with file_path.open("r", encoding="utf-8") as fp:
+ payload = json.load(fp)
+ if name in payload:
+ payload.pop(name)
+ deleted = True
+ if payload:
+ with file_path.open("w", encoding="utf-8") as fp:
+ json.dump(payload, fp, indent=2, ensure_ascii=False)
+ else:
+ file_path.unlink(missing_ok=True)
+ return deleted
+
+ def import_payload(self, payload: dict) -> list[DocumentSchema]:
+ schemas = self._parse_payload(payload)
+ for schema in schemas:
+ self.save_schema(schema)
+ return schemas
+
+ def export_schema(self, schema: DocumentSchema) -> str:
+ return json.dumps({schema.name: schema.to_dict()}, indent=2, ensure_ascii=False)
+
+ @staticmethod
+ def _slugify(name: str) -> str:
+ cleaned = re.sub(r"[^a-zA-Z0-9_-]+", "-", name.strip())
+ return cleaned.strip("-").lower() or "schema"
+
+ def _parse_payload(self, payload: dict) -> list[DocumentSchema]:
+ schemas: list[DocumentSchema] = []
+ for name, data in payload.items():
+ if isinstance(data, list):
+ fields = [self._parse_field(field) for field in data]
+ schemas.append(DocumentSchema(name=name, fields=fields))
+ continue
+
+ description = data.get("description", "")
+ version = data.get("version", "v1")
+ raw_fields = data.get("fields", [])
+ fields = [self._parse_field(field) for field in raw_fields]
+ schemas.append(
+ DocumentSchema(
+ name=name,
+ description=description,
+ fields=fields,
+ version=version,
+ )
+ )
+ return schemas
+
+ @staticmethod
+ def _parse_field(field: dict) -> SchemaField:
+ return SchemaField(
+ name=str(field.get("name", "")).strip(),
+ field_type=field.get("type", field.get("field_type", "string")),
+ required=bool(field.get("required", False)),
+ description=str(field.get("description", "")),
+ example=str(field.get("example", "")),
+ enum_values=list(field.get("enum", field.get("enum_values", [])) or []),
+ )
+
+
+def schemas_to_table(schema: DocumentSchema) -> list[dict]:
+ return [
+ {
+ "name": field.name,
+ "type": field.field_type,
+ "required": field.required,
+ "description": field.description,
+ "example": field.example,
+ "enum": ", ".join(field.enum_values),
+ }
+ for field in schema.fields
+ ]
+
+
+def table_to_schema(name: str, description: str, rows: Iterable[dict]) -> DocumentSchema:
+ fields: list[SchemaField] = []
+ for row in rows:
+ enum_values = [
+ item.strip()
+ for item in str(row.get("enum", "")).split(",")
+ if item.strip()
+ ]
+ fields.append(
+ SchemaField(
+ name=str(row.get("name", "")).strip(),
+ field_type=row.get("type", "string"),
+ required=bool(row.get("required", False)),
+ description=str(row.get("description", "")).strip(),
+ example=str(row.get("example", "")).strip(),
+ enum_values=enum_values,
+ )
+ )
+ return DocumentSchema(name=name, description=description, fields=fields)
diff --git a/src/extractly/domain/validation.py b/src/extractly/domain/validation.py
new file mode 100644
index 0000000..e9fd140
--- /dev/null
+++ b/src/extractly/domain/validation.py
@@ -0,0 +1,56 @@
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+
+from extractly.domain.models import DocumentSchema, FIELD_TYPES
+
+
+@dataclass
+class ValidationResult:
+ errors: list[str] = field(default_factory=list)
+ warnings: list[str] = field(default_factory=list)
+
+ @property
+ def is_valid(self) -> bool:
+ return not self.errors
+
+
+def validate_schema(schema: DocumentSchema) -> ValidationResult:
+ result = ValidationResult()
+
+ if not schema.name.strip():
+ result.errors.append("Schema name is required.")
+
+ seen_names = set()
+ for idx, field in enumerate(schema.fields, start=1):
+ if not field.name.strip():
+ result.errors.append(f"Field #{idx} is missing a name.")
+ continue
+
+ if field.name in seen_names:
+ result.errors.append(f"Field name '{field.name}' is duplicated.")
+ seen_names.add(field.name)
+
+ if field.field_type not in FIELD_TYPES:
+ result.errors.append(
+ f"Field '{field.name}' has invalid type '{field.field_type}'."
+ )
+
+ if field.field_type == "enum" and not field.enum_values:
+ result.errors.append(
+ f"Field '{field.name}' is enum but has no enum values."
+ )
+
+ if field.enum_values:
+ if len(set(field.enum_values)) != len(field.enum_values):
+ result.errors.append(
+ f"Field '{field.name}' has duplicate enum values."
+ )
+
+ if not schema.fields:
+ result.errors.append("At least one field is required.")
+
+ if schema.description and len(schema.description) < 10:
+ result.warnings.append("Schema description is very short.")
+
+ return result
diff --git a/src/extractly/integrations/__init__.py b/src/extractly/integrations/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/extractly/integrations/ocr.py b/src/extractly/integrations/ocr.py
new file mode 100644
index 0000000..a5d7b1e
--- /dev/null
+++ b/src/extractly/integrations/ocr.py
@@ -0,0 +1,52 @@
+from __future__ import annotations
+
+import base64
+import io
+from typing import Iterable
+from PIL import Image
+
+from extractly.config import load_config
+from extractly.integrations.openai_client import get_chat_completion
+from extractly.logging import get_logger
+
+
+logger = get_logger(__name__)
+
+
+_OCR_SYSTEM_PROMPT = """
+You are an expert OCR (Optical Character Recognition) assistant. Extract every visible piece
+of text from the document image. Preserve the reading order and line breaks where meaningful.
+Return only the raw text content.
+"""
+
+
+def _page_to_data_uri(page: Image.Image) -> str:
+ buf = io.BytesIO()
+ page.save(buf, format="PNG")
+ return f"data:image/png;base64,{base64.b64encode(buf.getvalue()).decode()}"
+
+
+def ocr_page(page: Image.Image) -> str:
+ config = load_config()
+ messages = [
+ {"role": "system", "content": _OCR_SYSTEM_PROMPT},
+ {
+ "role": "user",
+ "content": [
+ {"type": "image_url", "image_url": {"url": _page_to_data_uri(page)}},
+ ],
+ },
+ ]
+
+ return get_chat_completion(messages, model=config.ocr_model, temperature=0.0)
+
+
+def run_ocr(pages: Iterable[Image.Image]) -> str:
+ outputs: list[str] = []
+ for page in pages:
+ try:
+ outputs.append(ocr_page(page))
+ except Exception as exc:
+ logger.error("OCR failed: %s", exc)
+ outputs.append("")
+ return "\n\n".join(outputs).strip()
diff --git a/src/extractly/integrations/openai_client.py b/src/extractly/integrations/openai_client.py
new file mode 100644
index 0000000..c783d5f
--- /dev/null
+++ b/src/extractly/integrations/openai_client.py
@@ -0,0 +1,56 @@
+from __future__ import annotations
+
+import time
+from typing import Any
+
+from openai import OpenAI
+
+from extractly.config import load_config
+from extractly.logging import get_logger
+
+
+logger = get_logger(__name__)
+
+
+_client: OpenAI | None = None
+
+
+def get_client() -> OpenAI:
+ global _client
+ if _client is None:
+ config = load_config()
+ _client = OpenAI(api_key=config.openai_api_key, timeout=config.request_timeout_s)
+ return _client
+
+
+def get_chat_completion(
+ messages: list[dict[str, Any]],
+ *,
+ model: str,
+ temperature: float = 0.2,
+ max_retries: int | None = None,
+ timeout_s: int | None = None,
+) -> str:
+ config = load_config()
+ retries = max_retries if max_retries is not None else config.max_retries
+ timeout = timeout_s if timeout_s is not None else config.request_timeout_s
+
+ if not config.openai_api_key:
+ raise RuntimeError("OPENAI_API_KEY is not set. Add it to your environment.")
+
+ for attempt in range(retries + 1):
+ try:
+ response = get_client().chat.completions.create(
+ model=model,
+ messages=messages,
+ temperature=temperature,
+ timeout=timeout,
+ )
+ return response.choices[0].message.content or ""
+ except Exception as exc:
+ logger.warning("OpenAI request failed (attempt %s): %s", attempt + 1, exc)
+ if attempt >= retries:
+ raise
+ time.sleep(config.retry_backoff_s * (attempt + 1))
+
+ return ""
diff --git a/src/extractly/integrations/preprocess.py b/src/extractly/integrations/preprocess.py
new file mode 100644
index 0000000..93fb89e
--- /dev/null
+++ b/src/extractly/integrations/preprocess.py
@@ -0,0 +1,17 @@
+from __future__ import annotations
+
+import io
+from typing import BinaryIO
+
+from pdf2image import convert_from_bytes
+from PIL import Image
+
+
+def preprocess(uploaded: BinaryIO, filename: str) -> list[Image.Image]:
+ data = uploaded.read()
+ uploaded.seek(0)
+
+ if filename.lower().endswith(".pdf"):
+ return convert_from_bytes(data)
+
+ return [Image.open(io.BytesIO(data))]
diff --git a/src/extractly/logging.py b/src/extractly/logging.py
new file mode 100644
index 0000000..511a262
--- /dev/null
+++ b/src/extractly/logging.py
@@ -0,0 +1,14 @@
+from __future__ import annotations
+
+import logging
+
+
+LOG_FORMAT = "%(asctime)s | %(levelname)s | %(name)s | %(message)s"
+
+
+def setup_logging(level: int = logging.INFO) -> None:
+ logging.basicConfig(level=level, format=LOG_FORMAT)
+
+
+def get_logger(name: str) -> logging.Logger:
+ return logging.getLogger(name)
diff --git a/src/extractly/pipeline/__init__.py b/src/extractly/pipeline/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/extractly/pipeline/classification.py b/src/extractly/pipeline/classification.py
new file mode 100644
index 0000000..c232480
--- /dev/null
+++ b/src/extractly/pipeline/classification.py
@@ -0,0 +1,71 @@
+from __future__ import annotations
+
+import base64
+import io
+from statistics import StatisticsError, mode
+from typing import Any
+
+from PIL import Image
+
+from extractly.config import load_config
+from extractly.integrations.openai_client import get_chat_completion
+from extractly.logging import get_logger
+
+
+logger = get_logger(__name__)
+
+
+DEFAULT_CLASSIFIER_PROMPT = """
+You are an expert document classifier. Choose the most likely document type based on layout,
+visual cues, and key text. Respond only with a single label from the provided list.
+If uncertain, choose "Unknown".
+"""
+
+
+def _image_to_data_uri(image: Image.Image) -> str:
+ buf = io.BytesIO()
+ image.save(buf, format="PNG")
+ return f"data:image/png;base64,{base64.b64encode(buf.getvalue()).decode()}"
+
+
+def classify_document(
+ images: list[Image.Image],
+ candidates: list[str],
+ *,
+ use_confidence: bool = False,
+ n_votes: int = 3,
+ system_prompt: str | None = None,
+) -> dict[str, Any]:
+ config = load_config()
+ prompt = system_prompt or DEFAULT_CLASSIFIER_PROMPT
+
+ def _single_vote() -> str:
+ messages = [
+ {"role": "system", "content": prompt},
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "text",
+ "text": f"Choose one type from: {candidates}.",
+ },
+ {
+ "type": "image_url",
+ "image_url": {"url": _image_to_data_uri(images[0])},
+ },
+ ],
+ },
+ ]
+ return get_chat_completion(messages, model=config.classify_model).strip()
+
+ if not use_confidence:
+ return {"doc_type": _single_vote()}
+
+ votes = [_single_vote() for _ in range(n_votes)]
+ try:
+ best = mode(votes)
+ confidence = votes.count(best) / n_votes
+ except StatisticsError:
+ best, confidence = votes[0], 1 / n_votes
+
+ return {"doc_type": best, "confidence": confidence}
diff --git a/src/extractly/pipeline/extraction.py b/src/extractly/pipeline/extraction.py
new file mode 100644
index 0000000..60be5e5
--- /dev/null
+++ b/src/extractly/pipeline/extraction.py
@@ -0,0 +1,119 @@
+from __future__ import annotations
+
+import base64
+import contextlib
+import io
+import json
+import re
+from typing import Any
+
+from PIL import Image
+
+from extractly.config import load_config
+from extractly.domain.models import SchemaField
+from extractly.integrations.openai_client import get_chat_completion
+from extractly.logging import get_logger
+
+
+logger = get_logger(__name__)
+
+
+DEFAULT_EXTRACTION_PROMPT = """
+You are a metadata extraction specialist. Extract the requested fields with high accuracy.
+Return JSON with keys: metadata, snippets, confidence. Use null when a field is missing.
+"""
+
+
+def _image_to_data_uri(image: Image.Image) -> str:
+ buf = io.BytesIO()
+ image.save(buf, format="PNG")
+ return f"data:image/png;base64,{base64.b64encode(buf.getvalue()).decode()}"
+
+
+def _truncate(text: str, max_chars: int = 64_000) -> str:
+ return text[:max_chars]
+
+
+def _schema_payload(fields: list[SchemaField]) -> dict[str, Any]:
+ return {
+ field.name: {
+ "type": field.field_type,
+ "required": field.required,
+ "description": field.description,
+ "example": field.example,
+ "enum": field.enum_values,
+ }
+ for field in fields
+ }
+
+
+def extract_metadata(
+ images: list[Image.Image],
+ fields: list[SchemaField],
+ *,
+ ocr_text: str | None = None,
+ with_confidence: bool = False,
+ system_prompt: str | None = None,
+) -> dict[str, Any]:
+ config = load_config()
+ field_names = [field.name for field in fields]
+
+ schema_json = json.dumps(_schema_payload(fields), ensure_ascii=False)
+
+ user_content = [
+ {"type": "image_url", "image_url": {"url": _image_to_data_uri(images[0])}},
+ {"type": "text", "text": f"Schema: {schema_json}"},
+ ]
+ if ocr_text:
+ user_content.append(
+ {
+ "type": "text",
+ "text": "OCR context:\n" + _truncate(ocr_text),
+ }
+ )
+
+ messages = [
+ {"role": "system", "content": system_prompt or DEFAULT_EXTRACTION_PROMPT},
+ {"role": "user", "content": user_content},
+ ]
+
+ response = get_chat_completion(messages, model=config.extract_model)
+ if not response.strip():
+ raise RuntimeError("Empty extraction response")
+
+ raw: dict[str, Any] | None = None
+ with contextlib.suppress(json.JSONDecodeError):
+ raw = json.loads(response)
+
+ if raw is None and (match := re.search(r"\{.*\}", response, flags=re.S)):
+ with contextlib.suppress(json.JSONDecodeError):
+ raw = json.loads(match.group())
+
+ if not isinstance(raw, dict):
+ logger.error("Bad extraction JSON. Returning blanks.")
+ raw = {}
+
+ raw_meta = raw.get("metadata") if isinstance(raw.get("metadata"), dict) else {}
+ raw_snippets = (
+ raw.get("snippets") if isinstance(raw.get("snippets"), dict) else {}
+ )
+ raw_conf = (
+ raw.get("confidence") if isinstance(raw.get("confidence"), dict) else {}
+ )
+
+
+ if with_confidence and not raw_conf:
+ raw_conf = {name: 1.0 if raw_meta.get(name) else 0.0 for name in field_names}
+
+ metadata = {name: raw_meta.get(name) for name in field_names}
+ snippets = {name: raw_snippets.get(name) for name in field_names}
+
+ confidence = {}
+ if with_confidence:
+ confidence = {name: raw_conf.get(name, 0.0) for name in field_names}
+
+ return {
+ "metadata": metadata,
+ "snippets": snippets,
+ "confidence": confidence,
+ }
diff --git a/src/extractly/pipeline/runner.py b/src/extractly/pipeline/runner.py
new file mode 100644
index 0000000..20bb7bd
--- /dev/null
+++ b/src/extractly/pipeline/runner.py
@@ -0,0 +1,110 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from datetime import datetime, timezone
+from typing import Any
+
+from PIL import Image
+
+from extractly.domain.models import DocumentSchema
+from extractly.domain.run_store import ExtractionRun, RunDocument, RunStore
+from extractly.integrations.ocr import run_ocr
+from extractly.pipeline.classification import classify_document
+from extractly.pipeline.extraction import extract_metadata
+from extractly.logging import get_logger
+
+
+logger = get_logger(__name__)
+
+
+@dataclass
+class PipelineOptions:
+ enable_ocr: bool = False
+ compute_confidence: bool = False
+ mode: str = "fast"
+ classifier_prompt: str | None = None
+ extraction_prompt: str | None = None
+
+
+def run_pipeline(
+ *,
+ files: list[dict[str, Any]],
+ schema: DocumentSchema,
+ candidates: list[str],
+ run_store: RunStore,
+ options: PipelineOptions,
+) -> ExtractionRun:
+ run_id = run_store.create_run_id()
+ logs: list[str] = []
+ documents: list[RunDocument] = []
+
+ for payload in files:
+ filename = payload["name"]
+ images: list[Image.Image] = payload["images"]
+
+ logs.append(f"Parsing {filename}")
+ ocr_text = payload.get("ocr_text")
+ if ocr_text is None and options.enable_ocr:
+ ocr_text = run_ocr(images)
+
+ doc_type_override = payload.get("doc_type_override")
+ if doc_type_override:
+ doc_type = doc_type_override
+ confidence = None
+ logs.append(f"Using provided document type for {filename}: {doc_type}")
+ else:
+ classification = classify_document(
+ images,
+ candidates,
+ use_confidence=options.compute_confidence,
+ system_prompt=options.classifier_prompt,
+ )
+ doc_type = classification.get("doc_type", "Unknown")
+ confidence = classification.get("confidence")
+ logs.append(f"Classified {filename} as {doc_type}")
+
+ warnings: list[str] = []
+ errors: list[str] = []
+ extracted: dict[str, Any] = {}
+ field_confidence: dict[str, float] = {}
+
+ if doc_type in {"Unknown", "Other"}:
+ warnings.append("Document type is unknown. Extraction skipped.")
+ else:
+ try:
+ extraction = extract_metadata(
+ images,
+ schema.fields,
+ ocr_text=ocr_text,
+ with_confidence=options.compute_confidence,
+ system_prompt=options.extraction_prompt,
+ )
+ extracted = extraction.get("metadata", {})
+ field_confidence = extraction.get("confidence", {})
+ except Exception as exc:
+ logger.error("Extraction failed for %s: %s", filename, exc)
+ errors.append(str(exc))
+
+ documents.append(
+ RunDocument(
+ filename=filename,
+ document_type=doc_type,
+ confidence=confidence,
+ extracted=extracted,
+ corrected=extracted.copy(),
+ field_confidence=field_confidence,
+ warnings=warnings,
+ errors=errors,
+ )
+ )
+
+ run = ExtractionRun(
+ run_id=run_id,
+ started_at=datetime.now(timezone.utc).isoformat(),
+ schema_name=schema.name,
+ mode=options.mode,
+ documents=documents,
+ logs=logs,
+ )
+ run_store.save(run)
+ return run
diff --git a/src/extractly/ui/__init__.py b/src/extractly/ui/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/extractly/ui/components.py b/src/extractly/ui/components.py
new file mode 100644
index 0000000..0527394
--- /dev/null
+++ b/src/extractly/ui/components.py
@@ -0,0 +1,77 @@
+from __future__ import annotations
+
+import base64
+from pathlib import Path
+import streamlit as st
+
+
+def inject_branding(logo_path: str | Path, height: str = "64px") -> None:
+ logo_path = Path(logo_path)
+ if not logo_path.exists():
+ return
+
+ encoded = base64.b64encode(logo_path.read_bytes()).decode()
+ st.markdown(
+ f"""
+
+ """,
+ unsafe_allow_html=True,
+ )
+
+
+def inject_global_styles() -> None:
+ st.markdown(
+ """
+
+ """,
+ unsafe_allow_html=True,
+ )
+
+
+def section_title(title: str, subtitle: str | None = None) -> None:
+ st.markdown(f"", unsafe_allow_html=True)
+ if subtitle:
+ st.caption(subtitle)
diff --git a/src/extractor.py b/src/extractor.py
index 1c27ff7..82a3af7 100644
--- a/src/extractor.py
+++ b/src/extractor.py
@@ -1,111 +1,35 @@
-import contextlib
-import os
-import io
-import base64
-import json
-import logging
-import re
-from typing import List, Dict
-from PIL import Image
-from utils.openai_client import get_chat_completion
-from utils.utils import DEFAULT_OPENAI_MODEL
-from utils.confidence_utils import score_confidence
+from __future__ import annotations
+from PIL import Image
-def _truncate(txt: str, max_chars: int = 64_000) -> str:
- """Guard-rail so we don't blow up the context window with huge OCR dumps."""
- return txt[:max_chars]
+from extractly.domain.models import SchemaField
+from extractly.pipeline.extraction import extract_metadata
def extract(
- images: List[Image.Image],
- schema: List[Dict],
- ocr_text: str | None = None, # Changed from Mapping to str
- *, # keyword-only "tuning" flags
+ images: list[Image.Image],
+ schema: list[dict],
+ ocr_text: str | None = None,
+ *,
with_confidence: bool = False,
system_prompt: str = "",
-) -> Dict:
- """
- Return a dict with exactly three top-level keys:
- metadata โ field -> value
- snippets โ field -> supporting text (ocr or model)
- confidence โ field -> float in [0,1] (empty if with_confidence=False)
- """
- # 1๏ธโฃ template & helpers --------------------------------------------------
- field_names = [f["name"] for f in schema]
- blank_dict = {n: None for n in field_names}
- schema_json = json.dumps(blank_dict, ensure_ascii=False)
-
- buf = io.BytesIO()
- images[0].save(buf, format="PNG")
- data_uri = f"data:image/png;base64,{base64.b64encode(buf.getvalue()).decode()}"
-
- usr = [
- {"type": "image_url", "image_url": {"url": data_uri}},
- {"type": "text", "text": f"Fields schema: {schema_json}"},
- ]
-
- if ocr_text:
- usr.append(
- {
- "type": "text",
- "text": "Extra context (OCR dump):\n\n" + _truncate(ocr_text),
- }
+) -> dict:
+ fields = [
+ SchemaField(
+ name=field.get("name", ""),
+ field_type=field.get("type", "string"),
+ required=field.get("required", False),
+ description=field.get("description", ""),
+ example=field.get("example", ""),
+ enum_values=list(field.get("enum", []) or []),
)
-
- messages = [
- {"role": "system", "content": system_prompt},
- {"role": "user", "content": usr},
+ for field in schema
]
- # 2๏ธโฃ LLM call ------------------------------------------------------------
- resp = get_chat_completion(
- messages, model=os.getenv("EXTRACT_MODEL", DEFAULT_OPENAI_MODEL)
+ return extract_metadata(
+ images,
+ fields,
+ ocr_text=ocr_text,
+ with_confidence=with_confidence,
+ system_prompt=system_prompt or None,
)
-
- if not resp.strip():
- raise RuntimeError("empty LLM response")
-
- # 3๏ธโฃ JSON-safe parse -----------------------------------------------------
- raw: Dict | None = None
- with contextlib.suppress(json.JSONDecodeError):
- raw = json.loads(resp)
-
- if raw is None and (m := re.search(r"\{.*\}", resp, flags=re.S)):
- with contextlib.suppress(json.JSONDecodeError):
- raw = json.loads(m.group())
-
- if not isinstance(raw, dict):
- logging.error("Bad extraction JSON โ returning blanks")
- raw = {}
-
- # 4๏ธโฃ normalise sections --------------------------------------------------
- raw_meta = raw.get("metadata") or {}
- raw_snip = raw.get("snippets") or {}
- raw_conf = raw.get("confidence") or {}
-
- # force dicts (LLMs sometimes give a scalar there)
- if not isinstance(raw_meta, dict):
- raw_meta = {}
- if not isinstance(raw_snip, dict):
- raw_snip = {}
- if not isinstance(raw_conf, dict):
- raw_conf = {}
-
- # merge external OCR into snippets (OCR wins)
- if ocr_text: # Changed from isinstance(ocr_text, dict) check
- raw_snip["ocr_content"] = ocr_text
-
- # 5๏ธโฃ fallback confidence --------------------------------------------------
- if with_confidence and not raw_conf:
- # heuristic: 1.0 if field present & not null, else 0.0
- raw_conf = score_confidence(raw_meta, schema)
-
- # 6๏ธโฃ final payload with *exact* keys -------------------------------------
- return {
- "metadata": {n: raw_meta.get(n) for n in field_names},
- "snippets": {n: raw_snip.get(n) for n in field_names},
- "confidence": {n: raw_conf.get(n) for n in field_names}
- if with_confidence
- else {},
- }
diff --git a/src/ocr_engine.py b/src/ocr_engine.py
index a913cb3..3c81505 100644
--- a/src/ocr_engine.py
+++ b/src/ocr_engine.py
@@ -1,73 +1,12 @@
-"""
-src/ocr_engine.py
-A super-thin wrapper so you can swap engines without touching the rest of the app.
-"""
-
from __future__ import annotations
-from PIL import Image
-import io
-import base64
-import os
-import logging
-from openai import OpenAI
-from utils.utils import DEFAULT_OPENAI_MODEL
-
-api_key = os.getenv("OPENAI_API_KEY")
-if not api_key:
- logging.error("OPENAI_API_KEY not found in env")
-
-# โโ local OpenAI helper โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
-_client = OpenAI(api_key=api_key) # โ
-_VISION_MODEL = os.getenv("OCR_MODEL", DEFAULT_OPENAI_MODEL) # โก
-
-def _ocr_llm(page: Image.Image) -> str:
- """Do a *vision* chat-completion round-trip and return plain text."""
- buf = io.BytesIO()
- page.save(buf, format="PNG")
- data_uri = f"data:image/png;base64,{base64.b64encode(buf.getvalue()).decode()}"
-
- system_prompt = """
- You are an expert OCR (Optical Character Recognition) assistant. Your task is to extract ALL visible text from the document image with perfect accuracy.
-
- Instructions:
- 1. Read every piece of text visible in the image, including:
- - Headers, titles, and headings
- - Body text and paragraphs
- - Table contents and data
- - Form fields and labels
- - Numbers, dates, and codes
- - Fine print and footnotes
- - Watermarks or stamps (if readable)
+from PIL import Image
- 2. Maintain the logical reading order (top to bottom, left to right)
- 3. Preserve line breaks and spacing where meaningful
- 4. Return ONLY the literal text - no commentary, no JSON formatting
- 5. If text is unclear or partially obscured, make your best attempt
- 6. Format as plain text but keep line breaks as they appear so humans can read it easily
- """
+from extractly.integrations.ocr import run_ocr as _run_ocr
- msg = [
- {"role": "system", "content": system_prompt},
- {
- "role": "user",
- "content": [
- {"type": "image_url", "image_url": {"url": data_uri}},
- ],
- },
- ]
- try:
- resp = _client.chat.completions.create(model=_VISION_MODEL, messages=msg)
- return resp.choices[0].message.content.strip()
- except Exception as e:
- logging.error(f"LLM-OCR failed: {e}")
- return ""
+def run_ocr(pages: list[Image.Image]) -> str:
+ return _run_ocr(pages)
-def run_ocr(pages: list[Image.Image]) -> str:
- """
- Concatenate OCR text from **all** pages with double newlines.
- Now LLM-only.
- """
- return "\n\n".join(_ocr_llm(p) for p in pages)
+__all__ = ["run_ocr"]
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 0000000..2d24259
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,7 @@
+import sys
+from pathlib import Path
+
+PROJECT_ROOT = Path(__file__).resolve().parents[1]
+SRC_PATH = PROJECT_ROOT / "src"
+if str(SRC_PATH) not in sys.path:
+ sys.path.insert(0, str(SRC_PATH))
diff --git a/tests/test_run_store.py b/tests/test_run_store.py
new file mode 100644
index 0000000..e68e533
--- /dev/null
+++ b/tests/test_run_store.py
@@ -0,0 +1,30 @@
+from pathlib import Path
+
+from extractly.domain.run_store import ExtractionRun, RunDocument, RunStore
+
+
+def test_run_store_round_trip(tmp_path: Path):
+ store = RunStore(tmp_path)
+ run = ExtractionRun(
+ run_id=store.create_run_id(),
+ started_at="2025-01-01T00:00:00Z",
+ schema_name="Invoice",
+ mode="fast",
+ documents=[
+ RunDocument(
+ filename="sample.pdf",
+ document_type="Invoice",
+ confidence=0.8,
+ extracted={"Invoice Number": "INV-1"},
+ corrected={"Invoice Number": "INV-1"},
+ field_confidence={"Invoice Number": 0.8},
+ )
+ ],
+ )
+
+ store.save(run)
+ runs = store.list_runs()
+ assert runs
+ loaded = store.load(run.run_id)
+ assert loaded is not None
+ assert loaded["schema_name"] == "Invoice"
diff --git a/tests/test_schema_validation.py b/tests/test_schema_validation.py
new file mode 100644
index 0000000..d333ebf
--- /dev/null
+++ b/tests/test_schema_validation.py
@@ -0,0 +1,34 @@
+from extractly.domain.models import DocumentSchema, SchemaField
+from extractly.domain.validation import validate_schema
+
+
+def test_validate_schema_detects_errors():
+ schema = DocumentSchema(
+ name="",
+ fields=[
+ SchemaField(name="", field_type="string"),
+ SchemaField(name="status", field_type="enum", enum_values=[]),
+ SchemaField(name="status", field_type="string"),
+ ],
+ )
+
+ result = validate_schema(schema)
+ assert not result.is_valid
+ assert "Schema name is required." in result.errors
+ assert any("Field #1" in err for err in result.errors)
+ assert any("enum" in err for err in result.errors)
+ assert any("duplicated" in err for err in result.errors)
+
+
+def test_validate_schema_accepts_valid_schema():
+ schema = DocumentSchema(
+ name="Invoice",
+ description="Invoice metadata",
+ fields=[
+ SchemaField(name="Invoice Number", field_type="string", required=True),
+ SchemaField(name="Total", field_type="number"),
+ ],
+ )
+
+ result = validate_schema(schema)
+ assert result.is_valid