From 607fb23308eb1ea594de84e3a04026d607516d24 Mon Sep 17 00:00:00 2001 From: donglrd <31209724+donglrd@users.noreply.github.com> Date: Tue, 12 May 2026 09:20:59 +0800 Subject: [PATCH] feat: add state pretty printers --- graphix/pretty_print.py | 155 ++++++++++++++++++++++++++++++++++ graphix/sim/density_matrix.py | 5 ++ graphix/sim/statevec.py | 5 ++ tests/test_pretty_print.py | 26 +++++- 4 files changed, 190 insertions(+), 1 deletion(-) diff --git a/graphix/pretty_print.py b/graphix/pretty_print.py index b2bd72e8f..84135c67e 100644 --- a/graphix/pretty_print.py +++ b/graphix/pretty_print.py @@ -11,6 +11,7 @@ from typing import TYPE_CHECKING, SupportsFloat # `assert_never` introduced in Python 3.11 +import numpy as np from typing_extensions import assert_never from graphix import command @@ -26,6 +27,8 @@ from graphix.flow.core import PauliFlow, XZCorrections from graphix.fundamentals import Angle from graphix.pattern import Pattern + from graphix.sim.density_matrix import DensityMatrix + from graphix.sim.statevec import Statevec class OutputFormat(Enum): @@ -36,6 +39,158 @@ class OutputFormat(Enum): Unicode = enum.auto() +def complex_to_str(value: complex, output: OutputFormat, *, atol: float = 1e-8) -> str: + """Return a compact exact-looking representation of a complex number.""" + val = complex(value) + if abs(val) <= atol: + return "0" + + radius = abs(val) + if math.isclose(radius, 1.0, abs_tol=atol) and abs(val.real) > atol and abs(val.imag) > atol: + phase = math.atan2(val.imag, val.real) / pi + phase_str = angle_to_str(phase, output) + if output == OutputFormat.LaTeX: + return rf"\mathrm{{e}}^{{\mathrm{{i}}{phase_str}}}" + if output == OutputFormat.Unicode: + return f"e^(i{phase_str})" + return f"exp(i{phase_str})" + + real = _real_to_str(val.real, output, atol=atol) if abs(val.real) > atol else "" + imag = _imaginary_to_str(val.imag, output, atol=atol) if abs(val.imag) > atol else "" + if real and imag: + sign = " + " if not imag.startswith("-") else " - " + return f"{real}{sign}{imag.removeprefix('-')}" + return real or imag + + +def statevector_to_str( + state: Statevec, + output: OutputFormat, + *, + encoding: str = "MSB", + atol: float = 1e-8, +) -> str: + """Return a pretty ket expansion for a statevector.""" + terms = [] + for ket, amp in state.to_dict(encoding=encoding, atol=atol).items(): + coeff = complex_to_str(amp, output, atol=atol) + terms.append(_ket_term(coeff, ket, output)) + body = _join_terms(terms) if terms else "0" + return rf"\({body}\)" if output == OutputFormat.LaTeX else body + + +def density_matrix_to_str( + state: DensityMatrix, + output: OutputFormat, + *, + encoding: str = "MSB", + atol: float = 1e-8, +) -> str: + """Return a pretty outer-product expansion for a density matrix.""" + terms = [] + for row in range(1 << state.nqubit): + bra_ket = _format_basis(state.nqubit, row, encoding) + for col in range(1 << state.nqubit): + coeff = state.rho[row, col] + if np.isclose(abs(coeff), 0, atol=atol, rtol=0): + continue + terms.append(_density_term(complex_to_str(coeff, output, atol=atol), bra_ket, _format_basis(state.nqubit, col, encoding), output)) + body = _join_terms(terms) if terms else "0" + return rf"\({body}\)" if output == OutputFormat.LaTeX else body + + +def _real_to_str(value: float, output: OutputFormat, *, atol: float) -> str: + sign = "-" if value < 0 else "" + val = abs(value) + if math.isclose(val, math.sqrt(2) / 2, abs_tol=atol): + return sign + _sqrt_fraction(2, 2, output) + if math.isclose(val, math.sqrt(3) / 2, abs_tol=atol): + return sign + _sqrt_fraction(3, 2, output) + + frac = Fraction(val).limit_denominator(16) + if math.isclose(val, float(frac), abs_tol=atol): + return sign + _fraction_to_str(frac.numerator, frac.denominator, output) + return f"{value:.8g}" + + +def _imaginary_to_str(value: float, output: OutputFormat, *, atol: float) -> str: + sign = "-" if value < 0 else "" + magnitude = _real_to_str(abs(value), output, atol=atol) + unit = r"\mathrm{i}" if output == OutputFormat.LaTeX else "i" + if magnitude == "1": + return f"{sign}{unit}" + if output == OutputFormat.LaTeX: + return f"{sign}{magnitude}{unit}" + return f"{sign}{magnitude}{unit}" + + +def _fraction_to_str(num: int, den: int, output: OutputFormat) -> str: + if den == 1: + return str(num) + if output == OutputFormat.LaTeX: + return rf"\frac{{{num}}}{{{den}}}" + return f"{num}/{den}" + + +def _sqrt_fraction(rad: int, den: int, output: OutputFormat) -> str: + if output == OutputFormat.LaTeX: + return rf"\frac{{\sqrt{{{rad}}}}}{{{den}}}" + root = f"√{rad}" if output == OutputFormat.Unicode else f"sqrt({rad})" + return f"{root}/{den}" + + +def _ket(label: str, output: OutputFormat) -> str: + if output == OutputFormat.LaTeX: + return rf"\lvert {label}\rangle" + if output == OutputFormat.Unicode: + return f"|{label}⟩" + return f"|{label}>" + + +def _bra(label: str, output: OutputFormat) -> str: + if output == OutputFormat.LaTeX: + return rf"\langle {label}\rvert" + if output == OutputFormat.Unicode: + return f"⟨{label}|" + return f"<{label}|" + + +def _ket_term(coeff: str, ket: str, output: OutputFormat) -> str: + if coeff == "1": + return _ket(ket, output) + if coeff == "-1": + return f"-{_ket(ket, output)}" + return f"{coeff}{_ket(ket, output)}" + + +def _density_term(coeff: str, ket_label: str, bra_label: str, output: OutputFormat) -> str: + outer = f"{_ket(ket_label, output)}{_bra(bra_label, output)}" + if coeff == "1": + return outer + if coeff == "-1": + return f"-{outer}" + return f"{coeff}{outer}" + + +def _join_terms(terms: list[str]) -> str: + result = terms[0] + for term in terms[1:]: + if term.startswith("-"): + result += f" - {term[1:]}" + else: + result += f" + {term}" + return result + + +def _format_basis(nqubit: int, index: int, encoding: str) -> str: + label = f"{index:0{nqubit}b}" + if encoding == "MSB": + return label + if encoding == "LSB": + return label[::-1] + raise ValueError("encoding must be either 'MSB' or 'LSB'.") + + def angle_to_str( angle: Angle, output: OutputFormat, max_denominator: int = 1000, multiplication_sign: bool = False ) -> str: diff --git a/graphix/sim/density_matrix.py b/graphix/sim/density_matrix.py index cb5570e2a..dddea33cb 100644 --- a/graphix/sim/density_matrix.py +++ b/graphix/sim/density_matrix.py @@ -19,6 +19,7 @@ from graphix import parameter from graphix.channels import KrausChannel from graphix.parameter import Expression, ExpressionOrFloat, ExpressionOrSupportsComplex +from graphix.pretty_print import OutputFormat, density_matrix_to_str from graphix.sim.base_backend import DenseState, DenseStateBackend, Matrix, kron, matmul, outer, tensordot, vdot from graphix.sim.statevec import CNOT_TENSOR, CZ_TENSOR, SWAP_TENSOR, Statevec from graphix.states import BasicStates, State @@ -117,6 +118,10 @@ def __str__(self) -> str: """Return a string description.""" return f"DensityMatrix object, with density matrix {self.rho} and shape {self.dims()}." + def draw(self, output: OutputFormat = OutputFormat.ASCII, encoding: str = "MSB", *, atol: float = 1e-8) -> str: + """Return a pretty outer-product expansion of the density matrix.""" + return density_matrix_to_str(self, output=output, encoding=encoding, atol=atol) + @override def add_nodes(self, nqubit: int, data: Data) -> None: r""" diff --git a/graphix/sim/statevec.py b/graphix/sim/statevec.py index 1f7dd23e6..e11575540 100644 --- a/graphix/sim/statevec.py +++ b/graphix/sim/statevec.py @@ -16,6 +16,7 @@ from graphix import parameter, states from graphix.parameter import Expression, ExpressionOrSupportsComplex, check_expression_or_float +from graphix.pretty_print import OutputFormat, statevector_to_str from graphix.sim.base_backend import DenseState, DenseStateBackend, Matrix, kron, tensordot from graphix.states import BasicStates @@ -520,6 +521,10 @@ def to_prob_dict( """ return self._to_dict_map(lambda x: np.abs(x) ** 2, encoding, rtol=rtol, atol=atol) + def draw(self, output: OutputFormat = OutputFormat.ASCII, encoding: _ENCODING = "MSB", *, atol: float = 1e-8) -> str: + """Return a pretty ket expansion of the statevector.""" + return statevector_to_str(self, output=output, encoding=encoding, atol=atol) + def _to_dict_map( self, f: Callable[[npt.NDArray[np.object_ | np.complex128]], npt.NDArray[_ScalarT]], diff --git a/tests/test_pretty_print.py b/tests/test_pretty_print.py index 28dedce55..51f74e208 100644 --- a/tests/test_pretty_print.py +++ b/tests/test_pretty_print.py @@ -13,8 +13,11 @@ from graphix.opengraph import OpenGraph from graphix.parameter import Placeholder from graphix.pattern import Pattern -from graphix.pretty_print import OutputFormat, pattern_to_str +from graphix.pretty_print import OutputFormat, complex_to_str, pattern_to_str from graphix.random_objects import rand_circuit +from graphix.sim.density_matrix import DensityMatrix +from graphix.sim.statevec import Statevec +from graphix.states import BasicStates from graphix.transpiler import Circuit if TYPE_CHECKING: @@ -202,3 +205,24 @@ def test_xzcorr_str() -> None: str(flow) == "x(3) = {5}, x(4) = {6}, x(1) = {3}, x(2) = {4}; z(1) = {4, 5}, z(2) = {3, 6}; {1, 2} < {3, 4} < {5, 6}" ) + + +def test_complex_number_pretty_print() -> None: + assert complex_to_str(0.25, OutputFormat.ASCII) == "1/4" + assert complex_to_str(2**-0.5, OutputFormat.Unicode) == "√2/2" + assert complex_to_str(0.5 + 0.8660254037844386j, OutputFormat.LaTeX) == r"\mathrm{e}^{\mathrm{i}\frac{\pi}{3}}" + + +def test_statevector_draw() -> None: + state = Statevec(data=BasicStates.PLUS) + + assert state.draw() == "sqrt(2)/2|0> + sqrt(2)/2|1>" + assert state.draw(OutputFormat.Unicode) == "√2/2|0⟩ + √2/2|1⟩" + assert state.draw(OutputFormat.LaTeX) == r"\(\frac{\sqrt{2}}{2}\lvert 0\rangle + \frac{\sqrt{2}}{2}\lvert 1\rangle\)" + + +def test_density_matrix_draw() -> None: + state = DensityMatrix(data=BasicStates.PLUS) + + assert state.draw() == "1/2|0><0| + 1/2|0><1| + 1/2|1><0| + 1/2|1><1|" + assert state.draw(OutputFormat.Unicode) == "1/2|0⟩⟨0| + 1/2|0⟩⟨1| + 1/2|1⟩⟨0| + 1/2|1⟩⟨1|"