
Commit c86d01c (2 parents: 336a9cc + 9668dd3)

ivanbasov and claude committed

chore: merge ising/main and resolve README conflict

Keep badges from branch; take updated description from main.

Co-Authored-By: Claude Sonnet 4.6 <[email protected]>

30 files changed: 934 additions & 121 deletions

.github/workflows/long-running-tests.yml
Lines changed: 2 additions & 1 deletion

@@ -189,7 +189,8 @@ jobs:
           PREDECODER_TRAIN_SAMPLES: "32768"
           PREDECODER_VAL_SAMPLES: "4096"
           PREDECODER_TEST_SAMPLES: "4096"
-          PREDECODER_TRAIN_EPOCHS: "1"
+          PREDECODER_TRAIN_EPOCHS: "30"
+          PREDECODER_DISABLE_SDR: "1"

       - name: Multi-orientation inference (O1–O4) with LER output check
         shell: bash

.github/workflows/pr_file_check.yaml
Lines changed: 5 additions & 0 deletions

@@ -37,8 +37,13 @@ jobs:
       run: |
         MAX_SIZE_BYTES=102400 # 100KB
         MAX_SIZE_HUMAN="100KB"
+        # Exempt image formats and notebooks (documentation/tutorial assets)
+        EXEMPT_EXTENSIONS="(png|jpg|jpeg|gif|svg|ico|webp|ipynb)"
         LARGE_FILES=""
         while IFS= read -r file; do
+          if [[ "$file" =~ \.($EXEMPT_EXTENSIONS)$ ]]; then
+            continue
+          fi
          if [[ -f "$file" ]]; then
            size=$(stat --format='%s' "$file")
            if (( size > MAX_SIZE_BYTES )); then
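The exemption filter added above hinges on one bash regex. A quick way to sanity-check which paths it skips is to re-express the same pattern in Python; this is an illustrative translation for testing, not code from the workflow:

```python
import re

# Same extension list as the workflow's EXEMPT_EXTENSIONS, anchored to the
# end of the path, mirroring the bash test  [[ "$file" =~ \.(...)$ ]].
EXEMPT_RE = re.compile(r"\.(png|jpg|jpeg|gif|svg|ico|webp|ipynb)$")

def is_exempt(path: str) -> bool:
    """Return True if the file would skip the large-file size check."""
    return EXEMPT_RE.search(path) is not None

print(is_exempt("docs/tutorial.ipynb"))  # → True  (exempt notebook)
print(is_exempt("models/weights.pt"))    # → False (still size-checked)
```

Note that, like the bash `=~` test, the match is case-sensitive, so `LOGO.PNG` would not be exempt.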

README.md
Lines changed: 25 additions & 8 deletions

@@ -1,12 +1,12 @@
-# AI pre-decoder for surface-code memory circuits
+# Ising Decoding

 [![License](https://img.shields.io/badge/License-Apache%202.0-blue)](./LICENSE)
 [![Release](https://img.shields.io/badge/Release-v0.1.0-brightgreen)](https://github.com/NVIDIA/Ising-Decoding/tree/releases/v0.1.0)
 [![Paper](https://img.shields.io/badge/Paper-NVIDIA%20Research-76b900)](https://research.nvidia.com/publication/2026-04_fast-ai-based-pre-decoders-surface-codes)
 [![Model: Fast](https://img.shields.io/badge/🤗%20HuggingFace-Fast%20Model-ffd21e)](https://huggingface.co/nvidia/Ising-Decoder-SurfaceCode-1-Fast)
 [![Model: Accurate](https://img.shields.io/badge/🤗%20HuggingFace-Accurate%20Model-ffd21e)](https://huggingface.co/nvidia/Ising-Decoder-SurfaceCode-1-Accurate)

-This repo implements a **pre-decoder** for surface-code memory experiments:
+This repo offers AI training frameworks and recipes to build, customize, and deploy scalable quantum error correction **decoders**:

 - A neural network consumes detector syndromes across space **and** time
 - It predicts corrections that reduce syndrome density / improve decoding

@@ -100,8 +100,8 @@ pip install -r code/requirements_public_inference.txt

 2. **Get the pre-trained models**
    This repo ships two pre-trained model files (tracked with Git LFS):
-   - `models/PreDecoderModelMemory_r9_v1.0.77.pt` (receptive field R=9, checkpoint 77)
-   - `models/PreDecoderModelMemory_r13_v1.0.86.pt` (receptive field R=13, checkpoint 86)
+   - `models/Ising-Decoder-SurfaceCode-1-Fast.pt` (receptive field R=9)
+   - `models/Ising-Decoder-SurfaceCode-1-Accurate.pt` (receptive field R=13)

    Clones get the files via `git lfs pull`. Optionally, set `PREDECODER_MODEL_URL` to the LFS/raw URL to fetch files when not in the working tree (e.g. in a minimal checkout or CI).

@@ -146,8 +146,16 @@ The pre-trained public models use `--model-id 1` (R=9) and `--model-id 4` (R=13)
 After training (or starting from the shipped `.safetensors` files), you can export the model to
 ONNX and optionally apply INT8 or FP8 post-training quantization for deployment.

-Set the `ONNX_WORKFLOW` and (optionally) `QUANT_FORMAT` environment variables before running
-inference with `local_run.sh`:
+You may also change the surface code distance and number of rounds at inference
+time. That is, you are not required to retrain a model when changing either of
+these parameters; since the model is a 3D convolutional neural network, it is
+simply run over a new decoding volume.
+
+- To run with a new distance, add `DISTANCE=<your distance>` to the commands below.
+- To run with a new number of rounds, add `N_ROUNDS=<your number of rounds>` to the commands below.
+
+Set the `ONNX_WORKFLOW` and (optionally) the `QUANT_FORMAT`, `DISTANCE`, and
+`N_ROUNDS` environment variables before running inference with `local_run.sh`:

 | `ONNX_WORKFLOW` | Behavior |
 |---|---|

@@ -177,7 +185,16 @@ ONNX_WORKFLOW=3 WORKFLOW=inference bash code/scripts/local_run.sh
 | `QUANT_FORMAT` | unset | `int8` or `fp8`. Unset means no quantization (FP32 ONNX). |
 | `QUANT_CALIB_SAMPLES` | `256` | Calibration samples for INT8/FP8 post-training quantization. |

+**Circuit variables:**
+
+| Variable | Default | Description |
+|---|---|---|
+| `CONFIG_NAME` | `config_public` | Use the defaults from the `conf/$CONFIG_NAME.yaml` file |
+| `DISTANCE` | Use the distance specified in the `conf/$CONFIG_NAME.yaml` file | Surface code distance |
+| `N_ROUNDS` | Use the number of rounds specified in the `conf/$CONFIG_NAME.yaml` file | Number of rounds in the memory experiment |
+
 Notes:
 - TensorRT workflows (`ONNX_WORKFLOW=2` or `3`) require `tensorrt` and `modelopt`.
 - FP8 quantization failure is fatal. INT8 failure falls back to the FP32 ONNX model silently.
 - ONNX and engine files are written to the current working directory.

@@ -223,7 +240,7 @@ Results are written to `outputs/<EXPERIMENT_NAME>/plots/`.
 | Decoder | Source | Notes |
 |---|---|---|
 | No-op | | Pre-decoder output only, no global correction |
-| Union-Find | `ldpc` | Fast, sub-optimal LER |
+| Union-Find | `ldpc` | Fast, sub-optimal LER (Logical Error Rate) |
 | BP-only | `ldpc` | Belief propagation, no OSD |
 | BP+LSD-0 | `ldpc` | BP with localized statistics decoding |
 | Uncorr-PM | PyMatching | Uncorrelated minimum-weight perfect matching |

@@ -573,4 +590,4 @@ Presence of these headers is enforced automatically by the `spdx-header-check` C
 `.github/workflows/ci.yml`).

 Third-party open source components bundled with or required by this project are listed with their
-respective copyright notices and license texts in [NOTICE](NOTICE).
+respective copyright notices and license texts in [NOTICE](NOTICE).

TRAINING.md
Lines changed: 1 addition & 1 deletion

@@ -7,7 +7,7 @@ For local single-machine usage, see `README.md`.
 ## Prerequisites

 - Docker with NVIDIA GPU support (`nvidia-docker` / `--gpus`)
-- One or more NVIDIA GPUs (H100, A100, or similar)
+- One or more NVIDIA GPUs (B200, H200, or similar)

 ## Quick start (Docker — recommended)

code/evaluation/failure_analysis.py
Lines changed: 30 additions & 20 deletions

@@ -18,6 +18,7 @@
 """
 import os
 import random
+import warnings

 import numpy as np
 import torch

@@ -179,27 +180,39 @@ def _build_cudaq_decoders(det_model):

 def _decode_cudaq_batch(decoder, L_dense, syndromes_np):
     """
-    Decode a batch of syndromes with a cudaq-qec nv-qldpc-decoder (single-shot loop).
+    Decode a batch of syndromes with a cudaq-qec nv-qldpc-decoder.
     Returns (obs, stats) where:
     - obs: observable predictions as np.ndarray of shape (B,)
     - stats: dict with per-sample convergence flags, iteration counts
-    The decoder.decode() takes list[float] and returns DecoderResult with .result (list[float]).
     """
     B = syndromes_np.shape[0]
-    obs = np.zeros(B, dtype=np.uint8)
+    n_bits = L_dense.shape[1]
     converged_flags = np.zeros(B, dtype=bool)
     iter_counts = np.zeros(B, dtype=np.int32)
-    for i in range(B):
-        syndrome_list = syndromes_np[i].astype(np.float64).tolist()
-        result = decoder.decode(syndrome_list)
-        correction = np.array(result.result, dtype=np.uint8)
-        obs[i] = int((L_dense @ correction).item() %
-                     2) if L_dense.shape[0] == 1 else int((L_dense @ correction)[0] % 2)
+    corrections = np.empty((B, n_bits), dtype=np.uint8)
+    syndromes_f64 = np.ascontiguousarray(syndromes_np, dtype=np.float64)
+
+    def _unpack(i, result):
+        corrections[i] = np.array(result.result, dtype=np.uint8)
         converged_flags[i] = result.converged
-        # Collect iteration count if available via opt_results
         opt = getattr(result, 'opt_results', None)
         if opt and isinstance(opt, dict) and 'num_iter' in opt:
             iter_counts[i] = opt['num_iter']
+
+    def _loop_decode():
+        for i in range(B):
+            _unpack(i, decoder.decode(syndromes_f64[i].tolist()))
+
+    try:
+        results = decoder.decode_batch(syndromes_f64.tolist())
+    except Exception as exc:
+        warnings.warn(f"decode_batch failed ({exc}); falling back to per-sample loop")
+        _loop_decode()
+    else:
+        for i, result in enumerate(results):
+            _unpack(i, result)
+
+    obs = ((corrections.astype(np.int32) @ L_dense.T.astype(np.int32))[:, 0] % 2).astype(np.uint8)
     return obs, {"converged_flags": converged_flags, "iter_counts": iter_counts}

@@ -249,20 +262,17 @@ def _build_ldpc_decoders(det_model):

 def _decode_ldpc_batch(decoder, L_dense, syndromes_np):
     """
-    Decode a batch of syndromes with an ldpc decoder (single-shot loop).
+    Decode a batch of syndromes with an ldpc decoder.
     Returns observable predictions as np.ndarray of shape (B,).
     """
     B = syndromes_np.shape[0]
-    obs = np.zeros(B, dtype=np.uint8)
+    n_bits = L_dense.shape[1]
+    syndromes_c = np.ascontiguousarray(syndromes_np, dtype=np.uint8)
+    corrections = np.empty((B, n_bits), dtype=np.uint8)
     for i in range(B):
-        # Get the most-likely error configuration from the decoder for this syndrome.
-        correction = decoder.decode(syndromes_np[i])
-        # Project the correction onto the logical observable via L_dense (mod 2).
-        # L_dense has shape (num_obs, num_errors); the first observable row is used.
-        obs[i] = (
-            int((L_dense @ correction).item() %
-                2) if L_dense.shape[0] == 1 else int((L_dense @ correction)[0] % 2)
-        )
+        corrections[i] = decoder.decode(syndromes_c[i])
+
+    obs = ((corrections.astype(np.int32) @ L_dense.T.astype(np.int32))[:, 0] % 2).astype(np.uint8)
     return obs

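Both refactors above replace a per-sample mod-2 projection of the correction onto the logical observable with a single batched matrix multiply. A small self-contained check of that equivalence, using made-up toy data in place of real decoder output:

```python
import numpy as np

rng = np.random.default_rng(0)
B, n_bits = 8, 12
# Toy stand-ins: binary correction vectors and a (1, n_bits) logical matrix.
corrections = rng.integers(0, 2, size=(B, n_bits), dtype=np.uint8)
L_dense = rng.integers(0, 2, size=(1, n_bits), dtype=np.uint8)

# Old style: project each correction one sample at a time.
obs_loop = np.array(
    [(L_dense @ corrections[i])[0] % 2 for i in range(B)], dtype=np.uint8
)

# New style: one batched matrix multiply, then reduce mod 2.
# Casting to int32 guards against uint8 overflow for large n_bits.
obs_batch = (
    (corrections.astype(np.int32) @ L_dense.T.astype(np.int32))[:, 0] % 2
).astype(np.uint8)

assert np.array_equal(obs_loop, obs_batch)
```

The int32 cast matters: with uint8 operands, a row with more than 255 overlapping ones would wrap around before the mod-2 reduction.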
code/evaluation/logical_error_rate.py
Lines changed: 7 additions & 2 deletions

@@ -200,7 +200,7 @@ def _ort_quantize_int8(fp32_onnx_path: str, output_path: str, calib_dets: "np.nd
     class _DetCalibReader(CalibrationDataReader):

         def __init__(self, data):
-            self._rows = [{"dets": data[i:i + 1].astype("float32")} for i in range(len(data))]
+            self._rows = [{"dets": data[i:i + 1]} for i in range(len(data))]
             self._iter = iter(self._rows)

         def get_next(self):

@@ -1055,6 +1055,11 @@ def run_inference_and_decode_pre_decoder_memory(model, device, dist, cfg) -> dic
         pass
     except Exception:
         pass
+    # torch.compile + spawn workers causes a segfault (CUDA context conflict in
+    # spawned subprocesses after the model is compiled). Fall back to in-process
+    # loading when torch.compile has been applied.
+    if _applied_compile and int(test_loader_kwargs.get("num_workers", 0)) > 0:
+        test_loader_kwargs["num_workers"] = 0
     # Handle prefetch_factor when num_workers=0
     if test_loader_kwargs.get('num_workers', 0) == 0:
         test_loader_kwargs.pop('prefetch_factor', None)

@@ -1197,7 +1202,7 @@ def run_inference_and_decode_pre_decoder_memory(model, device, dist, cfg) -> dic
     mq.quantize(
         onnx_path=fp32_onnx_path,
         quantize_mode=quant_format,
-        calibration_data={"dets": calib_dets.astype("float32")},
+        calibration_data={"dets": calib_dets},
         output_path=onnx_path,
         **quant_kwargs,
     )
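The `_DetCalibReader` changed above follows onnxruntime's calibration-reader pattern: `get_next()` returns one `{input_name: batch}` dict per call and `None` when exhausted. A minimal sketch of that contract, with the `CalibrationDataReader` base class deliberately dropped so it runs without onnxruntime installed:

```python
import numpy as np

class DetCalibReaderSketch:
    """Toy mimic of an onnxruntime CalibrationDataReader subclass."""

    def __init__(self, data: np.ndarray):
        # One single-sample batch per row. The dtype is passed through
        # unchanged, matching the diff's removal of the redundant
        # .astype("float32") (calib_dets is assumed to be float32 already).
        self._rows = [{"dets": data[i:i + 1]} for i in range(len(data))]
        self._iter = iter(self._rows)

    def get_next(self):
        # None signals the quantizer that calibration data is exhausted.
        return next(self._iter, None)

reader = DetCalibReaderSketch(np.zeros((3, 4), dtype=np.float32))
batches = []
while (b := reader.get_next()) is not None:
    batches.append(b)
# Each batch keeps the leading axis, so shapes are (1, 4).
```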

code/export/checkpoint_to_safetensors.py
Lines changed: 2 additions & 2 deletions

@@ -20,11 +20,11 @@

 Usage:
     PYTHONPATH=code python code/export/checkpoint_to_safetensors.py \\
-        --checkpoint models/PreDecoderModelMemory_r9_v1.0.77.pt \\
+        --checkpoint models/Ising-Decoder-SurfaceCode-1-Fast.pt \\
         --model-id 1 [--fp16]

 Then run inference with:
-    PREDECODER_SAFETENSORS_CHECKPOINT=models/PreDecoderModelMemory_r9_v1.0.77_fp16.safetensors \\
+    PREDECODER_SAFETENSORS_CHECKPOINT=models/Ising-Decoder-SurfaceCode-1-Fast_fp16.safetensors \\
     WORKFLOW=inference DISTANCE=9 N_ROUNDS=9 EXPERIMENT_NAME=predecoder_model_1 \\
     bash code/scripts/local_run.sh
 """
