Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
269 changes: 214 additions & 55 deletions graphix/circ_ext/extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,69 +95,176 @@ class PauliString:
sign: Sign = Sign.PLUS

@staticmethod
def from_measured_node(flow: PauliFlow[Measurement], node: Node) -> PauliString:
"""Extract the Pauli string of a measured node and its focused correction set.
def from_str(ps: str) -> PauliString:
"""Construct a PauliString from its string representation.

Parameters
----------
flow : PauliFlow[Measurement]
A focused Pauli flow. The resulting Pauli string is extracted from its correction function.
node : int
A measured node whose associated Pauli string is computed.
ps : str
String encoding of a Pauli string. The first character must be
``'+'`` or ``'-'`` (the sign), followed by one or more single-character
Pauli operators (``'X'``, ``'Y'``, ``'Z'``, or ``'I'``).
Example: ``'+XYZ'``, ``'-IXI'``.

Returns
-------
PauliString
Primary extraction string associated to the input measured nodes. The Pauli string is defined over qubit indices corresponding to positions in ``output_nodes``.
The PauliString instance corresponding to the input string.

Notes
-----
See Eq. (13) and Lemma 4.4 in Ref. [1]. The phase of the Pauli string is given by Eq. (37).
Raises
------
ValueError
If the string is shorter than 2 characters,
the first character is not ``'+'`` or ``'-'``, or
any operator character is not one of ``'X'``, ``'Y'``, ``'Z'``, ``'I'``.

Examples
--------
>>> PauliString.from_str("+XYZ")
PauliString(dim=3, axes={0: Axis.X, 1: Axis.Y, 2: Axis.Z}, sign=Sign.PLUS)
>>> PauliString.from_str("-IXI")
PauliString(dim=3, axes={1: Axis.X}, sign=Sign.MINUS)
"""
if len(ps) < 2:
raise ValueError("Input string must have at least 2 characters (a sign followed by operators).")
Comment on lines +128 to +129
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we disallow empty Pauli strings?
Should the default sign be +?

For comparison:

>>> stim.PauliString("X")
stim.PauliString("+X")
>>> stim.PauliString("")
stim.PauliString("+")

In addition, PauliString.from_str(str(PauliString(0, {}))) now raises an exception. I'm not sure whether this behavior is intentional.


References
sign_char, ops = ps[0], ps[1:] # Mypy disallows string unpacking
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In Python, string unpacking is just a particular case of iterable unpacking. For example,

sign_char, ops = ps[0], ps[1:]

uses slicing, so ops becomes a string. In contrast,

sign_char, *ops = ps

produces a list of characters for ops. For the rest of the code, the difference is irrelevant.

If you prefer the more declarative string-unpacking style, you can be explicit about using iterable unpacking; mypy accepts this:

Suggested change
sign_char, ops = ps[0], ps[1:] # Mypy disallows string unpacking
sign_char, *ops = iter(ps)


_sign_map = {"+": Sign.PLUS, "-": Sign.MINUS}
_axis_map = {"X": Axis.X, "Y": Axis.Y, "Z": Axis.Z}
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The class Axis itself behaves much like a dictionary: Axis["X"] returns Axis.X. However, the expression "X" in Axis returns False. The proper mapping is Axis.__members__: "X" in Axis.__members__ returns True.


if sign_char not in _sign_map:
raise ValueError(f"First character must be '+' or '-', got '{sign_char}'.")

invalid = {op for op in ops if op not in _axis_map and op != "I"}
if invalid:
raise ValueError(f"Invalid Pauli operator(s): {invalid}. Each operator must be 'X', 'Y', 'Z', or 'I'.")

return PauliString(
sign=_sign_map[sign_char],
dim=len(ops),
axes={i: _axis_map[op] for i, op in enumerate(ops) if op != "I"},
)

def __str__(self) -> str:
"""Return a string representation of the Pauli string."""
pauli_str = (
str(self.sign),
*(getattr(self.axes.get(node), "name", "I") for node in range(self.dim)),
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would prefer to avoid getattr because it isn't type-safe. Moreover, it's clearer to test explicitly whether an axis exists or not rather than relying on the fact that there is no field name in None.

Suggested change
*(getattr(self.axes.get(node), "name", "I") for node in range(self.dim)),
*(axis.name if (axis := self.axes.get(node)) else "I" for node in range(self.dim)),

)

return "".join(pauli_str)

@staticmethod
def from_tableau(tab: MatGF2) -> PauliString:
r"""Construct a `PauliString` from a one-dimensional tableau representation.

The tableau encodes a Pauli operator of the form
:math:`\pm P_0 \otimes P_1 \otimes \cdots \otimes P_{n-1}`,
where each single-qubit Pauli is stored as an (x, z) bit pair and the final
element encodes the sign.

Layout of ``tab`` (length ``2n + 1``)::

[ x_0, x_1, …, x_{n-1} | z_0, z_1, …, z_{n-1} | sign ]

Encoding conventions:

* ``(x=1, z=0)`` → X
* ``(x=0, z=1)`` → Z
* ``(x=1, z=1)`` → Y
* ``(x=0, z=0)`` → I (identity, qubit absent in ``axes``)
* ``sign = 0`` → +1
* ``sign = 1`` → -1

Parameters
----------
[1] Simmons, 2021 (arXiv:2109.05654).
tab : MatGF2
A one-dimensional GF(2) array of odd length ``2n + 1``
representing an n-qubit Pauli operator.

Returns
-------
PauliString
The Pauli operator encoded by ``tab``.

Raises
------
ValueError
If ``tab`` is not one-dimensional or ``len(tab)`` is even.

Examples
--------
>>> tab = MatGF2(np.array([1, 1, 1]))
>>> PauliString.from_tableau(tab)
PauliString(dim=1, axes={0: Axis.Y}, sign=Sign.MINUS)
>>> tab = MatGF2(np.array([0, 0, 0, 1, 0]))
>>> PauliString.from_tableau(tab)
PauliString(dim=2, axes={1: Axis.Z}, sign=Sign.PLUS)
"""
og = flow.og
dim = len(flow.og.output_nodes)
c_set = set(flow.correction_function[node])
odd_c_set = og.odd_neighbors(c_set)
inter_c_odd_set = c_set & odd_c_set
if tab.ndim != 1:
raise ValueError(
f"Attempted to initialise a PauliString from a {tab.ndim}-dimensional tableau. `PauliString.from_tableau` expects a one-dimensional array."
)
if len(tab) % 2 == 0:
raise ValueError(
f"`PauliString.from_tableau` expects an array with an odd number of elements (got {len(tab)})."
)

x_corrections = frozenset((c_set - odd_c_set).intersection(og.output_nodes))
y_corrections = frozenset(inter_c_odd_set.intersection(og.output_nodes))
z_corrections = frozenset((odd_c_set - c_set).intersection(og.output_nodes))
dim = len(tab) // 2
sign = Sign.minus_if(tab[-1])

# Sign computation.
negative_sign = False
axes: dict[int, Axis] = {}
for i, (x, z) in enumerate(zip(tab[:dim], tab[dim:-1], strict=True)):
if (x, z) == (1, 0):
axes[i] = Axis.X
elif (x, z) == (0, 1):
axes[i] = Axis.Z
elif (x, z) == (1, 1):
axes[i] = Axis.Y

# One phase flip per edge between adjacent vertices in the correction set.
negative_sign ^= og.graph.subgraph(c_set).number_of_edges() % 2 == 1
return PauliString(dim, axes, sign)

# One phase flip per two Ys in the graph state stabilizer.
negative_sign ^= bool(len(inter_c_odd_set) // 2 % 2)
def to_tableau(self) -> MatGF2:
"""Serialise this PauliString into a one-dimensional tableau representation.

# One phase flip per node in the graph state stabilizer that is absorbed from a Pauli measurement with angle π.
for n in c_set | odd_c_set:
meas = og.measurements.get(n, None)
if isinstance(meas, PauliMeasurement):
negative_sign ^= meas.sign == Sign.MINUS
Produces the inverse of :meth:`from_tableau`: a ``MatGF2`` of length
``2n + 1`` whose layout is::

# One phase flip if measured on the YZ plane.
negative_sign ^= flow.node_measurement_label(node) == Plane.YZ
[ x_0, x_1, …, x_{n-1} | z_0, z_1, …, z_{n-1} | sign ]

axes_dict: dict[int, Axis] = {}
output_to_qubit_mapping = NodeIndex()
output_to_qubit_mapping.extend(og.output_nodes)
Encoding conventions:

# Sets `x_corrections`, `y_corrections` and `z_corrections` are disjoint.
corrections = (x_corrections, y_corrections, z_corrections)
for correction, axis in zip(corrections, Axis, strict=True):
for cnode in correction:
qubit = output_to_qubit_mapping.index(cnode)
axes_dict[qubit] = axis
* X → ``(x=1, z=0)``
* Z → ``(x=0, z=1)``
* Y → ``(x=1, z=1)``
* I → ``(x=0, z=0)`` (absent in ``self.axes``)
* ``+`` sign → ``0``
* ``-`` sign → ``1``

return PauliString(dim, axes_dict, Sign.minus_if(negative_sign))
Returns
-------
MatGF2
A one-dimensional GF(2) array of length ``2 * self.dim + 1``.

Examples
--------
>>> ps = PauliString.from_str("-XY")
>>> ps.to_tableau()
MatGF2([1, 1, 0, 1, 1], dtype=uint8)
"""
tab = MatGF2(np.zeros(2 * self.dim + 1, dtype=np.uint8))

for i, ax in self.axes.items():
if ax in {Axis.X, Axis.Y}:
tab[i] = 1
if ax in {Axis.Y, Axis.Z}:
tab[i + self.dim] = 1

if self.sign is Sign.MINUS:
tab[2 * self.dim] = 1

return tab


@dataclass(frozen=True)
Expand Down Expand Up @@ -386,20 +493,72 @@ def to_tableau(self) -> MatGF2:
f"Isometries are not supported yet: # of inputs ({len(self.input_nodes)}) must be equal to the # of outputs ({len(self.output_nodes)})."
)

tab = MatGF2(np.zeros((2 * n, 2 * n + 1)))
return MatGF2(np.vstack((*(ps.to_tableau() for ps in self.x_map), *(ps.to_tableau() for ps in self.z_map))))
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return MatGF2(np.vstack((*(ps.to_tableau() for ps in self.x_map), *(ps.to_tableau() for ps in self.z_map))))
return MatGF2(np.vstack([ps.to_tableau() for psmap in (self.x_map, self.z_map) for ps in psmap]))


for mapping, shift in (self.x_map, 0), (self.z_map, n):
for i, ps in enumerate(mapping): # Indices in the Clifford map correspond to qubits (0 to n-1).
for j, ax in ps.axes.items():
if ax in {Axis.X, Axis.Y}:
tab[i + shift, j] = 1
if ax in {Axis.Y, Axis.Z}:
tab[i + shift, j + n] = 1

if ps.sign is Sign.MINUS:
tab[i + shift, 2 * n] = 1
def extraction_ps_from_corrected_node(flow: PauliFlow[Measurement], node: Node) -> PauliString:
"""Extract the Pauli string of a measured node and its focused correction set.

return tab
Parameters
----------
flow : PauliFlow[Measurement]
A focused Pauli flow. The resulting Pauli string is extracted from its correction function.
node : int
A measured node whose associated Pauli string is computed.

Returns
-------
PauliString
Primary extraction string associated to the input measured nodes. The Pauli string is defined over qubit indices corresponding to positions in ``output_nodes``.

Notes
-----
See Eq. (13) and Lemma 4.4 in Ref. [1]. The phase of the Pauli string is given by Eq. (37).

References
----------
[1] Simmons, 2021 (arXiv:2109.05654).
"""
og = flow.og
dim = len(flow.og.output_nodes)
c_set = set(flow.correction_function[node])
odd_c_set = og.odd_neighbors(c_set)
inter_c_odd_set = c_set & odd_c_set

x_corrections = frozenset((c_set - odd_c_set).intersection(og.output_nodes))
y_corrections = frozenset(inter_c_odd_set.intersection(og.output_nodes))
z_corrections = frozenset((odd_c_set - c_set).intersection(og.output_nodes))

# Sign computation.
negative_sign = False

# One phase flip per edge between adjacent vertices in the correction set.
negative_sign ^= og.graph.subgraph(c_set).number_of_edges() % 2 == 1

# One phase flip per two Ys in the graph state stabilizer.
negative_sign ^= bool(len(inter_c_odd_set) // 2 % 2)

# One phase flip per node in the graph state stabilizer that is absorbed from a Pauli measurement with angle π.
for n in c_set | odd_c_set:
meas = og.measurements.get(n, None)
if isinstance(meas, PauliMeasurement):
negative_sign ^= meas.sign == Sign.MINUS

# One phase flip if measured on the YZ plane.
negative_sign ^= flow.node_measurement_label(node) == Plane.YZ

axes_dict: dict[int, Axis] = {}
output_to_qubit_mapping = NodeIndex()
output_to_qubit_mapping.extend(og.output_nodes)

# Sets `x_corrections`, `y_corrections` and `z_corrections` are disjoint.
corrections = (x_corrections, y_corrections, z_corrections)
for correction, axis in zip(corrections, Axis, strict=True):
for cnode in correction:
qubit = output_to_qubit_mapping.index(cnode)
axes_dict[qubit] = axis

return PauliString(dim, axes_dict, Sign.minus_if(negative_sign))


def extend_input(og: OpenGraph[Measurement]) -> tuple[OpenGraph[Measurement], dict[int, int]]:
Expand Down Expand Up @@ -512,6 +671,6 @@ def clifford_x_map_from_focused_flow(flow: PauliFlow[Measurement]) -> tuple[Paul
# In the context for `CliffordMap.from_focused_flow` the check is performed when accessing the cached property `flow.pauli_strings` in the function `clifford_z_map_from_focused_flow`.

# It's better to call the `PauliString` constructor instead of the cached property `flow_extended.pauli_strings` since the latter will compute a `PauliString` for _every_ node in the correction function and we just need it for the input nodes.
x_map_ancillas = {node: PauliString.from_measured_node(flow_extended, node) for node in og_extended.input_nodes}
x_map_ancillas = {node: extraction_ps_from_corrected_node(flow_extended, node) for node in og_extended.input_nodes}

return tuple(x_map_ancillas[ancillary_inputs_map[input_node]] for input_node in og.input_nodes)
10 changes: 8 additions & 2 deletions graphix/flow/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,12 @@
# `override` introduced in Python 3.12, `assert_never` introduced in Python 3.11
from typing_extensions import assert_never, override

from graphix.circ_ext.extraction import CliffordMap, ExtractionResult, PauliExponentialDAG, PauliString
from graphix.circ_ext.extraction import (
CliffordMap,
ExtractionResult,
PauliExponentialDAG,
extraction_ps_from_corrected_node,
)
from graphix.command import E, M, N, X, Z
from graphix.flow._find_gpflow import (
CorrectionMatrix,
Expand Down Expand Up @@ -51,6 +56,7 @@
# Unpack introduced in Python 3.12
from typing_extensions import Unpack

from graphix.circ_ext.extraction import PauliString
from graphix.opengraph import OpenGraph
from graphix.parameter import ExpressionOrSupportsFloat, Parameter
from graphix.pattern import Pattern
Expand Down Expand Up @@ -830,7 +836,7 @@ def extraction_pauli_strings(self: PauliFlow[Measurement]) -> dict[int, PauliStr
"""
if not self.is_focused():
raise ValueError("Flow is not focused.")
return {node: PauliString.from_measured_node(self, node) for node in self.correction_function}
return {node: extraction_ps_from_corrected_node(self, node) for node in self.correction_function}

def extract_circuit(self: PauliFlow[Measurement]) -> ExtractionResult:
"""Extract a circuit from a flow.
Expand Down
9 changes: 9 additions & 0 deletions tests/test_circ_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,6 +603,15 @@ def test_parametric_angles(self, test_case: float, fx_rng: Generator) -> None:
assert s1.isclose(s2)


class TestPauliString:
@pytest.mark.parametrize("ps", ["+X", "-XIY", "+II", "+YXZZYX", "-IIIYXZII"])
def test_round_trip_conversions(self, ps: str) -> None:
tab = PauliString.from_str(ps).to_tableau()
ps_test = str(PauliString.from_tableau(tab))

assert ps == ps_test


def test_extend_input() -> None:
og = OpenGraph(
graph=nx.Graph([(1, 3), (2, 4), (3, 4), (3, 5), (4, 6)]),
Expand Down
Loading