FFT and complex-valued tensor operations for AWS Trainium via NKI.
Trainium has no native complex number support and ships no FFT library. trnfft fills that gap with split real/imaginary representation, complex neural network layers, and NKI kernels optimized for the NeuronCore architecture.
Incorporates neuron-complex-ops. Part of the trnsci scientific computing suite (github.com/trnsci).
trnfft follows the trnsci 5-phase roadmap. Active work is tracked in phase-labeled GitHub issues:
- Phase 1 — correctness: complete as of v0.8.0 (butterfly + complex GEMM kernels hardware-validated on trn1.2xlarge, 70/70 benchmark cases passing).
- Phase 2 — precision (active): Kahan / Neumaier compensated summation. Partial delivery in v0.11.0 (precision-modes API + compensated butterfly). On-silicon Kahan characterization (#58) and Thread C (Tensor Engine Stockham) are the remaining items.
- Phase 3 — perf: plan reuse, streaming large FFTs, NEFF cache reuse.
- Phase 4 — multi-chip: multi-chip large-N FFT (N > 2²⁰).
- Phase 5 — generation: trn2 larger-SBUF butterfly path.
Suite-wide tracker: trnsci/trnsci#1.
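Phase 2 builds on compensated summation. The sketch below is a plain-Python illustration of the technique only and does not model trnfft's compensated butterfly kernel. Neumaier's variant of Kahan summation keeps a running correction term for the low-order bits that each rounded addition drops:

```python
def naive_sum(values):
    """Plain left-to-right float accumulation with no compensation."""
    total = 0.0
    for x in values:
        total += x
    return total


def neumaier_sum(values):
    """Neumaier's variant of Kahan summation.

    `comp` collects the rounding error of every `total + x`, measured against
    whichever operand is larger in magnitude, and is added back at the end.
    """
    total = 0.0
    comp = 0.0
    for x in values:
        t = total + x
        if abs(total) >= abs(x):
            comp += (total - t) + x  # low-order bits of x were lost
        else:
            comp += (x - t) + total  # low-order bits of total were lost
        total = t
    return total + comp


data = [1.0, 1e100, 1.0, -1e100]
print(naive_sum(data))     # 0.0
print(neumaier_sum(data))  # 2.0
```

Plain Kahan summation also returns 0.0 on this input, because its correction assumes the running total is always the larger operand; Neumaier's magnitude check removes that assumption.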
NVIDIA has cuFFT, cuBLAS, and native complex64. Trainium has none of these. Every signal processing, speech enhancement, physics simulation, and spectral method workload on Trainium currently falls back to CPU or requires hand-rolling complex arithmetic. trnfft fixes this.
```bash
pip install trnfft

# With Neuron hardware support (quoted so shells such as zsh don't expand the brackets)
pip install "trnfft[neuron]"
```

```python
import torch
import trnfft

# Drop-in replacement for torch.fft
signal = torch.randn(1024)
X = trnfft.fft(signal)
recovered = trnfft.ifft(X)

# Real-valued FFT
X = trnfft.rfft(signal)

# 2D FFT
image = torch.randn(256, 256)
F = trnfft.fft2(image)

# STFT (matches torch.stft signature)
waveform = torch.randn(16000)
S = trnfft.stft(waveform, n_fft=512, hop_length=256)
```

```python
import torch
from trnfft import ComplexTensor
from trnfft.nn import ComplexLinear, ComplexConv1d, ComplexModReLU

# Build complex-valued models for speech/audio/physics
real_part = torch.randn(8, 256)  # batch of 8, 256 features
imag_part = torch.randn(8, 256)
x = ComplexTensor(real_part, imag_part)
layer = ComplexLinear(256, 128)
y = layer(x)
```

```text
+--------------------------------------------+
| User Code / Model                          |
+--------------------------------------------+
| trnfft.api (torch.fft API)                 |
| fft() ifft() rfft() stft() fft2()          |
+--------------------------------------------+
| trnfft.fft_core        | trnfft.nn         |
| Cooley-Tukey           | ComplexLinear     |
| Bluestein              | ComplexConv1d     |
| Plan caching           | ComplexModReLU    |
+------------------------+-------------------+
| trnfft.nki.dispatch                        |
| "auto" | "pytorch" | "nki"                 |
+--------------------------------------------+
| PyTorch ops      | NKI kernels             |
| (any device)     | (Trainium only)         |
| torch.matmul     | nisa.nc_matmul          |
| element-wise     | Tensor Engine           |
|                  | Vector Engine           |
|                  | SBUF ↔ PSUM pipeline    |
+------------------+-------------------------+
```
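The diagram lists plan caching in trnfft.fft_core. In the FFTW sense, a plan is the size-dependent setup (here a permutation order and twiddle factors) that is computed once and reused by every later transform of that size. A minimal sketch using functools.lru_cache; `radix2_plan` and what it stores are hypothetical, not trnfft's internal API:

```python
import cmath
from functools import lru_cache


def _reverse_bits(i, bits):
    """Reverse the low `bits` bits of i (0b001 -> 0b100 when bits == 3)."""
    r = 0
    for _ in range(bits):
        r = (r << 1) | (i & 1)
        i >>= 1
    return r


@lru_cache(maxsize=None)
def radix2_plan(n):
    """Hypothetical per-size plan: bit-reversal order plus twiddle factors.

    Built once per n and then served from the cache; tuples keep the shared
    result immutable.
    """
    if n < 1 or n & (n - 1):
        raise ValueError("radix-2 plans need a power-of-2 size")
    bits = n.bit_length() - 1
    perm = tuple(_reverse_bits(i, bits) for i in range(n))
    twiddles = tuple(cmath.exp(-2j * cmath.pi * k / n) for k in range(n // 2))
    return perm, twiddles


perm, twiddles = radix2_plan(8)
print(perm)                           # (0, 4, 2, 6, 1, 5, 3, 7)
radix2_plan(8)                        # same size again: served from the cache
print(radix2_plan.cache_info().hits)  # 1
```

Keying the cache on size alone works for this kind of setup because it depends only on N, not on the data being transformed.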
No complex dtype? Trainium's NKI doesn't support complex64/complex128. ComplexTensor stores complex values as paired real tensors and decomposes complex arithmetic into real-valued operations.
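For a single value, the decomposition is the identity (a + bi)(c + di) = (ac - bd) + (ad + bc)i. A plain-Python sketch with a hypothetical `SplitComplex` type that keeps the two real parts separate, as ComplexTensor does for whole tensors:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SplitComplex:
    """Complex value stored as two reals (illustrative, not trnfft's ComplexTensor)."""
    re: float
    im: float

    def __add__(self, other):
        return SplitComplex(self.re + other.re, self.im + other.im)

    def __mul__(self, other):
        # (a + bi)(c + di) = (ac - bd) + (ad + bc)i: four real multiplies, two adds
        a, b, c, d = self.re, self.im, other.re, other.im
        return SplitComplex(a * c - b * d, a * d + b * c)

    def conj(self):
        return SplitComplex(self.re, -self.im)


z = SplitComplex(1.0, 2.0) * SplitComplex(3.0, -1.0)
print(z)  # SplitComplex(re=5.0, im=5.0)
```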
FFT → butterflies → matmul. Each Cooley-Tukey butterfly stage performs a complex multiply-and-add across all butterfly groups at once. On NKI, the complex multiply maps to the Tensor Engine (systolic array).
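A plain-Python sketch of an iterative radix-2 decimation-in-time FFT on split real/imaginary lists, as an illustration rather than trnfft's kernel. It runs sequentially, but within one stage the loop over groups has no dependencies between iterations; that independent work is what a vectorized or NKI implementation can issue together:

```python
import math


def fft_radix2_split(re, im):
    """Iterative radix-2 decimation-in-time FFT on split real/imaginary lists.

    len(re) == len(im) must be a power of 2; returns new (re, im) lists.
    """
    n = len(re)
    if n == 0 or n & (n - 1) or len(im) != n:
        raise ValueError("need equal-length inputs with a power-of-2 length")
    bits = n.bit_length() - 1

    # Bit-reversed reordering, so each stage combines adjacent sub-transforms.
    out_re, out_im = [0.0] * n, [0.0] * n
    for i in range(n):
        j, k = 0, i
        for _ in range(bits):
            j, k = (j << 1) | (k & 1), k >> 1
        out_re[j], out_im[j] = re[i], im[i]

    size = 2
    while size <= n:  # log2(n) butterfly stages
        half = size // 2
        for k in range(half):
            # Twiddle factor exp(-2*pi*i*k/size), held as two reals.
            angle = -2.0 * math.pi * k / size
            w_re, w_im = math.cos(angle), math.sin(angle)
            # Every group in this stage uses the same twiddle and is independent.
            for start in range(0, n, size):
                top, bot = start + k, start + k + half
                # t = w * x[bot]: the complex multiply written as real ops
                t_re = w_re * out_re[bot] - w_im * out_im[bot]
                t_im = w_re * out_im[bot] + w_im * out_re[bot]
                # Butterfly: x[top] + t and x[top] - t
                out_re[bot], out_im[bot] = out_re[top] - t_re, out_im[top] - t_im
                out_re[top], out_im[top] = out_re[top] + t_re, out_im[top] + t_im
        size *= 2
    return out_re, out_im


X_re, X_im = fft_radix2_split([1.0, 2.0, 3.0, 4.0], [0.0] * 4)
# X_re ≈ [10, -2, -2, -2], X_im ≈ [0, 2, 0, -2] (up to rounding)
```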
Algorithms:
- Power-of-2: Cooley-Tukey radix-2 (iterative, decimation-in-time)
- Arbitrary sizes: Bluestein's chirp-z transform (pads to a power of 2)
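Bluestein's method rewrites a DFT of any length as a convolution, using nk = (n² + k² - (k - n)²) / 2, and evaluates that convolution with power-of-2 FFTs. A self-contained plain-Python sketch; for brevity it uses Python's built-in complex type, whereas trnfft works on split real/imaginary tensors, and the function names are illustrative:

```python
import cmath


def _fft_pow2(x, inverse=False):
    """Recursive radix-2 FFT of a list of Python complex numbers (len must be 2**k).

    The inverse direction is left unnormalized; callers divide by len(x).
    """
    n = len(x)
    if n == 1:
        return list(x)
    sign = 1 if inverse else -1
    even = _fft_pow2(x[0::2], inverse)
    odd = _fft_pow2(x[1::2], inverse)
    out = [0j] * n
    for k in range(n // 2):
        t = cmath.exp(sign * 2j * cmath.pi * k / n) * odd[k]
        out[k] = even[k] + t
        out[k + n // 2] = even[k] - t
    return out


def bluestein_dft(x):
    """DFT of any length N via Bluestein's chirp-z transform.

    With c[j] = exp(-i*pi*j**2/N), the DFT becomes
    X[k] = c[k] * sum_n (x[n] * c[n]) * conj(c[k - n]),
    a convolution evaluated here with power-of-2 FFTs of size M >= 2N - 1.
    """
    n = len(x)
    if n == 0:
        return []
    m = 1
    while m < 2 * n - 1:
        m *= 2
    # Chirp; j*j % (2N) gives the same value (period 2N) with a smaller angle.
    chirp = [cmath.exp(-1j * cmath.pi * (j * j % (2 * n)) / n) for j in range(n)]
    a = [x[j] * chirp[j] for j in range(n)] + [0j] * (m - n)  # zero-padded to M
    b = [0j] * m  # conj(chirp) at lags 0..N-1, mirrored so negative lags wrap around
    b[0] = chirp[0].conjugate()
    for j in range(1, n):
        b[j] = b[m - j] = chirp[j].conjugate()
    fa, fb = _fft_pow2(a), _fft_pow2(b)
    conv = _fft_pow2([p * q for p, q in zip(fa, fb)], inverse=True)
    return [chirp[k] * conv[k] / m for k in range(n)]


X = bluestein_dft([1.0, 2.0, 3.0, 4.0, 5.0])  # N = 5 is not a power of 2
```

A length-5 transform here costs three length-16 FFTs. The chirp spectrum `fb` depends only on N, so it is a natural candidate for the same kind of per-size caching as a radix-2 plan.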
A complex GEMM decomposes into four real matmuls. The NKI complex GEMM kernel uses stationary tile reuse (2 SBUF loads instead of 8) and PSUM accumulation, overlapping Vector Engine negation with Tensor Engine matmul.
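The four real matmuls are C_re = A_re B_re - A_im B_im and C_im = A_re B_im + A_im B_re. The plain-Python sketch below negates A_im once so the real part becomes a pure sum of two products, the shape that suits accumulating into a single PSUM tile. Nested lists stand in for tiles, the helper names are illustrative, and which operand the actual kernel negates is an assumption here:

```python
def matmul(a, b):
    """Real matrix product of nested lists (stand-in for one Tensor Engine matmul)."""
    inner, cols = len(b), len(b[0])
    return [[sum(row[t] * b[t][j] for t in range(inner)) for j in range(cols)] for row in a]


def accumulate(x, y):
    """Element-wise sum (stand-in for a second matmul accumulating into the same tile)."""
    return [[p + q for p, q in zip(rx, ry)] for rx, ry in zip(x, y)]


def complex_gemm(a_re, a_im, b_re, b_im):
    """(A_re + i*A_im) @ (B_re + i*B_im) as four real matmuls.

    C_re = A_re @ B_re + (-A_im) @ B_im   (A_im negated once, then accumulated)
    C_im = A_re @ B_im + A_im @ B_re
    """
    neg_a_im = [[-v for v in row] for row in a_im]  # element-wise negation (assumed to be of A_im)
    c_re = accumulate(matmul(a_re, b_re), matmul(neg_a_im, b_im))
    c_im = accumulate(matmul(a_re, b_im), matmul(a_im, b_re))
    return c_re, c_im


# A = [1+0.5i, 2-1i] (1x2), B = [[1+2i], [3+0i]] (2x1)
c_re, c_im = complex_gemm([[1.0, 2.0]], [[0.5, -1.0]], [[1.0], [3.0]], [[2.0], [0.0]])
print(c_re, c_im)  # [[6.0]] [[-0.5]]
```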
NKI kernels are validated against Neuron SDK 2.24+ on the Deep Learning AMI Neuron PyTorch 2.9 (Ubuntu 24.04), build 20260410 or later. See docs/installation.md for the full compatibility matrix.
For NKI vs PyTorch timings on the same Trainium instance, see the benchmarks page for the latest numbers.
v0.13.0: NKI 0.3.0 (Neuron SDK 2.29) migration validated on trn1.2xlarge. The DFT-as-GEMM fast path (N ≤ 256) is up to 14× faster on batched FFT/STFT. A Stockham radix-4 POC is precision-safe up to N=4096 and available via a bench toggle. See the benchmarks page for the full picture.
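DFT-as-GEMM skips butterflies entirely: the whole transform is one multiplication by the N × N DFT matrix W[k][j] = exp(-2πi·kj/N). That trades O(N log N) dependent butterfly stages for a single O(N²) product that maps directly onto a matmul engine, which is presumably why it pays off only at small N. A plain-Python sketch for a batch of real signals, where the product needs just two real GEMMs (against cos and -sin tables); the function names are illustrative, and the N ≤ 256 cutoff is trnfft's, not something this code decides:

```python
import math


def dft_tables(n):
    """Split DFT matrix W[k][j] = exp(-2*pi*i*k*j/n) as (cos, -sin) real tables.

    (k * j) % n keeps the angle in [0, 2*pi) without changing the value.
    """
    w_re = [[math.cos(2 * math.pi * (k * j % n) / n) for j in range(n)] for k in range(n)]
    w_im = [[-math.sin(2 * math.pi * (k * j % n) / n) for j in range(n)] for k in range(n)]
    return w_re, w_im


def gemm(x, w):
    """Real (batch x n) @ (n x n) product on nested lists."""
    n = len(w)
    return [[sum(row[j] * w[j][k] for j in range(n)) for k in range(n)] for row in x]


def batched_dft_real(signals):
    """DFT of a batch of equal-length real signals as two real GEMMs.

    W is symmetric, so each row's spectrum is row @ W; for real rows that is
    (row @ W_re) + i * (row @ W_im).
    """
    w_re, w_im = dft_tables(len(signals[0]))
    return gemm(signals, w_re), gemm(signals, w_im)


X_re, X_im = batched_dft_real([[1.0, 2.0, 3.0, 4.0], [0.0, 1.0, 0.0, 0.0]])
```

In a kernel, the cos/-sin tables would be built once per N and reused across batches, so the per-call cost is just the two GEMMs.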
API coverage (13 functions: 11 from torch.fft, plus stft / istft):
fft, ifft, rfft, irfft, fft2, rfft2, irfft2, fftn, ifftn, rfftn, irfftn, stft, istft.
Not implemented: hfft and ihfft, the Hermitian-symmetric variants. These assume the signal is already conjugate-symmetric (X[k] = conj(X[N-k])), which in practice means you've post-processed an rfft output or are reconstructing from a known real signal's spectrum. Both workflows are more easily expressed with rfft / irfft plus a manual unpack/pack step. If you have a use case that produces Hermitian-symmetric tensors directly, open an issue with the concrete workload and we'll add them.
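The rfft / irfft route can be sketched in plain Python. Below, naive O(N²) stand-ins for rfft and irfft (not trnfft's implementations) express hfft and ihfft through the identities ihfft(x) = conj(rfft(x)) / n and hfft(X, n) = n · irfft(conj(X), n), which hold under the default "backward" normalization. `unpack_hermitian` is the manual unpack step that rebuilds the missing upper bins from X[N-k] = conj(X[k]); all function names are illustrative:

```python
import cmath


def naive_rfft(x):
    """One-sided DFT of a real signal: bins 0..n//2 (naive O(n^2) stand-in)."""
    n = len(x)
    return [sum(x[j] * cmath.exp(-2j * cmath.pi * j * k / n) for j in range(n))
            for k in range(n // 2 + 1)]


def unpack_hermitian(half, n):
    """Rebuild the full length-n spectrum from n//2 + 1 bins via X[n-k] = conj(X[k])."""
    return [half[k] if k <= n // 2 else half[n - k].conjugate() for k in range(n)]


def naive_irfft(half, n):
    """Inverse of naive_rfft with 1/n scaling; returns a real list of length n."""
    full = unpack_hermitian(half, n)
    return [(sum(full[k] * cmath.exp(2j * cmath.pi * j * k / n) for k in range(n)) / n).real
            for j in range(n)]


def ihfft_via_rfft(x):
    """ihfft of a real signal: conj(rfft(x)) / n."""
    n = len(x)
    return [v.conjugate() / n for v in naive_rfft(x)]


def hfft_via_irfft(half, n):
    """hfft of a one-sided Hermitian spectrum: n * irfft(conj(half), n), a real result."""
    return [n * v for v in naive_irfft([v.conjugate() for v in half], n)]


x = [1.0, -2.0, 0.5, 3.0, 4.0, -1.0]
roundtrip = hfft_via_irfft(ihfft_via_rfft(x), len(x))  # recovers x up to rounding
```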
Roadmap
- NKI ComplexConv1d / ComplexModReLU kernels (today both fall back to PyTorch on NKI)
- BF16 / FP16 support across NKI kernels
- Multi-NeuronCore parallelism (scaffold in trnfft/nki/multicore.py)
- SBUF-resident dispatch to reduce small-op overhead
All six suite libraries (trnfft and the five siblings below) are on PyPI, along with the umbrella meta-package:
| Project | What | Latest |
|---|---|---|
| trnsci | Umbrella meta-package pulling the whole suite | v0.1.0 |
| trnblas | BLAS Level 1–3 for Trainium | v0.4.0 |
| trnrand | Philox / Sobol / Halton random number generation | v0.1.0 |
| trnsolver | Linear solvers (CG, GMRES) and eigendecomposition | v0.3.0 |
| trnsparse | Sparse matrix operations | v0.1.1 |
| trntensor | Tensor contractions (einsum, TT/Tucker decompositions) | v0.1.1 |
| neuron-complex-ops | Original proof-of-concept, folded into trnfft | archived |
Apache 2.0 — Copyright 2026 Scott Friedman
trnsci is an independent open-source project. It is not sponsored by, endorsed by, or affiliated with Amazon.com, Inc., Amazon Web Services, Inc., or Annapurna Labs Ltd.
"AWS", "Amazon", "Trainium", "Inferentia", "NeuronCore", "Neuron SDK", and related identifiers are trademarks of their respective owners and are used here solely for descriptive and interoperability purposes. Use does not imply endorsement, partnership, or any other relationship.
All work, opinions, analyses, benchmark results, architectural commentary, and editorial judgments in this repository and on trnsci.dev are those of the project's contributors. They do not represent the views, positions, or commitments of Amazon, AWS, or Annapurna Labs.
Feedback directed at the Neuron SDK or Trainium hardware is good-faith ecosystem commentary from independent users. It is not privileged information, is not pre-reviewed by AWS, and should not be read as authoritative about product roadmap, behavior, or quality.
For official AWS guidance, see aws-neuron documentation and the AWS Trainium product page.
Built on insights from:
- tcFFT — Tensor Core FFT research
- FFTW — Plan-based FFT architecture
- AWS NKI documentation