speed up perform_gsn: device-aware fast backend + numerical fixes#27
Open
jacob-prince wants to merge 24 commits into
Open
speed up perform_gsn: device-aware fast backend + numerical fixes#27jacob-prince wants to merge 24 commits into
jacob-prince wants to merge 24 commits into
Conversation
Adds test10_ridge_2d_data_cov as a red test for a longstanding bug in calc_shrunken_covariance.py: lines 179 and 182 add c + 1e-6*I to a rank-deficient training covariance, masking the genuine singularity at alpha = 1. MATLAB's cholcov fails there -> nll = NaN -> min() skips and picks the largest grid point below 1 (0.98 on linspace(0, 1, 51)). Python's ridge lets alpha = 1 pass; when validation data lives in the same low-rank subspace as training, a spurious very-negative log-det dominates and Python picks shrinklevelD = 1.0. The default gsn.simulate_data generator produces full-rank populations that mask the bug, so an optional low_rank_spec field is added to TEST_DEFS to switch in a custom rank-deficient signal+noise generator. Test 10 (nvox=50 ncond=40 ntrial=3 rank_signal=5 rank_noise=10) fails on the unfixed code with shrinklevelD Python=1.0 vs MATLAB=0.98 and cSb/cS/cNb diverging at max abs err ~2e-3. The next commit removes the two c + np.eye(...) * 1e-6 lines.
Deletes the rank check and c + np.eye(c.shape[0]) * 1e-6 block in the 2D path. There is no counterpart in the MATLAB original calcshrunkencovariance.m. The shrinkage formula c2 = alpha*c + (1-alpha)*diag(c) already keeps c2 non-singular for any alpha < 1 via the diagonal injection; at alpha = 1, c2 equals the raw sample c and if that is singular, cholcov fails naturally, nll(p) is set to NaN, and nanargmin skips it. The Python ridge silently regularized the input and let alpha = 1 falsely pass the singularity check, which made Python pick shrinklevelD = 1.0 where MATLAB picks the largest grid point below 1 (0.98 on linspace(0, 1, 51)) for data whose population covariance is genuinely low-rank. Also drops two unused imports (math, scipy.stats as stats) at the top of the file. Test 10 (test10_ridge_2d_data_cov) added in the previous commit now passes; full equivalence suite is 10/10 and Python pytests are 61/61.
The Cholesky factor T is upper-triangular by construction, so np.linalg.pinv(T) ran a full SVD on a triangular matrix — same answer at O(N^3) where a triangular solve gives the same result at O(N^2) per RHS. MATLAB's calcmvgaussianpdf.m already uses `pts / T` (mrdivide), which dispatches to a triangular solve. Switches to scipy.linalg.solve_triangular(T, pts.T, lower=False, trans='T').T, which solves T.T @ X = pts.T for X = pts @ inv(T).
The input is symmetric (we symmetrize it on line 47), so eigh is the right tool. Mathematically equivalent to the SVD-based path used in the MATLAB original: for symmetric M with eigendecomposition V*D*V', SVD gives U=V*sign(D), S=|D|, V_svd=V, so (M + V*S*V')/2 reduces to V*max(D,0)*V' — exactly what eigh + clamp does directly. SVD does ~2x the work of eigh on symmetric input. The eig fallback that the reference used when SVD failed to converge is also dropped — eigh on a symmetric matrix is more numerically stable than eig, so the SVD-failure escape path is no longer needed.
Used only inside the figure block. Moving the three matplotlib imports inside that block saves ~400ms on every `import gsn` for callers who never draw (e.g. perform_gsn / mode=1).
calc_shrunken_covariance picked the optimal shrinkage level by running the held-out Gaussian NLL through a Python for-loop over all 51 shrinkage levels — each iteration its own O(N^3) Cholesky and O(M*N^2) triangular solve. For N in the hundreds-to-thousands range that loop is the dominant cost in perform_gsn. The new gsn.batched_nll.batched_shrunken_nll collapses those 51 sequential factorizations into a single batched torch.linalg.cholesky_ex plus a single batched solve_triangular over the (S, N, N) stack of shrunken covariances. cholesky_ex returns per-slot status without raising, so singular slots cleanly map to nll = NaN (matching MATLAB's min(nll) skip-NaN behavior). When torch is absent we fall back to a numpy + scipy loop that is bit-equivalent to the reference, just with the mean-subtraction lifted out of the loop (it was invariant across levels anyway). calc_shrunken_covariance is refactored to build pts_zm once before the loop and call the batched helper. The (N, N, S) covs array that the reference materialized just to index `covs[:,:,min0ix]` at the end is also dropped — for the wantfull=0 path we recompute the chosen shrunken cov on demand, avoiding the 51x memory blowup at large N. Adds matplotlib to requirements.txt (already used in the figure path, just wasn't declared) and registers torch>=2.0 as an optional 'fast' extra in setup.py — `pip install gsn[fast]` lights up the batched path with no code changes at call sites. Measured speedup of batched_shrunken_nll alone (numpy vs torch CPU): N=200 2.9x, N=500 2.1x, N=1000 11.0x. Larger N widens the gap further. Equivalence between paths is at floating-point noise (max|Δnll| ~5e-13).
batched_shrunken_nll now accepts device in {'cpu', 'cuda', 'mps', 'auto'}.
'cpu' is the default (unchanged behavior) and the right choice for
N up to ~1000 because GPU host<->device transfer costs more than the
batched cholesky_ex saves at that size. 'cuda' / 'mps' open up the GPU
path for large N; 'auto' picks cuda > mps > cpu based on availability.
_resolve_device errors clearly if the caller asks for a device this
torch install can't reach (better than letting it fail deep in a kernel
call). On mps we force float32 since Apple Metal has no float64.
calc_shrunken_covariance gains a device kwarg and threads it through.
rsa_noise_ceiling reads opt['device'] (defaults to 'cpu') and passes
it to both calc_shrunken_covariance calls. From a user perspective:
perform_gsn(data, {'device': 'cuda'}) # explicit
perform_gsn(data, {'device': 'auto'}) # cuda > mps > cpu
perform_gsn(data) # default 'cpu'
Docstrings updated in perform_gsn, rsa_noise_ceiling, and
calc_shrunken_covariance.
tests/test_gsn_python_speedups.py is a pure-Python pytest suite (no MATLAB required) covering every change made on this branch: - calc_mv_gaussian_pdf (pinv -> solve_triangular): parametric N/M matches a direct log-density formula, wantomitexp flag, singular cov returns err=1, single-variable case, no-input-mutation. - construct_nearest_psd_covariance (svd -> eigh): already-PSD passthrough, indefinite -> PSD projection, eigh vs SVD parity on symmetric input, asymmetric symmetrization, scalar/1x1/all-negative edge cases. - calc_shrunken_covariance (ridge removal): regression test that full-rank data is unaffected, rank-deficient 2D data picks shrinklevelD < 1.0 (the bug fix), 3D path still works. - rsa_noise_ceiling (lazy matplotlib): subprocess test that `import gsn` does not load matplotlib.pyplot. - batched_nll (new module): torch vs numpy parity at three shapes, singular slots -> NaN, all-singular -> all-NaN, N=1, S=1, float32 input, nanargmin picks the same level across paths. - device dispatch: 'cpu' / 'auto' resolution, unavailable cuda / mps raise clean RuntimeError, cpu and auto produce identical results. - perform_gsn integration: basic call, rank-deficient regression, torch vs numpy end-to-end equivalence, opt['device'] threading, determinism across repeated calls, uneven-trials path intact. 42 tests, all passing.
Mirror of the Python-side svd -> eigh change. The input is symmetric (we symmetrize on line 26), so eig is the right tool: for symmetric M with eigendecomposition V*D*V', the SVD-based form (M + V*|D|*V')/2 simplifies to V*max(D,0)*V' — exactly what we now do here directly. ~1.4-1.5x cheaper per call than svd on symmetric input. The eig fallback that the old code used when svd failed to converge is also dropped — eig on a (symmetrized) matrix is numerically robust enough that the svd-failure escape path is no longer needed. Equivalence tests Python<->MATLAB remain 10/10 with both languages on the eig path; floating-point reordering between LAPACK paths kept the diffs at machine precision.
Sweeps nunits and times perform_gsn / performgsn with K repeats per cell. Auto-detects available backends (python-numpy, python-torch-cpu, python-torch-cuda, python-torch-mps, matlab) and renders a 3-panel figure: absolute runtime, power-law extrapolation to N=1e6 (fit on N > 1000 only — small N is overhead-dominated), and relative speedup vs python-main-reference. Outputs gitignored.
cluster/: rsync code, SLURM array job (one H100 per nunits), Python driver times perform_gsn across python-numpy / python-torch-cpu / python-torch-cuda. Shards merge to one JSON consumable by tests/test_speedup_magnitude.py. tests/test_gsn_gpu_edge_cases.py: 19 skip-able CUDA/MPS tests for the gsn.batched_nll torch path — gpu↔cpu NLL parity, per-slot NaN under cholesky_ex, dtype handling, consecutive-call independence, end-to-end opt['device'] threading.
cluster/ is user/machine-specific (SLURM, conda paths, lab storage locations); gitignored and untracked here. tests/test_gsn_gpu_edge_cases.py: replace absolute-only tolerances with atol + rtol*|ref| so f32 NLLs at large N stay within precision; add per-test argmin-agreement check — the invariant downstream code actually depends on.
tests/test_gsn_gpu_edge_cases.py: replace absolute-only tolerances with atol + rtol*|ref| so f32 NLLs at large N stay within precision; add per-test argmin-agreement check — the invariant downstream code actually depends on.
New gsn/fast_perform_gsn.py runs the full GSN pipeline — noise+data
covariance, held-out shrinkage selection, biconvex loop, ncsnr — on a
single backend (numpy or torch on CPU/CUDA/MPS) without round-tripping
through host memory between stages.
What changed vs. the previous calc_shrunken_covariance + rsa_noise_ceiling
flow:
* Einsum for the 3D pooled noise covariance, in place of the per-condition
np.cov loop. On torch this collapses ncond_train kernel launches into one.
* Biconvex iteration stays on device. construct_nearest_psd_covariance
previously used numpy.linalg even when torch was available, so cSb/cNb
round-tripped host<->device every iteration; the new _nearest_psd is
written against the active backend.
* End-to-end on one backend. Data moves to device once at entry; we only
materialize numpy at the very end when building the results dict.
perform_gsn becomes a thin defaults+dispatcher (no more mode/ncsims/wantfig
setup for the rsa_noise_ceiling indirection). Uneven trials still fall
through to rsa_noise_ceiling.
batched_nll._torch_dtype_for now accepts both numpy and torch dtypes so
fast_perform_gsn can hand a device tensor straight to _torch_batched
without a numpy round trip.
Local cpu wall-clock (mac): 2-3x faster than the previous fast path
across N=200/500/1000. MATLAB equivalence still 10/10; all unit tests
(103) pass.
Log-log axes, dashed extrapolation, time-reference lines, per-backend power-law fits in the legend. Display names: numpy + scipy.linalg loop, torch CPU (batched), torch CUDA (batched), gsn.perform_gsn (reference).
The (S, N, N) shrunken-cov tensor scales as S * N^2 * dtype-bytes. With S=51, float64, N=20000 that's 160 GB — well past H100's 80 GB. _pick_chunk_size picks the largest shrinkage-level chunk that fits in ~70% of free device memory (queried via torch.cuda.mem_get_info when available, with safe fallbacks otherwise). The inner loop processes chunks sequentially, deleting intermediate tensors before the next chunk's allocations. Verified bit-identical to single-pass when chunk_size=S (the common case for N <= 3000); 103/103 unit tests still pass.
Previous 70% headroom triggered chunking at N=10000 f64 even though single-pass would have fit. 95% leaves single-pass behavior intact through ~N=8000 (any dtype) and ~N=10000 f64; chunking activates only where it's truly required (N >= ~12000 f64 or ~N=15000 f32 on H100).
Motivated by running GSN at large nunits where the torch path OOMed because the result dict carried all four (N, N) cov matrices through host memory while biconvex was still holding device tensors, and biconvex itself held more intermediates than necessary. opt['returns'] selector: callers pick which of cN / cS / cNb / cSb they actually need. Default still emits all four for backwards compat. PSN-only workloads (cSb + cNb consumers) can drop cN / cS and save 2 * N^2 * dtype-bytes of host memory. Memory cleanup in _run_torch: - cS is no longer materialized unless 'cS' in returns. ncsnr needs only its diagonal, computed directly from diag(cD) - diag(cN)/ntrial. - _biconvex_torch no longer takes cS; derives the iter-0 anchor inline as cD - cN/ntrial. - Intermediate data tensors freed as soon as their cov is built; torch.cuda.empty_cache() between the big allocations. _flat_pearson replaces torch.corrcoef for the biconvex convergence check. cuBLAS dot caps its input length at int32 (~2.1e9 elements); at large N the corrcoef path internally builds an intermediate (2, N^2) stack that trips the cap. _flat_pearson uses element-wise mul + reduce-sum with f64 accumulators (int64 strides, no cap). _nearest_psd_torch: - f32 -> f64 upcast for eigh. cuSOLVER syevd is unreliable on near-singular f32; upcasting fixes spurious negative eigenvalues at large N. - scipy.linalg.eigh CPU fallback if the device eigh raises. cuSOLVER syevd hits its workspace limit at very large N — the fallback keeps us going at the cost of one host round-trip. batched_nll: build the (S, N, N) shrunk-cov stack in place by scaling by alpha and then restoring the diagonal. Equivalent to the previous alpha*c + (1-alpha)*diag(c) but peak transient drops from ~3*chunk*N^2 to ~1*chunk*N^2. Test benchmark pins opt['returns'] to ['cSb', 'cNb'] so the cross- backend wall-clock matches the legacy main reference (which doesn't do the three extra eighs that the default 'returns' triggers).
replace the cpu-only _delegate_uneven fallback so uneven (nan-padded) trial counts get the same torch/gpu speedups as even data. - noise cov: per-condition pooled covariance over valid trials as one masked weighted gemm, with cv shrinkage selection. - data cov: deterministic min-trial truncation, then the existing 2d cov. - biconvex: add ntrialbc param (division uses ntrial=min(validcnt), coefficients use the average); default leaves the even path unchanged. - routing honors opt['device']; opt['uneven']='reference' keeps the old rsa delegation as a parity oracle. matches the reference rsa path to ~1e-14 and matches matlab on the uneven equivalence fixtures; cuda-validated; even-path tests unchanged. adds test_fast_uneven_matches_reference.
new opt['uneven']='missing' path for per-electrode/per-unit artifact rejection, where a trial may have some units present and others missing (standard gsn requires whole-trial validity and would discard good data). - cN: average over conditions (>=2 shared-clean trials) of the UNBIASED pairwise covariance of each unit pair over their shared-clean trials (each pair centered on its shared trials; closed form via 3 masked gemms S2=Xm Xm^T, Si=Xm Mb^T, K=Mb Mb^T, cov=(S2-Si.Si^T/K)/(K-1)). - cD: unbiased pairwise covariance across conditions of per-unit condition-means, over conditions where both units are defined. - bias: cS = cD - cN (.) alpha, alpha[i,j] = avg_c n_ij,c/(n_i,c n_j,c); generalizes the scalar 1/ntrial and reduces to it when complete. - biconvex: exact per-entry alpha in the cSb step, effective scalar ntrial in the regularizer coefficients. - shrinkage levels selected on the complete-data subset, applied to the per-entry covs. pairwise (not available-means) centering avoids the partial-overlap bias k/(k-1)*(1-1/n_i-1/n_j+k/(n_i n_j)). pinned by: exact reduction to the even path on complete data, brute-force pairwise references for cN/cD/alpha, and a monte-carlo unbiasedness test (catches the available-means bias). tests/test_missing_units.py (16 tests). numpy only; torch path to follow.
device-native version of the missing-units estimator. the per-condition pairwise-covariance loop and biconvex run on the active device; shrinkage levels are picked on the host (cheap scalars, reusing the numpy selectors) and applied on device. routing: opt['uneven']='missing' uses torch when available (honoring opt['device']), else numpy. pairwise cov per block via the same closed form (S2, Si=Xm Mb^T, K=Mb Mb^T). new test_torch_matches_numpy asserts torch(cpu) == numpy on cN/cS/cNb/cSb, ncsnr, means, eigenvalues, shrink levels, and the signal subspace.
- batched_nll: clone the expanded diagonal RHS so the N==1 shrinkage NLL no longer hits a torch memory-alias error (also unbroke calc_shrunken_covariance and fast_perform_gsn on single-variable data). - honor opt['shrinklevels'] everywhere via a centralized _get_shrinklevels (the fast/uneven/missing paths previously always used the default grid). - uneven noise-cov: assert when the held-out validation split has no condition with >=2 valid trials, instead of silently returning shrink level 0 on an all-NaN nll (matches rsa_noise_ceiling). - fast_perform_gsn / run_missing_units_numpy: guard data is 3D with >=2 trials (matches the reference 'Number of trials must be at least 2'). tests/test_bugfixes.py covers all four.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
summary
A faster, optionally GPU-accelerated GSN with NaN / per-unit-missing data
support, plus several numerical-correctness fixes to the shrinkage path.
The numpy CPU path stays the default; torch (cuda/mps) is optional + opt-in.
Paired with the PSN
refactor-dec25PR, which now routesperform_gsnthrough this package and so depends on this branch.
what's new
fast backend — why: GSN's covariance + cross-validated shrinkage estimation
is the runtime bottleneck at large nunits (thousands–tens of thousands of units),
and the original path couldn't handle missing data efficiently.
fast_perform_gsn.py— device-native (numpy + optional torch) GSN. Pushes theGEMM/solve-heavy work to torch (CPU or GPU) for large speedups; nan-aware
uneven-trials path;
opt['returns']selector + opt-in eigvecs/eigvals returns(
opt['eigh_device']) so callers (PSN) can skip re-doing the eigendecomposition;large-N memory cleanup.
batched_nll.py— batched-Cholesky shrinkage-NLL evaluation (optional torch).why: scoring all candidate shrinkage levels at once is far faster, but the full
stack blows device memory at large N — so it's chunked.
missing_units.py— per-unit missing-data GSN (numpy + torch/gpu). why: realdatasets have units missing on some conditions, not just whole-trial NaNs.
perform_gsn.py— wire in the fast path / options.numerical fixes (shrinkage / covariance) — why: use decompositions matched to
the matrices (symmetric / PSD) — faster and more numerically stable — and remove a
regularizer that was perturbing results.
calc_shrunken_covariance: remove the1e-6*Iridge (+ an equivalence test thatexercises the prior bug).
construct_nearest_psd_covariance:eighinstead ofsvd; matlabconstructnearestpsdcovariance.m:eiginstead ofsvd.calc_mv_gaussian_pdf:solve_triangularinstead ofpinv.tests — why: prove the fast/GPU paths stay equivalent to the reference and guard
the new edge cases.
benchmark, uneven-trials, bugfix tests; extended matlab↔python equivalence harness.
packaging — why: torch is heavy, so keep it optional; don't drag matplotlib into
every import.
in
rsa_noise_ceiling.review notes
device.a few downstream values differ slightly from a CPU run — inherent backend behavior,
not a bug (GPU test tolerances are relative, with an argmin-agreement check).
testing
pytest tests/tests/test_gsn_matlab_python_equivalence.sh