From bad7f5421e2cda2bee00ace4595549dae2e97e69 Mon Sep 17 00:00:00 2001 From: Masato Fukushima Date: Mon, 9 Feb 2026 19:01:58 +0900 Subject: [PATCH 01/10] Add .ptn text format for pattern serialization Implement a human-readable text format (.ptn) for exporting and importing MBQC patterns. The format separates quantum instructions from classical feedforward processing, with explicit timeslice markers for parallel execution. Format structure: - Header section: version, input/output nodes, coordinates - Quantum section: N/E/M/X/Z commands grouped by timeslice [0], [1], etc. - Classical section: xflow/zflow definitions, detector groups Key features: - Assembly-style syntax (one command per line) - Human-readable angle formatting (pi/2, pi/4, etc.) - Support for 2D/3D node coordinates - Inline comment support Co-Authored-By: Claude Opus 4.5 --- graphqomb/ptn_format.py | 530 +++++++++++++++++++++++++++++++++++++++ pyproject.toml | 10 + tests/test_ptn_format.py | 408 ++++++++++++++++++++++++++++++ 3 files changed, 948 insertions(+) create mode 100644 graphqomb/ptn_format.py create mode 100644 tests/test_ptn_format.py diff --git a/graphqomb/ptn_format.py b/graphqomb/ptn_format.py new file mode 100644 index 00000000..c2639385 --- /dev/null +++ b/graphqomb/ptn_format.py @@ -0,0 +1,530 @@ +"""Pattern text format (.ptn) module. + +This module provides: + +- `dump`: Write a pattern to a .ptn file or string. +- `load`: Read a pattern from a .ptn file or string. +- `dumps`: Serialize a pattern to a .ptn format string. +- `loads`: Deserialize a pattern from a .ptn format string. +""" + +from __future__ import annotations + +import math +import operator +import re +from io import StringIO +from pathlib import Path +from typing import TYPE_CHECKING + +from graphqomb.command import TICK, Command, E, M, N, X, Z +from graphqomb.common import Plane, PlannerMeasBasis + +if TYPE_CHECKING: + from graphqomb.pattern import Pattern + from graphqomb.pauli_frame import PauliFrame + +PTN_VERSION = 1 + + +def _format_angle(angle: float) -> str: + """Format angle for output, using pi fractions where appropriate. + + Parameters + ---------- + angle : `float` + The angle in radians. + + Returns + ------- + `str` + Formatted angle string. + """ + # Check for common pi fractions + if math.isclose(angle, 0.0, abs_tol=1e-10): + return "0" + if math.isclose(angle, math.pi, rel_tol=1e-10): + return "pi" + if math.isclose(angle, -math.pi, rel_tol=1e-10): + return "-pi" + if math.isclose(angle, math.pi / 2, rel_tol=1e-10): + return "pi/2" + if math.isclose(angle, -math.pi / 2, rel_tol=1e-10): + return "-pi/2" + if math.isclose(angle, math.pi / 4, rel_tol=1e-10): + return "pi/4" + if math.isclose(angle, -math.pi / 4, rel_tol=1e-10): + return "-pi/4" + if math.isclose(angle, 3 * math.pi / 2, rel_tol=1e-10): + return "3pi/2" + if math.isclose(angle, 3 * math.pi / 4, rel_tol=1e-10): + return "3pi/4" + return f"{angle}" + + +def _parse_angle(s: str) -> float: + """Parse angle string to float. + + Parameters + ---------- + s : `str` + Angle string (e.g., "0", "pi", "pi/2", "3pi/4", "1.5707963"). + + Returns + ------- + `float` + The angle in radians. + """ + s = s.strip() + if s == "0": + return 0.0 + if s == "pi": + return math.pi + if s == "-pi": + return -math.pi + if s == "pi/2": + return math.pi / 2 + if s == "-pi/2": + return -math.pi / 2 + if s == "pi/4": + return math.pi / 4 + if s == "-pi/4": + return -math.pi / 4 + if s == "3pi/2": + return 3 * math.pi / 2 + if s == "3pi/4": + return 3 * math.pi / 4 + + # Try to parse as a general pi expression (e.g., "2pi/3") + pi_match = re.match(r"^(-?\d*)pi(?:/(\d+))?$", s) + if pi_match: + numerator = pi_match.group(1) + denominator = pi_match.group(2) + num = int(numerator) if numerator and numerator != "-" else (1 if numerator != "-" else -1) + denom = int(denominator) if denominator else 1 + return num * math.pi / denom + + return float(s) + + +def _format_coord(coord: tuple[float, ...]) -> str: + """Format coordinate tuple for output. + + Parameters + ---------- + coord : `tuple`[`float`, ...] + Coordinate tuple (2D or 3D). + + Returns + ------- + `str` + Space-separated coordinate string. + """ + return " ".join(str(c) for c in coord) + + +def _parse_coord(parts: list[str]) -> tuple[float, ...]: + """Parse coordinate from string parts. + + Parameters + ---------- + parts : `list`[`str`] + List of coordinate value strings. + + Returns + ------- + `tuple`[`float`, ...] + Coordinate tuple. + """ + return tuple(float(p) for p in parts) + + +def _write_header( + out: StringIO, + pattern: Pattern, +) -> None: + """Write header section to output. + + Parameters + ---------- + out : `StringIO` + Output stream. + pattern : `Pattern` + The pattern to write. + """ + out.write(f"# GraphQOMB Pattern Format v{PTN_VERSION}\n") + out.write("\n") + out.write("#======== HEADER ========\n") + out.write(f".version {PTN_VERSION}\n") + + # Input nodes + if pattern.input_node_indices: + input_parts = [ + f"{node}:{qidx}" for node, qidx in sorted(pattern.input_node_indices.items(), key=operator.itemgetter(1)) + ] + out.write(f".input {' '.join(input_parts)}\n") + + # Output nodes + if pattern.output_node_indices: + output_parts = [ + f"{node}:{qidx}" for node, qidx in sorted(pattern.output_node_indices.items(), key=operator.itemgetter(1)) + ] + out.write(f".output {' '.join(output_parts)}\n") + + # Input coordinates + for node, coord in sorted(pattern.input_coordinates.items()): + out.write(f".coord {node} {_format_coord(coord)}\n") + + +def _write_quantum_section( + out: StringIO, + pattern: Pattern, +) -> None: + """Write quantum instructions section to output. + + Parameters + ---------- + out : `StringIO` + Output stream. + pattern : `Pattern` + The pattern to write. + """ + out.write("\n") + out.write("#======== QUANTUM ========\n") + + # Group commands by timeslice + timeslice = 0 + current_slice_commands: list[Command] = [] + + def write_slice(slice_num: int, commands: list[Command]) -> None: + if commands or slice_num == 0: + out.write(f"[{slice_num}]\n") + for cmd in commands: + _write_command(out, cmd) + + for cmd in pattern.commands: + if isinstance(cmd, TICK): + write_slice(timeslice, current_slice_commands) + current_slice_commands = [] + timeslice += 1 + else: + current_slice_commands.append(cmd) + + # Write remaining commands in last slice + if current_slice_commands or timeslice == 0: + write_slice(timeslice, current_slice_commands) + + +def _write_command(out: StringIO, cmd: Command) -> None: + """Write a single command to output. + + Parameters + ---------- + out : `StringIO` + Output stream. + cmd : `Command` + The command to write. + """ + if isinstance(cmd, N): + if cmd.coordinate is not None: + out.write(f"N {cmd.node} {_format_coord(cmd.coordinate)}\n") + else: + out.write(f"N {cmd.node}\n") + elif isinstance(cmd, E): + out.write(f"E {cmd.nodes[0]} {cmd.nodes[1]}\n") + elif isinstance(cmd, M): + plane_name = cmd.meas_basis.plane.name + angle_str = _format_angle(cmd.meas_basis.angle) + out.write(f"M {cmd.node} {plane_name} {angle_str}\n") + elif isinstance(cmd, X): + out.write(f"X {cmd.node}\n") + elif isinstance(cmd, Z): + out.write(f"Z {cmd.node}\n") + elif isinstance(cmd, TICK): + pass # TICK is handled by timeslice grouping + + +def _write_classical_section( + out: StringIO, + pauli_frame: PauliFrame, +) -> None: + """Write classical frame section to output. + + Parameters + ---------- + out : `StringIO` + Output stream. + pauli_frame : `PauliFrame` + The Pauli frame to write. + """ + out.write("\n") + out.write("#======== CLASSICAL ========\n") + + # Write xflow + for source, targets in sorted(pauli_frame.xflow.items()): + if targets: + targets_str = " ".join(str(t) for t in sorted(targets)) + out.write(f".xflow {source} -> {targets_str}\n") + + # Write zflow + for source, targets in sorted(pauli_frame.zflow.items()): + if targets: + targets_str = " ".join(str(t) for t in sorted(targets)) + out.write(f".zflow {source} -> {targets_str}\n") + + # Write parity check groups (detectors) + for group in pauli_frame.parity_check_group: + if group: + group_str = " ".join(str(n) for n in sorted(group)) + out.write(f".detector {group_str}\n") + + +def dumps(pattern: Pattern) -> str: + """Serialize a pattern to a .ptn format string. + + Parameters + ---------- + pattern : `Pattern` + The pattern to serialize. + + Returns + ------- + `str` + The .ptn format string. + """ + out = StringIO() + _write_header(out, pattern) + _write_quantum_section(out, pattern) + _write_classical_section(out, pattern.pauli_frame) + return out.getvalue() + + +def dump(pattern: Pattern, file: Path | str) -> None: + """Write a pattern to a .ptn file. + + Parameters + ---------- + pattern : `Pattern` + The pattern to write. + file : `Path` | `str` + The file path to write to. + """ + path = Path(file) + path.write_text(dumps(pattern), encoding="utf-8") + + +def _parse_node_qubit_pairs(parts: list[str]) -> dict[int, int]: + """Parse node:qubit pairs from string parts. + + Parameters + ---------- + parts : `list`[`str`] + List of "node:qubit" strings. + + Returns + ------- + `dict`[`int`, `int`] + Mapping from node to qubit index. + """ + result: dict[int, int] = {} + for part in parts: + node_str, qidx_str = part.split(":") + result[int(node_str)] = int(qidx_str) + return result + + +def _parse_flow(line: str) -> tuple[int, set[int]]: + """Parse a flow line (xflow or zflow). + + Parameters + ---------- + line : `str` + The flow line content after ".xflow" or ".zflow". + + Returns + ------- + `tuple`[`int`, `set`[`int`]] + Source node and set of target nodes. + """ + parts = line.split("->") + source = int(parts[0].strip()) + targets = {int(t) for t in parts[1].strip().split()} + return source, targets + + +class PatternData: + """Container for parsed pattern data from .ptn format. + + Attributes + ---------- + input_node_indices : `dict`[`int`, `int`] + Mapping from node to qubit index for input nodes. + output_node_indices : `dict`[`int`, `int`] + Mapping from node to qubit index for output nodes. + input_coordinates : `dict`[`int`, `tuple`[`float`, ...]] + Coordinates for input nodes. + commands : `list`[`Command`] + List of quantum commands. + xflow : `dict`[`int`, `set`[`int`]] + X correction flow mapping. + zflow : `dict`[`int`, `set`[`int`]] + Z correction flow mapping. + parity_check_groups : `list`[`set`[`int`]] + Parity check groups for error detection. + """ + + def __init__(self) -> None: + self.input_node_indices: dict[int, int] = {} + self.output_node_indices: dict[int, int] = {} + self.input_coordinates: dict[int, tuple[float, ...]] = {} + self.commands: list[Command] = [] + self.xflow: dict[int, set[int]] = {} + self.zflow: dict[int, set[int]] = {} + self.parity_check_groups: list[set[int]] = [] + + +def loads(s: str) -> PatternData: + """Deserialize a .ptn format string to pattern components. + + Parameters + ---------- + s : `str` + The .ptn format string. + + Returns + ------- + `PatternData` + Container with pattern components. + + Raises + ------ + ValueError + If the format is invalid or unsupported version. + """ + result = PatternData() + + current_timeslice = -1 + version_found = False + + for line_num, line in enumerate(s.splitlines(), 1): + # Remove inline comments + if "#" in line: + line = line[: line.index("#")] + line = line.strip() + + # Skip empty lines + if not line: + continue + + # Parse directives + if line.startswith("."): + parts = line.split(maxsplit=1) + directive = parts[0] + content = parts[1] if len(parts) > 1 else "" + + if directive == ".version": + version = int(content) + if version != PTN_VERSION: + msg = f"Unsupported .ptn version: {version} (expected {PTN_VERSION})" + raise ValueError(msg) + version_found = True + + elif directive == ".input": + result.input_node_indices = _parse_node_qubit_pairs(content.split()) + + elif directive == ".output": + result.output_node_indices = _parse_node_qubit_pairs(content.split()) + + elif directive == ".coord": + coord_parts = content.split() + node = int(coord_parts[0]) + coord = _parse_coord(coord_parts[1:]) + result.input_coordinates[node] = coord + + elif directive == ".xflow": + source, targets = _parse_flow(content) + result.xflow[source] = targets + + elif directive == ".zflow": + source, targets = _parse_flow(content) + result.zflow[source] = targets + + elif directive == ".detector": + nodes = {int(n) for n in content.split()} + result.parity_check_groups.append(nodes) + + elif directive == ".observable": + # Observable parsing - store for future use + pass + + continue + + # Parse timeslice header + if line.startswith("[") and line.endswith("]"): + slice_num = int(line[1:-1]) + # Add TICK commands for timeslice transitions + while current_timeslice < slice_num - 1: + result.commands.append(TICK()) + current_timeslice += 1 + if current_timeslice < slice_num: + if current_timeslice >= 0: + result.commands.append(TICK()) + current_timeslice = slice_num + continue + + # Parse commands + parts = line.split() + cmd_type = parts[0] + + if cmd_type == "N": + node = int(parts[1]) + n_coord: tuple[float, ...] | None = _parse_coord(parts[2:]) if len(parts) > 2 else None + result.commands.append(N(node=node, coordinate=n_coord)) + + elif cmd_type == "E": + node1 = int(parts[1]) + node2 = int(parts[2]) + result.commands.append(E(nodes=(node1, node2))) + + elif cmd_type == "M": + node = int(parts[1]) + plane = Plane[parts[2]] + angle = _parse_angle(parts[3]) + meas_basis = PlannerMeasBasis(plane, angle) + result.commands.append(M(node=node, meas_basis=meas_basis)) + + elif cmd_type == "X": + node = int(parts[1]) + result.commands.append(X(node=node)) + + elif cmd_type == "Z": + node = int(parts[1]) + result.commands.append(Z(node=node)) + + else: + msg = f"Unknown command at line {line_num}: {cmd_type}" + raise ValueError(msg) + + if not version_found: + msg = "Missing .version directive" + raise ValueError(msg) + + return result + + +def load(file: Path | str) -> PatternData: + """Read pattern components from a .ptn file. + + Parameters + ---------- + file : `Path` | `str` + The file path to read from. + + Returns + ------- + `PatternData` + Container with pattern components. + See `loads` for details. + """ + path = Path(file) + return loads(path.read_text(encoding="utf-8")) diff --git a/pyproject.toml b/pyproject.toml index ea35c112..e393cdf7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -125,6 +125,16 @@ docstring-code-format = true "T201", # print "D", ] +"graphqomb/ptn_format.py" = [ + "C901", # too complex (parser functions) + "PLR0911", # too many return statements (angle formatting) + "PLR0912", # too many branches (parser) + "PLR0914", # too many local variables (parser) + "PLR0915", # too many statements (parser) + "PLR2004", # magic values in parser + "PLW2901", # loop variable overwritten (line processing) + "FURB122", # writelines suggestion (readability preference) +] [tool.pytest.ini_options] pythonpath = ["graphqomb"] diff --git a/tests/test_ptn_format.py b/tests/test_ptn_format.py new file mode 100644 index 00000000..1dd0ec59 --- /dev/null +++ b/tests/test_ptn_format.py @@ -0,0 +1,408 @@ +"""Tests for ptn_format module.""" + +from __future__ import annotations + +import math +import tempfile +from pathlib import Path + +import pytest + +from graphqomb.command import TICK, E, M, N, X, Z +from graphqomb.common import Plane, PlannerMeasBasis +from graphqomb.graphstate import GraphState +from graphqomb.ptn_format import ( + PatternData, + dump, + dumps, + load, + loads, +) +from graphqomb.qompiler import qompile + + +def create_simple_pattern(): + """Create a simple pattern for testing.""" + graph = GraphState() + in_node = graph.add_physical_node(coordinate=(0.0, 0.0)) + mid_node = graph.add_physical_node(coordinate=(1.0, 0.0)) + out_node = graph.add_physical_node(coordinate=(2.0, 0.0)) + + graph.register_input(in_node, 0) + graph.register_output(out_node, 0) + + graph.add_physical_edge(in_node, mid_node) + graph.add_physical_edge(mid_node, out_node) + + graph.assign_meas_basis(in_node, PlannerMeasBasis(Plane.XY, 0.0)) + graph.assign_meas_basis(mid_node, PlannerMeasBasis(Plane.XY, math.pi / 2)) + + xflow = {in_node: {mid_node}, mid_node: {out_node}} + pattern = qompile(graph, xflow) + + return pattern + + +def test_dumps_basic(): + """Test basic pattern serialization.""" + pattern = create_simple_pattern() + ptn_str = dumps(pattern) + + assert ".version 1" in ptn_str + assert ".input" in ptn_str + assert ".output" in ptn_str + assert "#======== QUANTUM ========" in ptn_str + assert "#======== CLASSICAL ========" in ptn_str + + +def test_dumps_contains_commands(): + """Test that dumps includes all command types.""" + pattern = create_simple_pattern() + ptn_str = dumps(pattern) + + # Check for command types + assert "N " in ptn_str # Node creation + assert "E " in ptn_str # Entanglement + assert "M " in ptn_str # Measurement + + +def test_dumps_coordinates(): + """Test that coordinates are correctly serialized.""" + pattern = create_simple_pattern() + ptn_str = dumps(pattern) + + assert ".coord 0 0.0 0.0" in ptn_str + + +def test_dumps_measurement_angles(): + """Test that measurement angles are correctly formatted.""" + pattern = create_simple_pattern() + ptn_str = dumps(pattern) + + # pi/2 should be formatted as "pi/2" + assert "pi/2" in ptn_str + # 0.0 should be formatted as "0" + assert "XY 0" in ptn_str + + +def test_loads_basic(): + """Test basic pattern deserialization.""" + ptn_str = """# Test pattern +.version 1 +.input 0:0 +.output 2:0 +.coord 0 0.0 0.0 + +#======== QUANTUM ======== +[0] +N 1 +E 0 1 +M 0 XY 0 + +#======== CLASSICAL ======== +.xflow 0 -> 1 +""" + result = loads(ptn_str) + + assert isinstance(result, PatternData) + assert result.input_node_indices == {0: 0} + assert result.output_node_indices == {2: 0} + assert result.input_coordinates == {0: (0.0, 0.0)} + + +def test_loads_commands(): + """Test that commands are correctly parsed.""" + ptn_str = """ +.version 1 +.input 0:0 +.output 2:0 + +[0] +N 1 +N 3 1.0 2.0 +E 0 1 +M 0 XY 0 +M 1 XY pi/2 +X 2 +Z 2 +""" + result = loads(ptn_str) + + commands = result.commands + # Check command types + assert any(isinstance(c, N) and c.node == 1 for c in commands) + assert any(isinstance(c, N) and c.node == 3 and c.coordinate == (1.0, 2.0) for c in commands) + assert any(isinstance(c, E) and c.nodes == (0, 1) for c in commands) + assert any(isinstance(c, M) and c.node == 0 for c in commands) + assert any(isinstance(c, X) and c.node == 2 for c in commands) + assert any(isinstance(c, Z) and c.node == 2 for c in commands) + + +def test_loads_timeslices(): + """Test that timeslice markers generate TICK commands.""" + ptn_str = """ +.version 1 +.input 0:0 +.output 1:0 + +[0] +E 0 1 +[1] +M 0 XY 0 +[2] +M 1 XY 0 +""" + result = loads(ptn_str) + + # Count TICK commands + tick_count = sum(1 for c in result.commands if isinstance(c, TICK)) + assert tick_count == 2 + + +def test_loads_angle_parsing(): + """Test various angle format parsing.""" + ptn_str = """ +.version 1 +.input 0:0 +.output 5:0 + +[0] +M 0 XY 0 +M 1 XY pi +M 2 XY pi/2 +M 3 XY pi/4 +M 4 XY 3pi/4 +""" + result = loads(ptn_str) + + measurements = [c for c in result.commands if isinstance(c, M)] + angles = {m.node: m.meas_basis.angle for m in measurements} + + assert math.isclose(angles[0], 0.0) + assert math.isclose(angles[1], math.pi) + assert math.isclose(angles[2], math.pi / 2) + assert math.isclose(angles[3], math.pi / 4) + assert math.isclose(angles[4], 3 * math.pi / 4) + + +def test_loads_flow_parsing(): + """Test xflow and zflow parsing.""" + ptn_str = """ +.version 1 +.input 0:0 +.output 2:0 + +[0] +N 1 +E 0 1 +M 0 XY 0 + +.xflow 0 -> 1 2 +.zflow 0 -> 3 4 +""" + result = loads(ptn_str) + + assert result.xflow == {0: {1, 2}} + assert result.zflow == {0: {3, 4}} + + +def test_loads_detector_parsing(): + """Test detector (parity check group) parsing.""" + ptn_str = """ +.version 1 +.input 0:0 +.output 2:0 + +[0] +M 0 XY 0 + +.detector 0 1 2 +.detector 3 4 +""" + result = loads(ptn_str) + + assert len(result.parity_check_groups) == 2 + assert result.parity_check_groups[0] == {0, 1, 2} + assert result.parity_check_groups[1] == {3, 4} + + +def test_loads_missing_version(): + """Test that missing version raises ValueError.""" + ptn_str = """ +.input 0:0 +.output 1:0 +[0] +M 0 XY 0 +""" + with pytest.raises(ValueError, match="Missing .version directive"): + loads(ptn_str) + + +def test_loads_unsupported_version(): + """Test that unsupported version raises ValueError.""" + ptn_str = """ +.version 99 +.input 0:0 +.output 1:0 +""" + with pytest.raises(ValueError, match="Unsupported .ptn version"): + loads(ptn_str) + + +def test_loads_unknown_command(): + """Test that unknown command raises ValueError.""" + ptn_str = """ +.version 1 +.input 0:0 +.output 1:0 + +[0] +UNKNOWN 0 +""" + with pytest.raises(ValueError, match="Unknown command"): + loads(ptn_str) + + +def test_roundtrip(): + """Test that dumps followed by loads preserves data.""" + pattern = create_simple_pattern() + ptn_str = dumps(pattern) + result = loads(ptn_str) + + # Check input/output nodes match + assert result.input_node_indices == pattern.input_node_indices + assert result.output_node_indices == pattern.output_node_indices + + # Check command count matches (excluding internal differences) + original_count = len([c for c in pattern.commands if not isinstance(c, (X, Z))]) + parsed_count = len([c for c in result.commands if not isinstance(c, (X, Z))]) + assert original_count == parsed_count + + +def test_dump_and_load_file(): + """Test file I/O operations.""" + pattern = create_simple_pattern() + + with tempfile.TemporaryDirectory() as tmpdir: + filepath = Path(tmpdir) / "test.ptn" + + # Write to file + dump(pattern, filepath) + + # Verify file exists + assert filepath.exists() + + # Read from file + result = load(filepath) + + # Verify content + assert result.input_node_indices == pattern.input_node_indices + assert result.output_node_indices == pattern.output_node_indices + + +def test_multiple_input_output_qubits(): + """Test pattern with multiple input/output qubits.""" + graph = GraphState() + in0 = graph.add_physical_node() + in1 = graph.add_physical_node() + out0 = graph.add_physical_node() + out1 = graph.add_physical_node() + + graph.register_input(in0, 0) + graph.register_input(in1, 1) + graph.register_output(out0, 0) + graph.register_output(out1, 1) + + graph.add_physical_edge(in0, out0) + graph.add_physical_edge(in1, out1) + + graph.assign_meas_basis(in0, PlannerMeasBasis(Plane.XY, 0.0)) + graph.assign_meas_basis(in1, PlannerMeasBasis(Plane.XY, 0.0)) + + xflow = {in0: {out0}, in1: {out1}} + pattern = qompile(graph, xflow) + + ptn_str = dumps(pattern) + result = loads(ptn_str) + + assert len(result.input_node_indices) == 2 + assert len(result.output_node_indices) == 2 + + +def test_3d_coordinates(): + """Test 3D coordinate serialization and parsing.""" + ptn_str = """ +.version 1 +.input 0:0 +.output 1:0 +.coord 0 1.0 2.0 3.0 + +[0] +N 1 4.0 5.0 6.0 +M 0 XY 0 +""" + result = loads(ptn_str) + + assert result.input_coordinates[0] == (1.0, 2.0, 3.0) + + # Check N command coordinate + n_cmd = next(c for c in result.commands if isinstance(c, N)) + assert n_cmd.coordinate == (4.0, 5.0, 6.0) + + +def test_different_measurement_planes(): + """Test all measurement planes are correctly handled.""" + ptn_str = """ +.version 1 +.input 0:0 +.output 3:0 + +[0] +M 0 XY pi/4 +M 1 XZ pi/4 +M 2 YZ pi/4 +""" + result = loads(ptn_str) + + measurements = [c for c in result.commands if isinstance(c, M)] + planes = {m.node: m.meas_basis.plane for m in measurements} + + assert planes[0] == Plane.XY + assert planes[1] == Plane.XZ + assert planes[2] == Plane.YZ + + +def test_empty_flow(): + """Test pattern with empty flow mappings.""" + ptn_str = """ +.version 1 +.input 0:0 +.output 1:0 + +[0] +M 0 XY 0 + +#======== CLASSICAL ======== +""" + result = loads(ptn_str) + + assert result.xflow == {} + assert result.zflow == {} + + +def test_comments_ignored(): + """Test that comments are properly ignored.""" + ptn_str = """ +# This is a comment +.version 1 +# Another comment +.input 0:0 # inline comment should be parsed as part of content +.output 1:0 + +# Comment in quantum section +[0] +M 0 XY 0 +""" + result = loads(ptn_str) + # Should parse without error + assert result.input_node_indices is not None From ce5155bef63f63a4d768c1845f91b60fdbc02ba4 Mon Sep 17 00:00:00 2001 From: Masato Fukushima Date: Mon, 9 Feb 2026 19:16:56 +0900 Subject: [PATCH 02/10] Use compact notation for Pauli measurements in .ptn format Change measurement output format for Pauli basis measurements: - X measurement: "M X +" or "M X -" - Y measurement: "M Y +" or "M Y -" - Z measurement: "M Z +" or "M Z -" Non-Pauli measurements continue to use plane+angle format: - "M XY pi/4" etc. This makes the format more readable for common Clifford patterns while maintaining full expressiveness for arbitrary measurements. Co-Authored-By: Claude Opus 4.5 --- graphqomb/ptn_format.py | 47 ++++++++++++++++++++++++++++++------ tests/test_ptn_format.py | 52 +++++++++++++++++++++++++++++++++++----- 2 files changed, 86 insertions(+), 13 deletions(-) diff --git a/graphqomb/ptn_format.py b/graphqomb/ptn_format.py index c2639385..d7a95df0 100644 --- a/graphqomb/ptn_format.py +++ b/graphqomb/ptn_format.py @@ -18,7 +18,16 @@ from typing import TYPE_CHECKING from graphqomb.command import TICK, Command, E, M, N, X, Z -from graphqomb.common import Plane, PlannerMeasBasis +from graphqomb.common import ( + Axis, + AxisMeasBasis, + MeasBasis, + Plane, + PlannerMeasBasis, + Sign, + determine_pauli_axis, + is_close_angle, +) if TYPE_CHECKING: from graphqomb.pattern import Pattern @@ -233,9 +242,23 @@ def _write_command(out: StringIO, cmd: Command) -> None: elif isinstance(cmd, E): out.write(f"E {cmd.nodes[0]} {cmd.nodes[1]}\n") elif isinstance(cmd, M): - plane_name = cmd.meas_basis.plane.name - angle_str = _format_angle(cmd.meas_basis.angle) - out.write(f"M {cmd.node} {plane_name} {angle_str}\n") + # Check if this is a Pauli measurement (X, Y, or Z basis) + pauli_axis = determine_pauli_axis(cmd.meas_basis) + if pauli_axis is not None: + # Determine sign based on angle + angle = cmd.meas_basis.angle + if pauli_axis == Axis.Y: + # Y measurement: +Y at pi/2, -Y at 3pi/2 + sign = "+" if is_close_angle(angle, math.pi / 2) else "-" + else: + # X or Z measurement: + at 0, - at pi + sign = "+" if is_close_angle(angle, 0.0) else "-" + out.write(f"M {cmd.node} {pauli_axis.name} {sign}\n") + else: + # Non-Pauli measurement: use plane and angle + plane_name = cmd.meas_basis.plane.name + angle_str = _format_angle(cmd.meas_basis.angle) + out.write(f"M {cmd.node} {plane_name} {angle_str}\n") elif isinstance(cmd, X): out.write(f"X {cmd.node}\n") elif isinstance(cmd, Z): @@ -488,9 +511,19 @@ def loads(s: str) -> PatternData: elif cmd_type == "M": node = int(parts[1]) - plane = Plane[parts[2]] - angle = _parse_angle(parts[3]) - meas_basis = PlannerMeasBasis(plane, angle) + basis_spec = parts[2] + meas_basis: MeasBasis + # Check if this is a Pauli measurement (X/Y/Z with +/-) + if basis_spec in {"X", "Y", "Z"}: + sign_str = parts[3] + sign = Sign.PLUS if sign_str == "+" else Sign.MINUS + axis = Axis[basis_spec] + meas_basis = AxisMeasBasis(axis, sign) + else: + # Plane-based measurement (XY/XZ/YZ with angle) + plane = Plane[basis_spec] + m_angle = _parse_angle(parts[3]) + meas_basis = PlannerMeasBasis(plane, m_angle) result.commands.append(M(node=node, meas_basis=meas_basis)) elif cmd_type == "X": diff --git a/tests/test_ptn_format.py b/tests/test_ptn_format.py index 1dd0ec59..31923f5e 100644 --- a/tests/test_ptn_format.py +++ b/tests/test_ptn_format.py @@ -74,15 +74,15 @@ def test_dumps_coordinates(): assert ".coord 0 0.0 0.0" in ptn_str -def test_dumps_measurement_angles(): - """Test that measurement angles are correctly formatted.""" +def test_dumps_pauli_measurements(): + """Test that Pauli measurements are correctly formatted with +/- signs.""" pattern = create_simple_pattern() ptn_str = dumps(pattern) - # pi/2 should be formatted as "pi/2" - assert "pi/2" in ptn_str - # 0.0 should be formatted as "0" - assert "XY 0" in ptn_str + # X measurement (XY plane, angle 0) should be formatted as "X +" + assert "M 0 X +" in ptn_str + # Y measurement (XY plane, angle pi/2) should be formatted as "Y +" + assert "M 1 Y +" in ptn_str def test_loads_basic(): @@ -185,6 +185,46 @@ def test_loads_angle_parsing(): assert math.isclose(angles[4], 3 * math.pi / 4) +def test_loads_pauli_measurements(): + """Test parsing of Pauli measurement format (X/Y/Z +/-).""" + ptn_str = """ +.version 1 +.input 0:0 +.output 6:0 + +[0] +M 0 X + +M 1 X - +M 2 Y + +M 3 Y - +M 4 Z + +M 5 Z - +""" + result = loads(ptn_str) + + measurements = [c for c in result.commands if isinstance(c, M)] + assert len(measurements) == 6 + + # Check that Pauli measurements are parsed correctly + m0 = next(m for m in measurements if m.node == 0) + assert math.isclose(m0.meas_basis.angle, 0.0) # X + + + m1 = next(m for m in measurements if m.node == 1) + assert math.isclose(m1.meas_basis.angle, math.pi) # X - + + m2 = next(m for m in measurements if m.node == 2) + assert math.isclose(m2.meas_basis.angle, math.pi / 2) # Y + + + m3 = next(m for m in measurements if m.node == 3) + assert math.isclose(m3.meas_basis.angle, 3 * math.pi / 2) # Y - + + m4 = next(m for m in measurements if m.node == 4) + assert math.isclose(m4.meas_basis.angle, 0.0) # Z + + + m5 = next(m for m in measurements if m.node == 5) + assert math.isclose(m5.meas_basis.angle, math.pi) # Z - + + def test_loads_flow_parsing(): """Test xflow and zflow parsing.""" ptn_str = """ From 3c4e6cad58316904b851ce57aafb9d409d4d2a69 Mon Sep 17 00:00:00 2001 From: Masato Fukushima Date: Mon, 9 Feb 2026 19:18:54 +0900 Subject: [PATCH 03/10] Add CHANGELOG entry for PTN format feature Co-Authored-By: Claude Opus 4.5 --- CHANGELOG.md | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3210ebb3..ea3e7dfb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,19 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [Unreleased] + +### Added + +- **PTN Format**: Human-readable text format (`.ptn`) for pattern serialization + - `ptn_format.dumps()` / `ptn_format.dump()`: Serialize patterns to text + - `ptn_format.loads()` / `ptn_format.load()`: Deserialize patterns from text + - Format separates quantum instructions and classical feedforward processing + - Timeslice markers `[n]` indicate parallel execution groups + - Pauli measurements use compact notation (`X +`, `Y -`, `Z +`) + - Non-Pauli measurements use plane+angle format (`XY pi/4`) + - Support for node coordinates and inline comments + ## [0.2.1] - 2026-01-16 ### Added From 59a5b307adc6ee7ed2230ebbd0eab1c2a798246a Mon Sep 17 00:00:00 2001 From: Masato Fukushima Date: Mon, 9 Feb 2026 19:26:38 +0900 Subject: [PATCH 04/10] Refactor ptn_format.py to pass all lint checks without exclusions - Replace multiple if-statements with lookup tables for angle formatting/parsing - Split monolithic loads() function into _Parser class with small focused methods - Remove Raises section from loads() docstring (exception is raised by _Parser.parse) - Remove per-file-ignores for ptn_format.py from pyproject.toml Co-Authored-By: Claude Opus 4.5 --- graphqomb/ptn_format.py | 519 ++++++++++++++++++++-------------------- pyproject.toml | 10 - 2 files changed, 259 insertions(+), 270 deletions(-) diff --git a/graphqomb/ptn_format.py b/graphqomb/ptn_format.py index d7a95df0..72f90521 100644 --- a/graphqomb/ptn_format.py +++ b/graphqomb/ptn_format.py @@ -35,9 +35,36 @@ PTN_VERSION = 1 +# Angle formatting/parsing lookup tables +_ANGLE_TO_STR: dict[float, str] = { + 0.0: "0", + math.pi: "pi", + -math.pi: "-pi", + math.pi / 2: "pi/2", + -math.pi / 2: "-pi/2", + math.pi / 4: "pi/4", + -math.pi / 4: "-pi/4", + 3 * math.pi / 2: "3pi/2", + 3 * math.pi / 4: "3pi/4", +} + +_STR_TO_ANGLE: dict[str, float] = { + "0": 0.0, + "pi": math.pi, + "-pi": -math.pi, + "pi/2": math.pi / 2, + "-pi/2": -math.pi / 2, + "pi/4": math.pi / 4, + "-pi/4": -math.pi / 4, + "3pi/2": 3 * math.pi / 2, + "3pi/4": 3 * math.pi / 4, +} + +_PI_PATTERN = re.compile(r"^(-?\d*)pi(?:/(\d+))?$") + def _format_angle(angle: float) -> str: - """Format angle for output, using pi fractions where appropriate. + r"""Format angle for output, using pi fractions where appropriate. Parameters ---------- @@ -49,30 +76,18 @@ def _format_angle(angle: float) -> str: `str` Formatted angle string. """ - # Check for common pi fractions - if math.isclose(angle, 0.0, abs_tol=1e-10): - return "0" - if math.isclose(angle, math.pi, rel_tol=1e-10): - return "pi" - if math.isclose(angle, -math.pi, rel_tol=1e-10): - return "-pi" - if math.isclose(angle, math.pi / 2, rel_tol=1e-10): - return "pi/2" - if math.isclose(angle, -math.pi / 2, rel_tol=1e-10): - return "-pi/2" - if math.isclose(angle, math.pi / 4, rel_tol=1e-10): - return "pi/4" - if math.isclose(angle, -math.pi / 4, rel_tol=1e-10): - return "-pi/4" - if math.isclose(angle, 3 * math.pi / 2, rel_tol=1e-10): - return "3pi/2" - if math.isclose(angle, 3 * math.pi / 4, rel_tol=1e-10): - return "3pi/4" + for ref_angle, label in _ANGLE_TO_STR.items(): + tol = 1e-10 if ref_angle == 0.0 else None + if tol is not None: + if math.isclose(angle, ref_angle, abs_tol=tol): + return label + elif math.isclose(angle, ref_angle, rel_tol=1e-10): + return label return f"{angle}" def _parse_angle(s: str) -> float: - """Parse angle string to float. + r"""Parse angle string to float. Parameters ---------- @@ -85,27 +100,10 @@ def _parse_angle(s: str) -> float: The angle in radians. """ s = s.strip() - if s == "0": - return 0.0 - if s == "pi": - return math.pi - if s == "-pi": - return -math.pi - if s == "pi/2": - return math.pi / 2 - if s == "-pi/2": - return -math.pi / 2 - if s == "pi/4": - return math.pi / 4 - if s == "-pi/4": - return -math.pi / 4 - if s == "3pi/2": - return 3 * math.pi / 2 - if s == "3pi/4": - return 3 * math.pi / 4 - - # Try to parse as a general pi expression (e.g., "2pi/3") - pi_match = re.match(r"^(-?\d*)pi(?:/(\d+))?$", s) + if s in _STR_TO_ANGLE: + return _STR_TO_ANGLE[s] + + pi_match = _PI_PATTERN.match(s) if pi_match: numerator = pi_match.group(1) denominator = pi_match.group(2) @@ -117,11 +115,11 @@ def _parse_angle(s: str) -> float: def _format_coord(coord: tuple[float, ...]) -> str: - """Format coordinate tuple for output. + r"""Format coordinate tuple for output. Parameters ---------- - coord : `tuple`[`float`, ...] + coord : `tuple`\[`float`, ...\] Coordinate tuple (2D or 3D). Returns @@ -133,75 +131,88 @@ def _format_coord(coord: tuple[float, ...]) -> str: def _parse_coord(parts: list[str]) -> tuple[float, ...]: - """Parse coordinate from string parts. + r"""Parse coordinate from string parts. Parameters ---------- - parts : `list`[`str`] + parts : `list`\[`str`\] List of coordinate value strings. Returns ------- - `tuple`[`float`, ...] + `tuple`\[`float`, ...\] Coordinate tuple. """ return tuple(float(p) for p in parts) -def _write_header( - out: StringIO, - pattern: Pattern, -) -> None: - """Write header section to output. +# ============================================================ +# Serialization (dumps/dump) +# ============================================================ - Parameters - ---------- - out : `StringIO` - Output stream. - pattern : `Pattern` - The pattern to write. - """ + +def _write_header(out: StringIO, pattern: Pattern) -> None: + """Write header section to output.""" out.write(f"# GraphQOMB Pattern Format v{PTN_VERSION}\n") out.write("\n") out.write("#======== HEADER ========\n") out.write(f".version {PTN_VERSION}\n") - # Input nodes if pattern.input_node_indices: input_parts = [ f"{node}:{qidx}" for node, qidx in sorted(pattern.input_node_indices.items(), key=operator.itemgetter(1)) ] out.write(f".input {' '.join(input_parts)}\n") - # Output nodes if pattern.output_node_indices: output_parts = [ f"{node}:{qidx}" for node, qidx in sorted(pattern.output_node_indices.items(), key=operator.itemgetter(1)) ] out.write(f".output {' '.join(output_parts)}\n") - # Input coordinates - for node, coord in sorted(pattern.input_coordinates.items()): - out.write(f".coord {node} {_format_coord(coord)}\n") + out.writelines( + f".coord {node} {_format_coord(coord)}\n" for node, coord in sorted(pattern.input_coordinates.items()) + ) -def _write_quantum_section( - out: StringIO, - pattern: Pattern, -) -> None: - """Write quantum instructions section to output. +def _write_command(out: StringIO, cmd: Command) -> None: + """Write a single command to output.""" + if isinstance(cmd, N): + if cmd.coordinate is not None: + out.write(f"N {cmd.node} {_format_coord(cmd.coordinate)}\n") + else: + out.write(f"N {cmd.node}\n") + elif isinstance(cmd, E): + out.write(f"E {cmd.nodes[0]} {cmd.nodes[1]}\n") + elif isinstance(cmd, M): + _write_measurement(out, cmd) + elif isinstance(cmd, X): + out.write(f"X {cmd.node}\n") + elif isinstance(cmd, Z): + out.write(f"Z {cmd.node}\n") - Parameters - ---------- - out : `StringIO` - Output stream. - pattern : `Pattern` - The pattern to write. - """ + +def _write_measurement(out: StringIO, cmd: M) -> None: + """Write measurement command with appropriate format.""" + pauli_axis = determine_pauli_axis(cmd.meas_basis) + if pauli_axis is not None: + angle = cmd.meas_basis.angle + if pauli_axis == Axis.Y: + sign = "+" if is_close_angle(angle, math.pi / 2) else "-" + else: + sign = "+" if is_close_angle(angle, 0.0) else "-" + out.write(f"M {cmd.node} {pauli_axis.name} {sign}\n") + else: + plane_name = cmd.meas_basis.plane.name + angle_str = _format_angle(cmd.meas_basis.angle) + out.write(f"M {cmd.node} {plane_name} {angle_str}\n") + + +def _write_quantum_section(out: StringIO, pattern: Pattern) -> None: + """Write quantum instructions section to output.""" out.write("\n") out.write("#======== QUANTUM ========\n") - # Group commands by timeslice timeslice = 0 current_slice_commands: list[Command] = [] @@ -219,83 +230,25 @@ def write_slice(slice_num: int, commands: list[Command]) -> None: else: current_slice_commands.append(cmd) - # Write remaining commands in last slice if current_slice_commands or timeslice == 0: write_slice(timeslice, current_slice_commands) -def _write_command(out: StringIO, cmd: Command) -> None: - """Write a single command to output. - - Parameters - ---------- - out : `StringIO` - Output stream. - cmd : `Command` - The command to write. - """ - if isinstance(cmd, N): - if cmd.coordinate is not None: - out.write(f"N {cmd.node} {_format_coord(cmd.coordinate)}\n") - else: - out.write(f"N {cmd.node}\n") - elif isinstance(cmd, E): - out.write(f"E {cmd.nodes[0]} {cmd.nodes[1]}\n") - elif isinstance(cmd, M): - # Check if this is a Pauli measurement (X, Y, or Z basis) - pauli_axis = determine_pauli_axis(cmd.meas_basis) - if pauli_axis is not None: - # Determine sign based on angle - angle = cmd.meas_basis.angle - if pauli_axis == Axis.Y: - # Y measurement: +Y at pi/2, -Y at 3pi/2 - sign = "+" if is_close_angle(angle, math.pi / 2) else "-" - else: - # X or Z measurement: + at 0, - at pi - sign = "+" if is_close_angle(angle, 0.0) else "-" - out.write(f"M {cmd.node} {pauli_axis.name} {sign}\n") - else: - # Non-Pauli measurement: use plane and angle - plane_name = cmd.meas_basis.plane.name - angle_str = _format_angle(cmd.meas_basis.angle) - out.write(f"M {cmd.node} {plane_name} {angle_str}\n") - elif isinstance(cmd, X): - out.write(f"X {cmd.node}\n") - elif isinstance(cmd, Z): - out.write(f"Z {cmd.node}\n") - elif isinstance(cmd, TICK): - pass # TICK is handled by timeslice grouping - - -def _write_classical_section( - out: StringIO, - pauli_frame: PauliFrame, -) -> None: - """Write classical frame section to output. - - Parameters - ---------- - out : `StringIO` - Output stream. - pauli_frame : `PauliFrame` - The Pauli frame to write. - """ +def _write_classical_section(out: StringIO, pauli_frame: PauliFrame) -> None: + """Write classical frame section to output.""" out.write("\n") out.write("#======== CLASSICAL ========\n") - # Write xflow for source, targets in sorted(pauli_frame.xflow.items()): if targets: targets_str = " ".join(str(t) for t in sorted(targets)) out.write(f".xflow {source} -> {targets_str}\n") - # Write zflow for source, targets in sorted(pauli_frame.zflow.items()): if targets: targets_str = " ".join(str(t) for t in sorted(targets)) out.write(f".zflow {source} -> {targets_str}\n") - # Write parity check groups (detectors) for group in pauli_frame.parity_check_group: if group: group_str = " ".join(str(n) for n in sorted(group)) @@ -336,17 +289,22 @@ def dump(pattern: Pattern, file: Path | str) -> None: path.write_text(dumps(pattern), encoding="utf-8") +# ============================================================ +# Deserialization (loads/load) +# ============================================================ + + def _parse_node_qubit_pairs(parts: list[str]) -> dict[int, int]: - """Parse node:qubit pairs from string parts. + r"""Parse node:qubit pairs from string parts. Parameters ---------- - parts : `list`[`str`] + parts : `list`\[`str`\] List of "node:qubit" strings. Returns ------- - `dict`[`int`, `int`] + `dict`\[`int`, `int`\] Mapping from node to qubit index. """ result: dict[int, int] = {} @@ -357,7 +315,7 @@ def _parse_node_qubit_pairs(parts: list[str]) -> dict[int, int]: def _parse_flow(line: str) -> tuple[int, set[int]]: - """Parse a flow line (xflow or zflow). + r"""Parse a flow line (xflow or zflow). Parameters ---------- @@ -366,7 +324,7 @@ def _parse_flow(line: str) -> tuple[int, set[int]]: Returns ------- - `tuple`[`int`, `set`[`int`]] + `tuple`\[`int`, `set`\[`int`\]\] Source node and set of target nodes. """ parts = line.split("->") @@ -406,143 +364,184 @@ def __init__(self) -> None: self.parity_check_groups: list[set[int]] = [] -def loads(s: str) -> PatternData: - """Deserialize a .ptn format string to pattern components. - - Parameters - ---------- - s : `str` - The .ptn format string. - - Returns - ------- - `PatternData` - Container with pattern components. - - Raises - ------ - ValueError - If the format is invalid or unsupported version. - """ - result = PatternData() +class _Parser: + """Internal parser state for loads().""" - current_timeslice = -1 - version_found = False + def __init__(self) -> None: + self.result = PatternData() + self.current_timeslice = -1 + self.version_found = False + + def parse(self, s: str) -> PatternData: + r"""Parse the input string and return PatternData. + + Parameters + ---------- + s : `str` + The .ptn format string. + + Returns + ------- + `PatternData` + Container with pattern components. + + Raises + ------ + ValueError + If the format is invalid or unsupported version. + """ + for line_num, raw_line in enumerate(s.splitlines(), 1): + self._parse_line(line_num, raw_line) + + if not self.version_found: + msg = "Missing .version directive" + raise ValueError(msg) - for line_num, line in enumerate(s.splitlines(), 1): - # Remove inline comments - if "#" in line: - line = line[: line.index("#")] - line = line.strip() + return self.result - # Skip empty lines + def _parse_line(self, line_num: int, raw_line: str) -> None: + """Parse a single line.""" + line = raw_line.split("#", 1)[0].strip() if not line: - continue + return - # Parse directives if line.startswith("."): - parts = line.split(maxsplit=1) - directive = parts[0] - content = parts[1] if len(parts) > 1 else "" - - if directive == ".version": - version = int(content) - if version != PTN_VERSION: - msg = f"Unsupported .ptn version: {version} (expected {PTN_VERSION})" - raise ValueError(msg) - version_found = True - - elif directive == ".input": - result.input_node_indices = _parse_node_qubit_pairs(content.split()) - - elif directive == ".output": - result.output_node_indices = _parse_node_qubit_pairs(content.split()) - - elif directive == ".coord": - coord_parts = content.split() - node = int(coord_parts[0]) - coord = _parse_coord(coord_parts[1:]) - result.input_coordinates[node] = coord - - elif directive == ".xflow": - source, targets = _parse_flow(content) - result.xflow[source] = targets - - elif directive == ".zflow": - source, targets = _parse_flow(content) - result.zflow[source] = targets - - elif directive == ".detector": - nodes = {int(n) for n in content.split()} - result.parity_check_groups.append(nodes) - - elif directive == ".observable": - # Observable parsing - store for future use - pass - - continue - - # Parse timeslice header - if line.startswith("[") and line.endswith("]"): - slice_num = int(line[1:-1]) - # Add TICK commands for timeslice transitions - while current_timeslice < slice_num - 1: - result.commands.append(TICK()) - current_timeslice += 1 - if current_timeslice < slice_num: - if current_timeslice >= 0: - result.commands.append(TICK()) - current_timeslice = slice_num - continue - - # Parse commands + self._parse_directive(line) + elif line.startswith("[") and line.endswith("]"): + self._parse_timeslice(line) + else: + self._parse_command(line_num, line) + + def _parse_directive(self, line: str) -> None: + """Parse a directive line (starts with '.').""" + parts = line.split(maxsplit=1) + directive = parts[0] + content = parts[1] if len(parts) > 1 else "" + + if directive == ".version": + self._handle_version(content) + elif directive == ".input": + self.result.input_node_indices = _parse_node_qubit_pairs(content.split()) + elif directive == ".output": + self.result.output_node_indices = _parse_node_qubit_pairs(content.split()) + elif directive == ".coord": + self._handle_coord(content) + elif directive == ".xflow": + source, targets = _parse_flow(content) + self.result.xflow[source] = targets + elif directive == ".zflow": + source, targets = _parse_flow(content) + self.result.zflow[source] = targets + elif directive == ".detector": + nodes = {int(n) for n in content.split()} + self.result.parity_check_groups.append(nodes) + + def _handle_version(self, content: str) -> None: + r"""Handle .version directive. + + Raises + ------ + ValueError + If the version is unsupported. + """ + version = int(content) + if version != PTN_VERSION: + msg = f"Unsupported .ptn version: {version} (expected {PTN_VERSION})" + raise ValueError(msg) + self.version_found = True + + def _handle_coord(self, content: str) -> None: + """Handle .coord directive.""" + coord_parts = content.split() + node = int(coord_parts[0]) + coord = _parse_coord(coord_parts[1:]) + self.result.input_coordinates[node] = coord + + def _parse_timeslice(self, line: str) -> None: + """Parse timeslice marker [n].""" + slice_num = int(line[1:-1]) + while self.current_timeslice < slice_num - 1: + self.result.commands.append(TICK()) + self.current_timeslice += 1 + if self.current_timeslice < slice_num: + if self.current_timeslice >= 0: + self.result.commands.append(TICK()) + self.current_timeslice = slice_num + + def _parse_command(self, line_num: int, line: str) -> None: + r"""Parse a command line. + + Raises + ------ + ValueError + If the command type is unknown. + """ parts = line.split() cmd_type = parts[0] if cmd_type == "N": - node = int(parts[1]) - n_coord: tuple[float, ...] | None = _parse_coord(parts[2:]) if len(parts) > 2 else None - result.commands.append(N(node=node, coordinate=n_coord)) - + self._parse_n_command(parts) elif cmd_type == "E": - node1 = int(parts[1]) - node2 = int(parts[2]) - result.commands.append(E(nodes=(node1, node2))) - + self._parse_e_command(parts) elif cmd_type == "M": - node = int(parts[1]) - basis_spec = parts[2] - meas_basis: MeasBasis - # Check if this is a Pauli measurement (X/Y/Z with +/-) - if basis_spec in {"X", "Y", "Z"}: - sign_str = parts[3] - sign = Sign.PLUS if sign_str == "+" else Sign.MINUS - axis = Axis[basis_spec] - meas_basis = AxisMeasBasis(axis, sign) - else: - # Plane-based measurement (XY/XZ/YZ with angle) - plane = Plane[basis_spec] - m_angle = _parse_angle(parts[3]) - meas_basis = PlannerMeasBasis(plane, m_angle) - result.commands.append(M(node=node, meas_basis=meas_basis)) - + self._parse_m_command(parts) elif cmd_type == "X": - node = int(parts[1]) - result.commands.append(X(node=node)) - + self.result.commands.append(X(node=int(parts[1]))) elif cmd_type == "Z": - node = int(parts[1]) - result.commands.append(Z(node=node)) - + self.result.commands.append(Z(node=int(parts[1]))) else: msg = f"Unknown command at line {line_num}: {cmd_type}" raise ValueError(msg) - if not version_found: - msg = "Missing .version directive" - raise ValueError(msg) + def _parse_n_command(self, parts: list[str]) -> None: + """Parse N (node) command.""" + node = int(parts[1]) + coord: tuple[float, ...] | None = _parse_coord(parts[2:]) if len(parts) > 2 else None # noqa: PLR2004 + self.result.commands.append(N(node=node, coordinate=coord)) + + def _parse_e_command(self, parts: list[str]) -> None: + """Parse E (entangle) command.""" + node1 = int(parts[1]) + node2 = int(parts[2]) + self.result.commands.append(E(nodes=(node1, node2))) + + def _parse_m_command(self, parts: list[str]) -> None: + """Parse M (measure) command.""" + node = int(parts[1]) + basis_spec = parts[2] + meas_basis: MeasBasis + + if basis_spec in {"X", "Y", "Z"}: + sign_str = parts[3] + sign = Sign.PLUS if sign_str == "+" else Sign.MINUS + axis = Axis[basis_spec] + meas_basis = AxisMeasBasis(axis, sign) + else: + plane = Plane[basis_spec] + angle = _parse_angle(parts[3]) + meas_basis = PlannerMeasBasis(plane, angle) - return result + self.result.commands.append(M(node=node, meas_basis=meas_basis)) + + +def loads(s: str) -> PatternData: + """Deserialize a .ptn format string to pattern components. + + Parameters + ---------- + s : `str` + The .ptn format string. + + Returns + ------- + `PatternData` + Container with pattern components. + + See Also + -------- + _Parser.parse : Internal parser that may raise ValueError for invalid input. + """ + return _Parser().parse(s) def load(file: Path | str) -> PatternData: diff --git a/pyproject.toml b/pyproject.toml index e393cdf7..ea35c112 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -125,16 +125,6 @@ docstring-code-format = true "T201", # print "D", ] -"graphqomb/ptn_format.py" = [ - "C901", # too complex (parser functions) - "PLR0911", # too many return statements (angle formatting) - "PLR0912", # too many branches (parser) - "PLR0914", # too many local variables (parser) - "PLR0915", # too many statements (parser) - "PLR2004", # magic values in parser - "PLW2901", # loop variable overwritten (line processing) - "FURB122", # writelines suggestion (readability preference) -] [tool.pytest.ini_options] pythonpath = ["graphqomb"] From 60600114f544026610590ac5d9d93e3838f6a397 Mon Sep 17 00:00:00 2001 From: Masato Fukushima Date: Mon, 9 Feb 2026 19:29:07 +0900 Subject: [PATCH 05/10] Fix lint errors in test_ptn_format.py - Add return type annotations to all test functions - Move Pattern import to TYPE_CHECKING block - Add Returns section to create_simple_pattern docstring - Use raw strings for regex patterns in pytest.raises match - Fix import order Co-Authored-By: Claude Opus 4.5 --- tests/test_ptn_format.py | 64 ++++++++++++++++++++++------------------ 1 file changed, 36 insertions(+), 28 deletions(-) diff --git a/tests/test_ptn_format.py b/tests/test_ptn_format.py index 31923f5e..cb65d577 100644 --- a/tests/test_ptn_format.py +++ b/tests/test_ptn_format.py @@ -5,6 +5,7 @@ import math import tempfile from pathlib import Path +from typing import TYPE_CHECKING import pytest @@ -20,9 +21,18 @@ ) from graphqomb.qompiler import qompile +if TYPE_CHECKING: + from graphqomb.pattern import Pattern -def create_simple_pattern(): - """Create a simple pattern for testing.""" + +def create_simple_pattern() -> Pattern: + """Create a simple pattern for testing. + + Returns + ------- + Pattern + A compiled MBQC pattern for testing. + """ graph = GraphState() in_node = graph.add_physical_node(coordinate=(0.0, 0.0)) mid_node = graph.add_physical_node(coordinate=(1.0, 0.0)) @@ -38,12 +48,10 @@ def create_simple_pattern(): graph.assign_meas_basis(mid_node, PlannerMeasBasis(Plane.XY, math.pi / 2)) xflow = {in_node: {mid_node}, mid_node: {out_node}} - pattern = qompile(graph, xflow) - - return pattern + return qompile(graph, xflow) -def test_dumps_basic(): +def test_dumps_basic() -> None: """Test basic pattern serialization.""" pattern = create_simple_pattern() ptn_str = dumps(pattern) @@ -55,7 +63,7 @@ def test_dumps_basic(): assert "#======== CLASSICAL ========" in ptn_str -def test_dumps_contains_commands(): +def test_dumps_contains_commands() -> None: """Test that dumps includes all command types.""" pattern = create_simple_pattern() ptn_str = dumps(pattern) @@ -66,7 +74,7 @@ def test_dumps_contains_commands(): assert "M " in ptn_str # Measurement -def test_dumps_coordinates(): +def test_dumps_coordinates() -> None: """Test that coordinates are correctly serialized.""" pattern = create_simple_pattern() ptn_str = dumps(pattern) @@ -74,7 +82,7 @@ def test_dumps_coordinates(): assert ".coord 0 0.0 0.0" in ptn_str -def test_dumps_pauli_measurements(): +def test_dumps_pauli_measurements() -> None: """Test that Pauli measurements are correctly formatted with +/- signs.""" pattern = create_simple_pattern() ptn_str = dumps(pattern) @@ -85,7 +93,7 @@ def test_dumps_pauli_measurements(): assert "M 1 Y +" in ptn_str -def test_loads_basic(): +def test_loads_basic() -> None: """Test basic pattern deserialization.""" ptn_str = """# Test pattern .version 1 @@ -110,7 +118,7 @@ def test_loads_basic(): assert result.input_coordinates == {0: (0.0, 0.0)} -def test_loads_commands(): +def test_loads_commands() -> None: """Test that commands are correctly parsed.""" ptn_str = """ .version 1 @@ -138,7 +146,7 @@ def test_loads_commands(): assert any(isinstance(c, Z) and c.node == 2 for c in commands) -def test_loads_timeslices(): +def test_loads_timeslices() -> None: """Test that timeslice markers generate TICK commands.""" ptn_str = """ .version 1 @@ -159,7 +167,7 @@ def test_loads_timeslices(): assert tick_count == 2 -def test_loads_angle_parsing(): +def test_loads_angle_parsing() -> None: """Test various angle format parsing.""" ptn_str = """ .version 1 @@ -185,7 +193,7 @@ def test_loads_angle_parsing(): assert math.isclose(angles[4], 3 * math.pi / 4) -def test_loads_pauli_measurements(): +def test_loads_pauli_measurements() -> None: """Test parsing of Pauli measurement format (X/Y/Z +/-).""" ptn_str = """ .version 1 @@ -225,7 +233,7 @@ def test_loads_pauli_measurements(): assert math.isclose(m5.meas_basis.angle, math.pi) # Z - -def test_loads_flow_parsing(): +def test_loads_flow_parsing() -> None: """Test xflow and zflow parsing.""" ptn_str = """ .version 1 @@ -246,7 +254,7 @@ def test_loads_flow_parsing(): assert result.zflow == {0: {3, 4}} -def test_loads_detector_parsing(): +def test_loads_detector_parsing() -> None: """Test detector (parity check group) parsing.""" ptn_str = """ .version 1 @@ -266,7 +274,7 @@ def test_loads_detector_parsing(): assert result.parity_check_groups[1] == {3, 4} -def test_loads_missing_version(): +def test_loads_missing_version() -> None: """Test that missing version raises ValueError.""" ptn_str = """ .input 0:0 @@ -274,22 +282,22 @@ def test_loads_missing_version(): [0] M 0 XY 0 """ - with pytest.raises(ValueError, match="Missing .version directive"): + with pytest.raises(ValueError, match=r"Missing \.version directive"): loads(ptn_str) -def test_loads_unsupported_version(): +def test_loads_unsupported_version() -> None: """Test that unsupported version raises ValueError.""" ptn_str = """ .version 99 .input 0:0 .output 1:0 """ - with pytest.raises(ValueError, match="Unsupported .ptn version"): + with pytest.raises(ValueError, match=r"Unsupported \.ptn version"): loads(ptn_str) -def test_loads_unknown_command(): +def test_loads_unknown_command() -> None: """Test that unknown command raises ValueError.""" ptn_str = """ .version 1 @@ -303,7 +311,7 @@ def test_loads_unknown_command(): loads(ptn_str) -def test_roundtrip(): +def test_roundtrip() -> None: """Test that dumps followed by loads preserves data.""" pattern = create_simple_pattern() ptn_str = dumps(pattern) @@ -319,7 +327,7 @@ def test_roundtrip(): assert original_count == parsed_count -def test_dump_and_load_file(): +def test_dump_and_load_file() -> None: """Test file I/O operations.""" pattern = create_simple_pattern() @@ -340,7 +348,7 @@ def test_dump_and_load_file(): assert result.output_node_indices == pattern.output_node_indices -def test_multiple_input_output_qubits(): +def test_multiple_input_output_qubits() -> None: """Test pattern with multiple input/output qubits.""" graph = GraphState() in0 = graph.add_physical_node() @@ -369,7 +377,7 @@ def test_multiple_input_output_qubits(): assert len(result.output_node_indices) == 2 -def test_3d_coordinates(): +def test_3d_coordinates() -> None: """Test 3D coordinate serialization and parsing.""" ptn_str = """ .version 1 @@ -390,7 +398,7 @@ def test_3d_coordinates(): assert n_cmd.coordinate == (4.0, 5.0, 6.0) -def test_different_measurement_planes(): +def test_different_measurement_planes() -> None: """Test all measurement planes are correctly handled.""" ptn_str = """ .version 1 @@ -412,7 +420,7 @@ def test_different_measurement_planes(): assert planes[2] == Plane.YZ -def test_empty_flow(): +def test_empty_flow() -> None: """Test pattern with empty flow mappings.""" ptn_str = """ .version 1 @@ -430,7 +438,7 @@ def test_empty_flow(): assert result.zflow == {} -def test_comments_ignored(): +def test_comments_ignored() -> None: """Test that comments are properly ignored.""" ptn_str = """ # This is a comment From 72181f81a78090a094823006e6649c30eb3d89a5 Mon Sep 17 00:00:00 2001 From: Masato Fukushima Date: Tue, 12 May 2026 21:56:24 -0400 Subject: [PATCH 06/10] Fix PTN deserialization roundtrip --- CHANGELOG.md | 2 +- graphqomb/ptn_format.py | 446 ++++++++++++++++++++++++++++++++++----- tests/test_ptn_format.py | 214 +++++++++++++++---- 3 files changed, 568 insertions(+), 94 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e0c1e438..cac6b6c7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,7 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Timeslice markers `[n]` indicate parallel execution groups - Pauli measurements use compact notation (`X +`, `Y -`, `Z +`) - Non-Pauli measurements use plane+angle format (`XY pi/4`) - - Support for node coordinates and inline comments + - Support for node coordinates, logical observables, and inline comments - **Non-Unitary Parity Projection Example**: Added `examples/nonunitary_parity_projection.py` demonstrating measurement-induced entanglement via a 3-node star graph parity projector ### Fixed diff --git a/graphqomb/ptn_format.py b/graphqomb/ptn_format.py index 72f90521..0b866993 100644 --- a/graphqomb/ptn_format.py +++ b/graphqomb/ptn_format.py @@ -15,6 +15,7 @@ import re from io import StringIO from pathlib import Path +from types import MappingProxyType from typing import TYPE_CHECKING from graphqomb.command import TICK, Command, E, M, N, X, Z @@ -28,10 +29,13 @@ determine_pauli_axis, is_close_angle, ) +from graphqomb.graphstate import BaseGraphState +from graphqomb.pattern import Pattern +from graphqomb.pauli_frame import PauliFrame if TYPE_CHECKING: - from graphqomb.pattern import Pattern - from graphqomb.pauli_frame import PauliFrame + from collections.abc import Mapping, Sequence + from collections.abc import Set as AbstractSet PTN_VERSION = 1 @@ -77,7 +81,7 @@ def _format_angle(angle: float) -> str: Formatted angle string. """ for ref_angle, label in _ANGLE_TO_STR.items(): - tol = 1e-10 if ref_angle == 0.0 else None + tol = 1e-10 if label == "0" else None if tol is not None: if math.isclose(angle, ref_angle, abs_tol=tol): return label @@ -98,6 +102,11 @@ def _parse_angle(s: str) -> float: ------- `float` The angle in radians. + + Raises + ------ + ValueError + If the angle is not a valid number or pi expression. """ s = s.strip() if s in _STR_TO_ANGLE: @@ -109,6 +118,9 @@ def _parse_angle(s: str) -> float: denominator = pi_match.group(2) num = int(numerator) if numerator and numerator != "-" else (1 if numerator != "-" else -1) denom = int(denominator) if denominator else 1 + if denom == 0: + msg = "Angle denominator cannot be zero" + raise ValueError(msg) return num * math.pi / denom return float(s) @@ -130,7 +142,7 @@ def _format_coord(coord: tuple[float, ...]) -> str: return " ".join(str(c) for c in coord) -def _parse_coord(parts: list[str]) -> tuple[float, ...]: +def _parse_coord(parts: Sequence[str]) -> tuple[float, ...]: r"""Parse coordinate from string parts. Parameters @@ -254,6 +266,11 @@ def _write_classical_section(out: StringIO, pauli_frame: PauliFrame) -> None: group_str = " ".join(str(n) for n in sorted(group)) out.write(f".detector {group_str}\n") + for logical_idx, nodes in sorted(pauli_frame.logical_observables.items()): + if nodes: + nodes_str = " ".join(str(n) for n in sorted(nodes)) + out.write(f".observable {logical_idx} -> {nodes_str}\n") + def dumps(pattern: Pattern) -> str: """Serialize a pattern to a .ptn format string. @@ -294,7 +311,27 @@ def dump(pattern: Pattern, file: Path | str) -> None: # ============================================================ -def _parse_node_qubit_pairs(parts: list[str]) -> dict[int, int]: +def _parse_int(value: str, label: str) -> int: + """Parse an integer field. + + Returns + ------- + `int` + Parsed integer. + + Raises + ------ + ValueError + If the field is not an integer. + """ + try: + return int(value) + except ValueError as exc: + msg = f"Invalid {label}: {value!r}" + raise ValueError(msg) from exc + + +def _parse_node_qubit_pairs(parts: Sequence[str]) -> dict[int, int]: r"""Parse node:qubit pairs from string parts. Parameters @@ -306,15 +343,51 @@ def _parse_node_qubit_pairs(parts: list[str]) -> dict[int, int]: ------- `dict`\[`int`, `int`\] Mapping from node to qubit index. + + Raises + ------ + ValueError + If any pair is malformed or duplicated. """ result: dict[int, int] = {} for part in parts: - node_str, qidx_str = part.split(":") - result[int(node_str)] = int(qidx_str) + pair = part.split(":") + if len(pair) != 2: # noqa: PLR2004 + msg = f"Invalid node:qubit pair: {part!r}" + raise ValueError(msg) + node_str, qidx_str = pair + node = _parse_int(node_str, "node") + qidx = _parse_int(qidx_str, "qubit index") + if node in result: + msg = f"Duplicate node mapping: {node}" + raise ValueError(msg) + if qidx in result.values(): + msg = f"Duplicate qubit index: {qidx}" + raise ValueError(msg) + result[node] = qidx return result -def _parse_flow(line: str) -> tuple[int, set[int]]: +def _parse_node_set(parts: Sequence[str], label: str) -> set[int]: + r"""Parse a non-empty set of node ids. + + Returns + ------- + `set`\[`int`\] + Parsed node ids. + + Raises + ------ + ValueError + If the node set is empty or contains invalid integers. + """ + if not parts: + msg = f"{label} requires at least one node" + raise ValueError(msg) + return {_parse_int(part, "node") for part in parts} + + +def _parse_arrow_mapping(line: str, label: str) -> tuple[int, set[int]]: r"""Parse a flow line (xflow or zflow). Parameters @@ -326,14 +399,27 @@ def _parse_flow(line: str) -> tuple[int, set[int]]: ------- `tuple`\[`int`, `set`\[`int`\]\] Source node and set of target nodes. + + Raises + ------ + ValueError + If the mapping is malformed. """ parts = line.split("->") - source = int(parts[0].strip()) - targets = {int(t) for t in parts[1].strip().split()} + if len(parts) != 2: # noqa: PLR2004 + msg = f"{label} must contain exactly one '->'" + raise ValueError(msg) + source_part = parts[0].strip() + target_parts = parts[1].strip().split() + if not source_part: + msg = f"{label} requires a source node" + raise ValueError(msg) + source = _parse_int(source_part, "source node") + targets = _parse_node_set(target_parts, f"{label} targets") return source, targets -class PatternData: +class _PatternData: """Container for parsed pattern data from .ptn format. Attributes @@ -362,18 +448,182 @@ def __init__(self) -> None: self.xflow: dict[int, set[int]] = {} self.zflow: dict[int, set[int]] = {} self.parity_check_groups: list[set[int]] = [] + self.logical_observables: dict[int, set[int]] = {} + + +class _LoadedGraphState(BaseGraphState): + """Read-only graph state reconstructed from a .ptn file.""" + + def __init__( # noqa: PLR0913 + self, + *, + input_node_indices: Mapping[int, int], + output_node_indices: Mapping[int, int], + physical_nodes: AbstractSet[int], + physical_edges: AbstractSet[tuple[int, int]], + meas_bases: Mapping[int, MeasBasis], + coordinates: Mapping[int, tuple[float, ...]], + ) -> None: + self._input_node_indices = dict(input_node_indices) + self._output_node_indices = dict(output_node_indices) + self._physical_nodes = set(physical_nodes) + self._physical_edges = {(node1, node2) if node1 < node2 else (node2, node1) for node1, node2 in physical_edges} + self._meas_bases = dict(meas_bases) + self._coordinates = dict(coordinates) + self._neighbors: dict[int, set[int]] = {node: set() for node in self._physical_nodes} + for node1, node2 in self._physical_edges: + self._neighbors.setdefault(node1, set()).add(node2) + self._neighbors.setdefault(node2, set()).add(node1) + + @property + def input_node_indices(self) -> dict[int, int]: + return self._input_node_indices.copy() + + @property + def output_node_indices(self) -> dict[int, int]: + return self._output_node_indices.copy() + + @property + def physical_nodes(self) -> set[int]: + return set(self._physical_nodes) + + @property + def physical_edges(self) -> set[tuple[int, int]]: + return set(self._physical_edges) + + @property + def meas_bases(self) -> MappingProxyType[int, MeasBasis]: + return MappingProxyType(self._meas_bases) + + @property + def coordinates(self) -> dict[int, tuple[float, ...]]: + return self._coordinates.copy() + + def add_physical_node(self, coordinate: tuple[float, ...] | None = None) -> int: + msg = "Loaded .ptn graph states are read-only" + raise NotImplementedError(msg) + + def add_physical_edge(self, node1: int, node2: int) -> None: + msg = "Loaded .ptn graph states are read-only" + raise NotImplementedError(msg) + + def register_input(self, node: int, q_index: int) -> None: + msg = "Loaded .ptn graph states are read-only" + raise NotImplementedError(msg) + + def register_output(self, node: int, q_index: int) -> None: + msg = "Loaded .ptn graph states are read-only" + raise NotImplementedError(msg) + + def assign_meas_basis(self, node: int, meas_basis: MeasBasis) -> None: + msg = "Loaded .ptn graph states are read-only" + raise NotImplementedError(msg) + + def neighbors(self, node: int) -> set[int]: + if node not in self._physical_nodes: + msg = f"Node does not exist node={node}" + raise ValueError(msg) + return self._neighbors.get(node, set()).copy() + + def check_canonical_form(self) -> None: + for node in self._physical_nodes - self._output_node_indices.keys(): + if node not in self._meas_bases: + msg = f"Measurement basis not set for node {node}" + raise ValueError(msg) + + +def _command_nodes(cmd: Command) -> set[int]: + r"""Return node ids referenced by a command. + + Returns + ------- + `set`\[`int`\] + Node ids referenced by the command. + """ + if isinstance(cmd, (N, M, X, Z)): + return {cmd.node} + if isinstance(cmd, E): + return set(cmd.nodes) + return set() + + +def _build_pattern(data: _PatternData) -> Pattern: + """Build a Pattern from parsed .ptn data. + + Returns + ------- + `Pattern` + Reconstructed pattern. + + Raises + ------ + ValueError + If parsed commands contain invalid graph structure. + """ + nodes: set[int] = set(data.input_node_indices) | set(data.output_node_indices) | set(data.input_coordinates) + edges: set[tuple[int, int]] = set() + meas_bases: dict[int, MeasBasis] = {} + coordinates = dict(data.input_coordinates) + + for cmd in data.commands: + nodes.update(_command_nodes(cmd)) + if isinstance(cmd, E): + node1, node2 = cmd.nodes + edge = (node1, node2) if node1 < node2 else (node2, node1) + if edge[0] == edge[1]: + msg = f"Self-loop edge is not allowed: {cmd.nodes}" + raise ValueError(msg) + edges.add(edge) + elif isinstance(cmd, M): + meas_bases[cmd.node] = cmd.meas_basis + elif isinstance(cmd, N) and cmd.coordinate is not None: + coordinates[cmd.node] = cmd.coordinate + + for source, targets in data.xflow.items(): + nodes.add(source) + nodes.update(targets) + for source, targets in data.zflow.items(): + nodes.add(source) + nodes.update(targets) + for group in data.parity_check_groups: + nodes.update(group) + for nodes_in_observable in data.logical_observables.values(): + nodes.update(nodes_in_observable) + + graphstate = _LoadedGraphState( + input_node_indices=data.input_node_indices, + output_node_indices=data.output_node_indices, + physical_nodes=nodes, + physical_edges=edges, + meas_bases=meas_bases, + coordinates=coordinates, + ) + pauli_frame = PauliFrame( + graphstate, + data.xflow, + data.zflow, + parity_check_group=data.parity_check_groups, + logical_observables=data.logical_observables, + ) + return Pattern( + input_node_indices=dict(data.input_node_indices), + output_node_indices=dict(data.output_node_indices), + commands=tuple(data.commands), + pauli_frame=pauli_frame, + input_coordinates=dict(data.input_coordinates), + ) class _Parser: """Internal parser state for loads().""" def __init__(self) -> None: - self.result = PatternData() + self.result = _PatternData() self.current_timeslice = -1 self.version_found = False - def parse(self, s: str) -> PatternData: - r"""Parse the input string and return PatternData. + def parse(self, s: str) -> Pattern: + r"""Parse the input string and return Pattern. Parameters ---------- @@ -382,8 +632,8 @@ def parse(self, s: str) -> PatternData: Returns ------- - `PatternData` - Container with pattern components. + `Pattern` + Loaded measurement pattern. Raises ------ @@ -397,23 +647,39 @@ def parse(self, s: str) -> PatternData: msg = "Missing .version directive" raise ValueError(msg) - return self.result + return _build_pattern(self.result) def _parse_line(self, line_num: int, raw_line: str) -> None: - """Parse a single line.""" + """Parse a single line. + + Raises + ------ + ValueError + If the line is malformed. + """ line = raw_line.split("#", 1)[0].strip() if not line: return - if line.startswith("."): - self._parse_directive(line) - elif line.startswith("[") and line.endswith("]"): - self._parse_timeslice(line) - else: - self._parse_command(line_num, line) + try: + if line.startswith("."): + self._parse_directive(line) + elif line.startswith("[") and line.endswith("]"): + self._parse_timeslice(line) + else: + self._parse_command(line) + except ValueError as exc: + msg = f"Line {line_num}: {exc}" + raise ValueError(msg) from exc def _parse_directive(self, line: str) -> None: - """Parse a directive line (starts with '.').""" + """Parse a directive line (starts with '.'). + + Raises + ------ + ValueError + If the directive is invalid. + """ parts = line.split(maxsplit=1) directive = parts[0] content = parts[1] if len(parts) > 1 else "" @@ -427,14 +693,19 @@ def _parse_directive(self, line: str) -> None: elif directive == ".coord": self._handle_coord(content) elif directive == ".xflow": - source, targets = _parse_flow(content) + source, targets = _parse_arrow_mapping(content, ".xflow") self.result.xflow[source] = targets elif directive == ".zflow": - source, targets = _parse_flow(content) + source, targets = _parse_arrow_mapping(content, ".zflow") self.result.zflow[source] = targets elif directive == ".detector": - nodes = {int(n) for n in content.split()} - self.result.parity_check_groups.append(nodes) + self.result.parity_check_groups.append(_parse_node_set(content.split(), ".detector")) + elif directive == ".observable": + logical_idx, nodes = _parse_arrow_mapping(content, ".observable") + self.result.logical_observables[logical_idx] = nodes + else: + msg = f"Unknown directive: {directive}" + raise ValueError(msg) def _handle_version(self, content: str) -> None: r"""Handle .version directive. @@ -444,22 +715,43 @@ def _handle_version(self, content: str) -> None: ValueError If the version is unsupported. """ - version = int(content) + version = _parse_int(content, "version") if version != PTN_VERSION: msg = f"Unsupported .ptn version: {version} (expected {PTN_VERSION})" raise ValueError(msg) self.version_found = True def _handle_coord(self, content: str) -> None: - """Handle .coord directive.""" + """Handle .coord directive. + + Raises + ------ + ValueError + If the coordinate directive is malformed. + """ coord_parts = content.split() - node = int(coord_parts[0]) + if len(coord_parts) not in {3, 4}: + msg = ".coord requires a node and 2D or 3D coordinates" + raise ValueError(msg) + node = _parse_int(coord_parts[0], "node") coord = _parse_coord(coord_parts[1:]) self.result.input_coordinates[node] = coord def _parse_timeslice(self, line: str) -> None: - """Parse timeslice marker [n].""" - slice_num = int(line[1:-1]) + """Parse timeslice marker [n]. + + Raises + ------ + ValueError + If the timeslice marker is malformed. + """ + slice_num = _parse_int(line[1:-1], "timeslice") + if slice_num < 0: + msg = "Timeslice must be non-negative" + raise ValueError(msg) + if slice_num < self.current_timeslice: + msg = "Timeslices must be monotonically increasing" + raise ValueError(msg) while self.current_timeslice < slice_num - 1: self.result.commands.append(TICK()) self.current_timeslice += 1 @@ -468,7 +760,7 @@ def _parse_timeslice(self, line: str) -> None: self.result.commands.append(TICK()) self.current_timeslice = slice_num - def _parse_command(self, line_num: int, line: str) -> None: + def _parse_command(self, line: str) -> None: r"""Parse a command line. Raises @@ -486,46 +778,86 @@ def _parse_command(self, line_num: int, line: str) -> None: elif cmd_type == "M": self._parse_m_command(parts) elif cmd_type == "X": - self.result.commands.append(X(node=int(parts[1]))) + if len(parts) != 2: # noqa: PLR2004 + msg = "X command requires exactly one node" + raise ValueError(msg) + self.result.commands.append(X(node=_parse_int(parts[1], "node"))) elif cmd_type == "Z": - self.result.commands.append(Z(node=int(parts[1]))) + if len(parts) != 2: # noqa: PLR2004 + msg = "Z command requires exactly one node" + raise ValueError(msg) + self.result.commands.append(Z(node=_parse_int(parts[1], "node"))) else: - msg = f"Unknown command at line {line_num}: {cmd_type}" + msg = f"Unknown command: {cmd_type}" raise ValueError(msg) - def _parse_n_command(self, parts: list[str]) -> None: - """Parse N (node) command.""" - node = int(parts[1]) + def _parse_n_command(self, parts: Sequence[str]) -> None: + """Parse N (node) command. + + Raises + ------ + ValueError + If the command is malformed. + """ + if len(parts) not in {2, 4, 5}: + msg = "N command requires a node and optional 2D or 3D coordinates" + raise ValueError(msg) + node = _parse_int(parts[1], "node") coord: tuple[float, ...] | None = _parse_coord(parts[2:]) if len(parts) > 2 else None # noqa: PLR2004 self.result.commands.append(N(node=node, coordinate=coord)) - def _parse_e_command(self, parts: list[str]) -> None: - """Parse E (entangle) command.""" - node1 = int(parts[1]) - node2 = int(parts[2]) + def _parse_e_command(self, parts: Sequence[str]) -> None: + """Parse E (entangle) command. + + Raises + ------ + ValueError + If the command is malformed. + """ + if len(parts) != 3: # noqa: PLR2004 + msg = "E command requires exactly two nodes" + raise ValueError(msg) + node1 = _parse_int(parts[1], "node") + node2 = _parse_int(parts[2], "node") self.result.commands.append(E(nodes=(node1, node2))) - def _parse_m_command(self, parts: list[str]) -> None: - """Parse M (measure) command.""" - node = int(parts[1]) + def _parse_m_command(self, parts: Sequence[str]) -> None: + """Parse M (measure) command. + + Raises + ------ + ValueError + If the command is malformed. + """ + if len(parts) != 4: # noqa: PLR2004 + msg = "M command requires a node, basis, and angle/sign" + raise ValueError(msg) + node = _parse_int(parts[1], "node") basis_spec = parts[2] meas_basis: MeasBasis if basis_spec in {"X", "Y", "Z"}: sign_str = parts[3] + if sign_str not in {"+", "-"}: + msg = f"Invalid Pauli measurement sign: {sign_str!r}" + raise ValueError(msg) sign = Sign.PLUS if sign_str == "+" else Sign.MINUS axis = Axis[basis_spec] meas_basis = AxisMeasBasis(axis, sign) else: - plane = Plane[basis_spec] + try: + plane = Plane[basis_spec] + except KeyError as exc: + msg = f"Invalid measurement basis: {basis_spec!r}" + raise ValueError(msg) from exc angle = _parse_angle(parts[3]) meas_basis = PlannerMeasBasis(plane, angle) self.result.commands.append(M(node=node, meas_basis=meas_basis)) -def loads(s: str) -> PatternData: - """Deserialize a .ptn format string to pattern components. +def loads(s: str) -> Pattern: + """Deserialize a .ptn format string to a pattern. Parameters ---------- @@ -534,8 +866,8 @@ def loads(s: str) -> PatternData: Returns ------- - `PatternData` - Container with pattern components. + `Pattern` + The loaded pattern. See Also -------- @@ -544,8 +876,8 @@ def loads(s: str) -> PatternData: return _Parser().parse(s) -def load(file: Path | str) -> PatternData: - """Read pattern components from a .ptn file. +def load(file: Path | str) -> Pattern: + """Read a pattern from a .ptn file. Parameters ---------- @@ -554,8 +886,8 @@ def load(file: Path | str) -> PatternData: Returns ------- - `PatternData` - Container with pattern components. + `Pattern` + The loaded pattern. See `loads` for details. """ path = Path(file) diff --git a/tests/test_ptn_format.py b/tests/test_ptn_format.py index cb65d577..29a1e1e0 100644 --- a/tests/test_ptn_format.py +++ b/tests/test_ptn_format.py @@ -3,25 +3,26 @@ from __future__ import annotations import math -import tempfile -from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import pytest from graphqomb.command import TICK, E, M, N, X, Z -from graphqomb.common import Plane, PlannerMeasBasis +from graphqomb.common import Plane, PlannerMeasBasis, determine_pauli_axis from graphqomb.graphstate import GraphState from graphqomb.ptn_format import ( - PatternData, dump, dumps, load, loads, ) from graphqomb.qompiler import qompile +from graphqomb.stim_compiler import stim_compile if TYPE_CHECKING: + from pathlib import Path + + from graphqomb.command import Command from graphqomb.pattern import Pattern @@ -51,6 +52,68 @@ def create_simple_pattern() -> Pattern: return qompile(graph, xflow) +def create_measured_output_pattern_with_observable() -> Pattern: + """Create a stim-compatible pattern with a logical observable.""" + graph = GraphState() + in_node = graph.add_physical_node(coordinate=(10.0, 0.0)) + out_node = graph.add_physical_node(coordinate=(20.0, 0.0)) + + graph.register_input(in_node, 0) + graph.register_output(out_node, 0) + graph.add_physical_edge(in_node, out_node) + graph.assign_meas_basis(in_node, PlannerMeasBasis(Plane.XY, 0.0)) + graph.assign_meas_basis(out_node, PlannerMeasBasis(Plane.XY, 0.0)) + + return qompile(graph, {in_node: {out_node}}, logical_observables={0: {in_node}}) + + +def create_measured_output_pattern_with_detector() -> Pattern: + """Create a stim-compatible pattern with a detector.""" + graph = GraphState() + in_node = graph.add_physical_node(coordinate=(30.0, 0.0)) + out_node = graph.add_physical_node(coordinate=(40.0, 0.0)) + + graph.register_input(in_node, 0) + graph.register_output(out_node, 0) + graph.add_physical_edge(in_node, out_node) + graph.assign_meas_basis(in_node, PlannerMeasBasis(Plane.XY, 0.0)) + graph.assign_meas_basis(out_node, PlannerMeasBasis(Plane.XY, 0.0)) + + return qompile(graph, {in_node: {out_node}}, parity_check_group=[{in_node}]) + + +def command_signature(cmd: Command) -> tuple[Any, ...]: # noqa: PLR0911 + """Return a behavior-level signature for a pattern command.""" + if isinstance(cmd, N): + return ("N", cmd.node, cmd.coordinate) + if isinstance(cmd, E): + return ("E", cmd.nodes) + if isinstance(cmd, M): + pauli_axis = determine_pauli_axis(cmd.meas_basis) + if pauli_axis is not None: + return ("M", cmd.node, pauli_axis, cmd.meas_basis.angle) + return ("M", cmd.node, cmd.meas_basis.plane, cmd.meas_basis.angle) + if isinstance(cmd, X): + return ("X", cmd.node) + if isinstance(cmd, Z): + return ("Z", cmd.node) + if isinstance(cmd, TICK): + return ("TICK",) + return ("UNKNOWN", type(cmd).__name__) + + +def assert_pattern_equivalent(actual: Pattern, expected: Pattern) -> None: + """Assert that serialized pattern content survived a roundtrip.""" + assert actual.input_node_indices == expected.input_node_indices + assert actual.output_node_indices == expected.output_node_indices + assert actual.input_coordinates == expected.input_coordinates + assert actual.pauli_frame.xflow == expected.pauli_frame.xflow + assert actual.pauli_frame.zflow == expected.pauli_frame.zflow + assert actual.pauli_frame.parity_check_group == expected.pauli_frame.parity_check_group + assert actual.pauli_frame.logical_observables == expected.pauli_frame.logical_observables + assert [command_signature(cmd) for cmd in actual.commands] == [command_signature(cmd) for cmd in expected.commands] + + def test_dumps_basic() -> None: """Test basic pattern serialization.""" pattern = create_simple_pattern() @@ -112,7 +175,6 @@ def test_loads_basic() -> None: """ result = loads(ptn_str) - assert isinstance(result, PatternData) assert result.input_node_indices == {0: 0} assert result.output_node_indices == {2: 0} assert result.input_coordinates == {0: (0.0, 0.0)} @@ -250,8 +312,8 @@ def test_loads_flow_parsing() -> None: """ result = loads(ptn_str) - assert result.xflow == {0: {1, 2}} - assert result.zflow == {0: {3, 4}} + assert result.pauli_frame.xflow == {0: {1, 2}} + assert result.pauli_frame.zflow == {0: {3, 4}} def test_loads_detector_parsing() -> None: @@ -269,9 +331,27 @@ def test_loads_detector_parsing() -> None: """ result = loads(ptn_str) - assert len(result.parity_check_groups) == 2 - assert result.parity_check_groups[0] == {0, 1, 2} - assert result.parity_check_groups[1] == {3, 4} + assert len(result.pauli_frame.parity_check_group) == 2 + assert result.pauli_frame.parity_check_group[0] == {0, 1, 2} + assert result.pauli_frame.parity_check_group[1] == {3, 4} + + +def test_loads_observable_parsing() -> None: + """Test logical observable parsing.""" + ptn_str = """ +.version 1 +.input 0:0 +.output 1:0 + +[0] +M 0 X + +M 1 X + + +.observable 0 -> 0 1 +""" + result = loads(ptn_str) + + assert result.pauli_frame.logical_observables == {0: {0, 1}} def test_loads_missing_version() -> None: @@ -317,35 +397,21 @@ def test_roundtrip() -> None: ptn_str = dumps(pattern) result = loads(ptn_str) - # Check input/output nodes match - assert result.input_node_indices == pattern.input_node_indices - assert result.output_node_indices == pattern.output_node_indices - - # Check command count matches (excluding internal differences) - original_count = len([c for c in pattern.commands if not isinstance(c, (X, Z))]) - parsed_count = len([c for c in result.commands if not isinstance(c, (X, Z))]) - assert original_count == parsed_count + assert_pattern_equivalent(result, pattern) -def test_dump_and_load_file() -> None: +def test_dump_and_load_file(tmp_path: Path) -> None: """Test file I/O operations.""" pattern = create_simple_pattern() + filepath = tmp_path / "test.ptn" - with tempfile.TemporaryDirectory() as tmpdir: - filepath = Path(tmpdir) / "test.ptn" + dump(pattern, filepath) - # Write to file - dump(pattern, filepath) + assert filepath.exists() - # Verify file exists - assert filepath.exists() + result = load(filepath) - # Read from file - result = load(filepath) - - # Verify content - assert result.input_node_indices == pattern.input_node_indices - assert result.output_node_indices == pattern.output_node_indices + assert_pattern_equivalent(result, pattern) def test_multiple_input_output_qubits() -> None: @@ -434,8 +500,8 @@ def test_empty_flow() -> None: """ result = loads(ptn_str) - assert result.xflow == {} - assert result.zflow == {} + assert result.pauli_frame.xflow == {} + assert result.pauli_frame.zflow == {} def test_comments_ignored() -> None: @@ -452,5 +518,81 @@ def test_comments_ignored() -> None: M 0 XY 0 """ result = loads(ptn_str) - # Should parse without error - assert result.input_node_indices is not None + assert result.input_node_indices == {0: 0} + + +def test_roundtrip_preserves_logical_observables_for_stim() -> None: + """Logical observables should survive .ptn serialization.""" + pattern = create_measured_output_pattern_with_observable() + ptn_str = dumps(pattern) + + assert ".observable 0 -> 0" in ptn_str + + result = loads(ptn_str) + + assert_pattern_equivalent(result, pattern) + assert stim_compile(result) == stim_compile(pattern) + + +def test_roundtrip_preserves_detectors_for_stim() -> None: + """Detectors should survive .ptn serialization.""" + pattern = create_measured_output_pattern_with_detector() + ptn_str = dumps(pattern) + + assert ".detector 0" in ptn_str + + result = loads(ptn_str) + + assert_pattern_equivalent(result, pattern) + assert stim_compile(result) == stim_compile(pattern) + + +def test_loads_preserves_non_contiguous_node_ids() -> None: + """Loading should preserve node ids instead of remapping them.""" + ptn_str = """ +.version 1 +.input 10:0 +.output 30:0 +.coord 10 1.0 2.0 + +[0] +N 20 3.0 4.0 +E 10 20 +E 20 30 +[1] +M 10 X + +[2] +M 20 Y - +[3] +X 30 +Z 30 + +.xflow 10 -> 20 +.zflow 20 -> 30 +""" + result = loads(ptn_str) + + assert result.input_node_indices == {10: 0} + assert result.output_node_indices == {30: 0} + assert {10, 20, 30} <= result.pauli_frame.graphstate.physical_nodes + assert any(isinstance(cmd, N) and cmd.node == 20 for cmd in result.commands) + assert any(isinstance(cmd, X) and cmd.node == 30 for cmd in result.commands) + + +@pytest.mark.parametrize( + ("ptn_str", "message"), + [ + (".version 1\n.foo whatever\n", "Unknown directive"), + (".version 1\n[0]\nM 0 X bad\n", "Invalid Pauli measurement sign"), + (".version 1\n[0]\nM 0 X + junk\n", "M command requires"), + (".version 1\n.xflow 0 1\n", "must contain exactly one"), + (".version 1\n[-1]\n", "Timeslice must be non-negative"), + (".version 1\n[1]\n[0]\n", "monotonically increasing"), + (".version 1\n[0]\nM 0 XY pi/0\n", "denominator"), + (".version 1\n.detector\n", "requires at least one node"), + ], +) +def test_loads_rejects_malformed_input(ptn_str: str, message: str) -> None: + """Malformed .ptn input should fail instead of being guessed.""" + with pytest.raises(ValueError, match=message): + loads(ptn_str) From 3e966714b316b152621492eada741cc333051ac0 Mon Sep 17 00:00:00 2001 From: Masato Fukushima Date: Thu, 14 May 2026 23:45:37 -0400 Subject: [PATCH 07/10] Fix PTN Pauli signs and empty slices --- graphqomb/ptn_format.py | 34 +++++++++++++++++------- tests/test_ptn_format.py | 56 ++++++++++++++++++++++++++++++++++++++-- 2 files changed, 78 insertions(+), 12 deletions(-) diff --git a/graphqomb/ptn_format.py b/graphqomb/ptn_format.py index 0b866993..b532b0c1 100644 --- a/graphqomb/ptn_format.py +++ b/graphqomb/ptn_format.py @@ -204,15 +204,30 @@ def _write_command(out: StringIO, cmd: Command) -> None: out.write(f"Z {cmd.node}\n") +def _is_positive_pauli_measurement(meas_basis: MeasBasis, pauli_axis: Axis) -> bool: + """Return whether a Pauli measurement is on the positive eigenbasis. + + Returns + ------- + bool + True if the measurement basis is the positive Pauli eigenbasis. + """ + angle = meas_basis.angle + plane = meas_basis.plane + if pauli_axis == Axis.X: + positive_angle = math.pi / 2 if plane == Plane.XZ else 0.0 + elif pauli_axis == Axis.Y: + positive_angle = math.pi / 2 + else: + positive_angle = 0.0 + return is_close_angle(angle, positive_angle) + + def _write_measurement(out: StringIO, cmd: M) -> None: """Write measurement command with appropriate format.""" pauli_axis = determine_pauli_axis(cmd.meas_basis) if pauli_axis is not None: - angle = cmd.meas_basis.angle - if pauli_axis == Axis.Y: - sign = "+" if is_close_angle(angle, math.pi / 2) else "-" - else: - sign = "+" if is_close_angle(angle, 0.0) else "-" + sign = "+" if _is_positive_pauli_measurement(cmd.meas_basis, pauli_axis) else "-" out.write(f"M {cmd.node} {pauli_axis.name} {sign}\n") else: plane_name = cmd.meas_basis.plane.name @@ -229,10 +244,9 @@ def _write_quantum_section(out: StringIO, pattern: Pattern) -> None: current_slice_commands: list[Command] = [] def write_slice(slice_num: int, commands: list[Command]) -> None: - if commands or slice_num == 0: - out.write(f"[{slice_num}]\n") - for cmd in commands: - _write_command(out, cmd) + out.write(f"[{slice_num}]\n") + for cmd in commands: + _write_command(out, cmd) for cmd in pattern.commands: if isinstance(cmd, TICK): @@ -242,7 +256,7 @@ def write_slice(slice_num: int, commands: list[Command]) -> None: else: current_slice_commands.append(cmd) - if current_slice_commands or timeslice == 0: + if current_slice_commands or timeslice == 0 or (pattern.commands and isinstance(pattern.commands[-1], TICK)): write_slice(timeslice, current_slice_commands) diff --git a/tests/test_ptn_format.py b/tests/test_ptn_format.py index 29a1e1e0..f8a17966 100644 --- a/tests/test_ptn_format.py +++ b/tests/test_ptn_format.py @@ -8,8 +8,10 @@ import pytest from graphqomb.command import TICK, E, M, N, X, Z -from graphqomb.common import Plane, PlannerMeasBasis, determine_pauli_axis +from graphqomb.common import Axis, AxisMeasBasis, Plane, PlannerMeasBasis, Sign, determine_pauli_axis from graphqomb.graphstate import GraphState +from graphqomb.pattern import Pattern +from graphqomb.pauli_frame import PauliFrame from graphqomb.ptn_format import ( dump, dumps, @@ -23,7 +25,6 @@ from pathlib import Path from graphqomb.command import Command - from graphqomb.pattern import Pattern def create_simple_pattern() -> Pattern: @@ -156,6 +157,57 @@ def test_dumps_pauli_measurements() -> None: assert "M 1 Y +" in ptn_str +def test_dumps_preserves_xz_plane_x_pauli_sign() -> None: + """Plane.XZ X measurements should serialize with the correct Pauli sign.""" + graph = GraphState() + plus_node = graph.add_physical_node() + minus_node = graph.add_physical_node() + pattern = Pattern( + input_node_indices={}, + output_node_indices={}, + commands=( + M(plus_node, PlannerMeasBasis(Plane.XZ, math.pi / 2)), + M(minus_node, PlannerMeasBasis(Plane.XZ, 3 * math.pi / 2)), + ), + pauli_frame=PauliFrame(graph, xflow={}, zflow={}), + ) + + ptn_str = dumps(pattern) + result = loads(ptn_str) + + assert f"M {plus_node} X +" in ptn_str + assert f"M {minus_node} X -" in ptn_str + measurements = {cmd.node: cmd for cmd in result.commands if isinstance(cmd, M)} + assert isinstance(measurements[plus_node].meas_basis, AxisMeasBasis) + assert measurements[plus_node].meas_basis.axis == Axis.X + assert measurements[plus_node].meas_basis.sign == Sign.PLUS + assert isinstance(measurements[minus_node].meas_basis, AxisMeasBasis) + assert measurements[minus_node].meas_basis.axis == Axis.X + assert measurements[minus_node].meas_basis.sign == Sign.MINUS + + +def test_dumps_preserves_consecutive_trailing_ticks() -> None: + """Empty final timeslices should preserve consecutive trailing TICK commands.""" + graph = GraphState() + node = graph.add_physical_node() + pattern = Pattern( + input_node_indices={}, + output_node_indices={}, + commands=(N(node), TICK(), TICK()), + pauli_frame=PauliFrame(graph, xflow={}, zflow={}), + ) + + ptn_str = dumps(pattern) + result = loads(ptn_str) + + assert f"[0]\nN {node}\n[1]\n[2]\n" in ptn_str + assert [command_signature(cmd) for cmd in result.commands] == [ + ("N", node, None), + ("TICK",), + ("TICK",), + ] + + def test_loads_basic() -> None: """Test basic pattern deserialization.""" ptn_str = """# Test pattern From 15451c898b0303ba7e8d56bb1ae54a80d2cd6cb2 Mon Sep 17 00:00:00 2001 From: Masato Fukushima Date: Thu, 14 May 2026 23:48:05 -0400 Subject: [PATCH 08/10] Fix PTN format test typing --- tests/test_ptn_format.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/tests/test_ptn_format.py b/tests/test_ptn_format.py index f8a17966..c1818db8 100644 --- a/tests/test_ptn_format.py +++ b/tests/test_ptn_format.py @@ -178,12 +178,14 @@ def test_dumps_preserves_xz_plane_x_pauli_sign() -> None: assert f"M {plus_node} X +" in ptn_str assert f"M {minus_node} X -" in ptn_str measurements = {cmd.node: cmd for cmd in result.commands if isinstance(cmd, M)} - assert isinstance(measurements[plus_node].meas_basis, AxisMeasBasis) - assert measurements[plus_node].meas_basis.axis == Axis.X - assert measurements[plus_node].meas_basis.sign == Sign.PLUS - assert isinstance(measurements[minus_node].meas_basis, AxisMeasBasis) - assert measurements[minus_node].meas_basis.axis == Axis.X - assert measurements[minus_node].meas_basis.sign == Sign.MINUS + plus_basis = measurements[plus_node].meas_basis + minus_basis = measurements[minus_node].meas_basis + assert isinstance(plus_basis, AxisMeasBasis) + assert plus_basis.axis == Axis.X + assert plus_basis.sign == Sign.PLUS + assert isinstance(minus_basis, AxisMeasBasis) + assert minus_basis.axis == Axis.X + assert minus_basis.sign == Sign.MINUS def test_dumps_preserves_consecutive_trailing_ticks() -> None: From e2d6ed8812772ca2c106d868ebef7f1fe46016f5 Mon Sep 17 00:00:00 2001 From: Masato Fukushima Date: Thu, 14 May 2026 23:58:46 -0400 Subject: [PATCH 09/10] Fix initial PTN timeslice parsing --- graphqomb/ptn_format.py | 10 +++------- tests/test_ptn_format.py | 17 +++++++++++++++++ 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/graphqomb/ptn_format.py b/graphqomb/ptn_format.py index b532b0c1..50181894 100644 --- a/graphqomb/ptn_format.py +++ b/graphqomb/ptn_format.py @@ -766,13 +766,9 @@ def _parse_timeslice(self, line: str) -> None: if slice_num < self.current_timeslice: msg = "Timeslices must be monotonically increasing" raise ValueError(msg) - while self.current_timeslice < slice_num - 1: - self.result.commands.append(TICK()) - self.current_timeslice += 1 - if self.current_timeslice < slice_num: - if self.current_timeslice >= 0: - self.result.commands.append(TICK()) - self.current_timeslice = slice_num + ticks_to_insert = slice_num if self.current_timeslice < 0 else slice_num - self.current_timeslice + self.result.commands.extend(TICK() for _ in range(ticks_to_insert)) + self.current_timeslice = slice_num def _parse_command(self, line: str) -> None: r"""Parse a command line. diff --git a/tests/test_ptn_format.py b/tests/test_ptn_format.py index c1818db8..5bfa6252 100644 --- a/tests/test_ptn_format.py +++ b/tests/test_ptn_format.py @@ -283,6 +283,23 @@ def test_loads_timeslices() -> None: assert tick_count == 2 +def test_loads_initial_empty_timeslices() -> None: + """Test that an initial non-zero timeslice marker creates the matching TICKs.""" + ptn_str = """ +.version 1 + +[2] +M 0 XY 0 +""" + result = loads(ptn_str) + + assert [command_signature(cmd) for cmd in result.commands] == [ + ("TICK",), + ("TICK",), + ("M", 0, Axis.X, 0), + ] + + def test_loads_angle_parsing() -> None: """Test various angle format parsing.""" ptn_str = """ From df30a815d6b9e71942991a70caeee743083ab7d4 Mon Sep 17 00:00:00 2001 From: Masato Fukushima Date: Sun, 17 May 2026 13:32:26 -0400 Subject: [PATCH 10/10] Add ptn format docs and example output --- .gitignore | 1 + docs/source/ptn_format.rst | 9 +++ docs/source/references.rst | 1 + examples/pattern_generation.py | 13 +++ graphqomb/ptn_format.py | 142 +++++++++++++++++++++++---------- tests/test_ptn_format.py | 17 +++- 6 files changed, 141 insertions(+), 42 deletions(-) create mode 100644 docs/source/ptn_format.rst diff --git a/.gitignore b/.gitignore index b100828f..9210a8a8 100644 --- a/.gitignore +++ b/.gitignore @@ -162,6 +162,7 @@ cython_debug/ #.idea/ # Custom stuff +examples/generated/ docs/source/gallery docs/source/sg_execution_times.rst *.png diff --git a/docs/source/ptn_format.rst b/docs/source/ptn_format.rst new file mode 100644 index 00000000..bd6b1f14 --- /dev/null +++ b/docs/source/ptn_format.rst @@ -0,0 +1,9 @@ +Pattern Text Format +=================== + +:mod:`graphqomb.ptn_format` module +++++++++++++++++++++++++++++++++++ + +.. automodule:: graphqomb.ptn_format + :members: + :member-order: bysource diff --git a/docs/source/references.rst b/docs/source/references.rst index 37a72600..05571380 100644 --- a/docs/source/references.rst +++ b/docs/source/references.rst @@ -17,6 +17,7 @@ Module reference focus_flow command pattern + ptn_format pauli_frame qompiler scheduler diff --git a/examples/pattern_generation.py b/examples/pattern_generation.py index d691fae3..9e874cd8 100644 --- a/examples/pattern_generation.py +++ b/examples/pattern_generation.py @@ -6,7 +6,11 @@ """ # %% +import sys +from pathlib import Path + from graphqomb.pattern import print_pattern +from graphqomb.ptn_format import dump from graphqomb.qompiler import qompile from graphqomb.random_objects import generate_random_flow_graph @@ -22,3 +26,12 @@ print("pattern depth:", pattern.depth) print("pattern max space:", pattern.max_space) print_pattern(pattern) + +# Dump the pattern in GraphQOMB's text format. +script_path = Path(globals().get("__file__", sys.argv[0])).resolve() +example_dir = script_path.parent +output_dir = example_dir / "generated" +output_dir.mkdir(exist_ok=True) +ptn_path = output_dir / "pattern_generation.ptn" +dump(pattern, ptn_path) +print("wrote pattern:", ptn_path) diff --git a/graphqomb/ptn_format.py b/graphqomb/ptn_format.py index 50181894..20d0d92b 100644 --- a/graphqomb/ptn_format.py +++ b/graphqomb/ptn_format.py @@ -13,6 +13,7 @@ import math import operator import re +from dataclasses import dataclass, field from io import StringIO from pathlib import Path from types import MappingProxyType @@ -27,6 +28,7 @@ PlannerMeasBasis, Sign, determine_pauli_axis, + is_clifford_angle, is_close_angle, ) from graphqomb.graphstate import BaseGraphState @@ -34,8 +36,7 @@ from graphqomb.pauli_frame import PauliFrame if TYPE_CHECKING: - from collections.abc import Mapping, Sequence - from collections.abc import Set as AbstractSet + from collections.abc import Sequence PTN_VERSION = 1 @@ -80,12 +81,14 @@ def _format_angle(angle: float) -> str: `str` Formatted angle string. """ - for ref_angle, label in _ANGLE_TO_STR.items(): - tol = 1e-10 if label == "0" else None - if tol is not None: - if math.isclose(angle, ref_angle, abs_tol=tol): - return label - elif math.isclose(angle, ref_angle, rel_tol=1e-10): + candidates = ( + ((ref_angle, label) for ref_angle, label in _ANGLE_TO_STR.items() if is_clifford_angle(ref_angle)) + if is_clifford_angle(angle) + else _ANGLE_TO_STR.items() + ) + ordered_candidates = sorted(candidates, key=lambda item: item[0] < 0 if angle >= 0 else item[0] >= 0) + for ref_angle, label in ordered_candidates: + if is_close_angle(angle, ref_angle): return label return f"{angle}" @@ -313,7 +316,7 @@ def dump(pattern: Pattern, file: Path | str) -> None: ---------- pattern : `Pattern` The pattern to write. - file : `Path` | `str` + file : `pathlib.Path` | `str` The file path to write to. """ path = Path(file) @@ -433,6 +436,62 @@ def _parse_arrow_mapping(line: str, label: str) -> tuple[int, set[int]]: return source, targets +def _empty_node_index_map() -> dict[int, int]: + r"""Return an empty node-to-qubit-index map. + + Returns + ------- + `dict`\[`int`, `int`\] + Empty node-to-qubit-index map. + """ + return {} + + +def _empty_coordinates() -> dict[int, tuple[float, ...]]: + r"""Return an empty coordinate map. + + Returns + ------- + `dict`\[`int`, `tuple`\[`float`, ...\]\] + Empty coordinate map. + """ + return {} + + +def _empty_commands() -> list[Command]: + r"""Return an empty command list. + + Returns + ------- + `list`\[`Command`\] + Empty command list. + """ + return [] + + +def _empty_node_set_map() -> dict[int, set[int]]: + r"""Return an empty node-to-node-set map. + + Returns + ------- + `dict`\[`int`, `set`\[`int`\]\] + Empty node-to-node-set map. + """ + return {} + + +def _empty_node_groups() -> list[set[int]]: + r"""Return an empty node group list. + + Returns + ------- + `list`\[`set`\[`int`\]\] + Empty node group list. + """ + return [] + + +@dataclass(slots=True) class _PatternData: """Container for parsed pattern data from .ptn format. @@ -454,36 +513,37 @@ class _PatternData: Parity check groups for error detection. """ - def __init__(self) -> None: - self.input_node_indices: dict[int, int] = {} - self.output_node_indices: dict[int, int] = {} - self.input_coordinates: dict[int, tuple[float, ...]] = {} - self.commands: list[Command] = [] - self.xflow: dict[int, set[int]] = {} - self.zflow: dict[int, set[int]] = {} - self.parity_check_groups: list[set[int]] = [] - self.logical_observables: dict[int, set[int]] = {} + input_node_indices: dict[int, int] = field(default_factory=_empty_node_index_map) + output_node_indices: dict[int, int] = field(default_factory=_empty_node_index_map) + input_coordinates: dict[int, tuple[float, ...]] = field(default_factory=_empty_coordinates) + commands: list[Command] = field(default_factory=_empty_commands) + xflow: dict[int, set[int]] = field(default_factory=_empty_node_set_map) + zflow: dict[int, set[int]] = field(default_factory=_empty_node_set_map) + parity_check_groups: list[set[int]] = field(default_factory=_empty_node_groups) + logical_observables: dict[int, set[int]] = field(default_factory=_empty_node_set_map) +@dataclass(slots=True) class _LoadedGraphState(BaseGraphState): """Read-only graph state reconstructed from a .ptn file.""" - def __init__( # noqa: PLR0913 - self, - *, - input_node_indices: Mapping[int, int], - output_node_indices: Mapping[int, int], - physical_nodes: AbstractSet[int], - physical_edges: AbstractSet[tuple[int, int]], - meas_bases: Mapping[int, MeasBasis], - coordinates: Mapping[int, tuple[float, ...]], - ) -> None: - self._input_node_indices = dict(input_node_indices) - self._output_node_indices = dict(output_node_indices) - self._physical_nodes = set(physical_nodes) - self._physical_edges = {(node1, node2) if node1 < node2 else (node2, node1) for node1, node2 in physical_edges} - self._meas_bases = dict(meas_bases) - self._coordinates = dict(coordinates) + _input_node_indices: dict[int, int] + _output_node_indices: dict[int, int] + _physical_nodes: set[int] + _physical_edges: set[tuple[int, int]] + _meas_bases: dict[int, MeasBasis] + _coordinates: dict[int, tuple[float, ...]] + _neighbors: dict[int, set[int]] = field(init=False, repr=False) + + def __post_init__(self) -> None: + self._input_node_indices = dict(self._input_node_indices) + self._output_node_indices = dict(self._output_node_indices) + self._physical_nodes = set(self._physical_nodes) + self._physical_edges = { + (node1, node2) if node1 < node2 else (node2, node1) for node1, node2 in self._physical_edges + } + self._meas_bases = dict(self._meas_bases) + self._coordinates = dict(self._coordinates) self._neighbors: dict[int, set[int]] = {node: set() for node in self._physical_nodes} for node1, node2 in self._physical_edges: self._neighbors.setdefault(node1, set()).add(node2) @@ -605,12 +665,12 @@ def _build_pattern(data: _PatternData) -> Pattern: nodes.update(nodes_in_observable) graphstate = _LoadedGraphState( - input_node_indices=data.input_node_indices, - output_node_indices=data.output_node_indices, - physical_nodes=nodes, - physical_edges=edges, - meas_bases=meas_bases, - coordinates=coordinates, + _input_node_indices=data.input_node_indices, + _output_node_indices=data.output_node_indices, + _physical_nodes=nodes, + _physical_edges=edges, + _meas_bases=meas_bases, + _coordinates=coordinates, ) pauli_frame = PauliFrame( graphstate, @@ -891,7 +951,7 @@ def load(file: Path | str) -> Pattern: Parameters ---------- - file : `Path` | `str` + file : `pathlib.Path` | `str` The file path to read from. Returns diff --git a/tests/test_ptn_format.py b/tests/test_ptn_format.py index 5bfa6252..99bc0371 100644 --- a/tests/test_ptn_format.py +++ b/tests/test_ptn_format.py @@ -157,6 +157,20 @@ def test_dumps_pauli_measurements() -> None: assert "M 1 Y +" in ptn_str +def test_dumps_formats_near_known_angle() -> None: + """Angles near known constants should serialize canonically.""" + graph = GraphState() + node = graph.add_physical_node() + pattern = Pattern( + input_node_indices={}, + output_node_indices={}, + commands=(M(node, PlannerMeasBasis(Plane.XY, math.pi / 4 + 1e-10)),), + pauli_frame=PauliFrame(graph, xflow={}, zflow={}), + ) + + assert f"M {node} XY pi/4" in dumps(pattern) + + def test_dumps_preserves_xz_plane_x_pauli_sign() -> None: """Plane.XZ X measurements should serialize with the correct Pauli sign.""" graph = GraphState() @@ -645,7 +659,8 @@ def test_loads_preserves_non_contiguous_node_ids() -> None: assert result.input_node_indices == {10: 0} assert result.output_node_indices == {30: 0} - assert {10, 20, 30} <= result.pauli_frame.graphstate.physical_nodes + assert result.pauli_frame.graphstate.physical_nodes == {10, 20, 30} + assert result.pauli_frame.graphstate.physical_edges == {(10, 20), (20, 30)} assert any(isinstance(cmd, N) and cmd.node == 20 for cmd in result.commands) assert any(isinstance(cmd, X) and cmd.node == 30 for cmd in result.commands)