Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 155 additions & 0 deletions graphix/pretty_print.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from typing import TYPE_CHECKING, SupportsFloat

# `assert_never` introduced in Python 3.11
import numpy as np
from typing_extensions import assert_never

from graphix import command
Expand All @@ -26,6 +27,8 @@
from graphix.flow.core import PauliFlow, XZCorrections
from graphix.fundamentals import Angle
from graphix.pattern import Pattern
from graphix.sim.density_matrix import DensityMatrix
from graphix.sim.statevec import Statevec


class OutputFormat(Enum):
Expand All @@ -36,6 +39,158 @@ class OutputFormat(Enum):
Unicode = enum.auto()


def complex_to_str(value: complex, output: OutputFormat, *, atol: float = 1e-8) -> str:
"""Return a compact exact-looking representation of a complex number."""
val = complex(value)
if abs(val) <= atol:
return "0"

radius = abs(val)
if math.isclose(radius, 1.0, abs_tol=atol) and abs(val.real) > atol and abs(val.imag) > atol:
phase = math.atan2(val.imag, val.real) / pi
phase_str = angle_to_str(phase, output)
if output == OutputFormat.LaTeX:
return rf"\mathrm{{e}}^{{\mathrm{{i}}{phase_str}}}"
if output == OutputFormat.Unicode:
return f"e^(i{phase_str})"
return f"exp(i{phase_str})"

real = _real_to_str(val.real, output, atol=atol) if abs(val.real) > atol else ""
imag = _imaginary_to_str(val.imag, output, atol=atol) if abs(val.imag) > atol else ""
if real and imag:
sign = " + " if not imag.startswith("-") else " - "
return f"{real}{sign}{imag.removeprefix('-')}"
return real or imag


def statevector_to_str(
state: Statevec,
output: OutputFormat,
*,
encoding: str = "MSB",
atol: float = 1e-8,
) -> str:
"""Return a pretty ket expansion for a statevector."""
terms = []
for ket, amp in state.to_dict(encoding=encoding, atol=atol).items():
coeff = complex_to_str(amp, output, atol=atol)
terms.append(_ket_term(coeff, ket, output))
body = _join_terms(terms) if terms else "0"
return rf"\({body}\)" if output == OutputFormat.LaTeX else body


def density_matrix_to_str(
state: DensityMatrix,
output: OutputFormat,
*,
encoding: str = "MSB",
atol: float = 1e-8,
) -> str:
"""Return a pretty outer-product expansion for a density matrix."""
terms = []
for row in range(1 << state.nqubit):
bra_ket = _format_basis(state.nqubit, row, encoding)
for col in range(1 << state.nqubit):
coeff = state.rho[row, col]
if np.isclose(abs(coeff), 0, atol=atol, rtol=0):
continue
terms.append(_density_term(complex_to_str(coeff, output, atol=atol), bra_ket, _format_basis(state.nqubit, col, encoding), output))
body = _join_terms(terms) if terms else "0"
return rf"\({body}\)" if output == OutputFormat.LaTeX else body


def _real_to_str(value: float, output: OutputFormat, *, atol: float) -> str:
sign = "-" if value < 0 else ""
val = abs(value)
if math.isclose(val, math.sqrt(2) / 2, abs_tol=atol):
return sign + _sqrt_fraction(2, 2, output)
if math.isclose(val, math.sqrt(3) / 2, abs_tol=atol):
return sign + _sqrt_fraction(3, 2, output)

frac = Fraction(val).limit_denominator(16)
if math.isclose(val, float(frac), abs_tol=atol):
return sign + _fraction_to_str(frac.numerator, frac.denominator, output)
return f"{value:.8g}"


def _imaginary_to_str(value: float, output: OutputFormat, *, atol: float) -> str:
sign = "-" if value < 0 else ""
magnitude = _real_to_str(abs(value), output, atol=atol)
unit = r"\mathrm{i}" if output == OutputFormat.LaTeX else "i"
if magnitude == "1":
return f"{sign}{unit}"
if output == OutputFormat.LaTeX:
return f"{sign}{magnitude}{unit}"
return f"{sign}{magnitude}{unit}"


def _fraction_to_str(num: int, den: int, output: OutputFormat) -> str:
if den == 1:
return str(num)
if output == OutputFormat.LaTeX:
return rf"\frac{{{num}}}{{{den}}}"
return f"{num}/{den}"


def _sqrt_fraction(rad: int, den: int, output: OutputFormat) -> str:
if output == OutputFormat.LaTeX:
return rf"\frac{{\sqrt{{{rad}}}}}{{{den}}}"
root = f"√{rad}" if output == OutputFormat.Unicode else f"sqrt({rad})"
return f"{root}/{den}"


def _ket(label: str, output: OutputFormat) -> str:
if output == OutputFormat.LaTeX:
return rf"\lvert {label}\rangle"
if output == OutputFormat.Unicode:
return f"|{label}⟩"
return f"|{label}>"


def _bra(label: str, output: OutputFormat) -> str:
if output == OutputFormat.LaTeX:
return rf"\langle {label}\rvert"
if output == OutputFormat.Unicode:
return f"⟨{label}|"
return f"<{label}|"


def _ket_term(coeff: str, ket: str, output: OutputFormat) -> str:
if coeff == "1":
return _ket(ket, output)
if coeff == "-1":
return f"-{_ket(ket, output)}"
return f"{coeff}{_ket(ket, output)}"


def _density_term(coeff: str, ket_label: str, bra_label: str, output: OutputFormat) -> str:
outer = f"{_ket(ket_label, output)}{_bra(bra_label, output)}"
if coeff == "1":
return outer
if coeff == "-1":
return f"-{outer}"
return f"{coeff}{outer}"


def _join_terms(terms: list[str]) -> str:
result = terms[0]
for term in terms[1:]:
if term.startswith("-"):
result += f" - {term[1:]}"
else:
result += f" + {term}"
return result


def _format_basis(nqubit: int, index: int, encoding: str) -> str:
label = f"{index:0{nqubit}b}"
if encoding == "MSB":
return label
if encoding == "LSB":
return label[::-1]
raise ValueError("encoding must be either 'MSB' or 'LSB'.")


def angle_to_str(
angle: Angle, output: OutputFormat, max_denominator: int = 1000, multiplication_sign: bool = False
) -> str:
Expand Down
5 changes: 5 additions & 0 deletions graphix/sim/density_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from graphix import parameter
from graphix.channels import KrausChannel
from graphix.parameter import Expression, ExpressionOrFloat, ExpressionOrSupportsComplex
from graphix.pretty_print import OutputFormat, density_matrix_to_str
from graphix.sim.base_backend import DenseState, DenseStateBackend, Matrix, kron, matmul, outer, tensordot, vdot
from graphix.sim.statevec import CNOT_TENSOR, CZ_TENSOR, SWAP_TENSOR, Statevec
from graphix.states import BasicStates, State
Expand Down Expand Up @@ -117,6 +118,10 @@ def __str__(self) -> str:
"""Return a string description."""
return f"DensityMatrix object, with density matrix {self.rho} and shape {self.dims()}."

def draw(self, output: OutputFormat = OutputFormat.ASCII, encoding: str = "MSB", *, atol: float = 1e-8) -> str:
"""Return a pretty outer-product expansion of the density matrix."""
return density_matrix_to_str(self, output=output, encoding=encoding, atol=atol)

@override
def add_nodes(self, nqubit: int, data: Data) -> None:
r"""
Expand Down
5 changes: 5 additions & 0 deletions graphix/sim/statevec.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from graphix import parameter, states
from graphix.parameter import Expression, ExpressionOrSupportsComplex, check_expression_or_float
from graphix.pretty_print import OutputFormat, statevector_to_str
from graphix.sim.base_backend import DenseState, DenseStateBackend, Matrix, kron, tensordot
from graphix.states import BasicStates

Expand Down Expand Up @@ -520,6 +521,10 @@ def to_prob_dict(
"""
return self._to_dict_map(lambda x: np.abs(x) ** 2, encoding, rtol=rtol, atol=atol)

def draw(self, output: OutputFormat = OutputFormat.ASCII, encoding: _ENCODING = "MSB", *, atol: float = 1e-8) -> str:
"""Return a pretty ket expansion of the statevector."""
return statevector_to_str(self, output=output, encoding=encoding, atol=atol)

def _to_dict_map(
self,
f: Callable[[npt.NDArray[np.object_ | np.complex128]], npt.NDArray[_ScalarT]],
Expand Down
26 changes: 25 additions & 1 deletion tests/test_pretty_print.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,11 @@
from graphix.opengraph import OpenGraph
from graphix.parameter import Placeholder
from graphix.pattern import Pattern
from graphix.pretty_print import OutputFormat, pattern_to_str
from graphix.pretty_print import OutputFormat, complex_to_str, pattern_to_str
from graphix.random_objects import rand_circuit
from graphix.sim.density_matrix import DensityMatrix
from graphix.sim.statevec import Statevec
from graphix.states import BasicStates
from graphix.transpiler import Circuit

if TYPE_CHECKING:
Expand Down Expand Up @@ -202,3 +205,24 @@ def test_xzcorr_str() -> None:
str(flow)
== "x(3) = {5}, x(4) = {6}, x(1) = {3}, x(2) = {4}; z(1) = {4, 5}, z(2) = {3, 6}; {1, 2} < {3, 4} < {5, 6}"
)


def test_complex_number_pretty_print() -> None:
assert complex_to_str(0.25, OutputFormat.ASCII) == "1/4"
assert complex_to_str(2**-0.5, OutputFormat.Unicode) == "√2/2"
assert complex_to_str(0.5 + 0.8660254037844386j, OutputFormat.LaTeX) == r"\mathrm{e}^{\mathrm{i}\frac{\pi}{3}}"


def test_statevector_draw() -> None:
state = Statevec(data=BasicStates.PLUS)

assert state.draw() == "sqrt(2)/2|0> + sqrt(2)/2|1>"
assert state.draw(OutputFormat.Unicode) == "√2/2|0⟩ + √2/2|1⟩"
assert state.draw(OutputFormat.LaTeX) == r"\(\frac{\sqrt{2}}{2}\lvert 0\rangle + \frac{\sqrt{2}}{2}\lvert 1\rangle\)"


def test_density_matrix_draw() -> None:
state = DensityMatrix(data=BasicStates.PLUS)

assert state.draw() == "1/2|0><0| + 1/2|0><1| + 1/2|1><0| + 1/2|1><1|"
assert state.draw(OutputFormat.Unicode) == "1/2|0⟩⟨0| + 1/2|0⟩⟨1| + 1/2|1⟩⟨0| + 1/2|1⟩⟨1|"