Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions core/ast/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class NodeType(Enum):
LIMIT = "limit"
OFFSET = "offset"
QUERY = "query"
COMPOUND_QUERY = "compound_query"
CASE = "case"
WHEN_THEN = "when_then"

Expand Down
33 changes: 32 additions & 1 deletion core/ast/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,38 @@ def __init__(self,
if _offset:
children.append(_offset)
super().__init__(NodeType.QUERY, children=children, **kwargs)



class CompoundQueryNode(Node):
"""Binary UNION / UNION ALL node: left and right are query-producing Nodes.

Multi-branch chains are represented as left-associative trees, so
A UNION B UNION C becomes CompoundQueryNode(CompoundQueryNode(A, B, False), C, False).
The formatter collapses same-type left-chains back to flat lists to match
mo_sql_parsing's wire format.
"""

def __init__(
self,
_left: Node,
_right: Node,
_is_all: bool = False,
**kwargs,
):
super().__init__(NodeType.COMPOUND_QUERY, children=[_left, _right], **kwargs)
self.left = _left
self.right = _right
self.is_all = _is_all

def __eq__(self, other):
if not isinstance(other, CompoundQueryNode):
return False
return super().__eq__(other) and self.is_all == other.is_all

def __hash__(self):
return hash((super().__hash__(), self.is_all))


class WhenThenNode(Node):
"""Single WHEN ... THEN ... branch of a CASE expression"""
def __init__(self, _when: Node, _then: Node, **kwargs):
Expand Down
86 changes: 76 additions & 10 deletions core/query_formatter.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,24 @@
import re
import mo_sql_parsing as mosql
from core.ast.node import (
QueryNode, SelectNode, FromNode, WhereNode, TableNode, GroupByNode, HavingNode,
OrderByNode,JoinNode
QueryNode,
CompoundQueryNode,
SelectNode,
FromNode,
WhereNode,
TableNode,
GroupByNode,
HavingNode,
OrderByNode,
JoinNode,
SubqueryNode,
)
from core.ast.enums import NodeType, JoinType, SortOrder
from core.ast.node import Node

class QueryFormatter:
def format(self, query: QueryNode) -> str:
# [1] AST (QueryNode) -> JSON
def format(self, query: Node) -> str:
# [1] AST -> JSON
json_query = ast_to_json(query)

# [2] Any (JSON) -> str
Expand All @@ -20,8 +29,41 @@ def format(self, query: QueryNode) -> str:

return sql

def ast_to_json(node: QueryNode) -> dict:
"""Convert QueryNode AST to JSON dictionary for mosql"""
def _collect_union_branches(node: CompoundQueryNode, is_all: bool) -> list:
"""Flatten a left-chain of same-type CompoundQueryNodes into a list.

mo_sql_parsing uses flat lists for chains of the same operator
(e.g. A UNION B UNION C → {'union': [A, B, C]}). Nesting is only
used at type boundaries (e.g. (A UNION ALL B) UNION C). This helper
mirrors that convention so round-trips produce identical JSON.
"""
result = []
if isinstance(node.left, CompoundQueryNode) and node.left.is_all == is_all:
result.extend(_collect_union_branches(node.left, is_all))
else:
result.append(node.left)
if isinstance(node.right, CompoundQueryNode) and node.right.is_all == is_all:
result.extend(_collect_union_branches(node.right, is_all))
else:
result.append(node.right)
return result
Comment thread
HazelYuAhiru marked this conversation as resolved.


def compound_to_mosql_json(node: CompoundQueryNode) -> dict:
"""Convert a CompoundQueryNode binary tree to mo_sql_parsing union/union_all JSON."""
key = 'union_all' if node.is_all else 'union'
branches = _collect_union_branches(node, node.is_all)
return {key: [ast_to_json(b) for b in branches]}


def ast_to_json(node: Node) -> dict:
"""Convert AST to JSON dictionary for mosql."""
if isinstance(node, CompoundQueryNode):
return compound_to_mosql_json(node)
if not isinstance(node, QueryNode):
raise TypeError(
f"ast_to_json: expected QueryNode or CompoundQueryNode, got {type(node).__name__}"
)
result = {}

# process each clause in the query
Expand Down Expand Up @@ -81,11 +123,35 @@ def format_select(select_node: SelectNode) -> dict:
return result


def format_from(from_node: FromNode) -> list:
"""Format FROM clause with explicit JOIN support"""
sources = []
def format_from(from_node: FromNode):
"""Format the FROM clause for mo_sql_parsing.

mo_sql_parsing quirk: a bare (unaliased) UNION/UNION ALL in FROM is
represented as a plain dict at the FROM key, NOT as a one-element list
wrapping a {'value': ...} dict. For example:

SELECT * FROM (SELECT 1 UNION SELECT 2)
→ {"select": ..., "from": {"union": [...]}} ← dict, not list

An aliased variant uses the normal list-of-sources form:
SELECT * FROM (SELECT 1 UNION SELECT 2) t
→ {"select": ..., "from": [{"value": {"union": [...]}, "name": "t"}]}

Everything else (tables, aliased subqueries, JOINs) returns a list.
"""
children = list(from_node.children)


# Special case: single unaliased UNION subquery must be a bare dict
if (
len(children) == 1
and isinstance(children[0], SubqueryNode)
and children[0].alias is None
):
inner = list(children[0].children)[0]
if isinstance(inner, CompoundQueryNode):
return compound_to_mosql_json(inner)

sources = []
if not children:
return sources

Expand Down
Loading
Loading