diff --git a/pyproject.toml b/pyproject.toml index 16d5f46..5136cc4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,9 @@ dev = [ [project.scripts] sqlextract = "src.cli:main" +[tool.setuptools.packages.find] +include = ["src*"] + [tool.black] line-length = 100 target-version = ['py39', 'py310', 'py311', 'py312'] diff --git a/src/cli.py b/src/cli.py index 3a1ba4a..b2a91c4 100644 --- a/src/cli.py +++ b/src/cli.py @@ -22,6 +22,7 @@ @click.option('--seed-data', is_flag=True, help='Extract seed data from tables') @click.option('--tables', help='Comma-separated list of tables for seed data (format: schema.table)') @click.option('--max-rows', type=int, help='Maximum rows per table for seed data') +@click.option('--split', is_flag=True, help='Write each object to its own file, organized in subdirectories by type') @click.option('--verbose', '-v', is_flag=True, help='Verbose logging') @click.option('--quiet', '-q', is_flag=True, help='Quiet mode (errors only)') @click.option('--log-file', help='Log to file') @@ -39,6 +40,7 @@ def main( seed_data: bool, tables: str, max_rows: int, + split: bool, verbose: bool, quiet: bool, log_file: str @@ -110,7 +112,8 @@ def main( schema_filter=schema_filter, extract_seed_data=seed_data, seed_tables=seed_tables, - max_rows=max_rows + max_rows=max_rows, + split=split ) click.echo(f"\n✓ Extraction complete! SQL scripts written to: {output}") diff --git a/src/extractor.py b/src/extractor.py index 58e2b44..39d0e24 100644 --- a/src/extractor.py +++ b/src/extractor.py @@ -31,7 +31,8 @@ def extract( schema_filter: Optional[List[str]] = None, extract_seed_data: bool = False, seed_tables: Optional[List[str]] = None, - max_rows: Optional[int] = None + max_rows: Optional[int] = None, + split: bool = False ) -> None: """Extract database schema and data. 
@@ -41,6 +42,7 @@ def extract( extract_seed_data: Whether to extract seed data seed_tables: Optional list of tables for seed data (format: schema.table) max_rows: Maximum rows per table for seed data + split: Write each object to its own file """ try: # Connect to database @@ -95,8 +97,8 @@ def extract( logger.info("Extracting functions...") functions = schema_reader.get_functions(schema_filter) - # Extract seed data if requested - seed_data = [] + # Extract seed data if requested (dict keyed by schema.table for split support) + seed_data_dict: Dict[str, List[str]] = {} if extract_seed_data: logger.info("Extracting seed data...") @@ -131,24 +133,44 @@ def extract( table.columns, max_rows=max_rows ) - seed_data.extend(table_data) + if table_data: + key = f"{table.schema_name}.{table.table_name}" + seed_data_dict[key] = table_data # Generate SQL scripts logger.info("Generating SQL scripts...") - formatter.write_modular_format( - schemas=schemas, - tables=ordered_tables, - primary_keys=primary_keys, - foreign_keys=foreign_keys, - deferred_foreign_keys=deferred_fks, - unique_constraints=unique_constraints, - check_constraints=check_constraints, - indexes=indexes, - views=views, - procedures=procedures, - functions=functions, - seed_data=seed_data - ) + if split: + formatter.write_split_format( + schemas=schemas, + tables=ordered_tables, + primary_keys=primary_keys, + foreign_keys=foreign_keys, + deferred_foreign_keys=deferred_fks, + unique_constraints=unique_constraints, + check_constraints=check_constraints, + indexes=indexes, + views=views, + procedures=procedures, + functions=functions, + seed_data=seed_data_dict + ) + else: + # Flatten dict to list for modular format + seed_data_flat = [stmt for stmts in seed_data_dict.values() for stmt in stmts] + formatter.write_modular_format( + schemas=schemas, + tables=ordered_tables, + primary_keys=primary_keys, + foreign_keys=foreign_keys, + deferred_foreign_keys=deferred_fks, + unique_constraints=unique_constraints, + 
check_constraints=check_constraints, + indexes=indexes, + views=views, + procedures=procedures, + functions=functions, + seed_data=seed_data_flat + ) logger.info("Extraction complete!") diff --git a/src/formatter.py b/src/formatter.py index 5c51f28..33e08ed 100644 --- a/src/formatter.py +++ b/src/formatter.py @@ -6,7 +6,7 @@ from .schema import SchemaInfo, TableInfo, ColumnInfo from .constraints import PrimaryKeyInfo, ForeignKeyInfo, UniqueConstraintInfo, CheckConstraintInfo from .indexes import IndexInfo -from .utils import get_logger, escape_sql_identifier, OutputError +from .utils import get_logger, escape_sql_identifier, sanitize_filename, OutputError logger = get_logger(__name__) @@ -423,3 +423,120 @@ def write_modular_format( f.write("\n") logger.info(f"Modular SQL scripts written to {self.output_dir}") + + def write_split_format( + self, + schemas: List[SchemaInfo], + tables: List[TableInfo], + primary_keys: List[PrimaryKeyInfo], + foreign_keys: List[ForeignKeyInfo], + deferred_foreign_keys: List[ForeignKeyInfo], + unique_constraints: List[UniqueConstraintInfo], + check_constraints: List[CheckConstraintInfo], + indexes: List[IndexInfo], + views: List[Dict[str, Any]], + procedures: List[Dict[str, Any]], + functions: List[Dict[str, Any]], + seed_data: Dict[str, List[str]] + ) -> None: + """Write each object to its own file, organized in subdirectories by type. 
+ + Args: + schemas: List of schemas + tables: List of tables + primary_keys: List of primary keys + foreign_keys: List of foreign keys + deferred_foreign_keys: List of deferred foreign keys (circular deps) + unique_constraints: List of unique constraints + check_constraints: List of check constraints + indexes: List of indexes + views: List of views + procedures: List of procedures + functions: List of functions + seed_data: Dict keyed by "schema.table" with lists of INSERT statements + """ + header = self.write_header() + written_paths: dict[str, str] = {} + + def _write(subdir: str, filename: str, content: str) -> None: + """Write a single object file inside a typed subdirectory.""" + try: + dir_path = os.path.join(self.output_dir, subdir) + os.makedirs(dir_path, exist_ok=True) + safe_filename = sanitize_filename(filename) + path = os.path.join(dir_path, safe_filename) + + # Detect in-run collisions (different source names -> same sanitized path) + collision_key = os.path.normcase(os.path.normpath(path)) + previous_source = written_paths.get(collision_key) + if previous_source is not None and previous_source != filename: + raise OutputError( + f"Filename collision after sanitization: " + f"'{previous_source}' vs '{filename}' -> '{safe_filename}'" + ) + written_paths[collision_key] = filename + + with open(path, 'w', encoding='utf-8') as f: + f.write(header) + f.write(content) + except OutputError: + raise + except (OSError, ValueError) as exc: + raise OutputError(f"Failed to write split output file: {filename}") from exc + + # Schemas + for schema in schemas: + _write("schemas", f"{schema.name}.sql", self.format_schema(schema)) + + # Tables + for table in tables: + _write("tables", f"{table.schema_name}.{table.table_name}.sql", self.format_table(table)) + + # Constraints — all types in one folder with schema.table.constraint_name.sql + for pk in primary_keys: + _write("constraints", f"{pk.schema_name}.{pk.table_name}.{pk.constraint_name}.sql", + 
self.format_primary_key(pk)) + + regular_fks = [fk for fk in foreign_keys if fk not in deferred_foreign_keys] + for fk in regular_fks: + _write("constraints", f"{fk.schema_name}.{fk.table_name}.{fk.constraint_name}.sql", + self.format_foreign_key(fk)) + + for uq in unique_constraints: + _write("constraints", f"{uq.schema_name}.{uq.table_name}.{uq.constraint_name}.sql", + self.format_unique_constraint(uq)) + + for chk in check_constraints: + _write("constraints", f"{chk.schema_name}.{chk.table_name}.{chk.constraint_name}.sql", + self.format_check_constraint(chk)) + + # Indexes + for idx in indexes: + _write("indexes", f"{idx.schema_name}.{idx.table_name}.{idx.index_name}.sql", + self.format_index(idx)) + + # Views + for view in views: + _write("views", f"{view['schema_name']}.{view['view_name']}.sql", + self.format_view(view)) + + # Procedures + for proc in procedures: + _write("procedures", f"{proc['schema_name']}.{proc['procedure_name']}.sql", + self.format_procedure(proc)) + + # Functions + for func in functions: + _write("functions", f"{func['schema_name']}.{func['function_name']}.sql", + self.format_function(func)) + + # Deferred foreign keys + for fk in deferred_foreign_keys: + _write("deferred_fks", f"{fk.schema_name}.{fk.table_name}.{fk.constraint_name}.sql", + self.format_foreign_key(fk)) + + # Seed data — one file per table + for table_key, statements in seed_data.items(): + _write("seed_data", f"{table_key}.sql", "\n".join(statements) + "\n") + + logger.info(f"Split SQL scripts written to {self.output_dir}") diff --git a/src/utils.py b/src/utils.py index 34faf2f..8df4f22 100644 --- a/src/utils.py +++ b/src/utils.py @@ -1,6 +1,7 @@ """Utility functions and logging setup.""" import logging +import os import sys from typing import Optional from rich.console import Console @@ -95,6 +96,33 @@ def escape_sql_string(value: str) -> str: return value.replace("'", "''") +_WINDOWS_RESERVED = frozenset({ + 'CON', 'PRN', 'AUX', 'NUL', + *(f'COM{i}' for i in range(1, 
10)), + *(f'LPT{i}' for i in range(1, 10)), +}) + + +def sanitize_filename(name: str) -> str: + """Sanitize a SQL object name for use as a filename. + + Replaces filesystem-unsafe characters with underscores, strips trailing dots/spaces, + and escapes Windows-reserved device names (matched on the text before the first dot). + """ + for ch in r'/\:*?"<>|': + name = name.replace(ch, '_') + sanitized = name.rstrip('. ') + if not sanitized: + raise ValueError(f"Filename is empty after sanitizing: {name!r}") + + # Windows matches device names on the text before the first dot (CON.a.sql == CON) + stem, dot, rest = sanitized.partition('.') + if stem.upper() in _WINDOWS_RESERVED: + sanitized = stem + '_' + dot + rest + + return sanitized + + class SQLExtractError(Exception): """Base exception for SQL Extract.""" pass diff --git a/tests/test_extractor.py b/tests/test_extractor.py index 9acc8c6..09e2e9b 100644 --- a/tests/test_extractor.py +++ b/tests/test_extractor.py @@ -1,7 +1,7 @@ """Tests for core extractor module.""" import pytest -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, patch, call from src.extractor import DatabaseExtractor from src.connection import ConnectionConfig @@ -22,3 +22,153 @@ def test_extractor_initialization(self): assert extractor.config == config assert extractor.connection is None + + +class TestExtractorSplitFlag: + """Test that the split flag routes to write_split_format.""" + + def _run_extract(self, split): + """Run extract with all internals mocked, returning the formatter mock.""" + config = ConnectionConfig( + server="localhost", database="TestDB", + user="sa", password="TestPass" + ) + extractor = DatabaseExtractor(config) + + with patch('src.extractor.DatabaseConnection') as mock_conn_cls, \ + patch('src.extractor.SchemaReader') as mock_schema_cls, \ + patch('src.extractor.ConstraintExtractor') as mock_constr_cls, \ + patch('src.extractor.IndexExtractor') as mock_idx_cls, \ + patch('src.extractor.SeedDataExtractor'), \ + patch('src.extractor.SQLFormatter') as 
mock_fmt_cls: + + # Stub connection + mock_conn = MagicMock() + mock_conn_cls.return_value = mock_conn + + # Stub schema reader + mock_schema = mock_schema_cls.return_value + mock_schema.get_schemas.return_value = [] + mock_schema.get_tables.return_value = [] + mock_schema.get_views.return_value = [] + mock_schema.get_stored_procedures.return_value = [] + mock_schema.get_functions.return_value = [] + + # Stub constraint extractor + mock_constr = mock_constr_cls.return_value + mock_constr.get_primary_keys.return_value = [] + mock_constr.get_foreign_keys.return_value = [] + mock_constr.get_unique_constraints.return_value = [] + mock_constr.get_check_constraints.return_value = [] + + # Stub index extractor + mock_idx_cls.return_value.get_indexes.return_value = [] + + formatter = mock_fmt_cls.return_value + extractor.extract(output_dir="/tmp/test", split=split) + return formatter + + def test_split_true_calls_write_split_format(self): + """split=True calls write_split_format, not write_modular_format.""" + formatter = self._run_extract(split=True) + + formatter.write_split_format.assert_called_once() + formatter.write_modular_format.assert_not_called() + + def test_split_false_calls_write_modular_format(self): + """split=False calls write_modular_format, not write_split_format.""" + formatter = self._run_extract(split=False) + + formatter.write_modular_format.assert_called_once() + formatter.write_split_format.assert_not_called() + + def test_split_default_is_modular(self): + """Default (no split arg) calls write_modular_format.""" + config = ConnectionConfig( + server="localhost", database="TestDB", + user="sa", password="TestPass" + ) + extractor = DatabaseExtractor(config) + + with patch('src.extractor.DatabaseConnection'), \ + patch('src.extractor.SchemaReader') as mock_schema_cls, \ + patch('src.extractor.ConstraintExtractor') as mock_constr_cls, \ + patch('src.extractor.IndexExtractor') as mock_idx_cls, \ + patch('src.extractor.SeedDataExtractor'), \ + 
patch('src.extractor.SQLFormatter') as mock_fmt_cls: + + mock_schema = mock_schema_cls.return_value + mock_schema.get_schemas.return_value = [] + mock_schema.get_tables.return_value = [] + mock_schema.get_views.return_value = [] + mock_schema.get_stored_procedures.return_value = [] + mock_schema.get_functions.return_value = [] + + mock_constr = mock_constr_cls.return_value + mock_constr.get_primary_keys.return_value = [] + mock_constr.get_foreign_keys.return_value = [] + mock_constr.get_unique_constraints.return_value = [] + mock_constr.get_check_constraints.return_value = [] + + mock_idx_cls.return_value.get_indexes.return_value = [] + + formatter = mock_fmt_cls.return_value + + # Call without split kwarg + extractor.extract(output_dir="/tmp/test") + + formatter.write_modular_format.assert_called_once() + formatter.write_split_format.assert_not_called() + + def test_seed_data_passed_as_dict_to_split(self): + """split=True passes seed data as a dict keyed by schema.table.""" + config = ConnectionConfig( + server="localhost", database="TestDB", + user="sa", password="TestPass" + ) + extractor = DatabaseExtractor(config) + + with patch('src.extractor.DatabaseConnection'), \ + patch('src.extractor.SchemaReader') as mock_schema_cls, \ + patch('src.extractor.ConstraintExtractor') as mock_constr_cls, \ + patch('src.extractor.IndexExtractor') as mock_idx_cls, \ + patch('src.extractor.SeedDataExtractor') as mock_seed_cls, \ + patch('src.extractor.SQLFormatter') as mock_fmt_cls, \ + patch('src.extractor.order_tables_by_dependency') as mock_order: + + # Create a mock table + mock_table = MagicMock() + mock_table.schema_name = "dbo" + mock_table.table_name = "Config" + mock_table.columns = [] + + mock_schema = mock_schema_cls.return_value + mock_schema.get_schemas.return_value = [] + mock_schema.get_tables.return_value = [mock_table] + mock_schema.get_views.return_value = [] + mock_schema.get_stored_procedures.return_value = [] + mock_schema.get_functions.return_value = [] + 
+ mock_constr = mock_constr_cls.return_value + mock_constr.get_primary_keys.return_value = [] + mock_constr.get_foreign_keys.return_value = [] + mock_constr.get_unique_constraints.return_value = [] + mock_constr.get_check_constraints.return_value = [] + + mock_idx_cls.return_value.get_indexes.return_value = [] + mock_order.return_value = ([("dbo", "Config")], []) + + # Seed data extractor returns some data + mock_seed = mock_seed_cls.return_value + mock_seed.extract_table_data.return_value = ["INSERT INTO [dbo].[Config] VALUES (1);"] + + formatter = mock_fmt_cls.return_value + extractor.extract( + output_dir="/tmp/test", + extract_seed_data=True, + split=True + ) + + # Verify seed_data kwarg is a dict + call_kwargs = formatter.write_split_format.call_args[1] + assert call_kwargs["seed_data"] == {"dbo.Config": ["INSERT INTO [dbo].[Config] VALUES (1);"]} diff --git a/tests/test_formatter.py b/tests/test_formatter.py index 9d57f27..5f01a2d 100644 --- a/tests/test_formatter.py +++ b/tests/test_formatter.py @@ -6,7 +6,9 @@ from pathlib import Path from src.formatter import SQLFormatter from src.schema import SchemaInfo, TableInfo, ColumnInfo -from src.constraints import PrimaryKeyInfo, ForeignKeyInfo +from src.constraints import PrimaryKeyInfo, ForeignKeyInfo, UniqueConstraintInfo, CheckConstraintInfo +from src.indexes import IndexInfo +from src.utils import OutputError class TestSQLFormatter: @@ -125,3 +127,238 @@ def test_write_header(self): assert "SQL Extract" in header assert "Generated by" in header assert "SET QUOTED_IDENTIFIER ON" in header + + +# --- Helpers for split format tests --- + +def _make_column(name="ID", data_type="int"): + return ColumnInfo( + name=name, data_type=data_type, max_length=None, precision=None, + scale=None, is_nullable=False, is_identity=False, is_computed=False, + default_value=None, computed_definition=None + ) + + +def _make_table(schema="dbo", name="Users"): + return TableInfo(schema_name=schema, table_name=name, 
columns=[_make_column()]) + + +def _make_pk(schema="dbo", table="Users", constraint="PK_Users"): + return PrimaryKeyInfo( + constraint_name=constraint, schema_name=schema, + table_name=table, columns=["ID"], is_clustered=True + ) + + +def _make_fk(schema="dbo", table="Orders", constraint="FK_Orders_Users"): + return ForeignKeyInfo( + constraint_name=constraint, schema_name=schema, table_name=table, + columns=["UserID"], referenced_schema="dbo", referenced_table="Users", + referenced_columns=["ID"], delete_action="NO_ACTION", update_action="NO_ACTION" + ) + + +def _make_uq(schema="dbo", table="Users", constraint="UQ_Email"): + return UniqueConstraintInfo( + constraint_name=constraint, schema_name=schema, + table_name=table, columns=["Email"] + ) + + +def _make_chk(schema="dbo", table="Users", constraint="CHK_Age"): + return CheckConstraintInfo( + constraint_name=constraint, schema_name=schema, + table_name=table, definition="([Age] > 0)", is_disabled=False + ) + + +def _make_index(schema="dbo", table="Users", index="IX_Email"): + return IndexInfo( + index_name=index, schema_name=schema, table_name=table, + is_unique=False, is_clustered=False, is_primary_key=False, + is_unique_constraint=False, columns=["Email"], + included_columns=[], filter_definition=None + ) + + +class TestWriteSplitFormat: + """Test write_split_format method.""" + + def _call_split(self, tmpdir, **overrides): + """Call write_split_format with sensible defaults, overridable per-test.""" + formatter = SQLFormatter(tmpdir) + defaults = dict( + schemas=[], tables=[], primary_keys=[], foreign_keys=[], + deferred_foreign_keys=[], unique_constraints=[], check_constraints=[], + indexes=[], views=[], procedures=[], functions=[], seed_data={} + ) + defaults.update(overrides) + formatter.write_split_format(**defaults) + return formatter + + def test_schemas_directory(self): + """Each schema gets its own file in schemas/.""" + with tempfile.TemporaryDirectory() as tmpdir: + self._call_split(tmpdir, 
schemas=[ + SchemaInfo(name="dbo", owner="dbo"), + SchemaInfo(name="MQTT", owner="dbo"), + ]) + + files = os.listdir(os.path.join(tmpdir, "schemas")) + assert sorted(files) == ["MQTT.sql", "dbo.sql"] + + def test_tables_directory(self): + """Each table gets schema.table.sql in tables/.""" + with tempfile.TemporaryDirectory() as tmpdir: + self._call_split(tmpdir, tables=[ + _make_table("dbo", "Users"), + _make_table("MQTT", "Broker"), + ]) + + files = os.listdir(os.path.join(tmpdir, "tables")) + assert sorted(files) == ["MQTT.Broker.sql", "dbo.Users.sql"] + + def test_constraints_directory_mixed_types(self): + """PKs, FKs, UQs, CHKs all land in constraints/.""" + with tempfile.TemporaryDirectory() as tmpdir: + self._call_split(tmpdir, + primary_keys=[_make_pk()], + foreign_keys=[_make_fk()], + unique_constraints=[_make_uq()], + check_constraints=[_make_chk()], + ) + + files = sorted(os.listdir(os.path.join(tmpdir, "constraints"))) + assert files == [ + "dbo.Orders.FK_Orders_Users.sql", + "dbo.Users.CHK_Age.sql", + "dbo.Users.PK_Users.sql", + "dbo.Users.UQ_Email.sql", + ] + + def test_indexes_directory(self): + """Indexes get schema.table.index_name.sql in indexes/.""" + with tempfile.TemporaryDirectory() as tmpdir: + self._call_split(tmpdir, indexes=[_make_index()]) + + files = os.listdir(os.path.join(tmpdir, "indexes")) + assert files == ["dbo.Users.IX_Email.sql"] + + def test_views_directory(self): + """Views get schema.view_name.sql in views/.""" + with tempfile.TemporaryDirectory() as tmpdir: + self._call_split(tmpdir, views=[{ + 'schema_name': 'dbo', 'view_name': 'ActiveUsers', + 'definition': 'CREATE VIEW [dbo].[ActiveUsers] AS SELECT 1', + }]) + + files = os.listdir(os.path.join(tmpdir, "views")) + assert files == ["dbo.ActiveUsers.sql"] + + def test_procedures_directory(self): + """Procedures get schema.procedure_name.sql in procedures/.""" + with tempfile.TemporaryDirectory() as tmpdir: + self._call_split(tmpdir, procedures=[{ + 'schema_name': 'dbo', 
'procedure_name': 'GetUser', + 'definition': 'CREATE PROCEDURE [dbo].[GetUser] AS SELECT 1', + }]) + + files = os.listdir(os.path.join(tmpdir, "procedures")) + assert files == ["dbo.GetUser.sql"] + + def test_functions_directory(self): + """Functions get schema.function_name.sql in functions/.""" + with tempfile.TemporaryDirectory() as tmpdir: + self._call_split(tmpdir, functions=[{ + 'schema_name': 'dbo', 'function_name': 'FormatName', + 'definition': 'CREATE FUNCTION [dbo].[FormatName]() RETURNS int AS BEGIN RETURN 1 END', + }]) + + files = os.listdir(os.path.join(tmpdir, "functions")) + assert files == ["dbo.FormatName.sql"] + + def test_deferred_fks_directory(self): + """Deferred FKs go to deferred_fks/, not constraints/.""" + fk = _make_fk(constraint="FK_Circular") + with tempfile.TemporaryDirectory() as tmpdir: + self._call_split(tmpdir, + foreign_keys=[fk], + deferred_foreign_keys=[fk], + ) + + # Should NOT appear in constraints/ + assert not os.path.exists(os.path.join(tmpdir, "constraints")) + # Should appear in deferred_fks/ + files = os.listdir(os.path.join(tmpdir, "deferred_fks")) + assert files == ["dbo.Orders.FK_Circular.sql"] + + def test_seed_data_directory(self): + """Seed data gets one file per table in seed_data/.""" + with tempfile.TemporaryDirectory() as tmpdir: + self._call_split(tmpdir, seed_data={ + "dbo.Users": ["INSERT INTO [dbo].[Users] VALUES (1);"], + "dbo.Orders": ["INSERT INTO [dbo].[Orders] VALUES (1);"], + }) + + files = sorted(os.listdir(os.path.join(tmpdir, "seed_data"))) + assert files == ["dbo.Orders.sql", "dbo.Users.sql"] + + def test_empty_collections_skip_directories(self): + """Empty object lists don't create subdirectories.""" + with tempfile.TemporaryDirectory() as tmpdir: + self._call_split(tmpdir) + + # No subdirectories should exist + assert os.listdir(tmpdir) == [] + + def test_files_contain_header(self): + """Each split file contains the standard header.""" + with tempfile.TemporaryDirectory() as tmpdir: + 
self._call_split(tmpdir, schemas=[SchemaInfo(name="dbo", owner="dbo")]) + + content = Path(os.path.join(tmpdir, "schemas", "dbo.sql")).read_text() + assert "SQL Extract" in content + assert "SET QUOTED_IDENTIFIER ON" in content + + def test_files_contain_object_sql(self): + """Split files contain the actual SQL for the object.""" + with tempfile.TemporaryDirectory() as tmpdir: + self._call_split(tmpdir, tables=[_make_table("dbo", "Users")]) + + content = Path(os.path.join(tmpdir, "tables", "dbo.Users.sql")).read_text() + assert "CREATE TABLE" in content + assert "[dbo].[Users]" in content + + def test_collision_raises_output_error(self): + """Two objects that sanitize to the same filename within one run raise OutputError.""" + with tempfile.TemporaryDirectory() as tmpdir: + # "a*b" and "a|b" both sanitize to "a_b" — in-run collision + with pytest.raises(OutputError, match="Filename collision"): + self._call_split(tmpdir, schemas=[ + SchemaInfo(name="a*b", owner="dbo"), + SchemaInfo(name="a|b", owner="dbo"), + ]) + + def test_rerun_into_same_directory_succeeds(self): + """Re-running split into the same output directory overwrites cleanly.""" + with tempfile.TemporaryDirectory() as tmpdir: + # First run + self._call_split(tmpdir, schemas=[SchemaInfo(name="dbo", owner="dbo")]) + # Second run — should not raise + self._call_split(tmpdir, schemas=[SchemaInfo(name="dbo", owner="dbo")]) + + files = os.listdir(os.path.join(tmpdir, "schemas")) + assert files == ["dbo.sql"] + + def test_seed_data_content(self): + """Seed data files contain all INSERT statements.""" + with tempfile.TemporaryDirectory() as tmpdir: + stmts = [ + "INSERT INTO [dbo].[Users] VALUES (1, 'Alice');", + "INSERT INTO [dbo].[Users] VALUES (2, 'Bob');", + ] + self._call_split(tmpdir, seed_data={"dbo.Users": stmts}) + + content = Path(os.path.join(tmpdir, "seed_data", "dbo.Users.sql")).read_text() + for stmt in stmts: + assert stmt in content diff --git a/tests/test_utils.py b/tests/test_utils.py index 
17763a2..0e0621d 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,7 +1,7 @@ """Tests for utility functions.""" import pytest -from src.utils import get_logger, ConnectionError +from src.utils import get_logger, sanitize_filename, ConnectionError class TestUtils: @@ -27,3 +27,73 @@ def test_connection_error_inheritance(self): assert isinstance(error, Exception) assert str(error) == "Test error" + + +class TestSanitizeFilename: + """Test sanitize_filename function.""" + + def test_passthrough_safe_name(self): + """Safe names are returned unchanged.""" + assert sanitize_filename("dbo.Users.sql") == "dbo.Users.sql" + + def test_replaces_unsafe_characters(self): + """Each unsafe character is replaced with underscore.""" + assert sanitize_filename('a/b\\c:d*e?f"g<h>i|j') == "a_b_c_d_e_f_g_h_i_j" + + def test_strips_trailing_dots_and_spaces(self): + """Trailing dots and spaces are stripped.""" + assert sanitize_filename("name...") == "name" + assert sanitize_filename("name ") == "name" + assert sanitize_filename("name. . 
") == "name" + + def test_empty_after_sanitize_raises(self): + """Names that become empty after sanitization raise ValueError.""" + with pytest.raises(ValueError, match="empty after sanitizing"): + sanitize_filename("...") + + def test_only_spaces_raises(self): + """Names that are only spaces raise ValueError.""" + with pytest.raises(ValueError, match="empty after sanitizing"): + sanitize_filename(" ") + + def test_windows_reserved_con(self): + """CON.sql is escaped to CON_.sql.""" + assert sanitize_filename("CON.sql") == "CON_.sql" + + def test_windows_reserved_case_insensitive(self): + """Windows reserved name check is case-insensitive.""" + assert sanitize_filename("con.sql") == "con_.sql" + assert sanitize_filename("Con.sql") == "Con_.sql" + + def test_windows_reserved_prn(self): + """PRN is escaped.""" + assert sanitize_filename("PRN.sql") == "PRN_.sql" + + def test_windows_reserved_com1(self): + """COM1 through COM9 are escaped.""" + assert sanitize_filename("COM1.sql") == "COM1_.sql" + assert sanitize_filename("COM9.sql") == "COM9_.sql" + + def test_windows_reserved_lpt1(self): + """LPT1 through LPT9 are escaped.""" + assert sanitize_filename("LPT1.sql") == "LPT1_.sql" + + def test_windows_reserved_nul(self): + """NUL is escaped.""" + assert sanitize_filename("NUL.sql") == "NUL_.sql" + + def test_windows_reserved_aux(self): + """AUX is escaped.""" + assert sanitize_filename("AUX.sql") == "AUX_.sql" + + def test_not_reserved_when_prefixed(self): + """dbo.CON.sql is NOT reserved — only bare CON is.""" + assert sanitize_filename("dbo.CON.sql") == "dbo.CON.sql" + + def test_not_reserved_com10(self): + """COM10 is not a reserved name.""" + assert sanitize_filename("COM10.sql") == "COM10.sql" + + def test_multiple_unsafe_chars_replaced(self): + """Multiple occurrences of the same unsafe char are all replaced.""" + assert sanitize_filename("a:b:c.sql") == "a_b_c.sql"