export-validator

Per-layer ONNX export parity validator. Walks a PyTorch model leaf-by-leaf, exports each leaf as a named ONNX graph output, runs PyTorch and ONNX Runtime on the same input bytes, and reports divergence layer-by-layer. The comparator runs in a small C++20 binary; a pure-Python fallback emits byte-identical JSON when the binary is not built.

What this studies

Whole-output parity (e.g. np.allclose(pt_out, ort_out, atol=1e-4)) hides where a divergence is born. By the time the classifier output drifts, the root cause may be ten layers upstream and folded into other operations. This repo exports every leaf module as a named ONNX output, captures both runtimes, and reports the first layer whose max-abs diff exceeds the tolerance — the drift_origin.
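
A minimal sketch of that per-layer comparison, assuming activations arrive as name → ndarray dicts keyed in forward order (the function name and dict layout are illustrative, not the repo's actual API):

import numpy as np

def per_layer_compare(pt_acts, ort_acts, atol=1e-4):
    # pt_acts / ort_acts: {layer_name: np.ndarray}, keyed in forward order.
    report, drift_origin = [], None
    for name, pt in pt_acts.items():
        diff = np.abs(pt - ort_acts[name])
        entry = {"layer": name,
                 "max_abs_diff": float(diff.max()),
                 "mean_abs_diff": float(diff.mean())}
        report.append(entry)
        # The first layer past tolerance is where the divergence is born.
        if drift_origin is None and entry["max_abs_diff"] > atol:
            drift_origin = name
    return {"drift_origin": drift_origin, "layers": report}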

The C++ comparator (cpp/) exists for two reasons: (1) it demonstrates the same numerical contract holds under a non-Python implementation, and (2) tests/integration/test_python_cpp_parity.py enforces that both backends produce byte-identical JSON, which doubles as a regression test on the report format.
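
Byte-identical output across two languages mostly comes down to pinning the serialization. In Python terms, something like the following (a sketch, not necessarily the repo's exact choices):

import json

def canonical_json(report):
    # Sorted keys and fixed separators remove Python-side nondeterminism;
    # the harder cross-language part of the contract is float formatting,
    # which both backends must pin to the same digit count.
    return json.dumps(report, sort_keys=True, separators=(",", ":"))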

Real run on torchvision.models.resnet18 (FP32)

From examples/reports/resnet18_fp32.md, generated by make pipeline:

Tolerance: 0.0001 · layers checked: 60 · layers exceeding: 0
Drift origin: none

Worst per-layer divergence across the whole model:

metric             layer                  value
max abs diff       layer4.1.relu          9.537e-06
max mean abs diff  layer4.1.bn2           1.142e-06
min max abs diff   layer4.0.downsample.0  5.811e-07

No drift detected at tolerance 1e-4 across 60 layers.

That is the actual finding — every layer of FP32 ResNet-18 lands within ~1e-5 max abs diff between PyTorch 2.5.1 (CPU) and ONNX Runtime 1.20.1 (CPU EP). Useful baseline for the next experiment (FP16, INT8, fused-Conv-BN graphs, etc.).

Multi-architecture parity sweep (FP32)

Reports are committed under examples/reports/ for four torchvision architectures. The same exporter, capture, and compare path is reused for every model; the only per-model code is a 10-line builder under src/export_validator/models/.
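
A builder in that style might look like this (hypothetical sketch; the real builders live under src/export_validator/models/ and their exact signature may differ):

import torch
from torchvision.models import resnet18

def build():
    # Construct the model in eval mode plus a deterministic example input;
    # everything downstream (export, capture, compare) is shared code.
    model = resnet18(weights="IMAGENET1K_V1").eval()
    example = torch.randn(1, 3, 224, 224,
                          generator=torch.Generator().manual_seed(0))
    return model, example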

model               leaves  layers exceeding 1e-4  worst max_abs_diff  location
resnet18                60                      0           9.537e-06  layer4.1.relu
resnet50               158                      0           4.578e-05  layer4.2.conv1
mobilenet_v3_small     141                      0           7.534e-05  features.4.block.0.0
vit_b_16               100                     12           1.812e-04  encoder.layers.encoder_layer_5.mlp.3

The honest finding: ViT-B/16 is the only one where some layers exceed the 1e-4 envelope. The drift originates at the MLP block of encoder layer 5 (a Linear(768, 3072)) and propagates through the feed-forward path of the next several encoder blocks. CNNs of comparable depth (ResNet-50 at 2× the parameter count) stay clean — accumulated fp32 drift in GEMM-heavy transformer MLPs is a real per-layer phenomenon that whole-output parity hides.

ViT export also requires disabling PyTorch's MultiheadAttention fast-path (torch.backends.mha.set_fastpath_enabled(False)); the fused aten::_native_multi_head_attention op is not lowerable to ONNX opset 17 via the legacy exporter. The vit_b_16 builder applies that workaround internally. CNN builders are unaffected.
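
In builder form, the workaround is a one-liner before model construction (a sketch mirroring the builder shape above):

import torch
from torchvision.models import vit_b_16

def build():
    # Force the decomposed attention path so the legacy ONNX exporter
    # never sees the fused aten::_native_multi_head_attention op.
    torch.backends.mha.set_fastpath_enabled(False)
    return vit_b_16(weights="IMAGENET1K_V1").eval(), torch.randn(1, 3, 224, 224)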

A real subtlety surfaced during development: ResNet uses nn.ReLU(inplace=True) after every BatchNorm. The ONNX exporter records the BN node's output as a tensor that the in-place ReLU will mutate; without an explicit clone() in the named-output wrapper, the per-layer report shows divergences of 1-5 units on every BN layer (post-ReLU vs pre-ReLU) while every other layer matches at ~1e-6. See docs/drift-debugging.md.
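
The fix, in hook form (a sketch; the function and store names are illustrative):

def attach_capture_hooks(model, store):
    for name, module in model.named_modules():
        if next(module.children(), None) is None:  # leaf modules only
            # detach().clone() snapshots the activation before any
            # downstream nn.ReLU(inplace=True) can mutate the same tensor.
            module.register_forward_hook(
                lambda m, inp, out, name=name:
                    store.__setitem__(name, out.detach().clone()))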

Quickstart

make install
make build-cpp
make pipeline       # exports resnet18, runs both runtimes, writes reports
make test           # python + C++ tests

make pipeline writes examples/reports/resnet18_fp32.json and examples/reports/resnet18_fp32.md.

Architecture

                       PyTorch model (eval)
                              │
              ┌───────────────┴───────────────┐
              │                               │
       NamingHooks                    NamedOutputWrapper
       (per-leaf clone)              (returns final +
              │                       intermediates as
              │                       graph outputs)
              ▼                               ▼
     {layer → ndarray}                 torch.onnx.export
              │                               │
              │                               ▼
              │                          .onnx file with
              │                          one named output
              │                          per leaf module
              │                               │
              │                               ▼
              │                    onnxruntime.InferenceSession
              │                               │
              │                               ▼
              │                       {layer → ndarray}
              │                               │
              └────────────┬──────────────────┘
                           ▼
            ┌────────────────────────────────┐
            │   compare (C++ or Python)      │
            │   - per-layer max/mean abs     │
            │   - first violator = origin    │
            │   - byte-identical JSON        │
            └────────────────────────────────┘
                           │
                           ▼
                examples/reports/<model>.{json,md}
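
The right-hand path of the diagram, in sketch form, assuming the wrapper's forward returns (final, *per-leaf activations); the output names and opset below are illustrative:

import torch
import onnxruntime as ort

def export_and_capture(wrapper, example, leaf_names, path="model.onnx"):
    # One graph output per leaf lets ONNX Runtime hand back every
    # intermediate activation by name.
    torch.onnx.export(wrapper, example, path,
                      input_names=["input"],
                      output_names=["final"] + leaf_names,
                      opset_version=17)
    sess = ort.InferenceSession(path, providers=["CPUExecutionProvider"])
    outs = sess.run(None, {"input": example.numpy()})
    return dict(zip(["final"] + leaf_names, outs))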

NCHW/NHWC layout-mismatch detection

A separate module format_mismatch.py walks per-layer activation pairs and asks: for each layer that exceeds tolerance, does permuting one tensor onto the other restore agreement? If yes, that layer is flagged with the inferred permutation (e.g. (0, 2, 3, 1) for NCHW → NHWC). The module covers four-dimensional CNN tensors and three-dimensional transformer tensors ((B, C, T) ↔ (B, T, C)). It is a detector, not a fixer — it tells you where the layout flip happens, not how to patch it. See tests/unit/test_format_mismatch.py for the synthetic permute-injection contract.
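
A sketch of the detection idea (illustrative, not the module's actual API):

import itertools
import numpy as np

def infer_permutation(ref, other, atol=1e-4):
    # If the pair already agrees, there is nothing to flag.
    if ref.shape == other.shape and np.abs(ref - other).max() <= atol:
        return None
    # Otherwise search axis orders of `other`; the first permutation that
    # restores agreement is the inferred layout flip, e.g. (0, 2, 3, 1)
    # for NCHW -> NHWC.
    for perm in itertools.permutations(range(other.ndim)):
        candidate = np.transpose(other, perm)
        if candidate.shape == ref.shape and np.abs(ref - candidate).max() <= atol:
            return perm
    return None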

Adjacent ML-systems-debug experiments (SAY-5)

This repo answers the per-layer drift question in isolation. Two adjacent experiments cover the orthogonal questions:

  • SAY-5/onnx-deploy — whole-output parity, batched benchmarking, container packaging. The whole-model "is it within tolerance" view that this repo decomposes layer by layer.
  • SAY-5/quant-explorer — INT8/FP16 quantization drift. A different ML-systems-debug angle: this repo's attribution layer recognizes precision_loss as a cause but does not do quantization calibration; that lives next door.

What this is not

  • Not a whole-output parity tool. See SAY-5/onnx-deploy.
  • Not a quantization explorer. See SAY-5/quant-explorer. The attribution layer here labels fp16 drift as precision_loss but does not produce calibrated INT8 weights.
  • Not a multi-EP comparator. Only the CPU execution provider is wired up; CUDA/CoreML/TensorRT are out of scope.

Layout

src/export_validator/   Python package: hooks, export, capture, compare, report
cpp/                    C++20 comparator (CMake + GoogleTest, optional libtorch
                        + onnxruntime smoke binary when both deps resolve)
examples/               Committed layer map + real-run reports + deterministic inputs
tests/                  Python unit + integration (RUN_INTEGRATION=1)
docs/                   per-layer.md, tolerance.md, drift-debugging.md

License

MIT. Copyright 2026 Sai Asish Y.
