diff --git a/core/ast/enums.py b/core/ast/enums.py index 50a8849..63f79dc 100644 --- a/core/ast/enums.py +++ b/core/ast/enums.py @@ -37,6 +37,7 @@ class NodeType(Enum): LIMIT = "limit" OFFSET = "offset" QUERY = "query" + COMPOUND_QUERY = "compound_query" CASE = "case" WHEN_THEN = "when_then" diff --git a/core/ast/node.py b/core/ast/node.py index 52e505d..2442700 100644 --- a/core/ast/node.py +++ b/core/ast/node.py @@ -383,7 +383,38 @@ def __init__(self, if _offset: children.append(_offset) super().__init__(NodeType.QUERY, children=children, **kwargs) - + + +class CompoundQueryNode(Node): + """Binary UNION / UNION ALL node: left and right are query-producing Nodes. + + Multi-branch chains are represented as left-associative trees, so + A UNION B UNION C becomes CompoundQueryNode(CompoundQueryNode(A, B, False), C, False). + The formatter collapses same-type left-chains back to flat lists to match + mo_sql_parsing's wire format. + """ + + def __init__( + self, + _left: Node, + _right: Node, + _is_all: bool = False, + **kwargs, + ): + super().__init__(NodeType.COMPOUND_QUERY, children=[_left, _right], **kwargs) + self.left = _left + self.right = _right + self.is_all = _is_all + + def __eq__(self, other): + if not isinstance(other, CompoundQueryNode): + return False + return super().__eq__(other) and self.is_all == other.is_all + + def __hash__(self): + return hash((super().__hash__(), self.is_all)) + + class WhenThenNode(Node): """Single WHEN ... THEN ... branch of a CASE expression""" def __init__(self, _when: Node, _then: Node, **kwargs): diff --git a/core/query_formatter.py b/core/query_formatter.py index b27ed1b..43d9834 100644 --- a/core/query_formatter.py +++ b/core/query_formatter.py @@ -1,15 +1,24 @@ import re import mo_sql_parsing as mosql from core.ast.node import ( - QueryNode, SelectNode, FromNode, WhereNode, TableNode, GroupByNode, HavingNode, - OrderByNode,JoinNode + QueryNode, + CompoundQueryNode, + SelectNode, + FromNode, + WhereNode, + TableNode, + GroupByNode, + HavingNode, + OrderByNode, + JoinNode, + SubqueryNode, ) from core.ast.enums import NodeType, JoinType, SortOrder from core.ast.node import Node class QueryFormatter: - def format(self, query: QueryNode) -> str: - # [1] AST (QueryNode) -> JSON + def format(self, query: Node) -> str: + # [1] AST -> JSON json_query = ast_to_json(query) # [2] Any (JSON) -> str @@ -20,8 +29,41 @@ def format(self, query: QueryNode) -> str: return sql -def ast_to_json(node: QueryNode) -> dict: - """Convert QueryNode AST to JSON dictionary for mosql""" +def _collect_union_branches(node: CompoundQueryNode, is_all: bool) -> list: + """Flatten a left-chain of same-type CompoundQueryNodes into a list. + + mo_sql_parsing uses flat lists for chains of the same operator + (e.g. A UNION B UNION C → {'union': [A, B, C]}). Nesting is only + used at type boundaries (e.g. (A UNION ALL B) UNION C). This helper + mirrors that convention so round-trips produce identical JSON. + """ + result = [] + if isinstance(node.left, CompoundQueryNode) and node.left.is_all == is_all: + result.extend(_collect_union_branches(node.left, is_all)) + else: + result.append(node.left) + if isinstance(node.right, CompoundQueryNode) and node.right.is_all == is_all: + result.extend(_collect_union_branches(node.right, is_all)) + else: + result.append(node.right) + return result + + +def compound_to_mosql_json(node: CompoundQueryNode) -> dict: + """Convert a CompoundQueryNode binary tree to mo_sql_parsing union/union_all JSON.""" + key = 'union_all' if node.is_all else 'union' + branches = _collect_union_branches(node, node.is_all) + return {key: [ast_to_json(b) for b in branches]} + + +def ast_to_json(node: Node) -> dict: + """Convert AST to JSON dictionary for mosql.""" + if isinstance(node, CompoundQueryNode): + return compound_to_mosql_json(node) + if not isinstance(node, QueryNode): + raise TypeError( + f"ast_to_json: expected QueryNode or CompoundQueryNode, got {type(node).__name__}" + ) result = {} # process each clause in the query @@ -81,11 +123,35 @@ def format_select(select_node: SelectNode) -> dict: return result -def format_from(from_node: FromNode) -> list: - """Format FROM clause with explicit JOIN support""" - sources = [] +def format_from(from_node: FromNode): + """Format the FROM clause for mo_sql_parsing. + + mo_sql_parsing quirk: a bare (unaliased) UNION/UNION ALL in FROM is + represented as a plain dict at the FROM key, NOT as a one-element list + wrapping a {'value': ...} dict. For example: + + SELECT * FROM (SELECT 1 UNION SELECT 2) + → {"select": ..., "from": {"union": [...]}} ← dict, not list + + An aliased variant uses the normal list-of-sources form: + SELECT * FROM (SELECT 1 UNION SELECT 2) t + → {"select": ..., "from": [{"value": {"union": [...]}, "name": "t"}]} + + Everything else (tables, aliased subqueries, JOINs) returns a list. + """ children = list(from_node.children) - + + # Special case: single unaliased UNION subquery must be a bare dict + if ( + len(children) == 1 + and isinstance(children[0], SubqueryNode) + and children[0].alias is None + ): + inner = list(children[0].children)[0] + if isinstance(inner, CompoundQueryNode): + return compound_to_mosql_json(inner) + + sources = [] if not children: return sources diff --git a/core/query_parser.py b/core/query_parser.py index 19494bb..30c763b 100644 --- a/core/query_parser.py +++ b/core/query_parser.py @@ -1,5 +1,5 @@ from core.ast.node import ( - Node, QueryNode, SelectNode, FromNode, WhereNode, TableNode, ColumnNode, + Node, QueryNode, CompoundQueryNode, SelectNode, FromNode, WhereNode, TableNode, ColumnNode, LiteralNode, DataTypeNode, TimeUnitNode, IntervalNode, CaseNode, WhenThenNode, OperatorNode, UnaryOperatorNode, FunctionNode, GroupByNode, HavingNode, @@ -45,22 +45,18 @@ def normalize_to_list(value): "Expected None, list, dict, or str." ) - def parse(self, query: str) -> QueryNode: - # str -> mo_sql_parsing -> QueryNode + def parse(self, query: str) -> Node: + # str -> mo_sql_parsing -> QueryNode or CompoundQueryNode mosql_ast = mosql.parse(query) - return self.parse_query_dict(mosql_ast, aliases={}) + return self.parse_top_level_dict(mosql_ast, aliases={}) def parse_select(self, select_list: list, aliases: dict, distinct: bool = False, distinct_on_expr = None) -> SelectNode: items = [] for item in select_list: if isinstance(item, dict) and 'value' in item: value = item['value'] - # Check if value is a subquery - if isinstance(value, dict) and 'select' in value: - # This is a subquery in SELECT clause - # Subquery has its own alias scope (no leaking to/from outer query) - subquery_query = self.parse_query_dict(value, aliases={}) - expression = SubqueryNode(subquery_query) + if self._is_subquery_dict(value): + expression = SubqueryNode(self.parse_top_level_dict(value, aliases={})) else: expression = self.parse_expression(value, aliases) @@ -87,113 +83,74 @@ def parse_select(self, select_list: list, aliases: dict, distinct: bool = False, return SelectNode(items, _distinct=distinct, _distinct_on=distinct_on_node) + def _build_from_source(self, value, alias) -> Node: + """Resolve a FROM/JOIN value and its alias into a SubqueryNode or TableNode.""" + if self._is_subquery_dict(value): + return SubqueryNode(self.parse_top_level_dict(value, aliases={}), alias) + return TableNode(value, alias) + def parse_from(self, from_list: list, aliases: dict) -> FromNode: sources = [] left_source = None # Can be a table or the result of a previous join - + + def _append_source(node: Node, alias): + nonlocal left_source + if alias: + aliases[alias] = node + if left_source is None: + left_source = node + else: + sources.append(node) + for item in from_list: - # Check for JOIN first (before checking for 'value') if isinstance(item, dict): - # Look for any join key join_key = next((k for k in item.keys() if 'join' in k.lower()), None) - + if join_key: - # This is a JOIN if left_source is None: raise ValueError(f"JOIN found without a left table. join_key={join_key}, item={item}") - + join_info = item[join_key] - # Handle both string and dict join_info if isinstance(join_info, str): - table_name = join_info + right_source = TableNode(join_info) alias = None - right_source = TableNode(table_name, alias) elif isinstance(join_info, dict): - # Derived table: {'value': {}, 'name': } value = join_info.get('value') - if isinstance(value, dict) and 'select' in value: - subquery_query = self.parse_query_dict(value, aliases={}) - alias = join_info.get('name') - right_source = SubqueryNode(subquery_query, alias) - elif 'select' in join_info: - # Subquery at top level (alias in 'name' if present) - subquery_query = self.parse_query_dict(join_info, aliases={}) - alias = join_info.get('name') - right_source = SubqueryNode(subquery_query, alias) + alias = join_info.get('name') + if value is not None: + right_source = self._build_from_source(value, alias) else: - table_name = join_info.get('value', join_info) - alias = join_info.get('name') - right_source = TableNode(table_name, alias) + # Bare subquery dict at join level: {'select': ..., 'name': ...} + right_source = self._build_from_source(join_info, alias) else: - table_name = join_info + right_source = TableNode(join_info) alias = None - right_source = TableNode(table_name, alias) - - # Track alias + if alias: aliases[alias] = right_source - + on_condition = None if 'on' in item: on_condition = self.parse_expression(item['on'], aliases) - - # Create join node - left_source might be a table or a previous join + join_type = self.parse_join_type(join_key) join_node = JoinNode(left_source, right_source, join_type, on_condition) - # The result of this JOIN becomes the new left source for potential next JOIN left_source = join_node + elif 'value' in item: - # Check if value is a subquery - value = item['value'] alias = item.get('name') - - if isinstance(value, dict) and 'select' in value: - # This is a subquery in FROM clause - # Subquery has its own alias scope (no leaking to/from outer query) - subquery_query = self.parse_query_dict(value, aliases={}) - subquery_node = SubqueryNode(subquery_query, alias) - # Track subquery alias - if alias: - aliases[alias] = subquery_node - - if left_source is None: - left_source = subquery_node - else: - sources.append(subquery_node) - else: - # This is a table reference - table_name = value - table_node = TableNode(table_name, alias) - # Track table alias - if alias: - aliases[alias] = table_node - - if left_source is None: - # First table becomes the left source - left_source = table_node - else: - # Multiple tables without explicit JOIN (cross join) - sources.append(table_node) - # Subquery in FROM specified directly as a query dict (with 'select'). - elif 'select' in item: + node = self._build_from_source(item['value'], alias) + _append_source(node, alias) + + elif self._is_subquery_dict(item): + # Bare query/union dict directly in FROM (no 'value' wrapper) alias = item.get('name') - subquery_query = self.parse_query_dict(item, aliases={}) - subquery_node = SubqueryNode(subquery_query, alias) - if alias: - aliases[alias] = subquery_node - if left_source is None: - left_source = subquery_node - else: - sources.append(subquery_node) + node = self._build_from_source(item, alias) + _append_source(node, alias) + elif isinstance(item, str): - # Simple string table name - table_node = TableNode(item) - if left_source is None: - left_source = table_node - else: - sources.append(table_node) - - # Prepend the first/left source so order is preserved + _append_source(TableNode(item), None) + if left_source is not None: sources.insert(0, left_source) @@ -348,12 +305,8 @@ def parse_expression(self, expr, aliases: dict = None) -> Node: return parsed if isinstance(expr, dict): - # Check if this is a subquery (has 'select' key) - if 'select' in expr: - # This is a subquery - parse it recursively - # Subquery has its own alias scope (no leaking to/from outer query) - subquery_query = self.parse_query_dict(expr, aliases={}) - return SubqueryNode(subquery_query) + if self._is_subquery_dict(expr): + return SubqueryNode(self.parse_top_level_dict(expr, aliases={})) # Special cases first if 'all_columns' in expr: @@ -438,16 +391,8 @@ def parse_expression(self, expr, aliases: dict = None) -> Node: left = self.parse_expression(left_raw, aliases) - # Subquery RHS - if isinstance(right_raw, dict) and 'select' in right_raw: - right = self.parse_expression(right_raw, aliases) - # Literal-list RHS - elif isinstance(right_raw, dict) and 'literal' in right_raw: - # parse_expression on this dict will return a ListNode or LiteralNode - right = self.parse_expression(right_raw, aliases) - elif isinstance(right_raw, list): - items = [self.parse_expression(item, aliases) for item in right_raw] - right = ListNode(items) + if isinstance(right_raw, list): + right = ListNode([self.parse_expression(item, aliases) for item in right_raw]) else: right = self.parse_expression(right_raw, aliases) @@ -513,9 +458,8 @@ def parse_expression(self, expr, aliases: dict = None) -> Node: return UnaryOperatorNode(self.parse_expression(value, aliases), '-') # Pattern 3: EXISTS operator with subquery - if key == 'exists' and isinstance(value, dict) and 'select' in value: - # Subquery has its own alias scope (no leaking to/from outer query) - subquery_query = self.parse_query_dict(value, aliases={}) + if key == 'exists' and self._is_subquery_dict(value): + subquery_query = self.parse_top_level_dict(value, aliases={}) subquery_node = SubqueryNode(subquery_query) return OperatorNode(subquery_node, 'EXISTS') @@ -533,6 +477,59 @@ def parse_expression(self, expr, aliases: dict = None) -> Node: # Other types return LiteralNode(expr) + + @staticmethod + def _mosql_dict_is_compound_union(d: dict) -> bool: + if not isinstance(d, dict): + return False + if 'select' in d or 'select_distinct' in d: + return False + return 'union' in d or 'union_all' in d + + @classmethod + def _is_subquery_dict(cls, d) -> bool: + """True if d is any mo_sql_parsing dict that produces a query (SELECT, UNION, etc.).""" + return isinstance(d, dict) and ( + 'select' in d or 'select_distinct' in d + or cls._mosql_dict_is_compound_union(d) + ) + + def parse_compound_union_dict(self, d: dict) -> CompoundQueryNode: + """Build a left-associative binary tree from a mo_sql_parsing union/union_all dict. + + mo_sql_parsing never attaches extra clause keys (e.g. limit, orderby) directly + to the union dict — it always lifts them to an outer wrapper dict. So d is + expected to contain only a 'union' or 'union_all' key. + """ + if 'union_all' in d: + is_all = True + items = self.normalize_to_list(d['union_all']) + elif 'union' in d: + is_all = False + items = self.normalize_to_list(d['union']) + else: + raise ValueError(f'Expected union or union_all in dict, got keys {list(d.keys())}') + + if len(items) < 2: + raise ValueError( + f"Expected at least 2 branches for " + f"{'union_all' if is_all else 'union'}, got {len(items)}" + ) + + def parse_item(item) -> Node: + if isinstance(item, dict) and self._mosql_dict_is_compound_union(item): + return self.parse_compound_union_dict(item) + return self.parse_query_dict(item, {}) + + left = parse_item(items[0]) + for item in items[1:]: + left = CompoundQueryNode(left, parse_item(item), is_all) + return left + + def parse_top_level_dict(self, query_dict: dict, aliases: dict) -> Node: + if self._mosql_dict_is_compound_union(query_dict): + return self.parse_compound_union_dict(query_dict) + return self.parse_query_dict(query_dict, aliases) def parse_query_dict(self, query_dict: dict, aliases: dict) -> QueryNode: """Parse a mo_sql_parsing query-dict into a QueryNode. diff --git a/data/asts.py b/data/asts.py index e6070b8..17fdeec 100644 --- a/data/asts.py +++ b/data/asts.py @@ -1,12 +1,13 @@ from typing import Optional from core.ast.node import ( - CaseNode, WhenThenNode, IntervalNode, ListNode, QueryNode, SelectNode, FromNode, WhereNode, TableNode, ColumnNode, + CaseNode, WhenThenNode, IntervalNode, ListNode, Node, QueryNode, CompoundQueryNode, SelectNode, FromNode, + WhereNode, TableNode, ColumnNode, LiteralNode, DataTypeNode, TimeUnitNode, OperatorNode, UnaryOperatorNode, FunctionNode, GroupByNode, HavingNode, OrderByNode, OrderByItemNode, LimitNode, OffsetNode, JoinNode, SubqueryNode ) from core.ast.enums import JoinType, SortOrder -def get_ast(query_id: int) -> Optional[QueryNode]: +def get_ast(query_id: int) -> Optional[Node]: """Return the expected AST for a given query id, or None if not available.""" _asts = _build_asts() return _asts.get(query_id, None) @@ -28,7 +29,7 @@ def _build_asts() -> dict: 12: _ast_query_12(), 13: _ast_query_13(), 14: _ast_query_14(), - # 15: UNION not supported by parser + 15: _ast_query_15(), 16: _ast_query_16(), 17: _ast_query_17(), 18: _ast_query_18(), @@ -42,10 +43,10 @@ def _build_asts() -> dict: 26: _ast_query_26(), 27: _ast_query_27(), 28: _ast_query_28(), - # 29: UNION not supported by parser + 29: _ast_query_29(), 30: _ast_query_30(), 31: _ast_query_31(), - # 32: UNION not supported by parser + 32: _ast_query_32(), 33: _ast_query_33(), 34: _ast_query_34(), 35: _ast_query_35(), @@ -494,7 +495,24 @@ def _ast_query_14() -> QueryNode: _limit=limit_clause, ) -# TODO: Query 15 uses UNION, which is not supported by parser yet +def _ast_query_15() -> QueryNode: + """Query 15: UNION ALL inside derived table (Calcite push min through union).""" + emp_branch = QueryNode( + _select=SelectNode([ColumnNode("*")]), + _from=FromNode([TableNode("EMP", _alias="EMP")]), + ) + union_subquery = SubqueryNode( + CompoundQueryNode(emp_branch, emp_branch, True), + _alias="t", + ) + ename = ColumnNode("ENAME", _parent_alias="t") + empno_min = FunctionNode("MIN", _args=[ColumnNode("EMPNO", _parent_alias="t")]) + return QueryNode( + _select=SelectNode([ename, empno_min]), + _from=FromNode([union_subquery]), + _group_by=GroupByNode([ColumnNode("ENAME", _parent_alias="t")]), + ) + def _ast_query_16() -> QueryNode: """Query 16: Remove Max Distinct.""" @@ -797,6 +815,49 @@ def _ast_query_28() -> QueryNode: ) +def _ast_query_29() -> CompoundQueryNode: + """Query 29: Full matching — top-level UNION of two entity scans.""" + entities1 = TableNode("entities") + data1 = ColumnNode("data", _parent_alias="entities") + id1 = ColumnNode("_id", _parent_alias="entities") + sub_e_select = SelectNode([ColumnNode("_id", _parent_alias="index_users_email")]) + sub_e_from = FromNode([TableNode("index_users_email")]) + sub_e_where = WhereNode([ + OperatorNode( + ColumnNode("key", _parent_alias="index_users_email"), + "=", + LiteralNode("test"), + ) + ]) + sub_e = SubqueryNode(QueryNode(_select=sub_e_select, _from=sub_e_from, _where=sub_e_where)) + q1 = QueryNode( + _select=SelectNode([data1]), + _from=FromNode([entities1]), + _where=WhereNode([OperatorNode(id1, "IN", sub_e)]), + ) + + entities2 = TableNode("entities") + data2 = ColumnNode("data", _parent_alias="entities") + id2 = ColumnNode("_id", _parent_alias="entities") + sub_p_select = SelectNode([ColumnNode("_id", _parent_alias="index_users_profile_name")]) + sub_p_from = FromNode([TableNode("index_users_profile_name")]) + sub_p_where = WhereNode([ + OperatorNode( + ColumnNode("key", _parent_alias="index_users_profile_name"), + "=", + LiteralNode("test"), + ) + ]) + sub_p = SubqueryNode(QueryNode(_select=sub_p_select, _from=sub_p_from, _where=sub_p_where)) + q2 = QueryNode( + _select=SelectNode([data2]), + _from=FromNode([entities2]), + _where=WhereNode([OperatorNode(id2, "IN", sub_p)]), + ) + + return CompoundQueryNode(q1, q2, False) + + def _ast_query_30() -> QueryNode: """Query 30: Over Partial Matching.""" # Query pattern: SELECT * FROM table_name WHERE (title=1 AND grade=2) OR (title=2 AND debt=2 AND grade=3) OR (prog=1 AND title=1 AND debt=3) @@ -889,7 +950,59 @@ def _ast_query_31() -> QueryNode: _group_by=group_by_clause, ) -# TODO: Query 32: UNION not supported by parser +def _ast_query_32() -> QueryNode: + """Query 32: Spreadsheet ID 2 — rewrite: UNION of two limited scans in a derived table.""" + place_a = TableNode("place") + branch1 = QueryNode( + _select=SelectNode([ColumnNode("*")]), + _from=FromNode([place_a]), + _where=WhereNode([ + OperatorNode(ColumnNode("select"), "=", LiteralNode(True)), + ]), + _limit=LimitNode(10), + ) + + place_b = TableNode("place") + bookmark = TableNode("bookmark") + exists_inner = QueryNode( + _select=SelectNode([LiteralNode(1)]), + _from=FromNode([bookmark]), + _where=WhereNode([ + OperatorNode( + OperatorNode( + ColumnNode("user"), + "IN", + ListNode([ + LiteralNode(1), + LiteralNode(2), + LiteralNode(3), + LiteralNode(4), + ]), + ), + "AND", + OperatorNode( + ColumnNode("place", _parent_alias="bookmark"), + "=", + ColumnNode("id", _parent_alias="place"), + ), + ), + ]), + ) + exists_sq = SubqueryNode(exists_inner) + branch2 = QueryNode( + _select=SelectNode([ColumnNode("*")]), + _from=FromNode([place_b]), + _where=WhereNode([OperatorNode(exists_sq, "EXISTS")]), + _limit=LimitNode(10), + ) + + union_from = SubqueryNode(CompoundQueryNode(branch1, branch2, False), None) + return QueryNode( + _select=SelectNode([ColumnNode("*")]), + _from=FromNode([union_from]), + _limit=LimitNode(10), + ) + def _ast_query_33() -> QueryNode: """Query 33: Spreadsheet ID 3.""" diff --git a/tests/ast_util.py b/tests/ast_util.py index 274e074..e670b6d 100644 --- a/tests/ast_util.py +++ b/tests/ast_util.py @@ -4,7 +4,7 @@ import textwrap import sqlparse from core.ast.node import ( - Node, QueryNode, SelectNode, FromNode, WhereNode, TableNode, ColumnNode, + Node, QueryNode, CompoundQueryNode, SelectNode, FromNode, WhereNode, TableNode, ColumnNode, LiteralNode, OperatorNode, FunctionNode, GroupByNode, HavingNode, OrderByNode, OrderByItemNode, LimitNode, OffsetNode, JoinNode, SubqueryNode, VarNode, VarSetNode @@ -199,6 +199,14 @@ def _node_to_string(node: Node, indent: int = 0) -> str: value = node.limit if isinstance(node, LimitNode) else node.offset result.append(f"{prefix}{node_type}: {value}") + elif isinstance(node, CompoundQueryNode): + op = "UNION ALL" if node.is_all else "UNION" + result.append(f"{prefix}compound_query: {op}") + for child in node.children: + child_lines = _node_to_string(child, indent + 1).split('\n') + for line in child_lines: + result.append(line) + elif isinstance(node, QueryNode): # QueryNode: root query or subquery structure, display as "query" # Maintains tree structure consistency by using proper prefix and indentation @@ -238,7 +246,7 @@ def _node_to_string(node: Node, indent: int = 0) -> str: return '\n'.join(result) -def visualize_ast(sql: str, ast: QueryNode, max_sql_width: int = 50) -> str: +def visualize_ast(sql: str, ast: Node, max_sql_width: int = 50) -> str: """ Generate a side-by-side visualization of SQL query and AST structure. @@ -248,7 +256,7 @@ def visualize_ast(sql: str, ast: QueryNode, max_sql_width: int = 50) -> str: Args: sql: SQL query string to visualize - ast: QueryNode representing the parsed AST + ast: Root AST node (QueryNode or CompoundQueryNode, etc.) max_sql_width: Maximum width for SQL column before wrapping (default: 50) Returns: diff --git a/tests/test_parser_formatter_e2e.py b/tests/test_parser_formatter_e2e.py index 3fe8b36..f33eb54 100644 --- a/tests/test_parser_formatter_e2e.py +++ b/tests/test_parser_formatter_e2e.py @@ -126,7 +126,12 @@ def test_query_14(): assert parse(formatted_sql) == parse(original_sql) -# TODO: Query 15 uses UNION, which is not supported by parser yet +def test_query_15(): + query = get_query(15) + original_sql = query["pattern"] + parsed_ast = parser.parse(original_sql) + formatted_sql = formatter.format(parsed_ast) + assert parse(formatted_sql) == parse(original_sql) def test_query_16(): @@ -233,7 +238,12 @@ def test_query_28(): assert parse(formatted_sql) == parse(original_sql) -# TODO: Query 29: Full Matching: UNION not supported by parser +def test_query_29(): + query = get_query(29) + original_sql = query["pattern"] + parsed_ast = parser.parse(original_sql) + formatted_sql = formatter.format(parsed_ast) + assert parse(formatted_sql) == parse(original_sql) def test_query_30(): @@ -252,7 +262,12 @@ def test_query_31(): assert parse(formatted_sql) == parse(original_sql) -# TODO: Query 32: UNION not supported by parser +def test_query_32(): + query = get_query(32) + original_sql = query["rewrite"] + parsed_ast = parser.parse(original_sql) + formatted_sql = formatter.format(parsed_ast) + assert parse(formatted_sql) == parse(original_sql) def test_query_33(): diff --git a/tests/test_query_formatter.py b/tests/test_query_formatter.py index ca7a205..4fa0244 100644 --- a/tests/test_query_formatter.py +++ b/tests/test_query_formatter.py @@ -121,7 +121,11 @@ def test_query_14(): assert parse(sql) == parse(query["pattern"]) -# TODO: Query 15 uses UNION, which is not supported by parser yet +def test_query_15(): + """Query 15: UNION ALL in derived table.""" + query = get_query(15) + sql = formatter.format(get_ast(15)) + assert parse(sql) == parse(query["pattern"]) def test_query_16(): @@ -215,7 +219,11 @@ def test_query_28(): assert parse(sql) == parse(query["pattern"]) -# TODO: Query 29: Full Matching: UNION not supported by parser +def test_query_29(): + """Query 29: Top-level UNION.""" + query = get_query(29) + sql = formatter.format(get_ast(29)) + assert parse(sql) == parse(query["pattern"]) def test_query_30(): @@ -232,7 +240,11 @@ def test_query_31(): assert parse(sql) == parse(query["pattern"]) -# TODO: Query 32: UNION not supported by parser +def test_query_32(): + """Query 32: Spreadsheet ID 2.""" + query = get_query(32) + sql = formatter.format(get_ast(32)) + assert parse(sql) == parse(query["rewrite"]) def test_query_33(): diff --git a/tests/test_query_parser.py b/tests/test_query_parser.py index 69f6b4e..5b164de 100644 --- a/tests/test_query_parser.py +++ b/tests/test_query_parser.py @@ -137,7 +137,12 @@ def test_query_14(): assert parser.parse(sql) == get_ast(14) -# TODO: Query 15 uses UNION, which is not supported by parser yet +def test_query_15(): + """Query 15: UNION ALL in derived table (Calcite).""" + query = get_query(15) + sql = query["pattern"] + logger.info("\n" + visualize_ast(sql, get_ast(15))) + assert parser.parse(sql) == get_ast(15) def test_query_16(): @@ -244,7 +249,12 @@ def test_query_28(): assert parser.parse(sql) == get_ast(28) -# TODO: Query 29: Full Matching: UNION not supported by parser +def test_query_29(): + """Query 29: Full matching with top-level UNION.""" + query = get_query(29) + sql = query["pattern"] + logger.info("\n" + visualize_ast(sql, get_ast(29))) + assert parser.parse(sql) == get_ast(29) def test_query_30(): @@ -263,7 +273,12 @@ def test_query_31(): assert parser.parse(sql) == get_ast(31) -# TODO: Query 32: UNION not supported by parser +def test_query_32(): + """Query 32: Spreadsheet ID 2 (OR + EXISTS).""" + query = get_query(32) + sql = query["rewrite"] + logger.info("\n" + visualize_ast(sql, get_ast(32))) + assert parser.parse(sql) == get_ast(32) def test_query_33():