diff --git a/.gitignore b/.gitignore index b100828f..9210a8a8 100644 --- a/.gitignore +++ b/.gitignore @@ -162,6 +162,7 @@ cython_debug/ #.idea/ # Custom stuff +examples/generated/ docs/source/gallery docs/source/sg_execution_times.rst *.png diff --git a/CHANGELOG.md b/CHANGELOG.md index 3e8a9d61..cac6b6c7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- **PTN Format**: Human-readable text format (`.ptn`) for pattern serialization + - `ptn_format.dumps()` / `ptn_format.dump()`: Serialize patterns to text + - `ptn_format.loads()` / `ptn_format.load()`: Deserialize patterns from text + - Format separates quantum instructions and classical feedforward processing + - Timeslice markers `[n]` indicate parallel execution groups + - Pauli measurements use compact notation (`X +`, `Y -`, `Z +`) + - Non-Pauli measurements use plane+angle format (`XY pi/4`) + - Support for node coordinates, logical observables, and inline comments - **Non-Unitary Parity Projection Example**: Added `examples/nonunitary_parity_projection.py` demonstrating measurement-induced entanglement via a 3-node star graph parity projector ### Fixed diff --git a/docs/source/ptn_format.rst b/docs/source/ptn_format.rst new file mode 100644 index 00000000..bd6b1f14 --- /dev/null +++ b/docs/source/ptn_format.rst @@ -0,0 +1,9 @@ +Pattern Text Format +=================== + +:mod:`graphqomb.ptn_format` module +++++++++++++++++++++++++++++++++++ + +.. automodule:: graphqomb.ptn_format + :members: + :member-order: bysource diff --git a/docs/source/references.rst b/docs/source/references.rst index 37a72600..05571380 100644 --- a/docs/source/references.rst +++ b/docs/source/references.rst @@ -17,6 +17,7 @@ Module reference focus_flow command pattern + ptn_format pauli_frame qompiler scheduler diff --git a/examples/pattern_generation.py b/examples/pattern_generation.py index d691fae3..9e874cd8 100644 --- a/examples/pattern_generation.py +++ b/examples/pattern_generation.py @@ -6,7 +6,11 @@ """ # %% +import sys +from pathlib import Path + from graphqomb.pattern import print_pattern +from graphqomb.ptn_format import dump from graphqomb.qompiler import qompile from graphqomb.random_objects import generate_random_flow_graph @@ -22,3 +26,12 @@ print("pattern depth:", pattern.depth) print("pattern max space:", pattern.max_space) print_pattern(pattern) + +# Dump the pattern in GraphQOMB's text format. +script_path = Path(globals().get("__file__", sys.argv[0])).resolve() +example_dir = script_path.parent +output_dir = example_dir / "generated" +output_dir.mkdir(exist_ok=True) +ptn_path = output_dir / "pattern_generation.ptn" +dump(pattern, ptn_path) +print("wrote pattern:", ptn_path) diff --git a/graphqomb/ptn_format.py b/graphqomb/ptn_format.py new file mode 100644 index 00000000..20d0d92b --- /dev/null +++ b/graphqomb/ptn_format.py @@ -0,0 +1,964 @@ +"""Pattern text format (.ptn) module. + +This module provides: + +- `dump`: Write a pattern to a .ptn file or string. +- `load`: Read a pattern from a .ptn file or string. +- `dumps`: Serialize a pattern to a .ptn format string. +- `loads`: Deserialize a pattern from a .ptn format string. +""" + +from __future__ import annotations + +import math +import operator +import re +from dataclasses import dataclass, field +from io import StringIO +from pathlib import Path +from types import MappingProxyType +from typing import TYPE_CHECKING + +from graphqomb.command import TICK, Command, E, M, N, X, Z +from graphqomb.common import ( + Axis, + AxisMeasBasis, + MeasBasis, + Plane, + PlannerMeasBasis, + Sign, + determine_pauli_axis, + is_clifford_angle, + is_close_angle, +) +from graphqomb.graphstate import BaseGraphState +from graphqomb.pattern import Pattern +from graphqomb.pauli_frame import PauliFrame + +if TYPE_CHECKING: + from collections.abc import Sequence + +PTN_VERSION = 1 + +# Angle formatting/parsing lookup tables +_ANGLE_TO_STR: dict[float, str] = { + 0.0: "0", + math.pi: "pi", + -math.pi: "-pi", + math.pi / 2: "pi/2", + -math.pi / 2: "-pi/2", + math.pi / 4: "pi/4", + -math.pi / 4: "-pi/4", + 3 * math.pi / 2: "3pi/2", + 3 * math.pi / 4: "3pi/4", +} + +_STR_TO_ANGLE: dict[str, float] = { + "0": 0.0, + "pi": math.pi, + "-pi": -math.pi, + "pi/2": math.pi / 2, + "-pi/2": -math.pi / 2, + "pi/4": math.pi / 4, + "-pi/4": -math.pi / 4, + "3pi/2": 3 * math.pi / 2, + "3pi/4": 3 * math.pi / 4, +} + +_PI_PATTERN = re.compile(r"^(-?\d*)pi(?:/(\d+))?$") + + +def _format_angle(angle: float) -> str: + r"""Format angle for output, using pi fractions where appropriate. + + Parameters + ---------- + angle : `float` + The angle in radians. + + Returns + ------- + `str` + Formatted angle string. + """ + candidates = ( + ((ref_angle, label) for ref_angle, label in _ANGLE_TO_STR.items() if is_clifford_angle(ref_angle)) + if is_clifford_angle(angle) + else _ANGLE_TO_STR.items() + ) + ordered_candidates = sorted(candidates, key=lambda item: item[0] < 0 if angle >= 0 else item[0] >= 0) + for ref_angle, label in ordered_candidates: + if is_close_angle(angle, ref_angle): + return label + return f"{angle}" + + +def _parse_angle(s: str) -> float: + r"""Parse angle string to float. + + Parameters + ---------- + s : `str` + Angle string (e.g., "0", "pi", "pi/2", "3pi/4", "1.5707963"). + + Returns + ------- + `float` + The angle in radians. + + Raises + ------ + ValueError + If the angle is not a valid number or pi expression. + """ + s = s.strip() + if s in _STR_TO_ANGLE: + return _STR_TO_ANGLE[s] + + pi_match = _PI_PATTERN.match(s) + if pi_match: + numerator = pi_match.group(1) + denominator = pi_match.group(2) + num = int(numerator) if numerator and numerator != "-" else (1 if numerator != "-" else -1) + denom = int(denominator) if denominator else 1 + if denom == 0: + msg = "Angle denominator cannot be zero" + raise ValueError(msg) + return num * math.pi / denom + + return float(s) + + +def _format_coord(coord: tuple[float, ...]) -> str: + r"""Format coordinate tuple for output. + + Parameters + ---------- + coord : `tuple`\[`float`, ...\] + Coordinate tuple (2D or 3D). + + Returns + ------- + `str` + Space-separated coordinate string. + """ + return " ".join(str(c) for c in coord) + + +def _parse_coord(parts: Sequence[str]) -> tuple[float, ...]: + r"""Parse coordinate from string parts. + + Parameters + ---------- + parts : `list`\[`str`\] + List of coordinate value strings. + + Returns + ------- + `tuple`\[`float`, ...\] + Coordinate tuple. + """ + return tuple(float(p) for p in parts) + + +# ============================================================ +# Serialization (dumps/dump) +# ============================================================ + + +def _write_header(out: StringIO, pattern: Pattern) -> None: + """Write header section to output.""" + out.write(f"# GraphQOMB Pattern Format v{PTN_VERSION}\n") + out.write("\n") + out.write("#======== HEADER ========\n") + out.write(f".version {PTN_VERSION}\n") + + if pattern.input_node_indices: + input_parts = [ + f"{node}:{qidx}" for node, qidx in sorted(pattern.input_node_indices.items(), key=operator.itemgetter(1)) + ] + out.write(f".input {' '.join(input_parts)}\n") + + if pattern.output_node_indices: + output_parts = [ + f"{node}:{qidx}" for node, qidx in sorted(pattern.output_node_indices.items(), key=operator.itemgetter(1)) + ] + out.write(f".output {' '.join(output_parts)}\n") + + out.writelines( + f".coord {node} {_format_coord(coord)}\n" for node, coord in sorted(pattern.input_coordinates.items()) + ) + + +def _write_command(out: StringIO, cmd: Command) -> None: + """Write a single command to output.""" + if isinstance(cmd, N): + if cmd.coordinate is not None: + out.write(f"N {cmd.node} {_format_coord(cmd.coordinate)}\n") + else: + out.write(f"N {cmd.node}\n") + elif isinstance(cmd, E): + out.write(f"E {cmd.nodes[0]} {cmd.nodes[1]}\n") + elif isinstance(cmd, M): + _write_measurement(out, cmd) + elif isinstance(cmd, X): + out.write(f"X {cmd.node}\n") + elif isinstance(cmd, Z): + out.write(f"Z {cmd.node}\n") + + +def _is_positive_pauli_measurement(meas_basis: MeasBasis, pauli_axis: Axis) -> bool: + """Return whether a Pauli measurement is on the positive eigenbasis. + + Returns + ------- + bool + True if the measurement basis is the positive Pauli eigenbasis. + """ + angle = meas_basis.angle + plane = meas_basis.plane + if pauli_axis == Axis.X: + positive_angle = math.pi / 2 if plane == Plane.XZ else 0.0 + elif pauli_axis == Axis.Y: + positive_angle = math.pi / 2 + else: + positive_angle = 0.0 + return is_close_angle(angle, positive_angle) + + +def _write_measurement(out: StringIO, cmd: M) -> None: + """Write measurement command with appropriate format.""" + pauli_axis = determine_pauli_axis(cmd.meas_basis) + if pauli_axis is not None: + sign = "+" if _is_positive_pauli_measurement(cmd.meas_basis, pauli_axis) else "-" + out.write(f"M {cmd.node} {pauli_axis.name} {sign}\n") + else: + plane_name = cmd.meas_basis.plane.name + angle_str = _format_angle(cmd.meas_basis.angle) + out.write(f"M {cmd.node} {plane_name} {angle_str}\n") + + +def _write_quantum_section(out: StringIO, pattern: Pattern) -> None: + """Write quantum instructions section to output.""" + out.write("\n") + out.write("#======== QUANTUM ========\n") + + timeslice = 0 + current_slice_commands: list[Command] = [] + + def write_slice(slice_num: int, commands: list[Command]) -> None: + out.write(f"[{slice_num}]\n") + for cmd in commands: + _write_command(out, cmd) + + for cmd in pattern.commands: + if isinstance(cmd, TICK): + write_slice(timeslice, current_slice_commands) + current_slice_commands = [] + timeslice += 1 + else: + current_slice_commands.append(cmd) + + if current_slice_commands or timeslice == 0 or (pattern.commands and isinstance(pattern.commands[-1], TICK)): + write_slice(timeslice, current_slice_commands) + + +def _write_classical_section(out: StringIO, pauli_frame: PauliFrame) -> None: + """Write classical frame section to output.""" + out.write("\n") + out.write("#======== CLASSICAL ========\n") + + for source, targets in sorted(pauli_frame.xflow.items()): + if targets: + targets_str = " ".join(str(t) for t in sorted(targets)) + out.write(f".xflow {source} -> {targets_str}\n") + + for source, targets in sorted(pauli_frame.zflow.items()): + if targets: + targets_str = " ".join(str(t) for t in sorted(targets)) + out.write(f".zflow {source} -> {targets_str}\n") + + for group in pauli_frame.parity_check_group: + if group: + group_str = " ".join(str(n) for n in sorted(group)) + out.write(f".detector {group_str}\n") + + for logical_idx, nodes in sorted(pauli_frame.logical_observables.items()): + if nodes: + nodes_str = " ".join(str(n) for n in sorted(nodes)) + out.write(f".observable {logical_idx} -> {nodes_str}\n") + + +def dumps(pattern: Pattern) -> str: + """Serialize a pattern to a .ptn format string. + + Parameters + ---------- + pattern : `Pattern` + The pattern to serialize. + + Returns + ------- + `str` + The .ptn format string. + """ + out = StringIO() + _write_header(out, pattern) + _write_quantum_section(out, pattern) + _write_classical_section(out, pattern.pauli_frame) + return out.getvalue() + + +def dump(pattern: Pattern, file: Path | str) -> None: + """Write a pattern to a .ptn file. + + Parameters + ---------- + pattern : `Pattern` + The pattern to write. + file : `pathlib.Path` | `str` + The file path to write to. + """ + path = Path(file) + path.write_text(dumps(pattern), encoding="utf-8") + + +# ============================================================ +# Deserialization (loads/load) +# ============================================================ + + +def _parse_int(value: str, label: str) -> int: + """Parse an integer field. + + Returns + ------- + `int` + Parsed integer. + + Raises + ------ + ValueError + If the field is not an integer. + """ + try: + return int(value) + except ValueError as exc: + msg = f"Invalid {label}: {value!r}" + raise ValueError(msg) from exc + + +def _parse_node_qubit_pairs(parts: Sequence[str]) -> dict[int, int]: + r"""Parse node:qubit pairs from string parts. + + Parameters + ---------- + parts : `list`\[`str`\] + List of "node:qubit" strings. + + Returns + ------- + `dict`\[`int`, `int`\] + Mapping from node to qubit index. + + Raises + ------ + ValueError + If any pair is malformed or duplicated. + """ + result: dict[int, int] = {} + for part in parts: + pair = part.split(":") + if len(pair) != 2: # noqa: PLR2004 + msg = f"Invalid node:qubit pair: {part!r}" + raise ValueError(msg) + node_str, qidx_str = pair + node = _parse_int(node_str, "node") + qidx = _parse_int(qidx_str, "qubit index") + if node in result: + msg = f"Duplicate node mapping: {node}" + raise ValueError(msg) + if qidx in result.values(): + msg = f"Duplicate qubit index: {qidx}" + raise ValueError(msg) + result[node] = qidx + return result + + +def _parse_node_set(parts: Sequence[str], label: str) -> set[int]: + r"""Parse a non-empty set of node ids. + + Returns + ------- + `set`\[`int`\] + Parsed node ids. + + Raises + ------ + ValueError + If the node set is empty or contains invalid integers. + """ + if not parts: + msg = f"{label} requires at least one node" + raise ValueError(msg) + return {_parse_int(part, "node") for part in parts} + + +def _parse_arrow_mapping(line: str, label: str) -> tuple[int, set[int]]: + r"""Parse a flow line (xflow or zflow). + + Parameters + ---------- + line : `str` + The flow line content after ".xflow" or ".zflow". + + Returns + ------- + `tuple`\[`int`, `set`\[`int`\]\] + Source node and set of target nodes. + + Raises + ------ + ValueError + If the mapping is malformed. + """ + parts = line.split("->") + if len(parts) != 2: # noqa: PLR2004 + msg = f"{label} must contain exactly one '->'" + raise ValueError(msg) + source_part = parts[0].strip() + target_parts = parts[1].strip().split() + if not source_part: + msg = f"{label} requires a source node" + raise ValueError(msg) + source = _parse_int(source_part, "source node") + targets = _parse_node_set(target_parts, f"{label} targets") + return source, targets + + +def _empty_node_index_map() -> dict[int, int]: + r"""Return an empty node-to-qubit-index map. + + Returns + ------- + `dict`\[`int`, `int`\] + Empty node-to-qubit-index map. + """ + return {} + + +def _empty_coordinates() -> dict[int, tuple[float, ...]]: + r"""Return an empty coordinate map. + + Returns + ------- + `dict`\[`int`, `tuple`\[`float`, ...\]\] + Empty coordinate map. + """ + return {} + + +def _empty_commands() -> list[Command]: + r"""Return an empty command list. + + Returns + ------- + `list`\[`Command`\] + Empty command list. + """ + return [] + + +def _empty_node_set_map() -> dict[int, set[int]]: + r"""Return an empty node-to-node-set map. + + Returns + ------- + `dict`\[`int`, `set`\[`int`\]\] + Empty node-to-node-set map. + """ + return {} + + +def _empty_node_groups() -> list[set[int]]: + r"""Return an empty node group list. + + Returns + ------- + `list`\[`set`\[`int`\]\] + Empty node group list. + """ + return [] + + +@dataclass(slots=True) +class _PatternData: + """Container for parsed pattern data from .ptn format. + + Attributes + ---------- + input_node_indices : `dict`[`int`, `int`] + Mapping from node to qubit index for input nodes. + output_node_indices : `dict`[`int`, `int`] + Mapping from node to qubit index for output nodes. + input_coordinates : `dict`[`int`, `tuple`[`float`, ...]] + Coordinates for input nodes. + commands : `list`[`Command`] + List of quantum commands. + xflow : `dict`[`int`, `set`[`int`]] + X correction flow mapping. + zflow : `dict`[`int`, `set`[`int`]] + Z correction flow mapping. + parity_check_groups : `list`[`set`[`int`]] + Parity check groups for error detection. + """ + + input_node_indices: dict[int, int] = field(default_factory=_empty_node_index_map) + output_node_indices: dict[int, int] = field(default_factory=_empty_node_index_map) + input_coordinates: dict[int, tuple[float, ...]] = field(default_factory=_empty_coordinates) + commands: list[Command] = field(default_factory=_empty_commands) + xflow: dict[int, set[int]] = field(default_factory=_empty_node_set_map) + zflow: dict[int, set[int]] = field(default_factory=_empty_node_set_map) + parity_check_groups: list[set[int]] = field(default_factory=_empty_node_groups) + logical_observables: dict[int, set[int]] = field(default_factory=_empty_node_set_map) + + +@dataclass(slots=True) +class _LoadedGraphState(BaseGraphState): + """Read-only graph state reconstructed from a .ptn file.""" + + _input_node_indices: dict[int, int] + _output_node_indices: dict[int, int] + _physical_nodes: set[int] + _physical_edges: set[tuple[int, int]] + _meas_bases: dict[int, MeasBasis] + _coordinates: dict[int, tuple[float, ...]] + _neighbors: dict[int, set[int]] = field(init=False, repr=False) + + def __post_init__(self) -> None: + self._input_node_indices = dict(self._input_node_indices) + self._output_node_indices = dict(self._output_node_indices) + self._physical_nodes = set(self._physical_nodes) + self._physical_edges = { + (node1, node2) if node1 < node2 else (node2, node1) for node1, node2 in self._physical_edges + } + self._meas_bases = dict(self._meas_bases) + self._coordinates = dict(self._coordinates) + self._neighbors: dict[int, set[int]] = {node: set() for node in self._physical_nodes} + for node1, node2 in self._physical_edges: + self._neighbors.setdefault(node1, set()).add(node2) + self._neighbors.setdefault(node2, set()).add(node1) + + @property + def input_node_indices(self) -> dict[int, int]: + return self._input_node_indices.copy() + + @property + def output_node_indices(self) -> dict[int, int]: + return self._output_node_indices.copy() + + @property + def physical_nodes(self) -> set[int]: + return set(self._physical_nodes) + + @property + def physical_edges(self) -> set[tuple[int, int]]: + return set(self._physical_edges) + + @property + def meas_bases(self) -> MappingProxyType[int, MeasBasis]: + return MappingProxyType(self._meas_bases) + + @property + def coordinates(self) -> dict[int, tuple[float, ...]]: + return self._coordinates.copy() + + def add_physical_node(self, coordinate: tuple[float, ...] | None = None) -> int: + msg = "Loaded .ptn graph states are read-only" + raise NotImplementedError(msg) + + def add_physical_edge(self, node1: int, node2: int) -> None: + msg = "Loaded .ptn graph states are read-only" + raise NotImplementedError(msg) + + def register_input(self, node: int, q_index: int) -> None: + msg = "Loaded .ptn graph states are read-only" + raise NotImplementedError(msg) + + def register_output(self, node: int, q_index: int) -> None: + msg = "Loaded .ptn graph states are read-only" + raise NotImplementedError(msg) + + def assign_meas_basis(self, node: int, meas_basis: MeasBasis) -> None: + msg = "Loaded .ptn graph states are read-only" + raise NotImplementedError(msg) + + def neighbors(self, node: int) -> set[int]: + if node not in self._physical_nodes: + msg = f"Node does not exist node={node}" + raise ValueError(msg) + return self._neighbors.get(node, set()).copy() + + def check_canonical_form(self) -> None: + for node in self._physical_nodes - self._output_node_indices.keys(): + if node not in self._meas_bases: + msg = f"Measurement basis not set for node {node}" + raise ValueError(msg) + + +def _command_nodes(cmd: Command) -> set[int]: + r"""Return node ids referenced by a command. + + Returns + ------- + `set`\[`int`\] + Node ids referenced by the command. + """ + if isinstance(cmd, (N, M, X, Z)): + return {cmd.node} + if isinstance(cmd, E): + return set(cmd.nodes) + return set() + + +def _build_pattern(data: _PatternData) -> Pattern: + """Build a Pattern from parsed .ptn data. + + Returns + ------- + `Pattern` + Reconstructed pattern. + + Raises + ------ + ValueError + If parsed commands contain invalid graph structure. + """ + nodes: set[int] = set(data.input_node_indices) | set(data.output_node_indices) | set(data.input_coordinates) + edges: set[tuple[int, int]] = set() + meas_bases: dict[int, MeasBasis] = {} + coordinates = dict(data.input_coordinates) + + for cmd in data.commands: + nodes.update(_command_nodes(cmd)) + if isinstance(cmd, E): + node1, node2 = cmd.nodes + edge = (node1, node2) if node1 < node2 else (node2, node1) + if edge[0] == edge[1]: + msg = f"Self-loop edge is not allowed: {cmd.nodes}" + raise ValueError(msg) + edges.add(edge) + elif isinstance(cmd, M): + meas_bases[cmd.node] = cmd.meas_basis + elif isinstance(cmd, N) and cmd.coordinate is not None: + coordinates[cmd.node] = cmd.coordinate + + for source, targets in data.xflow.items(): + nodes.add(source) + nodes.update(targets) + for source, targets in data.zflow.items(): + nodes.add(source) + nodes.update(targets) + for group in data.parity_check_groups: + nodes.update(group) + for nodes_in_observable in data.logical_observables.values(): + nodes.update(nodes_in_observable) + + graphstate = _LoadedGraphState( + _input_node_indices=data.input_node_indices, + _output_node_indices=data.output_node_indices, + _physical_nodes=nodes, + _physical_edges=edges, + _meas_bases=meas_bases, + _coordinates=coordinates, + ) + pauli_frame = PauliFrame( + graphstate, + data.xflow, + data.zflow, + parity_check_group=data.parity_check_groups, + logical_observables=data.logical_observables, + ) + return Pattern( + input_node_indices=dict(data.input_node_indices), + output_node_indices=dict(data.output_node_indices), + commands=tuple(data.commands), + pauli_frame=pauli_frame, + input_coordinates=dict(data.input_coordinates), + ) + + +class _Parser: + """Internal parser state for loads().""" + + def __init__(self) -> None: + self.result = _PatternData() + self.current_timeslice = -1 + self.version_found = False + + def parse(self, s: str) -> Pattern: + r"""Parse the input string and return Pattern. + + Parameters + ---------- + s : `str` + The .ptn format string. + + Returns + ------- + `Pattern` + Loaded measurement pattern. + + Raises + ------ + ValueError + If the format is invalid or unsupported version. + """ + for line_num, raw_line in enumerate(s.splitlines(), 1): + self._parse_line(line_num, raw_line) + + if not self.version_found: + msg = "Missing .version directive" + raise ValueError(msg) + + return _build_pattern(self.result) + + def _parse_line(self, line_num: int, raw_line: str) -> None: + """Parse a single line. + + Raises + ------ + ValueError + If the line is malformed. + """ + line = raw_line.split("#", 1)[0].strip() + if not line: + return + + try: + if line.startswith("."): + self._parse_directive(line) + elif line.startswith("[") and line.endswith("]"): + self._parse_timeslice(line) + else: + self._parse_command(line) + except ValueError as exc: + msg = f"Line {line_num}: {exc}" + raise ValueError(msg) from exc + + def _parse_directive(self, line: str) -> None: + """Parse a directive line (starts with '.'). + + Raises + ------ + ValueError + If the directive is invalid. + """ + parts = line.split(maxsplit=1) + directive = parts[0] + content = parts[1] if len(parts) > 1 else "" + + if directive == ".version": + self._handle_version(content) + elif directive == ".input": + self.result.input_node_indices = _parse_node_qubit_pairs(content.split()) + elif directive == ".output": + self.result.output_node_indices = _parse_node_qubit_pairs(content.split()) + elif directive == ".coord": + self._handle_coord(content) + elif directive == ".xflow": + source, targets = _parse_arrow_mapping(content, ".xflow") + self.result.xflow[source] = targets + elif directive == ".zflow": + source, targets = _parse_arrow_mapping(content, ".zflow") + self.result.zflow[source] = targets + elif directive == ".detector": + self.result.parity_check_groups.append(_parse_node_set(content.split(), ".detector")) + elif directive == ".observable": + logical_idx, nodes = _parse_arrow_mapping(content, ".observable") + self.result.logical_observables[logical_idx] = nodes + else: + msg = f"Unknown directive: {directive}" + raise ValueError(msg) + + def _handle_version(self, content: str) -> None: + r"""Handle .version directive. + + Raises + ------ + ValueError + If the version is unsupported. + """ + version = _parse_int(content, "version") + if version != PTN_VERSION: + msg = f"Unsupported .ptn version: {version} (expected {PTN_VERSION})" + raise ValueError(msg) + self.version_found = True + + def _handle_coord(self, content: str) -> None: + """Handle .coord directive. + + Raises + ------ + ValueError + If the coordinate directive is malformed. + """ + coord_parts = content.split() + if len(coord_parts) not in {3, 4}: + msg = ".coord requires a node and 2D or 3D coordinates" + raise ValueError(msg) + node = _parse_int(coord_parts[0], "node") + coord = _parse_coord(coord_parts[1:]) + self.result.input_coordinates[node] = coord + + def _parse_timeslice(self, line: str) -> None: + """Parse timeslice marker [n]. + + Raises + ------ + ValueError + If the timeslice marker is malformed. + """ + slice_num = _parse_int(line[1:-1], "timeslice") + if slice_num < 0: + msg = "Timeslice must be non-negative" + raise ValueError(msg) + if slice_num < self.current_timeslice: + msg = "Timeslices must be monotonically increasing" + raise ValueError(msg) + ticks_to_insert = slice_num if self.current_timeslice < 0 else slice_num - self.current_timeslice + self.result.commands.extend(TICK() for _ in range(ticks_to_insert)) + self.current_timeslice = slice_num + + def _parse_command(self, line: str) -> None: + r"""Parse a command line. + + Raises + ------ + ValueError + If the command type is unknown. + """ + parts = line.split() + cmd_type = parts[0] + + if cmd_type == "N": + self._parse_n_command(parts) + elif cmd_type == "E": + self._parse_e_command(parts) + elif cmd_type == "M": + self._parse_m_command(parts) + elif cmd_type == "X": + if len(parts) != 2: # noqa: PLR2004 + msg = "X command requires exactly one node" + raise ValueError(msg) + self.result.commands.append(X(node=_parse_int(parts[1], "node"))) + elif cmd_type == "Z": + if len(parts) != 2: # noqa: PLR2004 + msg = "Z command requires exactly one node" + raise ValueError(msg) + self.result.commands.append(Z(node=_parse_int(parts[1], "node"))) + else: + msg = f"Unknown command: {cmd_type}" + raise ValueError(msg) + + def _parse_n_command(self, parts: Sequence[str]) -> None: + """Parse N (node) command. + + Raises + ------ + ValueError + If the command is malformed. + """ + if len(parts) not in {2, 4, 5}: + msg = "N command requires a node and optional 2D or 3D coordinates" + raise ValueError(msg) + node = _parse_int(parts[1], "node") + coord: tuple[float, ...] | None = _parse_coord(parts[2:]) if len(parts) > 2 else None # noqa: PLR2004 + self.result.commands.append(N(node=node, coordinate=coord)) + + def _parse_e_command(self, parts: Sequence[str]) -> None: + """Parse E (entangle) command. + + Raises + ------ + ValueError + If the command is malformed. + """ + if len(parts) != 3: # noqa: PLR2004 + msg = "E command requires exactly two nodes" + raise ValueError(msg) + node1 = _parse_int(parts[1], "node") + node2 = _parse_int(parts[2], "node") + self.result.commands.append(E(nodes=(node1, node2))) + + def _parse_m_command(self, parts: Sequence[str]) -> None: + """Parse M (measure) command. + + Raises + ------ + ValueError + If the command is malformed. + """ + if len(parts) != 4: # noqa: PLR2004 + msg = "M command requires a node, basis, and angle/sign" + raise ValueError(msg) + node = _parse_int(parts[1], "node") + basis_spec = parts[2] + meas_basis: MeasBasis + + if basis_spec in {"X", "Y", "Z"}: + sign_str = parts[3] + if sign_str not in {"+", "-"}: + msg = f"Invalid Pauli measurement sign: {sign_str!r}" + raise ValueError(msg) + sign = Sign.PLUS if sign_str == "+" else Sign.MINUS + axis = Axis[basis_spec] + meas_basis = AxisMeasBasis(axis, sign) + else: + try: + plane = Plane[basis_spec] + except KeyError as exc: + msg = f"Invalid measurement basis: {basis_spec!r}" + raise ValueError(msg) from exc + angle = _parse_angle(parts[3]) + meas_basis = PlannerMeasBasis(plane, angle) + + self.result.commands.append(M(node=node, meas_basis=meas_basis)) + + +def loads(s: str) -> Pattern: + """Deserialize a .ptn format string to a pattern. + + Parameters + ---------- + s : `str` + The .ptn format string. + + Returns + ------- + `Pattern` + The loaded pattern. + + See Also + -------- + _Parser.parse : Internal parser that may raise ValueError for invalid input. + """ + return _Parser().parse(s) + + +def load(file: Path | str) -> Pattern: + """Read a pattern from a .ptn file. + + Parameters + ---------- + file : `pathlib.Path` | `str` + The file path to read from. + + Returns + ------- + `Pattern` + The loaded pattern. + See `loads` for details. + """ + path = Path(file) + return loads(path.read_text(encoding="utf-8")) diff --git a/tests/test_ptn_format.py b/tests/test_ptn_format.py new file mode 100644 index 00000000..99bc0371 --- /dev/null +++ b/tests/test_ptn_format.py @@ -0,0 +1,684 @@ +"""Tests for ptn_format module.""" + +from __future__ import annotations + +import math +from typing import TYPE_CHECKING, Any + +import pytest + +from graphqomb.command import TICK, E, M, N, X, Z +from graphqomb.common import Axis, AxisMeasBasis, Plane, PlannerMeasBasis, Sign, determine_pauli_axis +from graphqomb.graphstate import GraphState +from graphqomb.pattern import Pattern +from graphqomb.pauli_frame import PauliFrame +from graphqomb.ptn_format import ( + dump, + dumps, + load, + loads, +) +from graphqomb.qompiler import qompile +from graphqomb.stim_compiler import stim_compile + +if TYPE_CHECKING: + from pathlib import Path + + from graphqomb.command import Command + + +def create_simple_pattern() -> Pattern: + """Create a simple pattern for testing. + + Returns + ------- + Pattern + A compiled MBQC pattern for testing. + """ + graph = GraphState() + in_node = graph.add_physical_node(coordinate=(0.0, 0.0)) + mid_node = graph.add_physical_node(coordinate=(1.0, 0.0)) + out_node = graph.add_physical_node(coordinate=(2.0, 0.0)) + + graph.register_input(in_node, 0) + graph.register_output(out_node, 0) + + graph.add_physical_edge(in_node, mid_node) + graph.add_physical_edge(mid_node, out_node) + + graph.assign_meas_basis(in_node, PlannerMeasBasis(Plane.XY, 0.0)) + graph.assign_meas_basis(mid_node, PlannerMeasBasis(Plane.XY, math.pi / 2)) + + xflow = {in_node: {mid_node}, mid_node: {out_node}} + return qompile(graph, xflow) + + +def create_measured_output_pattern_with_observable() -> Pattern: + """Create a stim-compatible pattern with a logical observable.""" + graph = GraphState() + in_node = graph.add_physical_node(coordinate=(10.0, 0.0)) + out_node = graph.add_physical_node(coordinate=(20.0, 0.0)) + + graph.register_input(in_node, 0) + graph.register_output(out_node, 0) + graph.add_physical_edge(in_node, out_node) + graph.assign_meas_basis(in_node, PlannerMeasBasis(Plane.XY, 0.0)) + graph.assign_meas_basis(out_node, PlannerMeasBasis(Plane.XY, 0.0)) + + return qompile(graph, {in_node: {out_node}}, logical_observables={0: {in_node}}) + + +def create_measured_output_pattern_with_detector() -> Pattern: + """Create a stim-compatible pattern with a detector.""" + graph = GraphState() + in_node = graph.add_physical_node(coordinate=(30.0, 0.0)) + out_node = graph.add_physical_node(coordinate=(40.0, 0.0)) + + graph.register_input(in_node, 0) + graph.register_output(out_node, 0) + graph.add_physical_edge(in_node, out_node) + graph.assign_meas_basis(in_node, PlannerMeasBasis(Plane.XY, 0.0)) + graph.assign_meas_basis(out_node, PlannerMeasBasis(Plane.XY, 0.0)) + + return qompile(graph, {in_node: {out_node}}, parity_check_group=[{in_node}]) + + +def command_signature(cmd: Command) -> tuple[Any, ...]: # noqa: PLR0911 + """Return a behavior-level signature for a pattern command.""" + if isinstance(cmd, N): + return ("N", cmd.node, cmd.coordinate) + if isinstance(cmd, E): + return ("E", cmd.nodes) + if isinstance(cmd, M): + pauli_axis = determine_pauli_axis(cmd.meas_basis) + if pauli_axis is not None: + return ("M", cmd.node, pauli_axis, cmd.meas_basis.angle) + return ("M", cmd.node, cmd.meas_basis.plane, cmd.meas_basis.angle) + if isinstance(cmd, X): + return ("X", cmd.node) + if isinstance(cmd, Z): + return ("Z", cmd.node) + if isinstance(cmd, TICK): + return ("TICK",) + return ("UNKNOWN", type(cmd).__name__) + + +def assert_pattern_equivalent(actual: Pattern, expected: Pattern) -> None: + """Assert that serialized pattern content survived a roundtrip.""" + assert actual.input_node_indices == expected.input_node_indices + assert actual.output_node_indices == expected.output_node_indices + assert actual.input_coordinates == expected.input_coordinates + assert actual.pauli_frame.xflow == expected.pauli_frame.xflow + assert actual.pauli_frame.zflow == expected.pauli_frame.zflow + assert actual.pauli_frame.parity_check_group == expected.pauli_frame.parity_check_group + assert actual.pauli_frame.logical_observables == expected.pauli_frame.logical_observables + assert [command_signature(cmd) for cmd in actual.commands] == [command_signature(cmd) for cmd in expected.commands] + + +def test_dumps_basic() -> None: + """Test basic pattern serialization.""" + pattern = create_simple_pattern() + ptn_str = dumps(pattern) + + assert ".version 1" in ptn_str + assert ".input" in ptn_str + assert ".output" in ptn_str + assert "#======== QUANTUM ========" in ptn_str + assert "#======== CLASSICAL ========" in ptn_str + + +def test_dumps_contains_commands() -> None: + """Test that dumps includes all command types.""" + pattern = create_simple_pattern() + ptn_str = dumps(pattern) + + # Check for command types + assert "N " in ptn_str # Node creation + assert "E " in ptn_str # Entanglement + assert "M " in ptn_str # Measurement + + +def test_dumps_coordinates() -> None: + """Test that coordinates are correctly serialized.""" + pattern = create_simple_pattern() + ptn_str = dumps(pattern) + + assert ".coord 0 0.0 0.0" in ptn_str + + +def test_dumps_pauli_measurements() -> None: + """Test that Pauli measurements are correctly formatted with +/- signs.""" + pattern = create_simple_pattern() + ptn_str = dumps(pattern) + + # X measurement (XY plane, angle 0) should be formatted as "X +" + assert "M 0 X +" in ptn_str + # Y measurement (XY plane, angle pi/2) should be formatted as "Y +" + assert "M 1 Y +" in ptn_str + + +def test_dumps_formats_near_known_angle() -> None: + """Angles near known constants should serialize canonically.""" + graph = GraphState() + node = graph.add_physical_node() + pattern = Pattern( + input_node_indices={}, + output_node_indices={}, + commands=(M(node, PlannerMeasBasis(Plane.XY, math.pi / 4 + 1e-10)),), + pauli_frame=PauliFrame(graph, xflow={}, zflow={}), + ) + + assert f"M {node} XY pi/4" in dumps(pattern) + + +def test_dumps_preserves_xz_plane_x_pauli_sign() -> None: + """Plane.XZ X measurements should serialize with the correct Pauli sign.""" + graph = GraphState() + plus_node = graph.add_physical_node() + minus_node = graph.add_physical_node() + pattern = Pattern( + input_node_indices={}, + output_node_indices={}, + commands=( + M(plus_node, PlannerMeasBasis(Plane.XZ, math.pi / 2)), + M(minus_node, PlannerMeasBasis(Plane.XZ, 3 * math.pi / 2)), + ), + pauli_frame=PauliFrame(graph, xflow={}, zflow={}), + ) + + ptn_str = dumps(pattern) + result = loads(ptn_str) + + assert f"M {plus_node} X +" in ptn_str + assert f"M {minus_node} X -" in ptn_str + measurements = {cmd.node: cmd for cmd in result.commands if isinstance(cmd, M)} + plus_basis = measurements[plus_node].meas_basis + minus_basis = measurements[minus_node].meas_basis + assert isinstance(plus_basis, AxisMeasBasis) + assert plus_basis.axis == Axis.X + assert plus_basis.sign == Sign.PLUS + assert isinstance(minus_basis, AxisMeasBasis) + assert minus_basis.axis == Axis.X + assert minus_basis.sign == Sign.MINUS + + +def test_dumps_preserves_consecutive_trailing_ticks() -> None: + """Empty final timeslices should preserve consecutive trailing TICK commands.""" + graph = GraphState() + node = graph.add_physical_node() + pattern = Pattern( + input_node_indices={}, + output_node_indices={}, + commands=(N(node), TICK(), TICK()), + pauli_frame=PauliFrame(graph, xflow={}, zflow={}), + ) + + ptn_str = dumps(pattern) + result = loads(ptn_str) + + assert f"[0]\nN {node}\n[1]\n[2]\n" in ptn_str + assert [command_signature(cmd) for cmd in result.commands] == [ + ("N", node, None), + ("TICK",), + ("TICK",), + ] + + +def test_loads_basic() -> None: + """Test basic pattern deserialization.""" + ptn_str = """# Test pattern +.version 1 +.input 0:0 +.output 2:0 +.coord 0 0.0 0.0 + +#======== QUANTUM ======== +[0] +N 1 +E 0 1 +M 0 XY 0 + +#======== CLASSICAL ======== +.xflow 0 -> 1 +""" + result = loads(ptn_str) + + assert result.input_node_indices == {0: 0} + assert result.output_node_indices == {2: 0} + assert result.input_coordinates == {0: (0.0, 0.0)} + + +def test_loads_commands() -> None: + """Test that commands are correctly parsed.""" + ptn_str = """ +.version 1 +.input 0:0 +.output 2:0 + +[0] +N 1 +N 3 1.0 2.0 +E 0 1 +M 0 XY 0 +M 1 XY pi/2 +X 2 +Z 2 +""" + result = loads(ptn_str) + + commands = result.commands + # Check command types + assert any(isinstance(c, N) and c.node == 1 for c in commands) + assert any(isinstance(c, N) and c.node == 3 and c.coordinate == (1.0, 2.0) for c in commands) + assert any(isinstance(c, E) and c.nodes == (0, 1) for c in commands) + assert any(isinstance(c, M) and c.node == 0 for c in commands) + assert any(isinstance(c, X) and c.node == 2 for c in commands) + assert any(isinstance(c, Z) and c.node == 2 for c in commands) + + +def test_loads_timeslices() -> None: + """Test that timeslice markers generate TICK commands.""" + ptn_str = """ +.version 1 +.input 0:0 +.output 1:0 + +[0] +E 0 1 +[1] +M 0 XY 0 +[2] +M 1 XY 0 +""" + result = loads(ptn_str) + + # Count TICK commands + tick_count = sum(1 for c in result.commands if isinstance(c, TICK)) + assert tick_count == 2 + + +def test_loads_initial_empty_timeslices() -> None: + """Test that an initial non-zero timeslice marker creates the matching TICKs.""" + ptn_str = """ +.version 1 + +[2] +M 0 XY 0 +""" + result = loads(ptn_str) + + assert [command_signature(cmd) for cmd in result.commands] == [ + ("TICK",), + ("TICK",), + ("M", 0, Axis.X, 0), + ] + + +def test_loads_angle_parsing() -> None: + """Test various angle format parsing.""" + ptn_str = """ +.version 1 +.input 0:0 +.output 5:0 + +[0] +M 0 XY 0 +M 1 XY pi +M 2 XY pi/2 +M 3 XY pi/4 +M 4 XY 3pi/4 +""" + result = loads(ptn_str) + + measurements = [c for c in result.commands if isinstance(c, M)] + angles = {m.node: m.meas_basis.angle for m in measurements} + + assert math.isclose(angles[0], 0.0) + assert math.isclose(angles[1], math.pi) + assert math.isclose(angles[2], math.pi / 2) + assert math.isclose(angles[3], math.pi / 4) + assert math.isclose(angles[4], 3 * math.pi / 4) + + +def test_loads_pauli_measurements() -> None: + """Test parsing of Pauli measurement format (X/Y/Z +/-).""" + ptn_str = """ +.version 1 +.input 0:0 +.output 6:0 + +[0] +M 0 X + +M 1 X - +M 2 Y + +M 3 Y - +M 4 Z + +M 5 Z - +""" + result = loads(ptn_str) + + measurements = [c for c in result.commands if isinstance(c, M)] + assert len(measurements) == 6 + + # Check that Pauli measurements are parsed correctly + m0 = next(m for m in measurements if m.node == 0) + assert math.isclose(m0.meas_basis.angle, 0.0) # X + + + m1 = next(m for m in measurements if m.node == 1) + assert math.isclose(m1.meas_basis.angle, math.pi) # X - + + m2 = next(m for m in measurements if m.node == 2) + assert math.isclose(m2.meas_basis.angle, math.pi / 2) # Y + + + m3 = next(m for m in measurements if m.node == 3) + assert math.isclose(m3.meas_basis.angle, 3 * math.pi / 2) # Y - + + m4 = next(m for m in measurements if m.node == 4) + assert math.isclose(m4.meas_basis.angle, 0.0) # Z + + + m5 = next(m for m in measurements if m.node == 5) + assert math.isclose(m5.meas_basis.angle, math.pi) # Z - + + +def test_loads_flow_parsing() -> None: + """Test xflow and zflow parsing.""" + ptn_str = """ +.version 1 +.input 0:0 +.output 2:0 + +[0] +N 1 +E 0 1 +M 0 XY 0 + +.xflow 0 -> 1 2 +.zflow 0 -> 3 4 +""" + result = loads(ptn_str) + + assert result.pauli_frame.xflow == {0: {1, 2}} + assert result.pauli_frame.zflow == {0: {3, 4}} + + +def test_loads_detector_parsing() -> None: + """Test detector (parity check group) parsing.""" + ptn_str = """ +.version 1 +.input 0:0 +.output 2:0 + +[0] +M 0 XY 0 + +.detector 0 1 2 +.detector 3 4 +""" + result = loads(ptn_str) + + assert len(result.pauli_frame.parity_check_group) == 2 + assert result.pauli_frame.parity_check_group[0] == {0, 1, 2} + assert result.pauli_frame.parity_check_group[1] == {3, 4} + + +def test_loads_observable_parsing() -> None: + """Test logical observable parsing.""" + ptn_str = """ +.version 1 +.input 0:0 +.output 1:0 + +[0] +M 0 X + +M 1 X + + +.observable 0 -> 0 1 +""" + result = loads(ptn_str) + + assert result.pauli_frame.logical_observables == {0: {0, 1}} + + +def test_loads_missing_version() -> None: + """Test that missing version raises ValueError.""" + ptn_str = """ +.input 0:0 +.output 1:0 +[0] +M 0 XY 0 +""" + with pytest.raises(ValueError, match=r"Missing \.version directive"): + loads(ptn_str) + + +def test_loads_unsupported_version() -> None: + """Test that unsupported version raises ValueError.""" + ptn_str = """ +.version 99 +.input 0:0 +.output 1:0 +""" + with pytest.raises(ValueError, match=r"Unsupported \.ptn version"): + loads(ptn_str) + + +def test_loads_unknown_command() -> None: + """Test that unknown command raises ValueError.""" + ptn_str = """ +.version 1 +.input 0:0 +.output 1:0 + +[0] +UNKNOWN 0 +""" + with pytest.raises(ValueError, match="Unknown command"): + loads(ptn_str) + + +def test_roundtrip() -> None: + """Test that dumps followed by loads preserves data.""" + pattern = create_simple_pattern() + ptn_str = dumps(pattern) + result = loads(ptn_str) + + assert_pattern_equivalent(result, pattern) + + +def test_dump_and_load_file(tmp_path: Path) -> None: + """Test file I/O operations.""" + pattern = create_simple_pattern() + filepath = tmp_path / "test.ptn" + + dump(pattern, filepath) + + assert filepath.exists() + + result = load(filepath) + + assert_pattern_equivalent(result, pattern) + + +def test_multiple_input_output_qubits() -> None: + """Test pattern with multiple input/output qubits.""" + graph = GraphState() + in0 = graph.add_physical_node() + in1 = graph.add_physical_node() + out0 = graph.add_physical_node() + out1 = graph.add_physical_node() + + graph.register_input(in0, 0) + graph.register_input(in1, 1) + graph.register_output(out0, 0) + graph.register_output(out1, 1) + + graph.add_physical_edge(in0, out0) + graph.add_physical_edge(in1, out1) + + graph.assign_meas_basis(in0, PlannerMeasBasis(Plane.XY, 0.0)) + graph.assign_meas_basis(in1, PlannerMeasBasis(Plane.XY, 0.0)) + + xflow = {in0: {out0}, in1: {out1}} + pattern = qompile(graph, xflow) + + ptn_str = dumps(pattern) + result = loads(ptn_str) + + assert len(result.input_node_indices) == 2 + assert len(result.output_node_indices) == 2 + + +def test_3d_coordinates() -> None: + """Test 3D coordinate serialization and parsing.""" + ptn_str = """ +.version 1 +.input 0:0 +.output 1:0 +.coord 0 1.0 2.0 3.0 + +[0] +N 1 4.0 5.0 6.0 +M 0 XY 0 +""" + result = loads(ptn_str) + + assert result.input_coordinates[0] == (1.0, 2.0, 3.0) + + # Check N command coordinate + n_cmd = next(c for c in result.commands if isinstance(c, N)) + assert n_cmd.coordinate == (4.0, 5.0, 6.0) + + +def test_different_measurement_planes() -> None: + """Test all measurement planes are correctly handled.""" + ptn_str = """ +.version 1 +.input 0:0 +.output 3:0 + +[0] +M 0 XY pi/4 +M 1 XZ pi/4 +M 2 YZ pi/4 +""" + result = loads(ptn_str) + + measurements = [c for c in result.commands if isinstance(c, M)] + planes = {m.node: m.meas_basis.plane for m in measurements} + + assert planes[0] == Plane.XY + assert planes[1] == Plane.XZ + assert planes[2] == Plane.YZ + + +def test_empty_flow() -> None: + """Test pattern with empty flow mappings.""" + ptn_str = """ +.version 1 +.input 0:0 +.output 1:0 + +[0] +M 0 XY 0 + +#======== CLASSICAL ======== +""" + result = loads(ptn_str) + + assert result.pauli_frame.xflow == {} + assert result.pauli_frame.zflow == {} + + +def test_comments_ignored() -> None: + """Test that comments are properly ignored.""" + ptn_str = """ +# This is a comment +.version 1 +# Another comment +.input 0:0 # inline comment should be parsed as part of content +.output 1:0 + +# Comment in quantum section +[0] +M 0 XY 0 +""" + result = loads(ptn_str) + assert result.input_node_indices == {0: 0} + + +def test_roundtrip_preserves_logical_observables_for_stim() -> None: + """Logical observables should survive .ptn serialization.""" + pattern = create_measured_output_pattern_with_observable() + ptn_str = dumps(pattern) + + assert ".observable 0 -> 0" in ptn_str + + result = loads(ptn_str) + + assert_pattern_equivalent(result, pattern) + assert stim_compile(result) == stim_compile(pattern) + + +def test_roundtrip_preserves_detectors_for_stim() -> None: + """Detectors should survive .ptn serialization.""" + pattern = create_measured_output_pattern_with_detector() + ptn_str = dumps(pattern) + + assert ".detector 0" in ptn_str + + result = loads(ptn_str) + + assert_pattern_equivalent(result, pattern) + assert stim_compile(result) == stim_compile(pattern) + + +def test_loads_preserves_non_contiguous_node_ids() -> None: + """Loading should preserve node ids instead of remapping them.""" + ptn_str = """ +.version 1 +.input 10:0 +.output 30:0 +.coord 10 1.0 2.0 + +[0] +N 20 3.0 4.0 +E 10 20 +E 20 30 +[1] +M 10 X + +[2] +M 20 Y - +[3] +X 30 +Z 30 + +.xflow 10 -> 20 +.zflow 20 -> 30 +""" + result = loads(ptn_str) + + assert result.input_node_indices == {10: 0} + assert result.output_node_indices == {30: 0} + assert result.pauli_frame.graphstate.physical_nodes == {10, 20, 30} + assert result.pauli_frame.graphstate.physical_edges == {(10, 20), (20, 30)} + assert any(isinstance(cmd, N) and cmd.node == 20 for cmd in result.commands) + assert any(isinstance(cmd, X) and cmd.node == 30 for cmd in result.commands) + + +@pytest.mark.parametrize( + ("ptn_str", "message"), + [ + (".version 1\n.foo whatever\n", "Unknown directive"), + (".version 1\n[0]\nM 0 X bad\n", "Invalid Pauli measurement sign"), + (".version 1\n[0]\nM 0 X + junk\n", "M command requires"), + (".version 1\n.xflow 0 1\n", "must contain exactly one"), + (".version 1\n[-1]\n", "Timeslice must be non-negative"), + (".version 1\n[1]\n[0]\n", "monotonically increasing"), + (".version 1\n[0]\nM 0 XY pi/0\n", "denominator"), + (".version 1\n.detector\n", "requires at least one node"), + ], +) +def test_loads_rejects_malformed_input(ptn_str: str, message: str) -> None: + """Malformed .ptn input should fail instead of being guessed.""" + with pytest.raises(ValueError, match=message): + loads(ptn_str)