Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ dev = [
[project.scripts]
sqlextract = "src.cli:main"

[tool.setuptools.packages.find]
include = ["src*"]

[tool.black]
line-length = 100
target-version = ['py39', 'py310', 'py311', 'py312']
Expand Down
5 changes: 4 additions & 1 deletion src/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
@click.option('--seed-data', is_flag=True, help='Extract seed data from tables')
@click.option('--tables', help='Comma-separated list of tables for seed data (format: schema.table)')
@click.option('--max-rows', type=int, help='Maximum rows per table for seed data')
@click.option('--split', is_flag=True, help='Write each object to its own file, organized in subdirectories by type')
@click.option('--verbose', '-v', is_flag=True, help='Verbose logging')
@click.option('--quiet', '-q', is_flag=True, help='Quiet mode (errors only)')
@click.option('--log-file', help='Log to file')
Expand All @@ -39,6 +40,7 @@ def main(
seed_data: bool,
tables: str,
max_rows: int,
split: bool,
verbose: bool,
quiet: bool,
log_file: str
Expand Down Expand Up @@ -110,7 +112,8 @@ def main(
schema_filter=schema_filter,
extract_seed_data=seed_data,
seed_tables=seed_tables,
max_rows=max_rows
max_rows=max_rows,
split=split
)

click.echo(f"\n✓ Extraction complete! SQL scripts written to: {output}")
Expand Down
58 changes: 40 additions & 18 deletions src/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ def extract(
schema_filter: Optional[List[str]] = None,
extract_seed_data: bool = False,
seed_tables: Optional[List[str]] = None,
max_rows: Optional[int] = None
max_rows: Optional[int] = None,
split: bool = False
) -> None:
"""Extract database schema and data.

Expand All @@ -41,6 +42,7 @@ def extract(
extract_seed_data: Whether to extract seed data
seed_tables: Optional list of tables for seed data (format: schema.table)
max_rows: Maximum rows per table for seed data
split: Write each object to its own file
"""
try:
# Connect to database
Expand Down Expand Up @@ -95,8 +97,8 @@ def extract(
logger.info("Extracting functions...")
functions = schema_reader.get_functions(schema_filter)

# Extract seed data if requested
seed_data = []
# Extract seed data if requested (dict keyed by schema.table for split support)
seed_data_dict: Dict[str, List[str]] = {}
if extract_seed_data:
logger.info("Extracting seed data...")

Expand Down Expand Up @@ -131,24 +133,44 @@ def extract(
table.columns,
max_rows=max_rows
)
seed_data.extend(table_data)
if table_data:
key = f"{table.schema_name}.{table.table_name}"
seed_data_dict[key] = table_data

# Generate SQL scripts
logger.info("Generating SQL scripts...")
formatter.write_modular_format(
schemas=schemas,
tables=ordered_tables,
primary_keys=primary_keys,
foreign_keys=foreign_keys,
deferred_foreign_keys=deferred_fks,
unique_constraints=unique_constraints,
check_constraints=check_constraints,
indexes=indexes,
views=views,
procedures=procedures,
functions=functions,
seed_data=seed_data
)
if split:
formatter.write_split_format(
schemas=schemas,
tables=ordered_tables,
primary_keys=primary_keys,
foreign_keys=foreign_keys,
deferred_foreign_keys=deferred_fks,
unique_constraints=unique_constraints,
check_constraints=check_constraints,
indexes=indexes,
views=views,
procedures=procedures,
functions=functions,
seed_data=seed_data_dict
)
else:
# Flatten dict to list for modular format
seed_data_flat = [stmt for stmts in seed_data_dict.values() for stmt in stmts]
formatter.write_modular_format(
schemas=schemas,
tables=ordered_tables,
primary_keys=primary_keys,
foreign_keys=foreign_keys,
deferred_foreign_keys=deferred_fks,
unique_constraints=unique_constraints,
check_constraints=check_constraints,
indexes=indexes,
views=views,
procedures=procedures,
functions=functions,
seed_data=seed_data_flat
)

logger.info("Extraction complete!")

Expand Down
119 changes: 118 additions & 1 deletion src/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from .schema import SchemaInfo, TableInfo, ColumnInfo
from .constraints import PrimaryKeyInfo, ForeignKeyInfo, UniqueConstraintInfo, CheckConstraintInfo
from .indexes import IndexInfo
from .utils import get_logger, escape_sql_identifier, OutputError
from .utils import get_logger, escape_sql_identifier, sanitize_filename, OutputError

logger = get_logger(__name__)

Expand Down Expand Up @@ -423,3 +423,120 @@ def write_modular_format(
f.write("\n")

logger.info(f"Modular SQL scripts written to {self.output_dir}")

def write_split_format(
self,
schemas: List[SchemaInfo],
tables: List[TableInfo],
primary_keys: List[PrimaryKeyInfo],
foreign_keys: List[ForeignKeyInfo],
deferred_foreign_keys: List[ForeignKeyInfo],
unique_constraints: List[UniqueConstraintInfo],
check_constraints: List[CheckConstraintInfo],
indexes: List[IndexInfo],
views: List[Dict[str, Any]],
procedures: List[Dict[str, Any]],
functions: List[Dict[str, Any]],
seed_data: Dict[str, List[str]]
) -> None:
"""Write each object to its own file, organized in subdirectories by type.

Args:
schemas: List of schemas
tables: List of tables
primary_keys: List of primary keys
foreign_keys: List of foreign keys
deferred_foreign_keys: List of deferred foreign keys (circular deps)
unique_constraints: List of unique constraints
check_constraints: List of check constraints
indexes: List of indexes
views: List of views
procedures: List of procedures
functions: List of functions
seed_data: Dict keyed by "schema.table" with lists of INSERT statements
"""
header = self.write_header()
written_paths: dict[str, str] = {}

def _write(subdir: str, filename: str, content: str) -> None:
"""Write a single object file inside a typed subdirectory."""
try:
dir_path = os.path.join(self.output_dir, subdir)
os.makedirs(dir_path, exist_ok=True)
safe_filename = sanitize_filename(filename)
path = os.path.join(dir_path, safe_filename)

# Detect in-run collisions (different source names -> same sanitized path)
collision_key = os.path.normcase(os.path.normpath(path))
previous_source = written_paths.get(collision_key)
if previous_source is not None and previous_source != filename:
raise OutputError(
f"Filename collision after sanitization: "
f"'{previous_source}' vs '{filename}' -> '{safe_filename}'"
)
written_paths[collision_key] = filename

with open(path, 'w', encoding='utf-8') as f:
f.write(header)
f.write(content)
except OutputError:
raise
except (OSError, ValueError) as exc:
raise OutputError(f"Failed to write split output file: {filename}") from exc

Comment thread
coderabbitai[bot] marked this conversation as resolved.
# Schemas
for schema in schemas:
_write("schemas", f"{schema.name}.sql", self.format_schema(schema))

# Tables
for table in tables:
_write("tables", f"{table.schema_name}.{table.table_name}.sql", self.format_table(table))

# Constraints — all types in one folder with schema.table.constraint_name.sql
for pk in primary_keys:
_write("constraints", f"{pk.schema_name}.{pk.table_name}.{pk.constraint_name}.sql",
self.format_primary_key(pk))

regular_fks = [fk for fk in foreign_keys if fk not in deferred_foreign_keys]
for fk in regular_fks:
_write("constraints", f"{fk.schema_name}.{fk.table_name}.{fk.constraint_name}.sql",
self.format_foreign_key(fk))

for uq in unique_constraints:
_write("constraints", f"{uq.schema_name}.{uq.table_name}.{uq.constraint_name}.sql",
self.format_unique_constraint(uq))

for chk in check_constraints:
_write("constraints", f"{chk.schema_name}.{chk.table_name}.{chk.constraint_name}.sql",
self.format_check_constraint(chk))

# Indexes
for idx in indexes:
_write("indexes", f"{idx.schema_name}.{idx.table_name}.{idx.index_name}.sql",
self.format_index(idx))

# Views
for view in views:
_write("views", f"{view['schema_name']}.{view['view_name']}.sql",
self.format_view(view))

# Procedures
for proc in procedures:
_write("procedures", f"{proc['schema_name']}.{proc['procedure_name']}.sql",
self.format_procedure(proc))

# Functions
for func in functions:
_write("functions", f"{func['schema_name']}.{func['function_name']}.sql",
self.format_function(func))

# Deferred foreign keys
for fk in deferred_foreign_keys:
_write("deferred_fks", f"{fk.schema_name}.{fk.table_name}.{fk.constraint_name}.sql",
self.format_foreign_key(fk))

# Seed data — one file per table
for table_key, statements in seed_data.items():
_write("seed_data", f"{table_key}.sql", "\n".join(statements) + "\n")

logger.info(f"Split SQL scripts written to {self.output_dir}")
28 changes: 28 additions & 0 deletions src/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Utility functions and logging setup."""

import logging
import os
import sys
from typing import Optional
from rich.console import Console
Expand Down Expand Up @@ -95,6 +96,33 @@ def escape_sql_string(value: str) -> str:
return value.replace("'", "''")


_WINDOWS_RESERVED = frozenset({
'CON', 'PRN', 'AUX', 'NUL',
*(f'COM{i}' for i in range(1, 10)),
*(f'LPT{i}' for i in range(1, 10)),
})


def sanitize_filename(name: str) -> str:
"""Sanitize a SQL object name for use as a filename.

Replaces filesystem-unsafe characters with underscores, strips trailing dots/spaces,
and escapes Windows-reserved device names.
"""
for ch in r'/\:*?"<>|':
name = name.replace(ch, '_')
sanitized = name.rstrip('. ')
if not sanitized:
raise ValueError(f"Filename is empty after sanitizing: {name!r}")

# Escape Windows-reserved device names (e.g. CON.sql -> CON_.sql)
base, ext = os.path.splitext(sanitized)
if base.upper() in _WINDOWS_RESERVED:
sanitized = base + '_' + ext

return sanitized

Comment thread
coderabbitai[bot] marked this conversation as resolved.

class SQLExtractError(Exception):
"""Base exception for SQL Extract."""
pass
Expand Down
Loading