diff --git a/graphix/circ_ext/extraction.py b/graphix/circ_ext/extraction.py index 7ae81f69..a2bff872 100644 --- a/graphix/circ_ext/extraction.py +++ b/graphix/circ_ext/extraction.py @@ -95,69 +95,176 @@ class PauliString: sign: Sign = Sign.PLUS @staticmethod - def from_measured_node(flow: PauliFlow[Measurement], node: Node) -> PauliString: - """Extract the Pauli string of a measured node and its focused correction set. + def from_str(ps: str) -> PauliString: + """Construct a PauliString from its string representation. Parameters ---------- - flow : PauliFlow[Measurement] - A focused Pauli flow. The resulting Pauli string is extracted from its correction function. - node : int - A measured node whose associated Pauli string is computed. + ps : str + String encoding of a Pauli string. The first character must be + ``'+'`` or ``'-'`` (the sign), followed by one or more single-character + Pauli operators (``'X'``, ``'Y'``, ``'Z'``, or ``'I'``). + Example: ``'+XYZ'``, ``'-IXI'``. Returns ------- PauliString - Primary extraction string associated to the input measured nodes. The Pauli string is defined over qubit indices corresponding to positions in ``output_nodes``. + The PauliString instance corresponding to the input string. - Notes - ----- - See Eq. (13) and Lemma 4.4 in Ref. [1]. The phase of the Pauli string is given by Eq. (37). + Raises + ------ + ValueError + If the string is shorter than 2 characters, + the first character is not ``'+'`` or ``'-'``, or + any operator character is not one of ``'X'``, ``'Y'``, ``'Z'``, ``'I'``. + + Examples + -------- + >>> PauliString.from_str("+XYZ") + PauliString(dim=3, axes={0: Axis.X, 1: Axis.Y, 2: Axis.Z}, sign=Sign.PLUS) + >>> PauliString.from_str("-IXI") + PauliString(dim=3, axes={1: Axis.X}, sign=Sign.MINUS) + """ + if len(ps) < 2: + raise ValueError("Input string must have at least 2 characters (a sign followed by operators).") - References + sign_char, ops = ps[0], ps[1:] # Mypy disallows string unpacking + + _sign_map = {"+": Sign.PLUS, "-": Sign.MINUS} + _axis_map = {"X": Axis.X, "Y": Axis.Y, "Z": Axis.Z} + + if sign_char not in _sign_map: + raise ValueError(f"First character must be '+' or '-', got '{sign_char}'.") + + invalid = {op for op in ops if op not in _axis_map and op != "I"} + if invalid: + raise ValueError(f"Invalid Pauli operator(s): {invalid}. Each operator must be 'X', 'Y', 'Z', or 'I'.") + + return PauliString( + sign=_sign_map[sign_char], + dim=len(ops), + axes={i: _axis_map[op] for i, op in enumerate(ops) if op != "I"}, + ) + + def __str__(self) -> str: + """Return a string representation of the Pauli string.""" + pauli_str = ( + str(self.sign), + *(getattr(self.axes.get(node), "name", "I") for node in range(self.dim)), + ) + + return "".join(pauli_str) + + @staticmethod + def from_tableau(tab: MatGF2) -> PauliString: + r"""Construct a `PauliString` from a one-dimensional tableau representation. + + The tableau encodes a Pauli operator of the form + :math:`\pm P_0 \otimes P_1 \otimes \cdots \otimes P_{n-1}`, + where each single-qubit Pauli is stored as an (x, z) bit pair and the final + element encodes the sign. + + Layout of ``tab`` (length ``2n + 1``):: + + [ x_0, x_1, …, x_{n-1} | z_0, z_1, …, z_{n-1} | sign ] + + Encoding conventions: + + * ``(x=1, z=0)`` → X + * ``(x=0, z=1)`` → Z + * ``(x=1, z=1)`` → Y + * ``(x=0, z=0)`` → I (identity, qubit absent in ``axes``) + * ``sign = 0`` → +1 + * ``sign = 1`` → -1 + + Parameters ---------- - [1] Simmons, 2021 (arXiv:2109.05654). + tab : MatGF2 + A one-dimensional GF(2) array of odd length ``2n + 1`` + representing an n-qubit Pauli operator. + + Returns + ------- + PauliString + The Pauli operator encoded by ``tab``. + + Raises + ------ + ValueError + If ``tab`` is not one-dimensional or ``len(tab)`` is even. + + Examples + -------- + >>> tab = MatGF2(np.array([1, 1, 1])) + >>> PauliString.from_tableau(tab) + PauliString(dim=1, axes={0: Axis.Y}, sign=Sign.MINUS) + >>> tab = MatGF2(np.array([0, 0, 0, 1, 0])) + >>> PauliString.from_tableau(tab) + PauliString(dim=2, axes={1: Axis.Z}, sign=Sign.PLUS) """ - og = flow.og - dim = len(flow.og.output_nodes) - c_set = set(flow.correction_function[node]) - odd_c_set = og.odd_neighbors(c_set) - inter_c_odd_set = c_set & odd_c_set + if tab.ndim != 1: + raise ValueError( + f"Attempted to initialise a PauliString from a {tab.ndim}-dimensional tableau. `PauliString.from_tableau` expects a one-dimensional array." + ) + if len(tab) % 2 == 0: + raise ValueError( + f"`PauliString.from_tableau` expects an array with an odd number of elements (got {len(tab)})." + ) - x_corrections = frozenset((c_set - odd_c_set).intersection(og.output_nodes)) - y_corrections = frozenset(inter_c_odd_set.intersection(og.output_nodes)) - z_corrections = frozenset((odd_c_set - c_set).intersection(og.output_nodes)) + dim = len(tab) // 2 + sign = Sign.minus_if(tab[-1]) - # Sign computation. - negative_sign = False + axes: dict[int, Axis] = {} + for i, (x, z) in enumerate(zip(tab[:dim], tab[dim:-1], strict=True)): + if (x, z) == (1, 0): + axes[i] = Axis.X + elif (x, z) == (0, 1): + axes[i] = Axis.Z + elif (x, z) == (1, 1): + axes[i] = Axis.Y - # One phase flip per edge between adjacent vertices in the correction set. - negative_sign ^= og.graph.subgraph(c_set).number_of_edges() % 2 == 1 + return PauliString(dim, axes, sign) - # One phase flip per two Ys in the graph state stabilizer. - negative_sign ^= bool(len(inter_c_odd_set) // 2 % 2) + def to_tableau(self) -> MatGF2: + """Serialise this PauliString into a one-dimensional tableau representation. - # One phase flip per node in the graph state stabilizer that is absorbed from a Pauli measurement with angle π. - for n in c_set | odd_c_set: - meas = og.measurements.get(n, None) - if isinstance(meas, PauliMeasurement): - negative_sign ^= meas.sign == Sign.MINUS + Produces the inverse of :meth:`from_tableau`: a ``MatGF2`` of length + ``2n + 1`` whose layout is:: - # One phase flip if measured on the YZ plane. - negative_sign ^= flow.node_measurement_label(node) == Plane.YZ + [ x_0, x_1, …, x_{n-1} | z_0, z_1, …, z_{n-1} | sign ] - axes_dict: dict[int, Axis] = {} - output_to_qubit_mapping = NodeIndex() - output_to_qubit_mapping.extend(og.output_nodes) + Encoding conventions: - # Sets `x_corrections`, `y_corrections` and `z_corrections` are disjoint. - corrections = (x_corrections, y_corrections, z_corrections) - for correction, axis in zip(corrections, Axis, strict=True): - for cnode in correction: - qubit = output_to_qubit_mapping.index(cnode) - axes_dict[qubit] = axis + * X → ``(x=1, z=0)`` + * Z → ``(x=0, z=1)`` + * Y → ``(x=1, z=1)`` + * I → ``(x=0, z=0)`` (absent in ``self.axes``) + * ``+`` sign → ``0`` + * ``-`` sign → ``1`` - return PauliString(dim, axes_dict, Sign.minus_if(negative_sign)) + Returns + ------- + MatGF2 + A one-dimensional GF(2) array of length ``2 * self.dim + 1``. + + Examples + -------- + >>> ps = PauliString.from_str("-XY") + >>> ps.to_tableau() + MatGF2([1, 1, 0, 1, 1], dtype=uint8) + """ + tab = MatGF2(np.zeros(2 * self.dim + 1, dtype=np.uint8)) + + for i, ax in self.axes.items(): + if ax in {Axis.X, Axis.Y}: + tab[i] = 1 + if ax in {Axis.Y, Axis.Z}: + tab[i + self.dim] = 1 + + if self.sign is Sign.MINUS: + tab[2 * self.dim] = 1 + + return tab @dataclass(frozen=True) @@ -386,20 +493,72 @@ def to_tableau(self) -> MatGF2: f"Isometries are not supported yet: # of inputs ({len(self.input_nodes)}) must be equal to the # of outputs ({len(self.output_nodes)})." ) - tab = MatGF2(np.zeros((2 * n, 2 * n + 1))) + return MatGF2(np.vstack((*(ps.to_tableau() for ps in self.x_map), *(ps.to_tableau() for ps in self.z_map)))) - for mapping, shift in (self.x_map, 0), (self.z_map, n): - for i, ps in enumerate(mapping): # Indices in the Clifford map correspond to qubits (0 to n-1). - for j, ax in ps.axes.items(): - if ax in {Axis.X, Axis.Y}: - tab[i + shift, j] = 1 - if ax in {Axis.Y, Axis.Z}: - tab[i + shift, j + n] = 1 - if ps.sign is Sign.MINUS: - tab[i + shift, 2 * n] = 1 +def extraction_ps_from_corrected_node(flow: PauliFlow[Measurement], node: Node) -> PauliString: + """Extract the Pauli string of a measured node and its focused correction set. - return tab + Parameters + ---------- + flow : PauliFlow[Measurement] + A focused Pauli flow. The resulting Pauli string is extracted from its correction function. + node : int + A measured node whose associated Pauli string is computed. + + Returns + ------- + PauliString + Primary extraction string associated to the input measured nodes. The Pauli string is defined over qubit indices corresponding to positions in ``output_nodes``. + + Notes + ----- + See Eq. (13) and Lemma 4.4 in Ref. [1]. The phase of the Pauli string is given by Eq. (37). + + References + ---------- + [1] Simmons, 2021 (arXiv:2109.05654). + """ + og = flow.og + dim = len(flow.og.output_nodes) + c_set = set(flow.correction_function[node]) + odd_c_set = og.odd_neighbors(c_set) + inter_c_odd_set = c_set & odd_c_set + + x_corrections = frozenset((c_set - odd_c_set).intersection(og.output_nodes)) + y_corrections = frozenset(inter_c_odd_set.intersection(og.output_nodes)) + z_corrections = frozenset((odd_c_set - c_set).intersection(og.output_nodes)) + + # Sign computation. + negative_sign = False + + # One phase flip per edge between adjacent vertices in the correction set. + negative_sign ^= og.graph.subgraph(c_set).number_of_edges() % 2 == 1 + + # One phase flip per two Ys in the graph state stabilizer. + negative_sign ^= bool(len(inter_c_odd_set) // 2 % 2) + + # One phase flip per node in the graph state stabilizer that is absorbed from a Pauli measurement with angle π. + for n in c_set | odd_c_set: + meas = og.measurements.get(n, None) + if isinstance(meas, PauliMeasurement): + negative_sign ^= meas.sign == Sign.MINUS + + # One phase flip if measured on the YZ plane. + negative_sign ^= flow.node_measurement_label(node) == Plane.YZ + + axes_dict: dict[int, Axis] = {} + output_to_qubit_mapping = NodeIndex() + output_to_qubit_mapping.extend(og.output_nodes) + + # Sets `x_corrections`, `y_corrections` and `z_corrections` are disjoint. + corrections = (x_corrections, y_corrections, z_corrections) + for correction, axis in zip(corrections, Axis, strict=True): + for cnode in correction: + qubit = output_to_qubit_mapping.index(cnode) + axes_dict[qubit] = axis + + return PauliString(dim, axes_dict, Sign.minus_if(negative_sign)) def extend_input(og: OpenGraph[Measurement]) -> tuple[OpenGraph[Measurement], dict[int, int]]: @@ -512,6 +671,6 @@ def clifford_x_map_from_focused_flow(flow: PauliFlow[Measurement]) -> tuple[Paul # In the context for `CliffordMap.from_focused_flow` the check is performed when accessing the cached property `flow.pauli_strings` in the function `clifford_z_map_from_focused_flow`. # It's better to call the `PauliString` constructor instead of the cached property `flow_extended.pauli_strings` since the latter will compute a `PauliString` for _every_ node in the correction function and we just need it for the input nodes. - x_map_ancillas = {node: PauliString.from_measured_node(flow_extended, node) for node in og_extended.input_nodes} + x_map_ancillas = {node: extraction_ps_from_corrected_node(flow_extended, node) for node in og_extended.input_nodes} return tuple(x_map_ancillas[ancillary_inputs_map[input_node]] for input_node in og.input_nodes) diff --git a/graphix/flow/core.py b/graphix/flow/core.py index 10e02f34..c9878586 100644 --- a/graphix/flow/core.py +++ b/graphix/flow/core.py @@ -15,7 +15,12 @@ # `override` introduced in Python 3.12, `assert_never` introduced in Python 3.11 from typing_extensions import assert_never, override -from graphix.circ_ext.extraction import CliffordMap, ExtractionResult, PauliExponentialDAG, PauliString +from graphix.circ_ext.extraction import ( + CliffordMap, + ExtractionResult, + PauliExponentialDAG, + extraction_ps_from_corrected_node, +) from graphix.command import E, M, N, X, Z from graphix.flow._find_gpflow import ( CorrectionMatrix, @@ -51,6 +56,7 @@ # Unpack introduced in Python 3.12 from typing_extensions import Unpack + from graphix.circ_ext.extraction import PauliString from graphix.opengraph import OpenGraph from graphix.parameter import ExpressionOrSupportsFloat, Parameter from graphix.pattern import Pattern @@ -830,7 +836,7 @@ def extraction_pauli_strings(self: PauliFlow[Measurement]) -> dict[int, PauliStr """ if not self.is_focused(): raise ValueError("Flow is not focused.") - return {node: PauliString.from_measured_node(self, node) for node in self.correction_function} + return {node: extraction_ps_from_corrected_node(self, node) for node in self.correction_function} def extract_circuit(self: PauliFlow[Measurement]) -> ExtractionResult: """Extract a circuit from a flow. diff --git a/tests/test_circ_extraction.py b/tests/test_circ_extraction.py index 4116860a..d0172fe5 100644 --- a/tests/test_circ_extraction.py +++ b/tests/test_circ_extraction.py @@ -603,6 +603,15 @@ def test_parametric_angles(self, test_case: float, fx_rng: Generator) -> None: assert s1.isclose(s2) +class TestPauliString: + @pytest.mark.parametrize("ps", ["+X", "-XIY", "+II", "+YXZZYX", "-IIIYXZII"]) + def test_round_trip_conversions(self, ps: str) -> None: + tab = PauliString.from_str(ps).to_tableau() + ps_test = str(PauliString.from_tableau(tab)) + + assert ps == ps_test + + def test_extend_input() -> None: og = OpenGraph( graph=nx.Graph([(1, 3), (2, 4), (3, 4), (3, 5), (4, 6)]),