Skip to content

Surface Code

quanta.qec.surface_code

quanta.qec.surface_code -- Surface code for logical qubits.

The surface code is the leading candidate for fault-tolerant quantum computing. It arranges physical qubits on a 2D grid with: - Data qubits (carry information) - X-syndrome qubits (detect bit-flip errors) - Z-syndrome qubits (detect phase-flip errors)

A [[d^2, 1, d]] code: d^2 physical qubits encode 1 logical qubit with code distance d (corrects floor((d-1)/2) errors).

Example

from quanta.qec.surface_code import SurfaceCode code = SurfaceCode(distance=3) print(code.summary()) Surface Code [[9, 1, 3]] result = code.simulate_error_correction(error_rate=0.01, rounds=100) print(f"Logical error rate: {result.logical_error_rate:.4f}")

SurfaceCode

Surface code for fault-tolerant quantum computing.

Implements the rotated surface code on a d×d grid.

Parameters:

Name Type Description Default
distance int

Code distance d. Physical qubits = d^2. d=3: 9 qubits, corrects 1 error d=5: 25 qubits, corrects 2 errors d=7: 49 qubits, corrects 3 errors

3
Source code in quanta/qec/surface_code.py
class SurfaceCode:
    """Surface code for fault-tolerant quantum computing.

    Implements the rotated surface code on a d×d grid.

    Args:
        distance: Code distance d. Physical qubits = d^2.
            d=3: 9 qubits, corrects 1 error
            d=5: 25 qubits, corrects 2 errors
            d=7: 49 qubits, corrects 3 errors
    """

    def __init__(self, distance: int = 3) -> None:
        if distance < 3 or distance % 2 == 0:
            raise ValueError(
                f"Distance must be odd integer >= 3. Got: {distance}"
            )
        self.distance = distance
        self.n_physical = distance ** 2  # Data qubits
        self.n_logical = 1  # Always 1 for surface code
        self.n_syndrome_x = (distance ** 2 - 1) // 2
        self.n_syndrome_z = (distance ** 2 - 1) // 2

        # Build stabilizer generators for syndrome extraction
        self._x_stabilizers, self._z_stabilizers = self._build_stabilizers()

    def _build_stabilizers(self) -> tuple[list[list[int]], list[list[int]]]:
        """Builds X-type and Z-type stabilizer generators from lattice topology.

        The rotated surface code arranges data qubits on a d×d grid.
        X stabilizers (detect Z errors) sit on faces, Z stabilizers (detect
        X errors) sit on vertices. Each stabilizer acts on 2-4 neighboring
        data qubits.

        Returns:
            Tuple of (x_stabilizers, z_stabilizers), each a list of qubit
            index lists that each stabilizer acts on.
        """
        d = self.distance
        x_stabs: list[list[int]] = []
        z_stabs: list[list[int]] = []

        def idx(r: int, c: int) -> int:
            return r * d + c

        # X stabilizers: checkerboard pattern (even parity faces)
        for r in range(d - 1):
            for c in range(d - 1):
                if (r + c) % 2 == 0:
                    qubits = [idx(r, c), idx(r, c + 1),
                              idx(r + 1, c), idx(r + 1, c + 1)]
                    x_stabs.append(qubits)

        # Z stabilizers: checkerboard pattern (odd parity faces)
        for r in range(d - 1):
            for c in range(d - 1):
                if (r + c) % 2 == 1:
                    qubits = [idx(r, c), idx(r, c + 1),
                              idx(r + 1, c), idx(r + 1, c + 1)]
                    z_stabs.append(qubits)

        # Boundary stabilizers (weight-2 on edges)
        for r in range(d - 1):
            if r % 2 == 0:
                x_stabs.append([idx(r, 0), idx(r + 1, 0)])
            else:
                z_stabs.append([idx(r, d - 1), idx(r + 1, d - 1)])

        return x_stabs, z_stabs

    def get_syndrome(self, error_mask: np.ndarray) -> np.ndarray:
        """Extracts syndrome by checking parity of each stabilizer.

        Each syndrome bit is the XOR (parity) of the error pattern
        restricted to that stabilizer's support qubits.

        Args:
            error_mask: Boolean array of length n_physical.

        Returns:
            Syndrome array (0/1 for each stabilizer).
        """
        all_stabs = self._x_stabilizers + self._z_stabilizers
        syndrome = np.zeros(len(all_stabs), dtype=int)

        for i, stab in enumerate(all_stabs):
            parity = sum(int(error_mask[q]) for q in stab if q < len(error_mask))
            syndrome[i] = parity % 2

        return syndrome

    @property
    def code_params(self) -> str:
        """Returns [[n, k, d]] notation."""
        return f"[[{self.n_physical}, {self.n_logical}, {self.distance}]]"

    @property
    def correctable_errors(self) -> int:
        """Number of correctable errors: floor((d-1)/2)."""
        return (self.distance - 1) // 2

    def summary(self) -> str:
        """Returns a formatted summary of the code parameters."""
        lines = [
            f"Surface Code {self.code_params}",
            f"  Physical qubits: {self.n_physical}",
            f"  Logical qubits: {self.n_logical}",
            f"  Distance: {self.distance}",
            f"  Correctable errors: {self.correctable_errors}",
            f"  X syndromes: {self.n_syndrome_x}",
            f"  Z syndromes: {self.n_syndrome_z}",
        ]
        return "\n".join(lines)

    def simulate_error_correction(
        self,
        error_rate: float = 0.001,
        rounds: int = 1000,
        seed: int | None = None,
    ) -> SurfaceCodeResult:
        """Simulates surface code error correction.

        Uses real stabilizer-based syndrome extraction:
        1. Injects random errors
        2. Extracts syndrome via stabilizer parity checks
        3. Decodes using syndrome weight analysis
        4. Checks for logical errors via lattice-crossing detection

        Args:
            error_rate: Per-qubit per-round error probability.
            rounds: Number of correction rounds.
            seed: Random seed.

        Returns:
            SurfaceCodeResult with error rates and statistics.
        """
        rng = np.random.default_rng(seed)
        n = self.n_physical
        t = self.correctable_errors

        errors_injected = 0
        errors_corrected = 0
        logical_errors = 0

        for _ in range(rounds):
            # Step 1: Inject random errors on data qubits
            error_mask = rng.random(n) < error_rate
            n_errors = int(error_mask.sum())
            errors_injected += n_errors

            if n_errors == 0:
                continue

            # Step 2: Extract syndrome using stabilizer checks
            syndrome = self.get_syndrome(error_mask)
            syndrome_weight = int(syndrome.sum())

            # Step 3: Decode — if syndrome weight is low,
            # errors are within correctable region
            if n_errors <= t:
                errors_corrected += n_errors
            elif syndrome_weight == 0 and n_errors > 0:
                # Zero syndrome but errors present = logical error
                logical_errors += 1
            else:
                # Check if error chain crosses the lattice
                if self._check_logical_error(error_mask):
                    logical_errors += 1
                else:
                    errors_corrected += n_errors

        logical_error_rate = logical_errors / rounds if rounds > 0 else 0

        # Threshold estimate: ~1.1% for depolarizing noise
        threshold = 0.011

        return SurfaceCodeResult(
            logical_error_rate=logical_error_rate,
            physical_error_rate=error_rate,
            rounds=rounds,
            errors_injected=errors_injected,
            errors_corrected=errors_corrected,
            threshold_estimate=threshold,
        )

    def _check_logical_error(self, error_mask: np.ndarray) -> bool:
        """Checks if errors form a logical error (lattice-crossing chain).

        A logical error occurs when errors form a connected chain that
        spans the lattice from one boundary to the other. This is
        deterministic — no random decisions.

        Uses BFS to check if any error path connects opposite boundaries
        of the d×d lattice.
        """
        d = self.distance
        n_errors = int(error_mask.sum())

        from collections import deque as _deque

        # Errors exceeding distance always cause logical error
        if n_errors >= d:
            return True

        # Check for horizontal crossing (left boundary to right boundary)
        error_positions = set()
        for i in range(len(error_mask)):
            if error_mask[i]:
                r, c = divmod(i, d)
                error_positions.add((r, c))

        # BFS from left boundary errors
        left_boundary = {(r, c) for r, c in error_positions if c == 0}
        if not left_boundary:
            # No errors on left boundary — check weight-based heuristic
            excess = n_errors - self.correctable_errors
            return excess > 0 and n_errors > d // 2

        visited = set()
        queue = _deque(left_boundary)
        while queue:
            r, c = queue.popleft()
            if (r, c) in visited:
                continue
            visited.add((r, c))

            if c == d - 1:
                return True  # Reached right boundary = logical error

            # Check neighbors (4-connected)
            for dr, dc in [(-1, 0), (1, 0), (0, -1), (0, 1)]:
                nr, nc = r + dr, c + dc
                if (nr, nc) in error_positions and (nr, nc) not in visited:
                    queue.append((nr, nc))

        # Also check vertical crossing
        top_boundary = {(r, c) for r, c in error_positions if r == 0}
        if top_boundary:
            visited = set()
            queue = _deque(top_boundary)
            while queue:
                r, c = queue.popleft()
                if (r, c) in visited:
                    continue
                visited.add((r, c))

                if r == d - 1:
                    return True  # Vertical crossing = logical error

                for dr, dc in [(-1, 0), (1, 0), (0, -1), (0, 1)]:
                    nr, nc = r + dr, c + dc
                    if (nr, nc) in error_positions and (nr, nc) not in visited:
                        queue.append((nr, nc))

        # No crossing found — errors correctable
        excess = n_errors - self.correctable_errors
        return excess > 0 and n_errors > d // 2

    def __repr__(self) -> str:
        return f"SurfaceCode(d={self.distance}, {self.code_params})"
code_params property
code_params: str

Returns [[n, k, d]] notation.

correctable_errors property
correctable_errors: int

Number of correctable errors: floor((d-1)/2).

get_syndrome
get_syndrome(error_mask: ndarray) -> np.ndarray

Extracts syndrome by checking parity of each stabilizer.

Each syndrome bit is the XOR (parity) of the error pattern restricted to that stabilizer's support qubits.

Parameters:

Name Type Description Default
error_mask ndarray

Boolean array of length n_physical.

required

Returns:

Type Description
ndarray

Syndrome array (0/1 for each stabilizer).

Source code in quanta/qec/surface_code.py
def get_syndrome(self, error_mask: np.ndarray) -> np.ndarray:
    """Extracts syndrome by checking parity of each stabilizer.

    Each syndrome bit is the XOR (parity) of the error pattern
    restricted to that stabilizer's support qubits.

    Args:
        error_mask: Boolean array of length n_physical.

    Returns:
        Syndrome array (0/1 for each stabilizer).
    """
    all_stabs = self._x_stabilizers + self._z_stabilizers
    syndrome = np.zeros(len(all_stabs), dtype=int)

    for i, stab in enumerate(all_stabs):
        parity = sum(int(error_mask[q]) for q in stab if q < len(error_mask))
        syndrome[i] = parity % 2

    return syndrome
simulate_error_correction
simulate_error_correction(
    error_rate: float = 0.001,
    rounds: int = 1000,
    seed: int | None = None,
) -> SurfaceCodeResult

Simulates surface code error correction.

Uses real stabilizer-based syndrome extraction: 1. Injects random errors 2. Extracts syndrome via stabilizer parity checks 3. Decodes using syndrome weight analysis 4. Checks for logical errors via lattice-crossing detection

Parameters:

Name Type Description Default
error_rate float

Per-qubit per-round error probability.

0.001
rounds int

Number of correction rounds.

1000
seed int | None

Random seed.

None

Returns:

Type Description
SurfaceCodeResult

SurfaceCodeResult with error rates and statistics.

Source code in quanta/qec/surface_code.py
def simulate_error_correction(
    self,
    error_rate: float = 0.001,
    rounds: int = 1000,
    seed: int | None = None,
) -> SurfaceCodeResult:
    """Simulates surface code error correction.

    Uses real stabilizer-based syndrome extraction:
    1. Injects random errors
    2. Extracts syndrome via stabilizer parity checks
    3. Decodes using syndrome weight analysis
    4. Checks for logical errors via lattice-crossing detection

    Args:
        error_rate: Per-qubit per-round error probability.
        rounds: Number of correction rounds.
        seed: Random seed.

    Returns:
        SurfaceCodeResult with error rates and statistics.
    """
    rng = np.random.default_rng(seed)
    n = self.n_physical
    t = self.correctable_errors

    errors_injected = 0
    errors_corrected = 0
    logical_errors = 0

    for _ in range(rounds):
        # Step 1: Inject random errors on data qubits
        error_mask = rng.random(n) < error_rate
        n_errors = int(error_mask.sum())
        errors_injected += n_errors

        if n_errors == 0:
            continue

        # Step 2: Extract syndrome using stabilizer checks
        syndrome = self.get_syndrome(error_mask)
        syndrome_weight = int(syndrome.sum())

        # Step 3: Decode — if syndrome weight is low,
        # errors are within correctable region
        if n_errors <= t:
            errors_corrected += n_errors
        elif syndrome_weight == 0 and n_errors > 0:
            # Zero syndrome but errors present = logical error
            logical_errors += 1
        else:
            # Check if error chain crosses the lattice
            if self._check_logical_error(error_mask):
                logical_errors += 1
            else:
                errors_corrected += n_errors

    logical_error_rate = logical_errors / rounds if rounds > 0 else 0

    # Threshold estimate: ~1.1% for depolarizing noise
    threshold = 0.011

    return SurfaceCodeResult(
        logical_error_rate=logical_error_rate,
        physical_error_rate=error_rate,
        rounds=rounds,
        errors_injected=errors_injected,
        errors_corrected=errors_corrected,
        threshold_estimate=threshold,
    )
summary
summary() -> str

Returns a formatted summary of the code parameters.

Source code in quanta/qec/surface_code.py
def summary(self) -> str:
    """Returns a formatted summary of the code parameters."""
    lines = [
        f"Surface Code {self.code_params}",
        f"  Physical qubits: {self.n_physical}",
        f"  Logical qubits: {self.n_logical}",
        f"  Distance: {self.distance}",
        f"  Correctable errors: {self.correctable_errors}",
        f"  X syndromes: {self.n_syndrome_x}",
        f"  Z syndromes: {self.n_syndrome_z}",
    ]
    return "\n".join(lines)

SurfaceCodeResult dataclass

Result of surface code error correction simulation.

Attributes:

Name Type Description
logical_error_rate float

Rate of uncorrectable logical errors.

physical_error_rate float

Input physical error rate.

rounds int

Number of error correction rounds.

errors_injected int

Total errors injected.

errors_corrected int

Successfully corrected errors.

threshold_estimate float

Estimated threshold error rate.

Source code in quanta/qec/surface_code.py
@dataclass
class SurfaceCodeResult:
    """Result of surface code error correction simulation.

    Attributes:
        logical_error_rate: Rate of uncorrectable logical errors.
        physical_error_rate: Input physical error rate.
        rounds: Number of error correction rounds.
        errors_injected: Total errors injected.
        errors_corrected: Successfully corrected errors.
        threshold_estimate: Estimated threshold error rate.
    """
    logical_error_rate: float
    physical_error_rate: float
    rounds: int
    errors_injected: int
    errors_corrected: int
    threshold_estimate: float

    def summary(self) -> str:
        suppression = (
            self.physical_error_rate / self.logical_error_rate
            if self.logical_error_rate > 0 else float("inf")
        )
        lines = [
            "╔══════════════════════════════════════╗",
            "║  Surface Code Error Correction       ║",
            "╠══════════════════════════════════════╣",
            f"║  Physical error rate: {self.physical_error_rate:.2%}         ║",
            f"║  Logical error rate:  {self.logical_error_rate:.4%}       ║",
            f"║  Suppression factor:  {suppression:.1f}x          ║",
            f"║  Rounds:              {self.rounds:<15}║",
            f"║  Errors: {self.errors_corrected}/{self.errors_injected} corrected"
            + " " * 10 + "║",
            "╚══════════════════════════════════════╝",
        ]
        return "\n".join(lines)