diff --git a/CHANGELOG.md b/CHANGELOG.md index 9cdab9f23..587b29737 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,12 +11,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - #490: Introduced new `Instruction` and `Command` namespace classes for instruction and command instantiation. -- #476 Introduced new methods `OpenGraph.extract_circuit`, `CliffordMap.to_tableau` and new function `graphix.circ_ext.compilation.cm_berg_pass`. Circuit extraction can be done natively in Graphix. +- #476: Introduced new methods `OpenGraph.extract_circuit`, `CliffordMap.to_tableau` and new function `graphix.circ_ext.compilation.cm_berg_pass`. Circuit extraction can be done natively in Graphix. ### Fixed ### Changed +- #168, #498: `Pattern.remove_pauli_measurements` replaces `Pattern.perform_pauli_measurements`. The new algorithm removes all non-input Pauli nodes from patterns that have a flow and returns a pattern that is equivalent for every input state. + - Field `Pattern.results` and function `incorporate_pauli_results` removed. + - #490: Exposed more common classes and methods to top level `__init__.py`. - Renamed `Instruction`, `InstructionWithoutRZZ` and `Command` to `InstructionType`, `InstructionTypeWithoutRZZ` and `CommandType` respectively. - Moved `InstructionType`, `InstructionTypeWithoutRZZ`, `CommandType`, `Correction` and `CommandOrNoise` to `TYPE_CHECKING` blocks. diff --git a/README.md b/README.md index d40169b05..c44e006f2 100644 --- a/README.md +++ b/README.md @@ -69,7 +69,7 @@ pattern.draw_graph(flow_from_pattern=False) ### preprocessing Pauli measurements (Clifford gates) ```python -pattern.perform_pauli_measurements() +pattern.remove_pauli_measurements() pattern.draw_graph() ``` diff --git a/docs/source/intro.rst b/docs/source/intro.rst index c3570dd2e..3fa3686c8 100644 --- a/docs/source/intro.rst +++ b/docs/source/intro.rst @@ -124,16 +124,17 @@ which we can express by a long sequence, Note that the input state has *teleported* to qubits 6 and 7 after the computation. .. - We can inspect the graph state using :class:`~graphix.graphsim.GraphState` class: + We can inspect the graph state using :class:`~graphix.opengraph.OpenGraph` class: .. code-block:: python - from graphix import GraphState - g = GraphState(nodes=[0,1],edges=[(0,1)]) + from graphix import OpenGraph + import networkx as nx + og = OpenGraph(nx.Graph([0, 1]), output_nodes=[0, 1]) - >>> print(g.to_statevector()) - Statevec, data=[[ 0.5+0.j 0.5+0.j] - [ 0.5+0.j -0.5+0.j]], shape=(2, 2) + >>> print(og.to_pattern().simulate_pattern()) + Statevec object with statevector [[ 0.5+0.j 0.5+0.j] + [ 0.5+0.j -0.5+0.j]] and length (2, 2). diff --git a/docs/source/modifier.rst b/docs/source/modifier.rst index eee84b258..4fe4f5727 100644 --- a/docs/source/modifier.rst +++ b/docs/source/modifier.rst @@ -32,7 +32,9 @@ Pattern Manipulation .. automethod:: remove_input_nodes - .. automethod:: perform_pauli_measurements + .. automethod:: perform_pauli_pushing + + .. automethod:: remove_pauli_measurements .. automethod:: to_ascii @@ -83,4 +85,4 @@ Pattern Manipulation .. automethod:: to_bloch -.. autofunction:: measure_pauli +.. autofunction:: shift_outcomes diff --git a/docs/source/optimization.rst b/docs/source/optimization.rst index b2663c8ac..ee965a0d1 100644 --- a/docs/source/optimization.rst +++ b/docs/source/optimization.rst @@ -16,3 +16,13 @@ This module defines space minimization procedures for patterns. .. automodule:: graphix.space_minimization :members: + +:mod:`graphix.remove_pauli_measurements` module ++++++++++++++++++++++++++++++++++++++++++++++++ + +This module provides procedures for pushing Pauli measurements in +front of a pattern and for subsequently removing them from the +pattern. + +.. automodule:: graphix.remove_pauli_measurements + :members: diff --git a/docs/source/tutorial.rst b/docs/source/tutorial.rst index c44505bee..a828719e7 100644 --- a/docs/source/tutorial.rst +++ b/docs/source/tutorial.rst @@ -172,14 +172,12 @@ Performing Pauli measurements +++++++++++++++++++++++++++++ It is known that quantum circuit consisting of Pauli basis states, Clifford gates and Pauli measurements can be simulated classically (see `Gottesman-Knill theorem -`_; e.g. the graph state simulator runs in :math:`\mathcal{O}(n \log n)` time). -The Pauli measurement part of the MBQC is exactly this, and they can be preprocessed by our graph state simulator :class:`~graphix.graphsim.GraphState` - see :doc:`lc-mbqc` for more detailed description. +`_). -We can call this in a line by calling :meth:`~graphix.pattern.Pattern.remove_input_nodes` followed by :meth:`~graphix.pattern.Pattern.perform_pauli_measurements()` (both methods of the :class:`~graphix.pattern.Pattern` object). The first method removes the input nodes, while the second method optimizes the measurement pattern. +We can call :meth:`~graphix.pattern.Pattern.remove_pauli_measurements()` (method of the :class:`~graphix.pattern.Pattern` object) to optimize the measurement pattern. We get an updated measurement pattern without Pauli measurements as follows: ->>> pattern.remove_input_nodes() ->>> pattern.perform_pauli_measurements() +>>> pattern.remove_pauli_measurements() >>> pattern Pattern(input_nodes=[], cmds=[N(0), N(1), N(3), N(7), E((0, 3)), E((1, 3)), E((1, 7)), M(0, Plane.YZ, 0.2907266109187514), M(1, Plane.YZ, 0.01258854060311348), C(3, Clifford.I), C(7, Clifford.I), Z(3, {0, 1, 5}), Z(7, {1, 5}), X(3, {2}), X(7, {2, 4, 6})], output_nodes=[3, 7]) diff --git a/examples/deutsch_jozsa.py b/examples/deutsch_jozsa.py index e5f82d056..68a032846 100644 --- a/examples/deutsch_jozsa.py +++ b/examples/deutsch_jozsa.py @@ -73,8 +73,7 @@ # %% # Now we preprocess all Pauli measurements, which requires that we move inputs to N commands -pattern.remove_input_nodes() -pattern.perform_pauli_measurements() +pattern.remove_pauli_measurements() print( pattern.to_ascii( left_to_right=True, diff --git a/examples/fusion_extraction.py b/examples/fusion_extraction.py index 803edddf0..abc0a1746 100644 --- a/examples/fusion_extraction.py +++ b/examples/fusion_extraction.py @@ -18,19 +18,23 @@ import itertools -import graphix +import matplotlib.pyplot as plt +import networkx as nx + from graphix import extraction -from graphix.extraction import graph_to_fusion_network +from graphix.extraction import Graph, graph_to_fusion_network # %% # Here we say we want a graph state with 9 nodes and 12 edges. # We can obtain resource graph for a measurement pattern by using :code:`pattern.extract_graph()`. -gs = graphix.GraphState() +gs = Graph() nodes = [0, 1, 2, 3, 4, 5, 6, 7, 8] edges = [(0, 1), (1, 2), (2, 3), (3, 0), (3, 4), (0, 5), (4, 5), (5, 6), (6, 7), (7, 0), (7, 8), (8, 1)] gs.add_nodes_from(nodes) gs.add_edges_from(edges) -gs.draw() +labels = {i: i for i in iter(nodes)} +nx.draw(gs, labels=labels, node_color="C0", edgecolors="k") +plt.show() # %% # Decomposition with GHZ and linear cluster resource states with no limitation in their sizes. diff --git a/examples/mbqc_vqe.py b/examples/mbqc_vqe.py index fce62723a..9b690df9f 100644 --- a/examples/mbqc_vqe.py +++ b/examples/mbqc_vqe.py @@ -92,8 +92,7 @@ def build_mbqc_pattern(self, params: Iterable[ParameterizedAngle]) -> Pattern: pattern = circuit.transpile().pattern pattern.standardize() pattern.shift_signals() - pattern.remove_input_nodes() - pattern.perform_pauli_measurements() # Perform Pauli measurements + pattern.remove_pauli_measurements() # Perform Pauli measurements return pattern # %% diff --git a/examples/qaoa.py b/examples/qaoa.py index 0e57f3b66..f0f2c1a42 100644 --- a/examples/qaoa.py +++ b/examples/qaoa.py @@ -40,8 +40,7 @@ # %% # perform Pauli measurements and plot the new (minimal) graph to perform the same quantum computation -pattern.remove_input_nodes() -pattern.perform_pauli_measurements() +pattern.remove_pauli_measurements() pattern.draw(flow_from_pattern=False) # %% diff --git a/examples/qft_with_tn.py b/examples/qft_with_tn.py index eb92a6b3e..8abb6ded5 100644 --- a/examples/qft_with_tn.py +++ b/examples/qft_with_tn.py @@ -64,10 +64,10 @@ def qft(circuit: Circuit, n: int) -> None: print(f"Number of edges: {len(graph.edges)}") # %% -# Using efficient graph state simulator `graphix.graphsim`, we can classically preprocess Pauli measurements. +# Using graph rewriting rules, we can classically preprocess Pauli measurements. # We are currently improving the speed of this process by using rust-based graph manipulation backend. pattern.remove_input_nodes() -pattern.perform_pauli_measurements() +pattern.remove_pauli_measurements(standardize=True) # %% diff --git a/examples/tn_simulation.py b/examples/tn_simulation.py index 8de62c998..63c061e1b 100644 --- a/examples/tn_simulation.py +++ b/examples/tn_simulation.py @@ -83,9 +83,9 @@ def ansatz( print(f"Number of edges: {len(graph.edges)}") # %% -# Optimizing by performing Pauli measurements in the pattern using efficient stabilizer simulator. +# Optimizing by removing Pauli measurements in the pattern. pattern.remove_input_nodes() -pattern.perform_pauli_measurements() +pattern.remove_pauli_measurements(standardize=True) # %% # Simulate using the TN backend of graphix, which will return an MBQCTensorNet object. @@ -203,7 +203,7 @@ def cost( pattern.standardize() pattern.shift_signals() pattern.remove_input_nodes() - pattern.perform_pauli_measurements() + pattern.remove_pauli_measurements(standardize=True) mbqc_tn = pattern.simulate_pattern(backend="tensornetwork", graph_prep="parallel") exp_val: float = 0 for op in ham: diff --git a/examples/visualization.py b/examples/visualization.py index 0aa642b2e..282cb6b91 100644 --- a/examples/visualization.py +++ b/examples/visualization.py @@ -37,8 +37,7 @@ # %% # next, show the gflow: -pattern.remove_input_nodes() -pattern.perform_pauli_measurements() +pattern.remove_pauli_measurements() pattern.draw(flow_from_pattern=False, measurement_labels=True) diff --git a/graphix/__init__.py b/graphix/__init__.py index 5b0e89eca..12f8da89c 100644 --- a/graphix/__init__.py +++ b/graphix/__init__.py @@ -4,6 +4,7 @@ from graphix._version import __version__ from graphix.branch_selector import ConstBranchSelector, FixedBranchSelector, RandomBranchSelector +from graphix.channels import KrausChannel from graphix.circ_ext import CliffordMap, PauliExponential, PauliExponentialDAG, PauliString from graphix.clifford import Clifford from graphix.command import Command diff --git a/graphix/command.py b/graphix/command.py index b1c5951d8..8dcad749e 100644 --- a/graphix/command.py +++ b/graphix/command.py @@ -14,6 +14,9 @@ from graphix.repr_mixins import DataclassReprMixin from graphix.states import BasicStates, State +if TYPE_CHECKING: + from collections.abc import Callable + Node: TypeAlias = int logger = logging.getLogger(__name__) @@ -141,6 +144,22 @@ def clifford(self, clifford_gate: Clifford) -> M: domains.t_domain, ) + def map(self, f: Callable[[Measurement], Measurement]) -> M: + """Return a measurement command where the function ``f`` has been applied to the measurement. + + Parameters + ---------- + f: Callable[[Measurement], Measurement] + Function applied to the measurement. + + Returns + ------- + M + The resulting command. + + """ + return M(self.node, f(self.measurement), self.s_domain, self.t_domain) + @dataclasses.dataclass(repr=False) class E(_KindChecker, BaseCommand): diff --git a/graphix/extraction.py b/graphix/extraction.py index 5dd400541..eea0dc32e 100644 --- a/graphix/extraction.py +++ b/graphix/extraction.py @@ -6,11 +6,17 @@ import dataclasses import operator from enum import Enum +from typing import TYPE_CHECKING import networkx as nx import numpy as np -from graphix.graphsim import GraphState +if TYPE_CHECKING: + from typing import TypeAlias + + Graph: TypeAlias = nx.Graph[int] +else: + Graph = nx.Graph class ResourceType(Enum): @@ -33,12 +39,12 @@ class ResourceGraph: ---------- cltype : :class:`ResourceType` object Type of the cluster. - graph : :class:`~graphix.graphsim.GraphState` object + graph : :class:`Graph` object Graph state of the cluster. """ cltype: ResourceType - graph: GraphState + graph: Graph def __eq__(self, other: object) -> bool: """Return `True` if two resource graphs are equal, `False` otherwise.""" @@ -49,11 +55,11 @@ def __eq__(self, other: object) -> bool: def graph_to_fusion_network( - graph: GraphState, + graph: Graph, max_ghz: float = np.inf, max_lin: float = np.inf, ) -> list[ResourceGraph]: - """Extract GHZ and linear cluster graph state decomposition of desired resource state :class:`~graphix.graphsim.GraphState`. + """Extract GHZ and linear cluster graph state decomposition of desired resource state :class:`Graph`. Extraction algorithm is based on [1]. @@ -61,7 +67,7 @@ def graph_to_fusion_network( Parameters ---------- - graph : :class:`~graphix.graphsim.GraphState` object + graph : :class:`Graph` object Graph state. phasedict : dict Dictionary of phases for each node. @@ -161,7 +167,7 @@ def create_resource_graph(node_ids: list[int], root: int | None = None) -> Resou else: edges = [(node_ids[i], node_ids[i + 1]) for i in range(len(node_ids)) if i + 1 < len(node_ids)] cluster_type = ResourceType.LINEAR - tmp_graph = GraphState() + tmp_graph = Graph() tmp_graph.add_nodes_from(node_ids) tmp_graph.add_edges_from(edges) return ResourceGraph(cltype=cluster_type, graph=tmp_graph) diff --git a/graphix/optimization.py b/graphix/optimization.py index 5b23d51c1..c04bea650 100644 --- a/graphix/optimization.py +++ b/graphix/optimization.py @@ -13,8 +13,6 @@ import networkx as nx # assert_never added in Python 3.11 -from typing_extensions import assert_never - from graphix import command from graphix.clifford import Clifford, Domains from graphix.command import CommandKind, Node @@ -25,7 +23,7 @@ FlowGenericErrorReason, ) from graphix.fundamentals import Axis, Plane, Sign -from graphix.measurements import BlochMeasurement, Measurement, Outcome, PauliMeasurement +from graphix.measurements import BlochMeasurement, Measurement, PauliMeasurement from graphix.opengraph import OpenGraph from graphix.space_minimization import ( minimize_space, @@ -86,13 +84,12 @@ class _StandardizedPattern: input_nodes: tuple[Node, ...] output_nodes: tuple[Node, ...] - results: Mapping[Node, Outcome] n_list: tuple[command.N, ...] e_set: frozenset[frozenset[Node]] m_list: tuple[command.M, ...] - c_dict: Mapping[Node, Clifford] z_dict: Mapping[Node, frozenset[Node]] x_dict: Mapping[Node, frozenset[Node]] + c_dict: Mapping[Node, Clifford] class StandardizedPattern(_StandardizedPattern): @@ -119,20 +116,18 @@ class StandardizedPattern(_StandardizedPattern): Input nodes. output_nodes: tuple[Node, ...] Output nodes. - results: Mapping[Node, Outcome] - Already measured nodes (by Pauli presimulation). n_list: tuple[command.N] The N commands. e_set: frozenset[frozenset[Node]] Set of edges. Each edge is a set with two elements. m_list: tuple[command.M] The M commands. - c_dict: Mapping[Node, Clifford] - Mapping associating Clifford corrections to some nodes. z_dict: Mapping[Node, frozenset[Node]] Mapping associating Z-domains to some nodes. x_dict: Mapping[Node, frozenset[Node]] Mapping associating X-domains to some nodes. + c_dict: Mapping[Node, Clifford] + Mapping associating Clifford corrections to some nodes. """ @@ -140,25 +135,23 @@ def __init__( self, input_nodes: Iterable[Node], output_nodes: Iterable[Node], - results: Mapping[Node, Outcome], n_list: Iterable[command.N], e_set: Iterable[Iterable[Node]], m_list: Iterable[command.M], - c_dict: Mapping[Node, Clifford], z_dict: Mapping[Node, Iterable[Node]], x_dict: Mapping[Node, Iterable[Node]], + c_dict: Mapping[Node, Clifford], ) -> None: """Return a new StandardizedPattern with immutable data structures.""" super().__init__( tuple(input_nodes), tuple(output_nodes), - MappingProxyType(dict(results)), tuple(n_list), frozenset(frozenset(edge) for edge in e_set), tuple(m_list), - MappingProxyType(dict(c_dict)), MappingProxyType({node: frozenset(nodes) for node, nodes in z_dict.items()}), MappingProxyType({node: frozenset(nodes) for node, nodes in x_dict.items()}), + MappingProxyType(dict(c_dict)), ) @classmethod @@ -172,9 +165,9 @@ def from_pattern(cls, pattern: Pattern) -> Self: n_list: list[command.N] = [] e_set: set[frozenset[Node]] = set() m_list: list[command.M] = [] - c_dict: dict[Node, Clifford] = {} z_dict: dict[Node, set[Node]] = {} x_dict: dict[Node, set[Node]] = {} + c_dict: dict[Node, Clifford] = {} # Standardization could turn non-runnable patterns into # runnable ones, so we check runnability first to avoid hiding @@ -231,9 +224,7 @@ def from_pattern(cls, pattern: Pattern) -> Self: # has been already applied to a node, applying a clifford `C'` to the same # node is equivalent to apply `C'C` to a fresh node. c_dict[cmd.node] = cmd.clifford @ c_dict.get(cmd.node, Clifford.I) - return cls( - pattern.input_nodes, pattern.output_nodes, pattern.results, n_list, e_set, m_list, c_dict, z_dict, x_dict - ) + return cls(pattern.input_nodes, pattern.output_nodes, n_list, e_set, m_list, z_dict, x_dict, c_dict) def extract_graph(self) -> nx.Graph[int]: """Return the graph state from the command sequence, extracted from 'N' and 'E' commands. @@ -250,9 +241,15 @@ def extract_graph(self) -> nx.Graph[int]: graph.add_edge(u, v) return graph - def perform_pauli_pushing(self, leave_nodes: AbstractSet[Node] | None = None, *, stacklevel: int = 1) -> Self: + def perform_pauli_pushing( + self, leave_nodes: AbstractSet[Node] | None = None, *, stacklevel: int = 1 + ) -> StandardizedPattern: """Move Pauli measurements before the other measurements. + If you need to recover the cut between Pauli measurements and + non-Pauli measurements or the shifted signal, you can use + :meth:`~graphix.remove_pauli_measurements.PauliPushingCut.from_standardized_pattern` instead. + Parameters ---------- leave_nodes : AbstractSet[Node], optional @@ -264,93 +261,15 @@ def perform_pauli_pushing(self, leave_nodes: AbstractSet[Node] | None = None, *, Returns ------- - Pattern + StandardizedPattern The pattern in which Pauli measurements have been moved before the other measurements. """ - self._warn_non_inferred_pauli_measurements(stacklevel=stacklevel + 1) - - if leave_nodes: - leave_non_pauli_nodes = [ - cmd.node - for cmd in self.m_list - if not isinstance(cmd.measurement, PauliMeasurement) and cmd.node in leave_nodes - ] - if leave_non_pauli_nodes: - warn( - f"`leave_nodes` contains nodes that are not Pauli: {leave_non_pauli_nodes}. The constraint has no effect on these nodes.", - stacklevel=stacklevel + 1, - ) + from graphix.remove_pauli_measurements import PauliPushingCut # noqa: PLC0415 - shift_domains: dict[int, set[int]] = {} - - def expand_domain(domain: AbstractSet[int]) -> set[int]: - """Merge previously shifted domains into ``domain``. - - Parameters - ---------- - domain : set[int] - Domain to update with any accumulated shift information. - """ - new_domain = set(domain) - for node in domain & shift_domains.keys(): - new_domain ^= shift_domains[node] - return new_domain - - pauli_list = [] - non_pauli_list = [] - for cmd in self.m_list: - s_domain = expand_domain(cmd.s_domain) - t_domain = expand_domain(cmd.t_domain) - if not isinstance(cmd.measurement, PauliMeasurement) or (leave_nodes and cmd.node in leave_nodes): - non_pauli_list.append( - command.M(node=cmd.node, measurement=cmd.measurement, s_domain=s_domain, t_domain=t_domain) - ) - else: - match cmd.measurement.axis: - case Axis.X: - # M^X X^s Z^t = M^{XY,0} X^s Z^t - # = M^{XY,(-1)^s·0+tπ} - # = S^t M^X - # M^{-X} X^s Z^t = M^{XY,π} X^s Z^t - # = M^{XY,(-1)^s·π+tπ} - # = S^t M^{-X} - shift_domains[cmd.node] = t_domain - case Axis.Y: - # M^Y X^s Z^t = M^{XY,π/2} X^s Z^t - # = M^{XY,(-1)^s·π/2+tπ} - # = M^{XY,π/2+(s+t)π} (since -π/2 = π/2 - π ≡ π/2 + π (mod 2π)) - # = S^{s+t} M^Y - # M^{-Y} X^s Z^t = M^{XY,-π/2} X^s Z^t - # = M^{XY,(-1)^s·(-π/2)+tπ} - # = M^{XY,-π/2+(s+t)π} (since π/2 = -π/2 + π) - # = S^{s+t} M^{-Y} - shift_domains[cmd.node] = s_domain ^ t_domain - case Axis.Z: - # M^Z X^s Z^t = M^{XZ,0} X^s Z^t - # = M^{XZ,(-1)^t((-1)^s·0+sπ)} - # = M^{XZ,(-1)^t·sπ} - # = M^{XZ,sπ} (since (-1)^t·π ≡ π (mod 2π)) - # = S^s M^Z - # M^{-Z} X^s Z^t = M^{XZ,π} X^s Z^t - # = M^{XZ,(-1)^t((-1)^s·π+sπ)} - # = M^{XZ,(s+1)π} - # = S^s M^{-Z} - shift_domains[cmd.node] = s_domain - case _: - assert_never(cmd.measurement.axis) - pauli_list.append(command.M(node=cmd.node, measurement=cmd.measurement)) - return self.__class__( - self.input_nodes, - self.output_nodes, - self.results, - self.n_list, - self.e_set, - pauli_list + non_pauli_list, - self.c_dict, - {node: expand_domain(domain) for node, domain in self.z_dict.items()}, - {node: expand_domain(domain) for node, domain in self.x_dict.items()}, - ) + return PauliPushingCut.from_standardized_pattern( + self, leave_nodes, stacklevel=stacklevel + 1 + ).to_standardized_pattern() def max_space(self) -> int: """Compute the maximum number of nodes that must be present in the graph (graph space) during the execution of the space-optimal pattern for the given measurement order. @@ -396,7 +315,6 @@ def to_pattern(self) -> Pattern: from graphix.pattern import Pattern # noqa: PLC0415 pattern = Pattern(input_nodes=self.input_nodes) - pattern.results = dict(self.results) pattern.extend( self.n_list, (command.E((u, v)) for u, v in self.e_set), @@ -470,8 +388,7 @@ def extract_partial_order_layers(self) -> tuple[frozenset[int], ...]: - There cannot be any empty layers. """ oset = frozenset(self.output_nodes) # First layer by convention. - pre_measured_nodes = self.results.keys() # Not included in the partial order layers. - excluded_nodes = oset | pre_measured_nodes + excluded_nodes = oset zero_indegree = set(self.input_nodes).union(n.node for n in self.n_list) - excluded_nodes dag: dict[int, set[int]] = { @@ -533,7 +450,6 @@ def extract_causal_flow(self) -> CausalFlow[BlochMeasurement]: In general, there may exist various layerings which represent the corrections of the pattern. To ensure that a given layering is compatible with the pattern's induced correction function, the partial order must be extracted from a standardized pattern. Commutation of entanglement commands with X and Z corrections in the standardization procedure may generate new corrections, which guarantees that all the topological information of the underlying graph is encoded in the extracted partial order. """ correction_function: dict[int, set[int]] = defaultdict(set) - pre_measured_nodes = self.results.keys() # Not included in the flow. for m in self.m_list: try: @@ -544,10 +460,10 @@ def extract_causal_flow(self) -> CausalFlow[BlochMeasurement]: valid = bloch.plane == Plane.XY if not valid: raise FlowGenericError(FlowGenericErrorReason.XYPlane) - _update_corrections(m.node, m.s_domain - pre_measured_nodes, correction_function) + _update_corrections(m.node, m.s_domain, correction_function) for node, domain in self.x_dict.items(): - _update_corrections(node, domain - pre_measured_nodes, correction_function) + _update_corrections(node, domain, correction_function) og = ( self.extract_opengraph() @@ -584,16 +500,15 @@ def extract_gflow(self) -> GFlow[BlochMeasurement]: The notes provided in :func:`self.extract_causal_flow` apply here as well. """ correction_function: dict[int, set[int]] = {} - pre_measured_nodes = self.results.keys() # Not included in the flow. for m in self.m_list: # Raises a `TypeError` if the measurement is not represented as a Bloch measurement if m.measurement.downcast_bloch().plane in {Plane.XZ, Plane.YZ}: correction_function.setdefault(m.node, set()).add(m.node) - _update_corrections(m.node, m.s_domain - pre_measured_nodes, correction_function) + _update_corrections(m.node, m.s_domain, correction_function) for node, domain in self.x_dict.items(): - _update_corrections(node, domain - pre_measured_nodes, correction_function) + _update_corrections(node, domain, correction_function) og = ( self.extract_opengraph() @@ -623,17 +538,15 @@ def extract_xzcorrections(self) -> XZCorrections[Measurement]: x_corr: dict[int, set[int]] = {} z_corr: dict[int, set[int]] = {} - pre_measured_nodes = self.results.keys() # Not included in the xz-corrections. - for m in self.m_list: - _update_corrections(m.node, m.s_domain - pre_measured_nodes, x_corr) - _update_corrections(m.node, m.t_domain - pre_measured_nodes, z_corr) + _update_corrections(m.node, m.s_domain, x_corr) + _update_corrections(m.node, m.t_domain, z_corr) for node, domain in self.x_dict.items(): - _update_corrections(node, domain - pre_measured_nodes, x_corr) + _update_corrections(node, domain, x_corr) for node, domain in self.z_dict.items(): - _update_corrections(node, domain - pre_measured_nodes, z_corr) + _update_corrections(node, domain, z_corr) og = ( self.extract_opengraph() @@ -643,6 +556,35 @@ def extract_xzcorrections(self) -> XZCorrections[Measurement]: og, x_corr, z_corr ) # Raises a `XZCorrectionsError` if the input dictionaries are not well formed. + def map(self, f: Callable[[Measurement], Measurement]) -> StandardizedPattern: + """Return a pattern where the function ``f`` has been applied to each measurement. + + Parameters + ---------- + f: Callable[[Measurement], Measurement] + Function applied to each measurement. + + Returns + ------- + StandardizedPattern + The resulting pattern. + """ + m_list = tuple(cmd_m.map(f) for cmd_m in self.m_list) + return StandardizedPattern( + self.input_nodes, + self.output_nodes, + self.n_list, + self.e_set, + m_list, + self.z_dict, + self.x_dict, + self.c_dict, + ) + + def to_bloch(self) -> StandardizedPattern: + """Return an equivalent pattern in which all measurements are represented as Bloch measurements.""" + return self.map(lambda m: m.to_bloch()) + def _warn_non_inferred_pauli_measurements(self, stacklevel: int) -> None: for m in self.m_list: if isinstance(m.measurement, BlochMeasurement) and m.measurement.try_to_pauli() is not None: @@ -695,16 +637,6 @@ def _commute_clifford(clifford_gate: Clifford, c_dict: dict[int, Clifford], i: i ) -def _incorporate_pauli_results_in_domain( - results: Mapping[int, int], domain: AbstractSet[int] -) -> tuple[bool, set[int]] | None: - if not (results.keys() & domain): - return None - new_domain = set(domain - results.keys()) - odd_outcome = sum(outcome for node, outcome in results.items() if node in domain) % 2 - return odd_outcome == 1, new_domain - - def _update_corrections(node: Node, domain: AbstractSet[Node], correction: dict[Node, set[Node]]) -> None: """Update the correction mapping by adding a node to all entries in a domain. @@ -726,59 +658,11 @@ def _update_corrections(node: Node, domain: AbstractSet[Node], correction: dict[ correction.setdefault(measured_node, set()).add(node) -def incorporate_pauli_results(pattern: Pattern) -> Pattern: - """Return an equivalent pattern where results from Pauli presimulation are integrated in corrections.""" - from graphix.pattern import Pattern # noqa: PLC0415 - - result = Pattern(input_nodes=pattern.input_nodes) - for cmd in pattern: - match cmd.kind: - case CommandKind.M: - s = _incorporate_pauli_results_in_domain(pattern.results, cmd.s_domain) - t = _incorporate_pauli_results_in_domain(pattern.results, cmd.t_domain) - if s or t: - if s: - apply_x, new_s_domain = s - else: - apply_x = False - new_s_domain = cmd.s_domain - if t: - apply_z, new_t_domain = t - else: - apply_z = False - new_t_domain = cmd.t_domain - new_cmd = command.M(cmd.node, cmd.measurement, new_s_domain, new_t_domain) - if apply_x: - new_cmd = new_cmd.clifford(Clifford.X) - if apply_z: - new_cmd = new_cmd.clifford(Clifford.Z) - result.add(new_cmd) - else: - result.add(cmd) - case CommandKind.X | CommandKind.Z: - signal = _incorporate_pauli_results_in_domain(pattern.results, cmd.domain) - if signal: - apply_c, new_domain = signal - if new_domain: - cmd_cstr = command.X if cmd.kind == CommandKind.X else command.Z - result.add(cmd_cstr(cmd.node, new_domain)) - if apply_c: - c = Clifford.X if cmd.kind == CommandKind.X else Clifford.Z - result.add(command.C(cmd.node, c)) - else: - result.add(cmd) - case _: - result.add(cmd) - result.reorder_output_nodes(pattern.output_nodes) - return result - - def remove_useless_domains(pattern: Pattern) -> Pattern: """Return an equivalent pattern where measurement domains that are not used given the specific measurement angles and planes are removed.""" from graphix.pattern import Pattern # noqa: PLC0415 new_pattern = Pattern(input_nodes=pattern.input_nodes) - new_pattern.results = pattern.results for cmd in pattern: if cmd.kind == CommandKind.M: match cmd.measurement: @@ -800,7 +684,6 @@ def single_qubit_domains(pattern: Pattern) -> Pattern: from graphix.pattern import Pattern # noqa: PLC0415 new_pattern = Pattern(input_nodes=pattern.input_nodes) - new_pattern.results = pattern.results def decompose_domain( cmd: Callable[[int, set[int]], command.CommandType], node: int, domain: AbstractSet[int] diff --git a/graphix/pattern.py b/graphix/pattern.py index c28bf7555..bb502bb3a 100644 --- a/graphix/pattern.py +++ b/graphix/pattern.py @@ -21,12 +21,10 @@ from typing_extensions import assert_never from graphix import command, optimization -from graphix.clifford import Clifford from graphix.command import CommandKind, Node from graphix.flow.exceptions import FlowError -from graphix.fundamentals import Axis, Plane, Sign -from graphix.graphsim import GraphState -from graphix.measurements import BlochMeasurement, Measurement, Outcome, PauliMeasurement, toggle_outcome +from graphix.fundamentals import Plane +from graphix.measurements import BlochMeasurement, Measurement, Outcome, toggle_outcome from graphix.opengraph import OpenGraph from graphix.pretty_print import OutputFormat, pattern_to_str from graphix.qasm3_exporter import pattern_to_qasm3_lines @@ -46,9 +44,9 @@ # Unpack introduced in Python 3.12 from typing_extensions import Unpack + from graphix.clifford import Clifford from graphix.command import CommandType from graphix.flow.core import CausalFlow, GFlow, PauliFlow, XZCorrections - from graphix.optimization import StandardizedPattern from graphix.parameter import ExpressionOrSupportsComplex, ExpressionOrSupportsFloat, Parameter from graphix.sim import Backend, Data, DensityMatrixBackend, StatevectorBackend from graphix.sim.base_backend import _StateT_co @@ -77,27 +75,8 @@ class Pattern: efficiency of the pattern accoring to measurement calculus. ref: V. Danos, E. Kashefi and P. Panangaden. J. ACM 54.2 8 (2007) - - Attributes - ---------- - list(self) : - list of commands. - - .. line-block:: - each command is a list [type, nodes, attr] which will be applied in the order of list indices. - type: one of {'N', 'M', 'E', 'X', 'Z', 'S', 'C'} - nodes: int for {'N', 'M', 'X', 'Z', 'S', 'C'} commands, tuple (i, j) for {'E'} command - attr for N: none - attr for M: meas_plane, angle, s_domain, t_domain - attr for X: signal_domain - attr for Z: signal_domain - attr for S: signal_domain - attr for C: clifford_index, as defined in :py:mod:`graphix.clifford` - n_node : int - total number of nodes in the resource state """ - results: dict[int, Outcome] __seq: list[CommandType] def __init__( @@ -118,7 +97,6 @@ def __init__( output_nodes : Iterable[int] | None Optional. List of output qubits. """ - self.results = {} # measurement results from the graph state simulator if input_nodes is None: self.__input_nodes = [] else: @@ -226,8 +204,8 @@ def compose( - Input (and, respectively, output) nodes in the returned pattern have the order of the pattern ``self`` followed by those of the pattern ``other``. Merged nodes are removed. - If ``preserve_mapping = True`` and :math:`|M_1| = |I_2| = |O_2|`, then the outputs of the returned pattern are the outputs of pattern ``self``, where the nth merged output is replaced by the output of pattern ``other`` corresponding to its nth input instead. """ - nodes_p1 = self.extract_nodes() | self.results.keys() # Results contain preprocessed Pauli nodes - nodes_p2 = other.extract_nodes() | other.results.keys() + nodes_p1 = self.extract_nodes() # Results contain preprocessed Pauli nodes + nodes_p2 = other.extract_nodes() if not mapping.keys() <= nodes_p2: raise PatternError("Keys of `mapping` must correspond to the nodes of `other`.") @@ -265,7 +243,6 @@ def compose( mapped_inputs = [mapping_complete[n] for n in other.input_nodes] mapped_outputs = [mapping_complete[n] for n in other.output_nodes] - mapped_results: dict[int, Outcome] = {mapping_complete[n]: m for n, m in other.results.items()} merged = mapping_values_set.intersection(self.__output_nodes) @@ -306,9 +283,7 @@ def update_command(cmd: CommandType) -> CommandType: seq = self.__seq + [update_command(c) for c in other] - results: dict[int, Outcome] = {**self.results, **mapped_results} p = Pattern(input_nodes=inputs, output_nodes=outputs, cmds=seq) - p.results = results return p, mapping_complete @@ -386,7 +361,6 @@ def __eq__(self, other: object) -> bool: self.__seq == other.__seq and self.__input_nodes == other.__input_nodes and self.__output_nodes == other.__output_nodes - and self.results == other.results ) def to_ascii( @@ -1428,28 +1402,6 @@ def remove_input_nodes(self) -> None: empty_nodes: list[int] = [] self.__input_nodes = empty_nodes - def perform_pauli_measurements(self, ignore_pauli_with_deps: bool = False, *, stacklevel: int = 1) -> None: - """Perform Pauli measurements in the pattern using efficient stabilizer simulator. - - Parameters - ---------- - ignore_pauli_with_deps : bool - Optional (*False* by default). - If *True*, Pauli measurements with domains depending on other measures are preserved as-is in the pattern. - If *False*, all Pauli measurements are preprocessed. Formally, measurements are swapped so that all Pauli measurements are applied first, and domains are updated accordingly. - stacklevel : int, optional - Stack level to use for warnings. Defaults to 1, meaning that warnings - are reported at this function's call site. - - .. seealso:: :func:`measure_pauli` - - """ - if self.input_nodes: - raise PatternError("Remove inputs with `self.remove_input_nodes()` before performing Pauli presimulation.") - self.__dict__.update( - measure_pauli(self, ignore_pauli_with_deps=ignore_pauli_with_deps, stacklevel=stacklevel + 1).__dict__ - ) - def _warn_non_inferred_pauli_measurements(self, stacklevel: int) -> None: for cmd in self: if ( @@ -1591,7 +1543,6 @@ def copy(self) -> Pattern: result.__input_nodes = self.__input_nodes.copy() result.__output_nodes = self.__output_nodes.copy() result.__n_node = self.__n_node - result.results = self.results.copy() return result def check_runnability(self) -> None: @@ -1609,7 +1560,7 @@ def check_runnability(self) -> None: have hidden domains that cannot be checked. """ active = set(self.input_nodes) - measured = set(self.results) + measured = set() def check_active(cmd: CommandType, node: int) -> None: if node in measured: @@ -1678,11 +1629,10 @@ def map(self, f: Callable[[Measurement], Measurement]) -> Pattern: Pattern(input_nodes=[0], cmds=[M(0, Measurement.XZ(1.25))]) """ new_pattern = Pattern(input_nodes=self.input_nodes) - new_pattern.results = self.results for cmd in self: if cmd.kind == CommandKind.M: - new_pattern.add(command.M(cmd.node, f(cmd.measurement), cmd.s_domain, cmd.t_domain)) + new_pattern.add(cmd.map(f)) else: new_pattern.add(cmd) @@ -1739,6 +1689,10 @@ def perform_pauli_pushing( ) -> Pattern: """Move Pauli measurements before the other measurements. + If you need to recover the cut between Pauli measurements and + non-Pauli measurements or the shifted signal, you can use + :meth:`~graphix.remove_pauli_measurements.PauliPushingCut.from_standardized_pattern` instead. + Parameters ---------- leave_nodes : AbstractSet[Node], optional @@ -1776,6 +1730,46 @@ def perform_pauli_pushing( self.__seq = pattern.__seq return self + def remove_pauli_measurements( + self, *, copy: bool = False, standardize: bool = False, stacklevel: int = 1 + ) -> Pattern: + """Remove non-input Pauli measurements from the given pattern. + + See :func:`~remove_pauli_measurements.remove_pauli_measurements` for more information. + + Parameters + ---------- + copy : bool, optional + If ``True``, the current pattern remains unchanged and a + new pattern is returned. The default is ``False``, meaning + that changes are performed in place. + standardize: bool, optional + If ``True``, the pattern is returned in standardized form. + The default is ``False``: the nodes are prepared on a + need-by-need basis, minimizing space usage. + stacklevel : int, optional + Stack level to use for warnings. Defaults to 1, meaning that warnings + are reported at this function's call site. + + Returns + ------- + Pattern + The pattern in which Pauli measurements have been moved + before the other measurements. If ``copy`` is ``False``, + the result is ``self``. + """ + from graphix.remove_pauli_measurements import PauliPushingCut, remove_pauli_measurements # noqa: PLC0415 + + standardized_pattern = optimization.StandardizedPattern.from_pattern(self) + cut = PauliPushingCut.from_standardized_pattern(standardized_pattern, stacklevel=stacklevel + 1) + standardized_pattern = remove_pauli_measurements(cut) + pattern = standardized_pattern.to_pattern() if standardize else standardized_pattern.to_space_optimal_pattern() + if copy: + return pattern + self.__seq = pattern.__seq + self.__output_nodes = pattern.__output_nodes + return self + class PatternError(Exception): """Exception subclass to handle pattern errors.""" @@ -1825,158 +1819,6 @@ def __str__(self) -> str: assert_never(self.reason) -def measure_pauli(pattern: Pattern, *, ignore_pauli_with_deps: bool = False, stacklevel: int = 1) -> Pattern: - """Perform Pauli measurement of a pattern by fast graph state simulator. - - Uses the decorated-graph method implemented in graphix.graphsim to perform the measurements in Pauli bases, and then sort remaining nodes back into - pattern together with Clifford commands. Users are required to ensure there are no input nodes with :func:`graphix.pattern.Pattern.remove_input_nodes` before using this function. - - TODO: non-XY plane measurements in original pattern - - Parameters - ---------- - pattern : graphix.pattern.Pattern object - ignore_pauli_with_deps : bool - Optional (*False* by default). - If *True*, Pauli measurements with domains depending on other measures are preserved as-is in the pattern. - If *False*, all Pauli measurements are preprocessed. Formally, measurements are swapped so that all Pauli measurements are applied first, and domains are updated accordingly. - stacklevel : int, optional - Stack level to use for warnings. Defaults to 1, meaning that warnings - are reported at this function's call site. - - Returns - ------- - new_pattern : graphix.Pattern object - pattern with Pauli measurement removed. - only returned if copy argument is True. - - - .. seealso:: :class:`graphix.pattern.Pattern.remove_input_nodes` - .. seealso:: :class:`graphix.graphsim.GraphState` - """ - pattern._warn_non_inferred_pauli_measurements(stacklevel=stacklevel + 1) - pat = Pattern() - standardized_pattern = optimization.StandardizedPattern.from_pattern(pattern) - if not ignore_pauli_with_deps: - standardized_pattern = standardized_pattern.perform_pauli_pushing(stacklevel=stacklevel + 1) - output_nodes = set(pattern.output_nodes) - graph = standardized_pattern.extract_graph() - graph_state = GraphState(nodes=graph.nodes, edges=graph.edges, vops=standardized_pattern.c_dict) - results: dict[int, Outcome] = pattern.results - to_measure, non_pauli_meas = pauli_nodes(standardized_pattern) - if not to_measure: - return pattern - for cmd in to_measure: - pattern_cmd = cmd[0] - measurement_basis = cmd[1] - # extract signals for adaptive angle. - s_signal = 0 - t_signal = 0 - match measurement_basis.axis: - case Axis.X: # X measurement is not affected by s_signal - t_signal = sum(results[j] for j in pattern_cmd.t_domain) - case Axis.Y: - s_signal = sum(results[j] for j in pattern_cmd.s_domain) - t_signal = sum(results[j] for j in pattern_cmd.t_domain) - case Axis.Z: # Z measurement is not affected by t_signal - s_signal = sum(results[j] for j in pattern_cmd.s_domain) - case _: - assert_never(measurement_basis.axis) - - if int(s_signal % 2) == 1: # equivalent to X byproduct - graph_state.h(pattern_cmd.node) - graph_state.z(pattern_cmd.node) - graph_state.h(pattern_cmd.node) - if int(t_signal % 2) == 1: # equivalent to Z byproduct - graph_state.z(pattern_cmd.node) - basis = measurement_basis - match basis.axis: - case Axis.X: - measure = graph_state.measure_x - case Axis.Y: - measure = graph_state.measure_y - case Axis.Z: - measure = graph_state.measure_z - case _: - assert_never(basis.axis) - if basis.sign == Sign.PLUS: - results[pattern_cmd.node] = measure(pattern_cmd.node, choice=0) - else: - results[pattern_cmd.node] = 0 if measure(pattern_cmd.node, choice=1) else 1 - - # measure (remove) isolated nodes. if they aren't Pauli measurements, - # measuring one of the results with probability of 1 should not occur as was possible above for Pauli measurements, - # which means we can just choose s=0. We should not remove output nodes even if isolated. - isolates = graph_state.isolated_nodes() - for node in non_pauli_meas: - if (node in isolates) and (node not in output_nodes): - graph_state.remove_node(node) - results[node] = 0 - - # update command sequence - vops = graph_state.extract_vops() - new_seq: list[CommandType] = [] - new_seq.extend(command.N(node=index) for index in set(graph_state.nodes)) - new_seq.extend(command.E(nodes=edge) for edge in graph_state.edges) - new_seq.extend( - cmd.clifford(Clifford(vops[cmd.node])) for cmd in standardized_pattern.m_list if cmd.node in graph_state.nodes - ) - new_seq.extend( - command.C(node=index, clifford=Clifford(vops[index])) - for index in pattern.output_nodes - if vops[index] != Clifford.I - ) - new_seq.extend(command.Z(node=node, domain=set(domain)) for node, domain in standardized_pattern.z_dict.items()) - new_seq.extend(command.X(node=node, domain=set(domain)) for node, domain in standardized_pattern.x_dict.items()) - pat.replace(new_seq, input_nodes=[]) - pat.reorder_output_nodes(standardized_pattern.output_nodes) - assert pat.n_node == len(graph_state.nodes) - pat.results = results - return pat - - -def pauli_nodes(pattern: StandardizedPattern) -> tuple[list[tuple[command.M, PauliMeasurement]], set[int]]: - """Return the list of measurement commands that are in Pauli bases and that are not dependent on any non-Pauli measurements. - - Parameters - ---------- - pattern : optimization.StandardizedPattern - - Returns - ------- - pauli_node : list - list of measures - non_pauli_nodes : set[int] - """ - pauli_node: list[tuple[command.M, PauliMeasurement]] = [] - # Nodes that are non-Pauli measured, or pauli measured but depends on pauli measurement - non_pauli_node: set[int] = set() - for cmd in pattern.m_list: - if isinstance(cmd.measurement, PauliMeasurement): - # Pauli measurement to be removed - match cmd.measurement.axis: - case Axis.X: - if cmd.t_domain & non_pauli_node: # cmd depend on non-Pauli measurement - non_pauli_node.add(cmd.node) - else: - pauli_node.append((cmd, cmd.measurement)) - case Axis.Y: - if (cmd.s_domain | cmd.t_domain) & non_pauli_node: # cmd depend on non-Pauli measurement - non_pauli_node.add(cmd.node) - else: - pauli_node.append((cmd, cmd.measurement)) - case Axis.Z: - if cmd.s_domain & non_pauli_node: # cmd depend on non-Pauli measurement - non_pauli_node.add(cmd.node) - else: - pauli_node.append((cmd, cmd.measurement)) - case _: - raise PatternError("Unknown Pauli measurement basis") - else: - non_pauli_node.add(cmd.node) - return pauli_node, non_pauli_node - - def assert_permutation(original: list[int], user: list[int]) -> None: """Check that the provided `user` node list is a permutation from `original`.""" node_set = set(user) @@ -2016,23 +1858,23 @@ def extract_signal(plane: Plane, s_domain: set[int], t_domain: set[int]) -> Extr assert_never(plane) -def shift_outcomes(outcomes: dict[int, Outcome], signal_dict: dict[int, set[int]]) -> dict[int, Outcome]: +def shift_outcomes(outcomes: Mapping[int, Outcome], signal_dict: Mapping[int, AbstractSet[int]]) -> dict[int, Outcome]: """Update outcomes with shifted signals. Shifted signals (as returned by the method :func:`Pattern.shift_signals`) affect classical outputs (measurements) while leaving the quantum state invariant. - This method updates the given `outcomes` by swapping the + This method updates the given ``outcomes`` by swapping the measurements affected by signals. This can be used either to - transform the value of :data:`Pattern.results` into measurements - observed in the unshifted pattern, or vice versa. + transform the results into measurements observed in the unshifted + pattern, or vice versa. Parameters ---------- - outcomes : dict[int, int] + outcomes : Mapping[int, Outcome] Classical outputs. - signal_dict : dict[int, set[int]] + signal_dict : Mapping[int, AbstractSet[int]] For each node, the signal that has been shifted (as returned by :func:`Pattern.shift_signals`). diff --git a/graphix/qasm3_exporter.py b/graphix/qasm3_exporter.py index 8b0066a4b..7e4a9336d 100644 --- a/graphix/qasm3_exporter.py +++ b/graphix/qasm3_exporter.py @@ -66,7 +66,7 @@ def qasm3_gate_call(gate: str, operands: Iterable[str], args: Iterable[str] | No def angle_to_qasm3(angle: ParameterizedAngle) -> str: """Get the OpenQASM3 representation of an angle.""" - if not isinstance(angle, float): + if not isinstance(angle, (int, float)): raise TypeError("QASM export of symbolic pattern is not supported") return angle_to_str(angle, output=OutputFormat.ASCII, multiplication_sign=True) @@ -120,9 +120,8 @@ def pattern_to_qasm3(pattern: Pattern, input_state: dict[int, State] | State = B qubits if the pattern has been Pauli-presimulated, and it may include Boolean expressions using xor (`^`) if some domains contain multiple qubits. These features are not supported by - `qiskit-qasm3-import`. The functions - :func:`graphix.optimization.incorporate_pauli_results` and - :func:`graphix.optimization.single_qubit_domains` transform any + `qiskit-qasm3-import`. The function + :func:`graphix.optimization.single_qubit_domains` transforms any pattern into an equivalent one such that exporting to OpenQASM 3.0 produces a circuit that can be imported into Qiskit. @@ -151,12 +150,6 @@ def pattern_to_qasm3_lines(pattern: Pattern, input_state: dict[int, State] | Sta state = input_state if isinstance(input_state, State) else input_state[node] yield from state_to_qasm3_lines(node, state) yield "\n" - if pattern.results != {}: - for i in pattern.results: - res = pattern.results[i] - yield f"// measurement result of qubit q{i}\n" - yield f"bit c{i} = {res};\n" - yield "\n" for cmd in pattern: yield from command_to_qasm3_lines(cmd) diff --git a/graphix/remove_pauli_measurements.py b/graphix/remove_pauli_measurements.py new file mode 100644 index 000000000..a0e292109 --- /dev/null +++ b/graphix/remove_pauli_measurements.py @@ -0,0 +1,579 @@ +"""Remove Pauli measurements. + +This module provides procedures for pushing Pauli measurements in +front of a pattern and for subsequently removing them from the +pattern. + +Pauli pushing uses commutation rules of Pauli measurements to move +them before other measurements while appropriately shifting their +signals, so that all Pauli measurements end up with empty +domains. This step is required before the actual removal can be +performed. + +For the removal itself, this module implements the algorithm described +in [BMBdF+21], Theorem 4.12 (Section 4.3: Removing Clifford +vertices). + +[BMBdF+21] Miriam Backens, Hector Miller-Bakewell, Giovanni de Felice, + Leo Lobski, and John van de Wetering, + There and back again: A circuit extraction tale, Quantum, 2021, + https://doi.org/10.22331/q-2021-03-25-421 +""" + +from __future__ import annotations + +import dataclasses +import itertools +from dataclasses import dataclass +from typing import TYPE_CHECKING +from warnings import warn + +import networkx as nx +from typing_extensions import assert_never + +from graphix.clifford import Clifford, Domains +from graphix.command import Command +from graphix.fundamentals import Axis, Sign +from graphix.measurements import PauliMeasurement +from graphix.optimization import StandardizedPattern + +if TYPE_CHECKING: + from collections.abc import Iterable, Mapping + from collections.abc import Set as AbstractSet + from typing import TypeAlias + + from graphix.command import Node + + Graph: TypeAlias = nx.Graph[int] +else: + Graph = nx.Graph + + +@dataclass(frozen=True, slots=True) +class PauliPushingCut: + """Cut of the pattern measurements into Pauli and non-Pauli measurements.""" + + original_pattern: StandardizedPattern + + pauli_measurements: tuple[Command.M, ...] + """Pauli measurements: they are all applied before non-Pauli measurements and their domains are empty.""" + + non_pauli_measurements: tuple[Command.M, ...] + + shifted_domains: dict[int, set[int]] + """The shifted domains. + + The output of the original pattern can be retrieved by using + :func:`~graphix.pattern.shift_outcomes` with these domains. + """ + + @property + def measurements(self) -> tuple[Command.M, ...]: + """Return the list of measurements, where Pauli measurements appear first and without signal.""" + return self.pauli_measurements + self.non_pauli_measurements + + @classmethod + def from_standardized_pattern( + cls, pattern: StandardizedPattern, leave_nodes: AbstractSet[Node] | None = None, *, stacklevel: int = 1 + ) -> PauliPushingCut: + """Move Pauli measurements before the other measurements and return the cut between Pauli measurements and non-Pauli measurements. + + If you only need the resulting pattern, you can use + :meth:`~graphix.optimization.StandardizedPattern.perform_pauli_pushing` or + :meth:`~graphix.pattern.Pattern.perform_pauli_pushing` instead. + + Parameters + ---------- + pattern: StandardizedPattern + The pattern to reorder. + leave_nodes : AbstractSet[Node], optional + Nodes that should not be moved. This constraint only + applies to Pauli nodes and has no effect on non-Pauli nodes. + stacklevel : int, optional + Stack level to use for warnings. Defaults to 1, meaning that warnings + are reported at this function's call site. + + Returns + ------- + PauliPushingCut + The cut between Pauli measurements and non-Pauli measurements. + """ + pattern._warn_non_inferred_pauli_measurements(stacklevel=stacklevel + 1) + + if leave_nodes: + leave_non_pauli_nodes = [ + cmd.node + for cmd in pattern.m_list + if not isinstance(cmd.measurement, PauliMeasurement) and cmd.node in leave_nodes + ] + if leave_non_pauli_nodes: + warn( + f"`leave_nodes` contains nodes that are not Pauli: {leave_non_pauli_nodes}. The constraint has no effect on these nodes.", + stacklevel=stacklevel + 1, + ) + + shifted_domains: dict[int, set[int]] = {} + + pauli_measurements: list[Command.M] = [] + non_pauli_measurements: list[Command.M] = [] + for cmd in pattern.m_list: + s_domain = _expand_domain(shifted_domains, cmd.s_domain) + t_domain = _expand_domain(shifted_domains, cmd.t_domain) + if not isinstance(cmd.measurement, PauliMeasurement) or (leave_nodes and cmd.node in leave_nodes): + non_pauli_measurements.append( + Command.M(node=cmd.node, measurement=cmd.measurement, s_domain=s_domain, t_domain=t_domain) + ) + else: + match cmd.measurement.axis: + case Axis.X: + # M^X X^s Z^t = M^{XY,0} X^s Z^t + # = M^{XY,(-1)^s·0+tπ} + # = S^t M^X + # M^{-X} X^s Z^t = M^{XY,π} X^s Z^t + # = M^{XY,(-1)^s·π+tπ} + # = S^t M^{-X} + shifted_domains[cmd.node] = t_domain + case Axis.Y: + # M^Y X^s Z^t = M^{XY,π/2} X^s Z^t + # = M^{XY,(-1)^s·π/2+tπ} + # = M^{XY,π/2+(s+t)π} (since -π/2 = π/2 - π ≡ π/2 + π (mod 2π)) + # = S^{s+t} M^Y + # M^{-Y} X^s Z^t = M^{XY,-π/2} X^s Z^t + # = M^{XY,(-1)^s·(-π/2)+tπ} + # = M^{XY,-π/2+(s+t)π} (since π/2 = -π/2 + π) + # = S^{s+t} M^{-Y} + shifted_domains[cmd.node] = s_domain ^ t_domain + case Axis.Z: + # M^Z X^s Z^t = M^{XZ,0} X^s Z^t + # = M^{XZ,(-1)^t((-1)^s·0+sπ)} + # = M^{XZ,(-1)^t·sπ} + # = M^{XZ,sπ} (since (-1)^t·π ≡ π (mod 2π)) + # = S^s M^Z + # M^{-Z} X^s Z^t = M^{XZ,π} X^s Z^t + # = M^{XZ,(-1)^t((-1)^s·π+sπ)} + # = M^{XZ,(s+1)π} + # = S^s M^{-Z} + shifted_domains[cmd.node] = s_domain + case _: # pragma: no cover + assert_never(cmd.measurement.axis) + pauli_measurements.append(Command.M(node=cmd.node, measurement=cmd.measurement)) + return cls(pattern, tuple(pauli_measurements), tuple(non_pauli_measurements), shifted_domains) + + def to_standardized_pattern(self) -> StandardizedPattern: + """Return the standardized pattern where all Pauli measurements have been pushed.""" + return StandardizedPattern( + self.original_pattern.input_nodes, + self.original_pattern.output_nodes, + self.original_pattern.n_list, + self.original_pattern.e_set, + self.measurements, + _expand_corrections(self.shifted_domains, self.original_pattern.z_dict), + _expand_corrections(self.shifted_domains, self.original_pattern.x_dict), + self.original_pattern.c_dict, + ) + + +def _expand_domain(shifted_domains: Mapping[Node, AbstractSet[Node]], domain: AbstractSet[Node]) -> set[Node]: + """Merge previously shifted domains into ``domain``. + + Parameters + ---------- + shifted_domains: Mapping[Node, AbstractSet[Node]] + Shifted domains + domain : AbstractSet[Node] + Domain to update with any accumulated shift information. + """ + new_domain = set(domain) + for node in domain & shifted_domains.keys(): + new_domain ^= shifted_domains[node] + return new_domain + + +def _expand_corrections( + shifted_domains: Mapping[Node, AbstractSet[Node]], corrections: Mapping[Node, AbstractSet[Node]] +) -> dict[Node, set[Node]]: + return {node: _expand_domain(shifted_domains, domain) for node, domain in corrections.items()} + + +@dataclass(slots=True) +class _NodeSpec: + """Annotations attached to every node of the graph state.""" + + src: Node + """The corresponding node in the original pattern.""" + + domains: Domains = dataclasses.field(default_factory=lambda: Domains(set(), set())) + """Correction domains (the nodes refer to the numbering of the original pattern).""" + + clifford: Clifford = Clifford.I + + pauli_measurement: PauliMeasurement | None = None + """Pauli measurement if the node is measured with a Pauli measurement. + + ``None`` if the node is an output or measured with a non-Pauli measurement. + """ + + +class _RemovePauliMeasurements: + """Processing structure for Pauli measurement removal. + + This class is instantiated from a Pauli-pushing cut and can be + converted back to a standardized pattern with the method + :meth:`to_standardized_pattern`. The public methods preserve the + pattern semantics as invariant, such that an equivalent + standardized pattern can be obtained at any stage of the process. + """ + + cut: PauliPushingCut + """Cut of the pattern measurements obtained by Pauli-pushing.""" + + graph: Graph + node_specs: dict[Node, _NodeSpec] + + measurements: tuple[Command.M, ...] + """List of the original measurements after Pauli-pushing.""" + + pauli_measurements: dict[Axis, set[Node]] + """For each axis, the set of non-input nodes that have a Pauli measurement on that axis. + + Nodes are given with the indexing of the original pattern: use ``node_map`` to retrieve the index in the graph.""" + + input_node_set: set[Node] + """Set of input nodes: inputs nodes are never pivoted, therefore their indexing is preserved.""" + + output_node_set: set[Node] + """Set of output nodes, using the new indexing.""" + + node_map: dict[Node, Node] + """Mapping from the nodes of the original pattern to the nodes of the graph (that may have been pivoted). + + The following invariant is maintained for all node ``u``: ``node_specs[node_map[u]].src == u``. + """ + + def __init__(self, cut: PauliPushingCut) -> None: + self.cut = cut + self.graph = cut.original_pattern.extract_graph() + self.node_specs = {node: _NodeSpec(node) for node in self.graph.nodes()} + for node, domain in cut.original_pattern.x_dict.items(): + self.node_specs[node].domains.s_domain = _expand_domain(cut.shifted_domains, domain) + for node, domain in cut.original_pattern.z_dict.items(): + self.node_specs[node].domains.t_domain = _expand_domain(cut.shifted_domains, domain) + for node, clifford in cut.original_pattern.c_dict.items(): + self.node_specs[node].clifford = clifford + self.measurements = cut.measurements + self.pauli_measurements = {axis: set() for axis in Axis} + self.input_node_set = set(cut.original_pattern.input_nodes) + self.output_node_set = set(cut.original_pattern.output_nodes) + for cmd_m in self.cut.pauli_measurements: + if not isinstance(cmd_m.measurement, PauliMeasurement): # pragma: no cover + msg = "Pauli measurement expected." + raise TypeError(msg) + self.node_specs[cmd_m.node].pauli_measurement = cmd_m.measurement + if cmd_m.node not in self.input_node_set: + self.pauli_measurements[cmd_m.measurement.axis].add(cmd_m.node) + self.node_map = {node: node for node in self.graph.nodes()} + + def _apply_clifford(self, node: Node, clifford: Clifford) -> None: + """Apply a single-qubit Clifford gate to a node. + + This internal method breaks the semantics invariant: the + semantics of the pattern is not preserved. + """ + spec = self.node_specs[node] + spec.clifford @= clifford + spec.domains = clifford.commute_domains(spec.domains) + if spec.pauli_measurement is not None: + axis = spec.pauli_measurement.axis + spec.pauli_measurement = spec.pauli_measurement.clifford(clifford) + # Update pauli_measurements: sets in `pauli_measurements` + # dict only cover non-input nodes, so this update is + # skipped when the node is an input. + if node in self.input_node_set: + return + new_axis = spec.pauli_measurement.axis + if new_axis != axis: + self.pauli_measurements[axis].remove(spec.src) + self.pauli_measurements[new_axis].add(spec.src) + + def local_complement(self, u: Node) -> None: + """ + Local complement. + + Implements Lemma 2.31 and 4.3 [BMBdF+21]. + """ + n_u = list(self.graph.neighbors(u)) + _complement_subgraph(self.graph, n_u) + # |+⟩⟨+| + exp(-iπ/2) |-⟩⟨-| = H S† H + self._apply_clifford(u, Clifford.H @ Clifford.SDG @ Clifford.H) + for node in n_u: + # |0⟩⟨0| + exp(iπ/2) |1⟩⟨1| = S + self._apply_clifford(node, Clifford.S) + + def pivot_vertices(self, u: Node, v: Node) -> None: + """ + Pivot two vertices. + + Prerequisite (not checked): + - (u, v) is a graph edge; + - u and v are not input nodes. + + Implements Lemmas 2.32 and 4.5 [BMBdF+21]. + """ + n_u = set(self.graph.neighbors(u)) + n_v = set(self.graph.neighbors(v)) + + inter = n_u & n_v + only_u = n_u - inter - {v} + only_v = n_v - inter - {u} + + _complement_edges(self.graph, only_u, only_v) + _complement_edges(self.graph, only_u, inter) + _complement_edges(self.graph, only_v, inter) + + spec_u = self.node_specs[u] + spec_v = self.node_specs[v] + self.node_specs[v] = spec_u + self.node_specs[u] = spec_v + self.node_map[spec_u.src] = v + self.node_map[spec_v.src] = u + + self._apply_clifford(u, Clifford.H) + self._apply_clifford(v, Clifford.H) + + for node in inter: + self._apply_clifford(node, Clifford.Z) + + u_output = u in self.output_node_set + v_output = v in self.output_node_set + if u_output != v_output: + if u_output: + old_output, new_output = u, v + else: + old_output, new_output = v, u + self.output_node_set.remove(old_output) + self.output_node_set.add(new_output) + + def _remove_node(self, u: Node) -> None: + """Remove a node from the graph. + + This internal method breaks the semantics invariant: the + semantics of the pattern is not preserved. + """ + spec = self.node_specs[u] + if spec.pauli_measurement is not None: + self.pauli_measurements[spec.pauli_measurement.axis].remove(spec.src) + del self.node_map[spec.src] + del self.node_specs[u] + self.graph.remove_node(u) + + def remove_z(self, u: Node, sign: Sign) -> None: + """ + Remove Z/-Z measurement. + + Prerequisite (not checked): + - u measured in Z (sign==PLUS) or -Z (sign=MINUS); + - u is not an input node. + + Implements Lemma 4.7 [BMBdF+21]. + """ + if sign == Sign.MINUS: + for node in self.graph.neighbors(u): + self._apply_clifford(node, Clifford.Z) + self._remove_node(u) + + def remove_y(self, u: Node, sign: Sign) -> None: + """ + Remove Y/-Y measurement. + + Prerequisite (not checked): + - u measured in Y (sign==PLUS) or -Y (sign=MINUS); + - u is not an input node. + + Implements Lemma 4.8 [BMBdF+21]. + """ + self.local_complement(u) + self.remove_z(u, sign) + + def remove_x_with_internal_neighbor(self, u: Node, v: Node, sign: Sign) -> None: + """ + Remove X/-X measurement. + + Prerequisite (not checked): + - u measured in X (sign==PLUS) or -X (sign=MINUS); + - (u, v) is a graph edge; + - u and v are internal nodes. + + Implements Lemma 4.9 [BMBdF+21]. + """ + self.pivot_vertices(u, v) + self.remove_z(v, sign) + + def remove_all_y_or_z(self) -> None: + """ + Remove all Y and Z measurements, repeatedly. + + Implements Theorem 4.12, Steps 1 and 2. + """ + for axis, remove in ( + (Axis.Y, self.remove_y), # Step 1: remove any non-input Y measured node + (Axis.Z, self.remove_z), # Step 2: remove any non-input Z measured node + ): + while (node := next(iter(self.pauli_measurements[axis]), None)) is not None: + new_node = self.node_map[node] + spec = self.node_specs[new_node] + if spec.pauli_measurement is None: # pragma: no cover + msg = "Pauli measurement expected." + raise RuntimeError(msg) + remove(new_node, spec.pauli_measurement.sign) + + def try_remove_x_with_internal_neighbor(self) -> bool: + """ + Find an X measurement connected to internal neighbor and remove it if any. + + Implements Theorem 4.12, Step 3. + + Returns + ------- + bool + ``True`` if a node has been found and removed, ``False`` otherwise + """ + for node in self.pauli_measurements[Axis.X]: + new_node = self.node_map[node] + internal_neighbors = set(self.graph.neighbors(new_node)) - self.input_node_set - self.output_node_set + v = next(iter(internal_neighbors), None) + if v is None: + continue + spec = self.node_specs[new_node] + if spec.pauli_measurement is None: # pragma: no cover + msg = "Pauli measurement expected." + raise RuntimeError(msg) + self.remove_x_with_internal_neighbor(new_node, v, spec.pauli_measurement.sign) + return True + return False + + def try_pivot_x_with_output_node(self) -> bool: + """ + Find an X measurement connected to an output node that is not also an input and pivot it if any. + + Implements Lemma 4.11 and Theorem 4.12, Step 4. + + Returns + ------- + bool + ``True`` if a node has been found and pivoted, ``False`` otherwise + """ + for node in self.pauli_measurements[Axis.X]: + new_node = self.node_map[node] + non_input_output_nodes = set(self.graph.neighbors(new_node)) & self.output_node_set - self.input_node_set + v = next(iter(non_input_output_nodes), None) + if v is None: + continue + self.pivot_vertices(new_node, v) + return True + return False + + def remove_isolated_internal_nodes(self) -> None: + """Remove isolated internal nodes.""" + # Construct the list first since the graph should not be + # modified while enumerating isolated nodes. + for node in list(nx.isolates(self.graph)): + if node not in self.input_node_set and node not in self.output_node_set: + self._remove_node(node) + + def _create_new_m(self, original_m: Command.M) -> Command.M | None: + node = self.node_map.get(original_m.node) + if node is None: + return None + spec = self.node_specs[node] + new_m = original_m.clifford(spec.clifford) + new_m.node = node + new_m.s_domain = _map_domain(self.node_map, new_m.s_domain) + new_m.t_domain = _map_domain(self.node_map, new_m.t_domain) + return new_m + + def to_standardized_pattern(self) -> StandardizedPattern: + n_list = tuple(cmd_n for cmd_n in self.cut.original_pattern.n_list if cmd_n.node in self.node_specs) + output_nodes = tuple(self.node_map[node] for node in self.cut.original_pattern.output_nodes) + measurements = tuple(new_m for original_m in self.measurements if (new_m := self._create_new_m(original_m))) + z_dict = { + node: t_domain + for node in output_nodes + if (t_domain := _map_domain(self.node_map, self.node_specs[node].domains.t_domain)) + } + x_dict = { + node: s_domain + for node in output_nodes + if (s_domain := _map_domain(self.node_map, self.node_specs[node].domains.s_domain)) + } + c_dict = {node: clifford for node in output_nodes if (clifford := self.node_specs[node].clifford) != Clifford.I} + return StandardizedPattern( + self.cut.original_pattern.input_nodes, + output_nodes, + n_list, + self.graph.edges(), + measurements, + z_dict, + x_dict, + c_dict, + ) + + +def _complement_subgraph(graph: nx.Graph[Node], s: Iterable[Node]) -> None: + """Complement edges in a given subgraph.""" + all_pairs = set(itertools.combinations(s, 2)) + existing = all_pairs & graph.edges() + graph.remove_edges_from(existing) + graph.add_edges_from(all_pairs - existing) + + +def _complement_edges(graph: nx.Graph[Node], s: set[Node], t: set[Node]) -> None: + """Complement edges between two set of nodes. + + ``s`` and ``t`` are supposed to be disjoint. + """ + all_pairs = {(u, v) for u in s for v in t} + existing = {(u, v) for u, v in graph.edges(s) if v in t} + graph.remove_edges_from(existing) + graph.add_edges_from(all_pairs - existing) + + +def _map_domain(node_map: Mapping[Node, Node], domain: set[Node]) -> set[Node]: + return {v for node in domain if (v := node_map.get(node)) is not None} + + +def remove_pauli_measurements(cut: PauliPushingCut) -> StandardizedPattern: + """Remove non-input Pauli measurements from the given pattern. + + This function implements the algorithm described in [BMBdF+21], + Theorem 4.12 (Section 4.3: Removing Clifford vertices). + + This function removes all non-input Y and Z measured nodes and all + non-input X measured nodes connected to any other internal vertex. + Furthermore, if any non-input X measured node is connected to an + output node that is not also an input, pivoting these nodes + enables eliminating further nodes. In particular, if the pattern + has flow, all non-input Pauli measurements are removed. + + Note that if the pattern is nondeterministic, only the 0-branch is + preserved. + + Parameters + ---------- + cut: PauliPushingCut + The Pauli-pushed pattern to optimize. + + Returns + ------- + StandardizedPattern + The pattern in which Pauli measurements have been removed. + """ + process = _RemovePauliMeasurements(cut) + while True: + process.remove_all_y_or_z() # Steps 1 and 2 + if ( + not process.try_remove_x_with_internal_neighbor() # Step 3 + and not process.try_pivot_x_with_output_node() # Step 4 + ): + break + process.remove_isolated_internal_nodes() + return process.to_standardized_pattern() diff --git a/graphix/sim/tensornet.py b/graphix/sim/tensornet.py index 11f0bd072..f9e8c56f7 100644 --- a/graphix/sim/tensornet.py +++ b/graphix/sim/tensornet.py @@ -32,6 +32,7 @@ from graphix import Pattern from graphix.clifford import Clifford + from graphix.command import Node from graphix.measurements import Measurement, Outcome from graphix.sim import Data @@ -643,7 +644,7 @@ def __init__( graph_prep = "sequential" if max_degree > 5 or not pattern.is_standard() else "parallel" case _: raise ValueError(f"Invalid graph preparation strategy: {graph_prep}") - results = deepcopy(pattern.results) + results: dict[Node, Outcome] = {} if graph_prep == "parallel": if not pattern.is_standard(): raise ValueError("parallel preparation strategy does not support not-standardized pattern") diff --git a/graphix/simulator.py b/graphix/simulator.py index 29afa9858..e81439ebc 100644 --- a/graphix/simulator.py +++ b/graphix/simulator.py @@ -374,7 +374,7 @@ def __init__( prepare_method = DefaultPrepareMethod() self.__prepare_method = prepare_method if measure_method is None: - measure_method = DefaultMeasureMethod(pattern.results) + measure_method = DefaultMeasureMethod() self.__measure_method = measure_method @property diff --git a/graphix/space_minimization.py b/graphix/space_minimization.py index 602e42461..a20988803 100644 --- a/graphix/space_minimization.py +++ b/graphix/space_minimization.py @@ -110,7 +110,6 @@ def standardized_to_space_optimal_pattern(pattern: StandardizedPattern) -> Patte """ target = graphix.Pattern(input_nodes=pattern.input_nodes) - target.results = dict(pattern.results) initialized = set(pattern.input_nodes) done: set[Node] = set() n_dict = {n.node: n for n in pattern.n_list} @@ -229,8 +228,6 @@ def greedy_degree(pattern: StandardizedPattern) -> SpaceMinimizationHeuristicRes nodes = set(graph.nodes) not_measured = nodes - set(pattern.output_nodes) dependency = _extract_dependency(pattern) - # keys() should be converted into `set` because it is transient. - _update_dependency(set(pattern.results.keys()), dependency) meas_order = [] while not_measured: next_node = min((i for i in not_measured if not dependency[i]), key=graph.degree) diff --git a/noxfile.py b/noxfile.py index 5379888a5..65419b640 100644 --- a/noxfile.py +++ b/noxfile.py @@ -110,16 +110,16 @@ class ReverseDependency: @nox.parametrize( "package", [ - ReverseDependency("https://github.com/thierry-martinez/graphix-stim-backend", branch="fix/graphix_namespace"), ReverseDependency( - "https://github.com/TeamGraphix/graphix-symbolic", + "https://github.com/thierry-martinez/graphix-stim-backend", branch="fix/graphix_498_remove_pauli" ), - ReverseDependency("https://github.com/TeamGraphix/graphix-qasm-parser", branch="fix_angles"), + ReverseDependency("https://github.com/TeamGraphix/graphix-symbolic"), + ReverseDependency("https://github.com/TeamGraphix/graphix-qasm-parser"), ReverseDependency( "https://github.com/thierry-martinez/veriphix", doctest_modules=False, install_target=".[dev]", - branch="fix/graphix_namespace", + branch="fix/graphix_498_remove_pauli", ), ReverseDependency("https://github.com/TeamGraphix/graphix-ibmq", doctest_modules=False), ReverseDependency("https://github.com/qat-inria/graphix-stim-compiler", branch="ps_dim"), diff --git a/tests/baseline/test_draw_graph_reference_False.png b/tests/baseline/test_draw_graph_reference_False.png index 6d0ee27c3..55b09a743 100644 Binary files a/tests/baseline/test_draw_graph_reference_False.png and b/tests/baseline/test_draw_graph_reference_False.png differ diff --git a/tests/test_extraction.py b/tests/test_extraction.py index da5981cc1..c448509c4 100644 --- a/tests/test_extraction.py +++ b/tests/test_extraction.py @@ -1,12 +1,12 @@ from __future__ import annotations from graphix import extraction -from graphix.graphsim import GraphState +from graphix.extraction import Graph class TestExtraction: def test_cluster_extraction_one_ghz_cluster(self) -> None: - gs = GraphState() + gs = Graph() nodes = [0, 1, 2, 3, 4] edges = [(0, 1), (0, 2), (0, 3), (0, 4)] gs.add_nodes_from(nodes) @@ -18,7 +18,7 @@ def test_cluster_extraction_one_ghz_cluster(self) -> None: # we consider everything smaller than 4, a GHZ def test_cluster_extraction_small_ghz_cluster_1(self) -> None: - gs = GraphState() + gs = Graph() nodes = [0, 1, 2] edges = [(0, 1), (1, 2)] gs.add_nodes_from(nodes) @@ -30,7 +30,7 @@ def test_cluster_extraction_small_ghz_cluster_1(self) -> None: # we consider everything smaller than 4, a GHZ def test_cluster_extraction_small_ghz_cluster_2(self) -> None: - gs = GraphState() + gs = Graph() nodes = [0, 1] edges = [(0, 1)] gs.add_nodes_from(nodes) @@ -41,7 +41,7 @@ def test_cluster_extraction_small_ghz_cluster_2(self) -> None: assert clusters[0] == extraction.ResourceGraph(cltype=extraction.ResourceType.GHZ, graph=gs) def test_cluster_extraction_one_linear_cluster(self) -> None: - gs = GraphState() + gs = Graph() nodes = [0, 1, 2, 3, 4, 5, 6] edges = [(0, 1), (1, 2), (2, 3), (5, 4), (4, 6), (6, 0)] gs.add_nodes_from(nodes) @@ -52,7 +52,7 @@ def test_cluster_extraction_one_linear_cluster(self) -> None: assert clusters[0] == extraction.ResourceGraph(cltype=extraction.ResourceType.LINEAR, graph=gs) def test_cluster_extraction_one_ghz_one_linear(self) -> None: - gs = GraphState() + gs = Graph() nodes = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] edges = [(0, 1), (0, 2), (0, 3), (0, 4), (4, 5), (5, 6), (6, 7), (7, 8), (8, 9)] gs.add_nodes_from(nodes) @@ -61,11 +61,11 @@ def test_cluster_extraction_one_ghz_one_linear(self) -> None: assert len(clusters) == 2 clusters_expected = [] - lin_cluster = GraphState() + lin_cluster = Graph() lin_cluster.add_nodes_from([4, 5, 6, 7, 8, 9]) lin_cluster.add_edges_from([(4, 5), (5, 6), (6, 7), (7, 8), (8, 9)]) clusters_expected.append(extraction.ResourceGraph(extraction.ResourceType.LINEAR, lin_cluster)) - ghz_cluster = GraphState() + ghz_cluster = Graph() ghz_cluster.add_nodes_from([0, 1, 2, 3, 4]) ghz_cluster.add_edges_from([(0, 1), (0, 2), (0, 3), (0, 4)]) clusters_expected.append(extraction.ResourceGraph(extraction.ResourceType.GHZ, ghz_cluster)) @@ -75,7 +75,7 @@ def test_cluster_extraction_one_ghz_one_linear(self) -> None: ) def test_cluster_extraction_pentagonal_cluster(self) -> None: - gs = GraphState() + gs = Graph() nodes = [0, 1, 2, 3, 4] edges = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 0)] gs.add_nodes_from(nodes) @@ -92,7 +92,7 @@ def test_cluster_extraction_pentagonal_cluster(self) -> None: ) def test_cluster_extraction_one_plus_two(self) -> None: - gs = GraphState() + gs = Graph() nodes = [0, 1, 2] edges = [(0, 1)] gs.add_nodes_from(nodes) diff --git a/tests/test_optimization.py b/tests/test_optimization.py index 29bc663c9..42db33200 100644 --- a/tests/test_optimization.py +++ b/tests/test_optimization.py @@ -9,7 +9,7 @@ from graphix.command import C, CommandKind, E, M, N, X, Z from graphix.fundamentals import ANGLE_PI, Plane from graphix.measurements import Measurement -from graphix.optimization import StandardizedPattern, incorporate_pauli_results, remove_useless_domains +from graphix.optimization import StandardizedPattern, remove_useless_domains from graphix.pattern import Pattern from graphix.random_objects import rand_circuit from graphix.states import PlanarState @@ -57,23 +57,6 @@ def test_standardize_clifford_entanglement(fx_rng: Generator) -> None: assert state_p.isclose(state_ref) -@pytest.mark.parametrize("jumps", range(1, 11)) -def test_incorporate_pauli_results(fx_bg: PCG64, jumps: int) -> None: - rng = Generator(fx_bg.jumped(jumps)) - nqubits = 3 - depth = 3 - circuit = rand_circuit(nqubits, depth, rng) - pattern = circuit.transpile().pattern - pattern.standardize() - pattern.shift_signals() - pattern.remove_input_nodes() - pattern.perform_pauli_measurements() - pattern2 = incorporate_pauli_results(pattern) - state = pattern.simulate_pattern(rng=rng) - state2 = pattern2.simulate_pattern(rng=rng) - assert state.isclose(state2) - - @pytest.mark.parametrize("jumps", range(1, 11)) def test_flow_after_pauli_preprocessing(fx_bg: PCG64, jumps: int) -> None: rng = Generator(fx_bg.jumped(jumps)) @@ -83,11 +66,10 @@ def test_flow_after_pauli_preprocessing(fx_bg: PCG64, jumps: int) -> None: pattern = circuit.transpile().pattern pattern.standardize() pattern.shift_signals() - # pattern.move_pauli_measurements_to_the_front() - pattern.remove_input_nodes() - pattern.perform_pauli_measurements() - pattern2 = incorporate_pauli_results(pattern) - gflow = pattern2.extract_gflow() + pattern.remove_pauli_measurements() + # We should convert to Bloch measurement the remaining Pauli + # measurements on input nodes. + gflow = pattern.to_bloch().extract_gflow() gflow.check_well_formed() @@ -100,8 +82,7 @@ def test_remove_useless_domains(fx_bg: PCG64, jumps: int) -> None: pattern = circuit.transpile().pattern pattern.standardize() pattern.shift_signals() - pattern.remove_input_nodes() - pattern.perform_pauli_measurements() + pattern.remove_pauli_measurements() pattern2 = remove_useless_domains(pattern) state = pattern.simulate_pattern(rng=rng) state2 = pattern2.simulate_pattern(rng=rng) diff --git a/tests/test_parameter.py b/tests/test_parameter.py index f33040634..88f68d07e 100644 --- a/tests/test_parameter.py +++ b/tests/test_parameter.py @@ -170,8 +170,7 @@ def test_random_circuit_with_parameters(fx_bg: PCG64, jumps: int, use_xreplace: pattern = circuit.transpile().pattern pattern.standardize() pattern.shift_signals() - pattern.remove_input_nodes() - pattern.perform_pauli_measurements() + pattern.remove_pauli_measurements() pattern.minimize_space() assignment: dict[Parameter, float] = {alpha: rng.uniform(high=2), beta: rng.uniform(high=2)} if use_xreplace: diff --git a/tests/test_pattern.py b/tests/test_pattern.py index 948c5d9ca..7dc073a6f 100644 --- a/tests/test_pattern.py +++ b/tests/test_pattern.py @@ -141,8 +141,7 @@ def test_pauli_non_contiguous(self) -> None: M(0, Measurement.X, s_domain=set(), t_domain=set()), ] ) - pattern.remove_input_nodes() - pattern.perform_pauli_measurements() + pattern.remove_pauli_measurements() @pytest.mark.parametrize("jumps", range(1, 11)) def test_minimize_space_with_gflow(self, fx_bg: PCG64, jumps: int) -> None: @@ -154,8 +153,7 @@ def test_minimize_space_with_gflow(self, fx_bg: PCG64, jumps: int) -> None: pattern = circuit.transpile().pattern pattern.standardize() pattern.shift_signals(method="mc") - pattern.remove_input_nodes() - pattern.perform_pauli_measurements() + pattern.remove_pauli_measurements() pattern.minimize_space() state = circuit.simulate_statevector().statevec state_mbqc = pattern.simulate_pattern(rng=rng) @@ -233,18 +231,14 @@ def test_pauli_measurement_random_circuit( pattern = circuit.transpile().pattern pattern.standardize() pattern.shift_signals(method="mc") - pattern.remove_input_nodes() - pattern.perform_pauli_measurements() + pattern.remove_pauli_measurements() pattern.minimize_space() state = circuit.simulate_statevector().statevec state_mbqc: Statevec | DensityMatrix = pattern.simulate_pattern(backend, rng=rng) assert compare_backend_result_with_statevec(state_mbqc, state) == pytest.approx(1) @pytest.mark.parametrize("jumps", range(1, 11)) - @pytest.mark.parametrize("ignore_pauli_with_deps", [False, True]) - def test_pauli_measurement_random_circuit_all_paulis( - self, fx_bg: PCG64, jumps: int, ignore_pauli_with_deps: bool - ) -> None: + def test_pauli_measurement_random_circuit_all_paulis(self, fx_bg: PCG64, jumps: int) -> None: rng = Generator(fx_bg.jumped(jumps)) nqubits = 3 depth = 3 @@ -252,10 +246,12 @@ def test_pauli_measurement_random_circuit_all_paulis( pattern = circuit.transpile().pattern pattern.standardize() pattern.shift_signals(method="mc") - pattern.remove_input_nodes() - pattern.perform_pauli_measurements(ignore_pauli_with_deps=ignore_pauli_with_deps) - assert ignore_pauli_with_deps or not any( - cmd.measurement.try_to_pauli() is not None for cmd in pattern if cmd.kind == CommandKind.M + pattern.remove_pauli_measurements() + input_node_set = set(pattern.input_nodes) + assert not any( + cmd.measurement.try_to_pauli() is not None + for cmd in pattern + if cmd.kind == CommandKind.M and cmd.node not in input_node_set ) @pytest.mark.parametrize("pm", PauliMeasurement) @@ -264,10 +260,10 @@ def test_pauli_measurement_single(self, pm: PauliMeasurement) -> None: pattern.add(E(nodes=(0, 1))) pattern.add(M(0, pm)) pattern_ref = pattern.copy() - pattern.remove_input_nodes() - pattern.perform_pauli_measurements() - state = pattern.simulate_pattern() - state_ref = pattern_ref.simulate_pattern(branch_selector=ConstBranchSelector(0)) + pattern.remove_pauli_measurements() + branch_selector = ConstBranchSelector(0) + state = pattern.simulate_pattern(branch_selector=branch_selector) + state_ref = pattern_ref.simulate_pattern(branch_selector=branch_selector) assert state.isclose(state_ref) def test_pauli_measurement(self) -> None: @@ -288,46 +284,17 @@ def test_pauli_measurement(self) -> None: pattern = circuit.transpile().pattern pattern.standardize() pattern.shift_signals(method="mc") - pattern.remove_input_nodes() - pattern.perform_pauli_measurements() - isolated_nodes = pattern.extract_isolated_nodes() - # 42-node is the isolated and output node. - isolated_nodes_ref = {42} - assert isolated_nodes == isolated_nodes_ref - - def test_pauli_measurement_error(self, fx_rng: Generator) -> None: - nqubits = 2 - depth = 1 - circuit = rand_circuit(nqubits, depth, fx_rng) - pattern = circuit.transpile().pattern - pattern.standardize() - with pytest.raises(PatternError): - pattern.perform_pauli_measurements() - - def test_pauli_measurement_leave_input(self) -> None: - # test pattern is obtained from 3-qubit QFT with pauli measurement - circuit = Circuit(3) - for i in range(3): - circuit.h(i) - circuit.x(1) - circuit.x(2) - - # QFT - circuit.h(2) - cp(circuit, ANGLE_PI / 4, 0, 2) - cp(circuit, ANGLE_PI / 2, 1, 2) - circuit.h(1) - cp(circuit, ANGLE_PI / 2, 0, 1) - circuit.h(0) - swap(circuit, 0, 2) - pattern = circuit.transpile().pattern - pattern.standardize() - with pytest.raises(PatternError): - pattern.perform_pauli_measurements() + pattern_opt = pattern.remove_pauli_measurements(copy=True) + isolated_nodes = pattern_opt.extract_isolated_nodes() + assert isolated_nodes == set() + pattern.minimize_space() + pattern_opt.minimize_space() + state = pattern.simulate_pattern() + state_opt = pattern.simulate_pattern() + assert state.isclose(state_opt) @pytest.mark.parametrize("jumps", range(1, 6)) - @pytest.mark.parametrize("ignore_pauli_with_deps", [False, True]) - def test_pauli_measured_against_nonmeasured(self, fx_bg: PCG64, jumps: int, ignore_pauli_with_deps: bool) -> None: + def test_pauli_measured_against_nonmeasured(self, fx_bg: PCG64, jumps: int) -> None: rng = Generator(fx_bg.jumped(jumps)) nqubits = 2 depth = 2 @@ -335,48 +302,11 @@ def test_pauli_measured_against_nonmeasured(self, fx_bg: PCG64, jumps: int, igno pattern = circuit.transpile().pattern pattern.standardize() pattern1 = copy.deepcopy(pattern) - pattern1.remove_input_nodes() - pattern1.perform_pauli_measurements(ignore_pauli_with_deps=ignore_pauli_with_deps) + pattern1.remove_pauli_measurements() state = pattern.simulate_pattern(rng=rng) state1 = pattern1.simulate_pattern(rng=rng) assert state.isclose(state1) - @pytest.mark.parametrize("jumps", range(1, 4)) - def test_pauli_repeated_measurement(self, fx_bg: PCG64, jumps: int) -> None: - rng = Generator(fx_bg.jumped(jumps)) - nqubits = 2 - depth = 2 - circuit = rand_circuit(nqubits, depth, rng, use_ccx=False) - pattern = circuit.transpile().pattern - pattern.remove_input_nodes() - assert not pattern.results - pattern.perform_pauli_measurements() - assert pattern.results - pattern.perform_pauli_measurements() - assert pattern.results - - @pytest.mark.parametrize("jumps", range(1, 4)) - def test_pauli_repeated_measurement_compose(self, fx_bg: PCG64, jumps: int) -> None: - rng = Generator(fx_bg.jumped(jumps)) - nqubits = 2 - depth = 2 - circuit = rand_circuit(nqubits, depth, rng, use_ccx=False) - circuit1 = rand_circuit(nqubits, depth, rng, use_ccx=False) - pattern = circuit.transpile().pattern - pattern1 = circuit1.transpile().pattern - composed_pattern, _ = pattern.compose( - pattern1, mapping=dict(zip(pattern1.input_nodes, pattern.output_nodes, strict=True)), preserve_mapping=True - ) - pattern.remove_input_nodes() - pattern1.remove_input_nodes() - assert not pattern.results - assert not pattern1.results - pattern.perform_pauli_measurements() - pattern1.perform_pauli_measurements() - composed_pattern.remove_input_nodes() - composed_pattern.perform_pauli_measurements() - assert abs(len(composed_pattern.results) - len(pattern.results) - len(pattern1.results)) <= 2 - def test_extract_measurement_commands(self) -> None: preset_meas_plane = [ Plane.XY, @@ -477,8 +407,7 @@ def test_pauli_measurement_then_standardize(self, fx_bg: PCG64, jumps: int) -> N depth = 3 circuit = rand_circuit(nqubits, depth, rng) pattern = circuit.transpile().pattern - pattern.remove_input_nodes() - pattern.perform_pauli_measurements() + pattern.remove_pauli_measurements() pattern.standardize() pattern.minimize_space() state = circuit.simulate_statevector().statevec @@ -498,23 +427,6 @@ def test_standardize_two_cliffords(self, fx_bg: PCG64, jumps: int) -> None: state_p = pattern.simulate_pattern() assert state_p.isclose(state_ref) - @pytest.mark.parametrize("jumps", range(1, 48)) - def test_standardize_domains_and_clifford(self, fx_bg: PCG64, jumps: int) -> None: - rng = Generator(fx_bg.jumped(jumps)) - x, z = rng.integers(2, size=2) - c = rng.integers(len(Clifford)) - pattern = Pattern(input_nodes=[0]) - pattern.results[1] = x - pattern.add(X(node=0, domain={1})) - pattern.results[2] = z - pattern.add(Z(node=0, domain={2})) - pattern.add(C(node=0, clifford=Clifford(c))) - pattern_ref = pattern.copy() - pattern.standardize() - state_ref = pattern_ref.simulate_pattern() - state_p = pattern.simulate_pattern() - assert state_p.isclose(state_ref) - # Simple pattern composition def test_compose_1(self) -> None: i1_lst = [0] @@ -754,7 +666,7 @@ def test_compose_7(self, fx_rng: Generator) -> None: circuit_1.rz(0, alpha) p1 = circuit_1.transpile().pattern p1.remove_input_nodes() - p1.perform_pauli_measurements() + p1.remove_pauli_measurements() circuit_2 = Circuit(1) circuit_2.rz(0, alpha) @@ -816,13 +728,6 @@ def test_check_runnability_failures(self) -> None: assert exc_info.value.node == 0 assert exc_info.value.reason == RunnabilityErrorReason.NotYetActive - pattern = Pattern(cmds=[N(0), M(0)]) - pattern.results = {0: 0} - with pytest.raises(RunnabilityError) as exc_info: - pattern.check_runnability() - assert exc_info.value.node == 0 - assert exc_info.value.reason == RunnabilityErrorReason.AlreadyMeasured - pattern = Pattern(cmds=[N(0), M(0, s_domain={0})]) with pytest.raises(RunnabilityError) as exc_info: pattern.check_runnability() @@ -880,12 +785,11 @@ def test_extract_partial_order_layers_results(self) -> None: c = Circuit(1) c.rz(0, 0.2) p = c.transpile().pattern - p.remove_input_nodes() - p.perform_pauli_measurements() - assert p.extract_partial_order_layers() == (frozenset({2}), frozenset({0})) + p.remove_pauli_measurements() + assert p.extract_partial_order_layers() == (frozenset({1}), frozenset({0})) p = Pattern(cmds=[N(0), N(1), N(2), M(0), E((1, 2)), X(1, {0}), M(2, Measurement.XY(0.3))]) - p.perform_pauli_measurements() + p.remove_pauli_measurements() assert p.extract_partial_order_layers() == (frozenset({1}), frozenset({2})) class PatternFlowTestCase(NamedTuple): @@ -1006,10 +910,8 @@ def test_extract_causal_flow_rnd_circuit(self, fx_bg: PCG64, jumps: int) -> None p_ref = circuit_1.transpile().pattern p_test = p_ref.to_bloch().extract_causal_flow().to_corrections().to_pattern().infer_pauli_measurements() - p_ref.remove_input_nodes() - p_test.remove_input_nodes() - p_ref.perform_pauli_measurements() - p_test.perform_pauli_measurements() + p_ref.remove_pauli_measurements() + p_test.remove_pauli_measurements() s_ref = p_ref.simulate_pattern(rng=rng) s_test = p_test.simulate_pattern(rng=rng) @@ -1025,10 +927,8 @@ def test_extract_gflow_rnd_circuit(self, fx_bg: PCG64, jumps: int) -> None: p_ref = circuit_1.transpile().pattern p_test = p_ref.to_bloch().extract_gflow().to_corrections().to_pattern().infer_pauli_measurements() - p_ref.remove_input_nodes() - p_test.remove_input_nodes() - p_ref.perform_pauli_measurements() - p_test.perform_pauli_measurements() + p_ref.remove_pauli_measurements() + p_test.remove_pauli_measurements() s_ref = p_ref.simulate_pattern(rng=rng) s_test = p_test.simulate_pattern(rng=rng) @@ -1122,8 +1022,7 @@ def test_extract_xzc_rnd_circuit(self, fx_bg: PCG64, jumps: int) -> None: p_test = xzc.to_pattern() for p in [p_ref, p_test]: - p.remove_input_nodes() - p.perform_pauli_measurements() + p.remove_pauli_measurements() s_ref = p_ref.simulate_pattern(rng=rng) s_test = p_test.simulate_pattern(rng=rng) @@ -1343,8 +1242,7 @@ def test_pauli_measurement_end_with_measure(self) -> None: p = Pattern(input_nodes=[0]) p.add(N(node=1)) p.add(M(1, Measurement.X)) - p.remove_input_nodes() - p.perform_pauli_measurements() + p.remove_pauli_measurements() @pytest.mark.parametrize("backend", ["statevector", "densitymatrix"]) @pytest.mark.filterwarnings("ignore:Simulating using densitymatrix backend with no noise.") diff --git a/tests/test_pyzx.py b/tests/test_pyzx.py index e4a0d868e..a45762e18 100644 --- a/tests/test_pyzx.py +++ b/tests/test_pyzx.py @@ -78,8 +78,7 @@ def test_random_clifford_t() -> None: def simulate_pattern(pattern: Pattern, rng: Generator) -> Statevec: - pattern.remove_input_nodes() - pattern.perform_pauli_measurements() + pattern.remove_pauli_measurements() pattern.minimize_space() return pattern.simulate_pattern(rng=rng) diff --git a/tests/test_qasm3_exporter.py b/tests/test_qasm3_exporter.py index f9751d0de..e0b4d590d 100644 --- a/tests/test_qasm3_exporter.py +++ b/tests/test_qasm3_exporter.py @@ -46,16 +46,16 @@ def test_to_qasm3_random_circuit(fx_bg: PCG64, jumps: int) -> None: See :func:`test_qasm3_exporter_to_qiskit:test_to_qasm3_random_circuit`, - where the result is validated. The current test does not go through the - normalization passes ``incorporate_pauli_results`` and ``single_qubit_domains``, - so it exercises execution paths that are not tested elsewhere. + where the result is validated. The current test does not go + through the normalization pass ``single_qubit_domains``, so it + exercises execution paths that are not tested elsewhere. """ rng = Generator(fx_bg.jumped(jumps)) nqubits = 5 depth = 5 circuit = rand_circuit(nqubits, depth, rng=rng) pattern = circuit.transpile().pattern - pattern.remove_input_nodes() - pattern.perform_pauli_measurements() + pattern.remove_pauli_measurements() pattern.minimize_space() + print(pattern) _qasm3 = pattern_to_qasm3(pattern) diff --git a/tests/test_qasm3_exporter_to_qiskit.py b/tests/test_qasm3_exporter_to_qiskit.py index a857daa7d..c44eb9dda 100644 --- a/tests/test_qasm3_exporter_to_qiskit.py +++ b/tests/test_qasm3_exporter_to_qiskit.py @@ -14,7 +14,7 @@ from graphix.command import C, CommandKind, E, M, N from graphix.fundamentals import Plane from graphix.measurements import BlochMeasurement, Measurement, outcome -from graphix.optimization import incorporate_pauli_results, single_qubit_domains +from graphix.optimization import single_qubit_domains from graphix.qasm3_exporter import pattern_to_qasm3 from graphix.random_objects import rand_circuit from graphix.sim.statevec import StatevectorBackend @@ -119,13 +119,9 @@ def test_to_qasm3_random_circuit(fx_bg: PCG64, jumps: int) -> None: depth = 5 circuit = rand_circuit(nqubits, depth, rng=rng) pattern = circuit.transpile().pattern - pattern.remove_input_nodes() - pattern.perform_pauli_measurements() + pattern.remove_pauli_measurements() pattern.minimize_space() - # qiskit_qasm3_import.exceptions.ConversionError: initialisation of classical bits is not supported - pattern = incorporate_pauli_results(pattern) - # qiskit_qasm3_import.exceptions.ConversionError: unhandled binary operator '^' pattern = single_qubit_domains(pattern) diff --git a/tests/test_remove_pauli_measurements.py b/tests/test_remove_pauli_measurements.py new file mode 100644 index 000000000..6c1455537 --- /dev/null +++ b/tests/test_remove_pauli_measurements.py @@ -0,0 +1,323 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import networkx as nx +import pytest +from numpy.random import Generator + +from graphix import ( + Axis, + BlochMeasurement, + Circuit, + Clifford, + Command, + Measurement, + OpenGraph, + Pattern, + PauliMeasurement, + Sign, + StandardizedPattern, +) +from graphix.random_objects import rand_circuit, rand_state_vector +from graphix.remove_pauli_measurements import PauliPushingCut, _RemovePauliMeasurements, remove_pauli_measurements + +if TYPE_CHECKING: + from collections.abc import Iterable, Mapping + from collections.abc import Set as AbstractSet + + from numpy.random import PCG64 + + from graphix.command import Node + from graphix.remove_pauli_measurements import Graph + + +def opengraph_lemma_2_31(measurements: Mapping[Node, Measurement]) -> OpenGraph[Measurement]: + graph: Graph = nx.Graph( + [ + (0, 1), + (0, 2), + (0, 3), + (1, 2), + (1, 3), + ] + ) + output_nodes = tuple(node for node in range(4) if node not in measurements) + return OpenGraph(graph, input_nodes=(1, 2, 3), output_nodes=output_nodes, measurements=measurements) + + +def opengraph_lemma_2_32(measurements: Mapping[Node, Measurement]) -> OpenGraph[Measurement]: + graph: Graph = nx.Graph( + [ + (0, 1), + (0, 2), + (1, 2), + (0, 3), + (1, 3), + (0, 4), + (3, 4), + (1, 5), + (3, 5), + (0, 6), + (2, 6), + (5, 6), + (1, 7), + (2, 7), + (4, 7), + ] + ) + output_nodes = tuple(node for node in range(8) if node not in measurements) + return OpenGraph(graph, input_nodes=(4, 5, 6, 7), output_nodes=output_nodes, measurements=measurements) + + +@pytest.mark.parametrize("measured_set", [set(), {1}, {2}]) +def test_local_complement(fx_rng: Generator, measured_set: AbstractSet[int]) -> None: + og = opengraph_lemma_2_31({node: Measurement.XY(0.25) for node in measured_set}) + pattern = og.to_pattern() + standardized_pattern = StandardizedPattern.from_pattern(pattern) + cut = PauliPushingCut.from_standardized_pattern(standardized_pattern) + remove_pauli_measurements = _RemovePauliMeasurements(cut) + remove_pauli_measurements.local_complement(0) + standardized_pattern2 = remove_pauli_measurements.to_standardized_pattern() + og2 = standardized_pattern2.extract_opengraph() + expected_graph: Graph = nx.Graph( + [ + (0, 1), + (0, 2), + (0, 3), + (2, 3), + ] + ) + assert nx.utils.graphs_equal(og2.graph, expected_graph) + pattern2 = standardized_pattern2.to_pattern() + assert pattern2.extract_gflow() + check_pattern_equivalence(pattern, pattern2, rng=fx_rng) + + +@pytest.mark.parametrize("measured_set", [set(), {4}, {4, 5}, {4, 5, 6}, {4, 5, 7}, {0}, {1}, {2}, {3}, {0, 2}]) +def test_pivot_vertices(fx_rng: Generator, measured_set: AbstractSet[int]) -> None: + og = opengraph_lemma_2_32({node: Measurement.XY(0.25) for node in measured_set}) + pattern = og.to_pattern() + standardized_pattern = StandardizedPattern.from_pattern(pattern) + cut = PauliPushingCut.from_standardized_pattern(standardized_pattern) + remove_pauli_measurements = _RemovePauliMeasurements(cut) + remove_pauli_measurements.pivot_vertices(0, 1) + standardized_pattern2 = remove_pauli_measurements.to_standardized_pattern() + og2 = standardized_pattern2.extract_opengraph() + expected_graph: Graph = nx.Graph( + [ + (0, 1), + (0, 2), + (1, 2), + (0, 3), + (1, 3), + (0, 4), + (2, 4), + (1, 5), + (2, 5), + (4, 5), + (0, 6), + (3, 6), + (1, 7), + (3, 7), + (6, 7), + ] + ) + assert nx.utils.graphs_equal(og2.graph, expected_graph) + assert og2.output_nodes == tuple(0 if node == 1 else 1 if node == 0 else node for node in og.output_nodes) + pattern2 = standardized_pattern2.to_pattern() + assert pattern2.extract_gflow() + check_pattern_equivalence(pattern, pattern2, rng=fx_rng) + + +@pytest.mark.parametrize("node", [0, 1, 2, 3]) +@pytest.mark.parametrize("sign", Sign) +def test_remove_z(fx_rng: Generator, node: Node, sign: Sign) -> None: + og = opengraph_lemma_2_32({node: PauliMeasurement(Axis.Z, sign)}) + pattern = og.to_pattern() + standardized_pattern = StandardizedPattern.from_pattern(pattern) + cut = PauliPushingCut.from_standardized_pattern(standardized_pattern) + remove_pauli_measurements = _RemovePauliMeasurements(cut) + remove_pauli_measurements.remove_z(node, sign) + standardized_pattern2 = remove_pauli_measurements.to_standardized_pattern() + pattern2 = standardized_pattern2.to_pattern() + check_pattern_equivalence(pattern, pattern2, rng=fx_rng) + + +@pytest.mark.parametrize("node", [0, 1, 2, 3]) +@pytest.mark.parametrize("sign", Sign) +def test_remove_y(fx_rng: Generator, node: Node, sign: Sign) -> None: + og = opengraph_lemma_2_32({node: PauliMeasurement(Axis.Y, sign)}) + pattern = og.to_pattern() + standardized_pattern = StandardizedPattern.from_pattern(pattern) + cut = PauliPushingCut.from_standardized_pattern(standardized_pattern) + remove_pauli_measurements = _RemovePauliMeasurements(cut) + remove_pauli_measurements.remove_y(node, sign) + standardized_pattern2 = remove_pauli_measurements.to_standardized_pattern() + pattern2 = standardized_pattern2.to_pattern() + check_pattern_equivalence(pattern, pattern2, rng=fx_rng) + + +@pytest.mark.parametrize("sign", Sign) +def test_remove_x_with_internal_neighbor(fx_rng: Generator, sign: Sign) -> None: + og = opengraph_lemma_2_32({0: PauliMeasurement(Axis.X, sign)}) + pattern = og.to_pattern() + standardized_pattern = StandardizedPattern.from_pattern(pattern) + cut = PauliPushingCut.from_standardized_pattern(standardized_pattern) + remove_pauli_measurements = _RemovePauliMeasurements(cut) + remove_pauli_measurements.remove_x_with_internal_neighbor(0, 1, sign) + standardized_pattern2 = remove_pauli_measurements.to_standardized_pattern() + pattern2 = standardized_pattern2.to_pattern() + check_pattern_equivalence(pattern, pattern2, rng=fx_rng) + + +def all_bloch_measurement_or_input_node(input_nodes: Iterable[Node], measurement_commands: Iterable[Command.M]) -> bool: + input_node_set = set(input_nodes) + return all( + isinstance(cmd_m.measurement, BlochMeasurement) or cmd_m.node in input_node_set + for cmd_m in measurement_commands + ) + + +def check_pattern(pattern: Pattern, rng: Generator) -> None: + standardized_pattern = StandardizedPattern.from_pattern(pattern) + cut = PauliPushingCut.from_standardized_pattern(standardized_pattern) + standardized_pattern2 = remove_pauli_measurements(cut) + + assert all_bloch_measurement_or_input_node(standardized_pattern2.input_nodes, standardized_pattern2.m_list) + + # Check that the pattern has a gflow + standardized_pattern2.to_bloch().extract_gflow() + + pattern2 = standardized_pattern2.to_pattern() + check_pattern_equivalence(pattern, pattern2, rng=rng) + + +def check_pattern_equivalence(pattern: Pattern, pattern2: Pattern, rng: Generator) -> None: + pattern.minimize_space() + pattern2.minimize_space() + for _ in range(4): + input_state = rand_state_vector(len(pattern.input_nodes), rng=rng) + state = pattern.simulate_pattern(input_state=input_state, rng=rng) + state2 = pattern2.simulate_pattern(input_state=input_state, rng=rng) + assert state.isclose(state2) + + +def test_ccx(fx_rng: Generator) -> None: + circuit = Circuit(3) + circuit.ccx(0, 1, 2) + check_pattern(circuit.transpile().pattern, fx_rng) + + +@pytest.mark.parametrize("jumps", range(1, 11)) +def test_random_circuit(fx_bg: PCG64, jumps: int) -> None: + rng = Generator(fx_bg.jumped(jumps)) + nqubits = 4 + depth = 4 + circuit = rand_circuit(nqubits, depth, rng) + check_pattern(circuit.transpile().pattern, rng) + + +def test_step_4() -> None: + graph: Graph = nx.Graph([(0, 1), (1, 2)]) + measurements = {0: Measurement.XY(0.25), 1: Measurement.X} + og = OpenGraph(graph, input_nodes=(0,), output_nodes=(2,), measurements=measurements) + pattern = og.to_pattern() + standardized_pattern = StandardizedPattern.from_pattern(pattern) + cut = PauliPushingCut.from_standardized_pattern(standardized_pattern) + standardized_pattern2 = remove_pauli_measurements(cut) + assert len(standardized_pattern2.m_list) == 1 + + +def test_step_4_no_flow() -> None: + # This example tests the case of a pattern that contains a + # non-input X-measured node 1 which is connected to an output node + # 0, where the node 0 is also an input. In this situation Lemma + # 4.11 cannot be applied; this exercices the filtering implemented + # in the `try_pivot_x_with_output_node` method. + pattern = Pattern(input_nodes=(0,), output_nodes=(0,), cmds=[Command.N(1), Command.E((0, 1)), Command.M(1)]) + standardized_pattern = StandardizedPattern.from_pattern(pattern) + cut = PauliPushingCut.from_standardized_pattern(standardized_pattern) + standardized_pattern2 = remove_pauli_measurements(cut) + assert len(standardized_pattern2.m_list) == 1 + + +def test_cliffords_in_original_pattern(fx_rng: Generator) -> None: + circuit = Circuit(2) + circuit.cnot(0, 1) + pattern = circuit.transpile().pattern + u, v = pattern.output_nodes + pattern.add(Command.C(u, Clifford.S)) + pattern.add(Command.C(v, Clifford.SDG)) + check_pattern(pattern, fx_rng) + + +def test_pattern_remove_pauli_measurements() -> None: + circuit = Circuit(2) + circuit.cnot(0, 1) + pattern = circuit.transpile().pattern + pattern2 = pattern.remove_pauli_measurements(copy=True) + assert all_bloch_measurement_or_input_node( + pattern2.input_nodes, (cmd for cmd in pattern2 if isinstance(cmd, Command.M)) + ) + assert not pattern2.is_standard() + pattern3 = pattern.remove_pauli_measurements(copy=True, standardize=True) + assert all_bloch_measurement_or_input_node( + pattern3.input_nodes, (cmd for cmd in pattern3 if isinstance(cmd, Command.M)) + ) + assert pattern3.is_standard() + assert not all_bloch_measurement_or_input_node( + pattern.input_nodes, (cmd for cmd in pattern if isinstance(cmd, Command.M)) + ) + pattern.remove_pauli_measurements() + assert all_bloch_measurement_or_input_node( + pattern.input_nodes, (cmd for cmd in pattern if isinstance(cmd, Command.M)) + ) + + +def test_pattern_remove_pauli_measurements_output_nodes() -> None: + og = OpenGraph( + graph=nx.Graph([(1, 2)]), + input_nodes=[], + output_nodes=[2], + measurements={ + 1: Measurement.X, + }, + ) + pattern = og.to_pattern() + pattern.remove_pauli_measurements() + pattern.simulate_pattern() + + +def test_try_pivot_x_with_output_node_after_pivot() -> None: + # This test checks that `try_pivot_x_with_output_node` applies + # `pivot_vertices` using `new_node` rather than the original + # `node`. + # + # In practice this situation is unlikely to arise: for `node != new_node` + # to occur, a pivot must have already been applied to `node`. Yet, + # after such a pivot we would need `new_node` to be measured in X, which + # implies that `node` was originally measured in Z. The removal strategy + # would then delete `node` before the pivot could take place. + # + # Consequently, this test guarantees that `try_pivot_x_with_output_node` + # works correctly regardless of the removal strategy and maintains the + # intended invariant, even though the earlier bug (pivoting with the + # original node) was not observable through the public API. + pattern = Pattern( + cmds=[ + Command.N(0), + Command.N(1), + Command.N(2), + Command.E((0, 1)), + Command.E((0, 2)), + Command.M(0), + Command.M(1, Measurement.Z), + ] + ) + standardized_pattern = StandardizedPattern.from_pattern(pattern) + cut = PauliPushingCut.from_standardized_pattern(standardized_pattern) + process = _RemovePauliMeasurements(cut) + process.remove_x_with_internal_neighbor(0, 1, Sign.PLUS) + # Fail if pivot is applied to the original node + process.try_pivot_x_with_output_node() diff --git a/tests/test_tnsim.py b/tests/test_tnsim.py index 9827576af..c347e492b 100644 --- a/tests/test_tnsim.py +++ b/tests/test_tnsim.py @@ -1,7 +1,6 @@ from __future__ import annotations import itertools -from typing import TYPE_CHECKING import numpy as np import numpy.typing as npt @@ -10,8 +9,7 @@ from quimb.tensor import Tensor from graphix.branch_selector import RandomBranchSelector -from graphix.clifford import Clifford -from graphix.command import C, E, X, Z +from graphix.command import E from graphix.fundamentals import ANGLE_PI from graphix.ops import Ops from graphix.random_objects import rand_circuit @@ -20,9 +18,6 @@ from graphix.states import BasicStates from graphix.transpiler import Circuit -if TYPE_CHECKING: - from graphix.command import CommandType - def random_op(sites: int, rng: Generator) -> npt.NDArray[np.complex128]: size = 2**sites @@ -73,37 +68,37 @@ def test_entangle_nodes(self, fx_rng: Generator) -> None: contracted_ref = np.einsum("abcd, c, d, ab->", CZ.reshape(2, 2, 2, 2), plus, plus, random_vec) assert contracted == pytest.approx(contracted_ref) - def test_apply_one_site_operator(self, fx_rng: Generator) -> None: - clifford = Clifford(fx_rng.integers(len(Clifford))) - cmds: list[CommandType] = [ - X(node=0, domain={15}), - Z(node=0, domain={15}), - C(node=0, clifford=clifford), - ] - random_vec = fx_rng.normal(size=2) - - circuit = Circuit(1) - pattern = circuit.transpile().pattern - pattern.results[15] = 1 # X&Z operator will be applied. - for cmd in cmds: - pattern.add(cmd) - tn = pattern.simulate_pattern(backend="tensornetwork", rng=fx_rng) - dummy_index = gen_str() - ind = tn._dangling.pop("0") - tensor = tn.tensor_map[tn._get_tids_from_inds(ind).popleft()] - tensor.reindex({ind: dummy_index}, inplace=True) - random_vec_ts = Tensor(random_vec, [dummy_index], ["random_vector"]) - tn.add_tensor(random_vec_ts) - contracted = tn.contract() - - # reference - ops = [ - np.array([[0.0, 1.0], [1.0, 0.0]]), - np.array([[1.0, 0.0], [0.0, -1.0]]), - clifford.matrix, - ] - contracted_ref = np.einsum("i,ij,jk,kl,l", random_vec, ops[2], ops[1], ops[0], plus) - assert contracted == pytest.approx(contracted_ref) + # def test_apply_one_site_operator(self, fx_rng: Generator) -> None: + # clifford = Clifford(fx_rng.integers(len(Clifford))) + # cmds: list[CommandType] = [ + # X(node=0, domain={15}), + # Z(node=0, domain={15}), + # C(node=0, clifford=clifford), + # ] + # random_vec = fx_rng.normal(size=2) + # + # circuit = Circuit(1) + # pattern = circuit.transpile().pattern + # pattern.results[15] = 1 # X&Z operator will be applied. + # for cmd in cmds: + # pattern.add(cmd) + # tn = pattern.simulate_pattern(backend="tensornetwork", rng=fx_rng) + # dummy_index = gen_str() + # ind = tn._dangling.pop("0") + # tensor = tn.tensor_map[tn._get_tids_from_inds(ind).popleft()] + # tensor.reindex({ind: dummy_index}, inplace=True) + # random_vec_ts = Tensor(random_vec, [dummy_index], ["random_vector"]) + # tn.add_tensor(random_vec_ts) + # contracted = tn.contract() + # + # # reference + # ops = [ + # np.array([[0.0, 1.0], [1.0, 0.0]]), + # np.array([[1.0, 0.0], [0.0, -1.0]]), + # clifford.matrix, + # ] + # contracted_ref = np.einsum("i,ij,jk,kl,l", random_vec, ops[2], ops[1], ops[0], plus) + # assert contracted == pytest.approx(contracted_ref) def test_expectation_value1(self, fx_rng: Generator) -> None: circuit = Circuit(1) @@ -335,8 +330,7 @@ def test_with_graphtrans(self, fx_bg: PCG64, jumps: int, fx_rng: Generator) -> N pattern = circuit.transpile().pattern pattern.standardize() pattern.shift_signals() - pattern.remove_input_nodes() - pattern.perform_pauli_measurements() + pattern.remove_pauli_measurements() state = circuit.simulate_statevector().statevec tn_mbqc = pattern.simulate_pattern(backend="tensornetwork", rng=fx_rng) random_op3 = random_op(3, rng) @@ -354,8 +348,7 @@ def test_with_graphtrans_sequential(self, fx_bg: PCG64, jumps: int, fx_rng: Gene pattern = circuit.transpile().pattern pattern.standardize() pattern.shift_signals() - pattern.remove_input_nodes() - pattern.perform_pauli_measurements() + pattern.remove_pauli_measurements() state = circuit.simulate_statevector().statevec tn_mbqc = pattern.simulate_pattern(backend="tensornetwork", graph_prep="sequential", rng=fx_rng) random_op3 = random_op(3, rng) @@ -403,8 +396,7 @@ def test_evolve(self, fx_bg: PCG64, jumps: int, fx_rng: Generator) -> None: pattern = circuit.transpile().pattern pattern.standardize() pattern.shift_signals() - pattern.remove_input_nodes() - pattern.perform_pauli_measurements() + pattern.remove_pauli_measurements() state = circuit.simulate_statevector().statevec tn_mbqc = pattern.simulate_pattern(backend="tensornetwork", rng=fx_rng) random_op3 = random_op(3, rng) diff --git a/tests/test_visualization.py b/tests/test_visualization.py index e1ee898b7..9c6dae1f6 100644 --- a/tests/test_visualization.py +++ b/tests/test_visualization.py @@ -158,8 +158,7 @@ def example_hadamard() -> Pattern: def example_local_clifford() -> Pattern: pattern = example_hadamard() - pattern.remove_input_nodes() - pattern.perform_pauli_measurements() + pattern.remove_pauli_measurements() return pattern @@ -258,8 +257,7 @@ def test_draw_graph_reference(flow_and_not_pauli_presimulate: bool) -> Figure: # to have causal flow. pattern = pattern.to_bloch() else: - pattern.remove_input_nodes() - pattern.perform_pauli_measurements() + pattern.remove_pauli_measurements() pattern.standardize() pattern.draw( flow_from_pattern=flow_and_not_pauli_presimulate, node_distance=(1, 1), measurement_labels=True, legend=False diff --git a/uv.lock b/uv.lock index c80acabd5..46b5e2481 100644 --- a/uv.lock +++ b/uv.lock @@ -993,7 +993,7 @@ requires-dist = [ { name = "pytest-mock", marker = "extra == 'dev'" }, { name = "pytest-mpl", marker = "extra == 'dev'" }, { name = "pyzx", marker = "extra == 'extra'", specifier = ">=0.10.0" }, - { name = "qiskit", marker = "extra == 'dev'", specifier = "==2.3.1" }, + { name = "qiskit", marker = "extra == 'dev'", specifier = ">=1.0" }, { name = "qiskit-aer", marker = "extra == 'dev'" }, { name = "qiskit-qasm3-import", marker = "extra == 'dev'" }, { name = "quimb" },