From d3dfe39f056f90c014c8f28fb49eb0874a8560cd Mon Sep 17 00:00:00 2001 From: David Ndungu Date: Thu, 16 Apr 2026 09:24:07 -0700 Subject: [PATCH 1/3] feat(compute): T2.2 capture-aware allocWeight routing via cudaMallocAsync When CaptureAwareAllocator is active (set by BeginCapture/WithCapture), allocWeight routes through cudaMallocAsync on the capture stream so allocations are recorded as graph nodes. This avoids the silent hang caused by cudaMallocManaged during CUDA graph capture on GB10. Similarly, uploadBytes routes through cudaMemcpyAsync on the capture stream instead of the synchronous CPU copy used by the managed-memory path, which is illegal during capture. The ensureNotCapturing guard now only fires when capture is active but the allocator was NOT properly switched via BeginCapture/WithCapture. Changes: - Add IsCapturing() to CaptureAwareAllocator interface - Implement IsCapturing() on cuda.MemPool and gpuapi.CUDAMemPool - Add async allocation/copy routing in allocWeight and uploadBytes - Add function variable indirections for MallocManaged, MallocAsync, and MemcpyAsync to enable CPU-mock testing - Add 7 unit tests covering all routing paths --- compute/capture_alloc_test.go | 329 ++++++++++++++++++++++++++++++++ compute/gpu_engine.go | 41 +++- internal/cuda/mempool.go | 7 + internal/gpuapi/cuda_mempool.go | 5 + internal/gpuapi/mempool.go | 3 + 5 files changed, 379 insertions(+), 6 deletions(-) create mode 100644 compute/capture_alloc_test.go diff --git a/compute/capture_alloc_test.go b/compute/capture_alloc_test.go new file mode 100644 index 0000000..b70223c --- /dev/null +++ b/compute/capture_alloc_test.go @@ -0,0 +1,329 @@ +package compute + +import ( + "errors" + "sync/atomic" + "testing" + "unsafe" + + "github.com/zerfoo/ztensor/internal/cuda" + "github.com/zerfoo/ztensor/internal/gpuapi" +) + +// --- fake CaptureAwareAllocator pool for tests --- + +type fakeCapturePool struct { + capturing bool +} + +func (p *fakeCapturePool) Alloc(int, int) (unsafe.Pointer, error) { return nil, nil } +func (p *fakeCapturePool) Free(int, unsafe.Pointer, int) {} +func (p *fakeCapturePool) AllocManaged(int, int) (unsafe.Pointer, error) { return nil, nil } +func (p *fakeCapturePool) FreeManaged(int, unsafe.Pointer, int) {} +func (p *fakeCapturePool) Drain() error { return nil } +func (p *fakeCapturePool) Stats() (int, int) { return 0, 0 } +func (p *fakeCapturePool) SetCaptureStream(_ unsafe.Pointer) { p.capturing = true } +func (p *fakeCapturePool) ClearCaptureStream() { p.capturing = false } +func (p *fakeCapturePool) IsCapturing() bool { return p.capturing } + +var ( + _ gpuapi.MemPool = (*fakeCapturePool)(nil) + _ gpuapi.CaptureAwareAllocator = (*fakeCapturePool)(nil) +) + +// --- fake non-capture-aware pool (like CUDAArenaPool) --- + +type fakeBasicPool struct{} + +func (p *fakeBasicPool) Alloc(int, int) (unsafe.Pointer, error) { return nil, nil } +func (p *fakeBasicPool) Free(int, unsafe.Pointer, int) {} +func (p *fakeBasicPool) AllocManaged(int, int) (unsafe.Pointer, error) { return nil, nil } +func (p *fakeBasicPool) FreeManaged(int, unsafe.Pointer, int) {} +func (p *fakeBasicPool) Drain() error { return nil } +func (p *fakeBasicPool) Stats() (int, int) { return 0, 0 } + +var _ gpuapi.MemPool = (*fakeBasicPool)(nil) + +// --- test helpers --- + +// swapMallocAsyncFn replaces the package-level mallocAsyncFn and returns +// a restore closure. +func swapMallocAsyncFn(fn func(int, *cuda.Stream) (unsafe.Pointer, error)) func() { + prev := mallocAsyncFn + mallocAsyncFn = fn + return func() { mallocAsyncFn = prev } +} + +// swapMallocManagedFn replaces the package-level mallocManagedFn and returns +// a restore closure. +func swapMallocManagedFn(fn func(int) (unsafe.Pointer, error)) func() { + prev := mallocManagedFn + mallocManagedFn = fn + return func() { mallocManagedFn = prev } +} + +// swapMemcpyAsyncFn replaces the package-level memcpyAsyncFn and returns +// a restore closure. +func swapMemcpyAsyncFn(fn func(unsafe.Pointer, unsafe.Pointer, int, cuda.MemcpyKind, *cuda.Stream) error) func() { + prev := memcpyAsyncFn + memcpyAsyncFn = fn + return func() { memcpyAsyncFn = prev } +} + +// --- allocWeight tests --- + +// TestAllocWeight_UsesAsyncWhenCapturing verifies that allocWeight routes +// through cudaMallocAsync when CaptureAwareAllocator is active. +func TestAllocWeight_UsesAsyncWhenCapturing(t *testing.T) { + var asyncCalled atomic.Bool + var requestedSize int + var sentinel byte + + restore := swapMallocAsyncFn(func(size int, _ *cuda.Stream) (unsafe.Pointer, error) { + asyncCalled.Store(true) + requestedSize = size + return unsafe.Pointer(&sentinel), nil + }) + defer restore() + + // Also stub captureStatusFn so ensureNotCapturing does not interfere. + restoreStatus := swapCaptureStatusFn(func(_ *cuda.Stream) (cuda.CaptureStatus, error) { + return cuda.CaptureStatusActive, nil + }) + defer restoreStatus() + + pool := &fakeCapturePool{capturing: true} + e := &GPUEngine[float32]{ + stream: fakePtrStream{}, + pool: pool, + } + + ptr, err := e.allocWeight(4096) + if err != nil { + t.Fatalf("allocWeight during capture: unexpected error: %v", err) + } + if !asyncCalled.Load() { + t.Fatal("allocWeight during capture: expected cudaMallocAsync to be called") + } + if requestedSize != 4096 { + t.Fatalf("allocWeight during capture: async alloc size = %d, want 4096", requestedSize) + } + if ptr != unsafe.Pointer(&sentinel) { + t.Fatal("allocWeight during capture: returned pointer does not match async allocation") + } +} + +// TestAllocWeight_UsesManagedWhenNotCapturing verifies that allocWeight +// still uses cudaMallocManaged when capture is NOT active and managedMem +// is true. +func TestAllocWeight_UsesManagedWhenNotCapturing(t *testing.T) { + var managedCalled atomic.Bool + var sentinel byte + + restoreManaged := swapMallocManagedFn(func(size int) (unsafe.Pointer, error) { + managedCalled.Store(true) + return unsafe.Pointer(&sentinel), nil + }) + defer restoreManaged() + + var asyncCalled atomic.Bool + restoreAsync := swapMallocAsyncFn(func(_ int, _ *cuda.Stream) (unsafe.Pointer, error) { + asyncCalled.Store(true) + return nil, nil + }) + defer restoreAsync() + + restoreStatus := swapCaptureStatusFn(func(_ *cuda.Stream) (cuda.CaptureStatus, error) { + return cuda.CaptureStatusNone, nil + }) + defer restoreStatus() + + pool := &fakeCapturePool{capturing: false} + e := &GPUEngine[float32]{ + stream: fakePtrStream{}, + pool: pool, + managedMem: true, + } + + ptr, err := e.allocWeight(4096) + if err != nil { + t.Fatalf("allocWeight (not capturing, managed): unexpected error: %v", err) + } + if !managedCalled.Load() { + t.Fatal("allocWeight (not capturing, managed): expected cudaMallocManaged to be called") + } + if asyncCalled.Load() { + t.Fatal("allocWeight (not capturing, managed): cudaMallocAsync should NOT be called") + } + if ptr != unsafe.Pointer(&sentinel) { + t.Fatal("allocWeight (not capturing, managed): returned pointer does not match managed allocation") + } +} + +// TestAllocWeight_GuardFiresWithoutCaptureAwareAllocator verifies that +// ensureNotCapturing still blocks allocWeight when capture is active +// but the pool does NOT implement CaptureAwareAllocator (e.g., +// CUDAArenaPool). This is the "raw capture without BeginCapture" path. +func TestAllocWeight_GuardFiresWithoutCaptureAwareAllocator(t *testing.T) { + restoreStatus := swapCaptureStatusFn(func(_ *cuda.Stream) (cuda.CaptureStatus, error) { + return cuda.CaptureStatusActive, nil + }) + defer restoreStatus() + + e := &GPUEngine[float32]{ + stream: fakePtrStream{}, + pool: &fakeBasicPool{}, + } + + ptr, err := e.allocWeight(4096) + if err == nil { + t.Fatal("allocWeight with non-capture-aware pool during capture: expected error, got nil") + } + if !errors.Is(err, ErrCaptureIncompatibleAllocation) { + t.Fatalf("allocWeight: expected ErrCaptureIncompatibleAllocation, got %v", err) + } + if ptr != nil { + t.Fatalf("allocWeight: expected nil pointer on guard trip, got %p", ptr) + } +} + +// TestAllocWeight_GuardSkippedWhenCaptureAwareAllocatorActive verifies +// that ensureNotCapturing does NOT fire when CaptureAwareAllocator is +// properly engaged via BeginCapture/WithCapture. +func TestAllocWeight_GuardSkippedWhenCaptureAwareAllocatorActive(t *testing.T) { + var ensureNotCapturingReached atomic.Bool + restoreStatus := swapCaptureStatusFn(func(_ *cuda.Stream) (cuda.CaptureStatus, error) { + ensureNotCapturingReached.Store(true) + return cuda.CaptureStatusActive, nil + }) + defer restoreStatus() + + restoreAsync := swapMallocAsyncFn(func(_ int, _ *cuda.Stream) (unsafe.Pointer, error) { + var sentinel byte + return unsafe.Pointer(&sentinel), nil + }) + defer restoreAsync() + + pool := &fakeCapturePool{capturing: true} + e := &GPUEngine[float32]{ + stream: fakePtrStream{}, + pool: pool, + } + + _, err := e.allocWeight(4096) + if err != nil { + t.Fatalf("allocWeight with capture-aware allocator active: unexpected error: %v", err) + } + if ensureNotCapturingReached.Load() { + t.Fatal("ensureNotCapturing should NOT be called when CaptureAwareAllocator is active") + } +} + +// --- uploadBytes tests --- + +// TestUploadBytes_UsesAsyncWhenCapturing verifies that uploadBytes routes +// through cudaMemcpyAsync when CaptureAwareAllocator is active. +func TestUploadBytes_UsesAsyncWhenCapturing(t *testing.T) { + var asyncCalled atomic.Bool + var copiedSize int + var copiedKind cuda.MemcpyKind + + restoreMemcpy := swapMemcpyAsyncFn(func(_ unsafe.Pointer, _ unsafe.Pointer, count int, kind cuda.MemcpyKind, _ *cuda.Stream) error { + asyncCalled.Store(true) + copiedSize = count + copiedKind = kind + return nil + }) + defer restoreMemcpy() + + restoreStatus := swapCaptureStatusFn(func(_ *cuda.Stream) (cuda.CaptureStatus, error) { + return cuda.CaptureStatusActive, nil + }) + defer restoreStatus() + + pool := &fakeCapturePool{capturing: true} + e := &GPUEngine[float32]{ + stream: fakePtrStream{}, + pool: pool, + } + + src := []byte{0x01, 0x02, 0x03, 0x04} + var devMem byte + err := e.uploadBytes(unsafe.Pointer(&devMem), src) + if err != nil { + t.Fatalf("uploadBytes during capture: unexpected error: %v", err) + } + if !asyncCalled.Load() { + t.Fatal("uploadBytes during capture: expected cudaMemcpyAsync to be called") + } + if copiedSize != 4 { + t.Fatalf("uploadBytes during capture: copied size = %d, want 4", copiedSize) + } + if copiedKind != cuda.MemcpyHostToDevice { + t.Fatalf("uploadBytes during capture: copy kind = %v, want MemcpyHostToDevice", copiedKind) + } +} + +// TestUploadBytes_UsesSyncWhenNotCapturing verifies that uploadBytes +// falls through to the normal (non-async) path when capture is NOT active. +func TestUploadBytes_UsesSyncWhenNotCapturing(t *testing.T) { + var asyncCalled atomic.Bool + restoreMemcpy := swapMemcpyAsyncFn(func(_ unsafe.Pointer, _ unsafe.Pointer, _ int, _ cuda.MemcpyKind, _ *cuda.Stream) error { + asyncCalled.Store(true) + return nil + }) + defer restoreMemcpy() + + restoreStatus := swapCaptureStatusFn(func(_ *cuda.Stream) (cuda.CaptureStatus, error) { + return cuda.CaptureStatusNone, nil + }) + defer restoreStatus() + + pool := &fakeCapturePool{capturing: false} + e := &GPUEngine[float32]{ + stream: fakePtrStream{}, + pool: pool, + managedMem: true, + } + + // With managedMem=true and not capturing, uploadBytes does a direct CPU copy. + // We can't test the actual copy without a real managed pointer, but we can + // verify cudaMemcpyAsync was NOT called. + src := []byte{0x01, 0x02} + buf := make([]byte, 2) + err := e.uploadBytes(unsafe.Pointer(&buf[0]), src) + if err != nil { + t.Fatalf("uploadBytes (not capturing, managed): unexpected error: %v", err) + } + if asyncCalled.Load() { + t.Fatal("uploadBytes (not capturing, managed): cudaMemcpyAsync should NOT be called") + } + // Verify the sync copy worked. + if buf[0] != 0x01 || buf[1] != 0x02 { + t.Fatalf("uploadBytes (not capturing, managed): sync copy produced %v, want [1 2]", buf) + } +} + +// TestUploadBytes_GuardFiresWithoutCaptureAwareAllocator verifies that +// ensureNotCapturing still blocks uploadBytes when capture is active +// but the pool does NOT implement CaptureAwareAllocator. +func TestUploadBytes_GuardFiresWithoutCaptureAwareAllocator(t *testing.T) { + restoreStatus := swapCaptureStatusFn(func(_ *cuda.Stream) (cuda.CaptureStatus, error) { + return cuda.CaptureStatusActive, nil + }) + defer restoreStatus() + + e := &GPUEngine[float32]{ + stream: fakePtrStream{}, + pool: &fakeBasicPool{}, + } + + src := []byte{0x01} + err := e.uploadBytes(nil, src) + if err == nil { + t.Fatal("uploadBytes with non-capture-aware pool during capture: expected error, got nil") + } + if !errors.Is(err, ErrCaptureIncompatibleAllocation) { + t.Fatalf("uploadBytes: expected ErrCaptureIncompatibleAllocation, got %v", err) + } +} diff --git a/compute/gpu_engine.go b/compute/gpu_engine.go index dd2508f..012b912 100644 --- a/compute/gpu_engine.go +++ b/compute/gpu_engine.go @@ -588,6 +588,15 @@ var ( graphDestroyFn = cuda.GraphDestroy ) +// mallocManagedFn, mallocAsyncFn, and memcpyAsyncFn are indirection points +// for the CUDA allocation and copy functions used by allocWeight and uploadBytes. +// Tests swap them to verify capture-aware routing without real CUDA hardware. +var ( + mallocManagedFn = cuda.MallocManaged + mallocAsyncFn = cuda.MallocAsync + memcpyAsyncFn = cuda.MemcpyAsync +) + // ensureNotCapturing returns ErrCaptureIncompatibleAllocation if the // engine's stream is currently capturing a CUDA graph. On CPU-only // runtimes or when the stream handle is nil, returns nil (no capture @@ -614,24 +623,44 @@ func (e *GPUEngine[T]) ensureNotCapturing() error { // allocWeight allocates permanent memory for a weight tensor. // Uses cudaMallocManaged on devices with managed memory support, -// otherwise uses cudaMalloc. Returns ErrCaptureIncompatibleAllocation -// if invoked while a CUDA graph capture is active on the engine's stream. +// otherwise uses cudaMalloc. +// +// When CaptureAwareAllocator is active (set by BeginCapture/WithCapture), +// allocations route through cudaMallocAsync on the capture stream so they +// are recorded as graph nodes. This avoids the silent hang caused by +// cudaMallocManaged during CUDA graph capture on GB10. +// +// Returns ErrCaptureIncompatibleAllocation only if capture is active but +// the allocator was NOT properly switched via BeginCapture/WithCapture. func (e *GPUEngine[T]) allocWeight(byteSize int) (unsafe.Pointer, error) { + if cap, ok := e.pool.(gpuapi.CaptureAwareAllocator); ok && cap.IsCapturing() { + s := cuda.StreamFromPtr(e.Stream()) + return mallocAsyncFn(byteSize, s) + } if err := e.ensureNotCapturing(); err != nil { return nil, err } if e.managedMem { - return cuda.MallocManaged(byteSize) + return mallocManagedFn(byteSize) } return e.runtime.Malloc(byteSize) } // uploadBytes copies src bytes into a device (or managed) pointer. // With managed memory, this is a direct CPU memcpy (no H2D needed). -// Without managed memory, this uses cudaMemcpy H2D. Returns -// ErrCaptureIncompatibleAllocation if invoked while a CUDA graph capture -// is active on the engine's stream. +// Without managed memory, this uses cudaMemcpy H2D. +// +// When CaptureAwareAllocator is active, uses cudaMemcpyAsync on the +// capture stream so the copy is recorded as a graph node. The synchronous +// CPU copy used by the managed-memory path is illegal during capture. +// +// Returns ErrCaptureIncompatibleAllocation only if capture is active but +// the allocator was NOT properly switched via BeginCapture/WithCapture. func (e *GPUEngine[T]) uploadBytes(devPtr unsafe.Pointer, src []byte) error { + if cap, ok := e.pool.(gpuapi.CaptureAwareAllocator); ok && cap.IsCapturing() { + s := cuda.StreamFromPtr(e.Stream()) + return memcpyAsyncFn(devPtr, unsafe.Pointer(&src[0]), len(src), cuda.MemcpyHostToDevice, s) + } if err := e.ensureNotCapturing(); err != nil { return err } diff --git a/internal/cuda/mempool.go b/internal/cuda/mempool.go index 1b72cb2..9dd1c1d 100644 --- a/internal/cuda/mempool.go +++ b/internal/cuda/mempool.go @@ -65,6 +65,13 @@ func (p *MemPool) ClearCaptureStream() { p.mu.Unlock() } +// IsCapturing returns true when capture-aware allocation is active. +func (p *MemPool) IsCapturing() bool { + p.mu.Lock() + defer p.mu.Unlock() + return p.captureStream != nil +} + // bucketSize rounds byteSize up to the next reuse bucket. // Sizes <= 256 are kept exact (these are typically small scalar or shape // metadata). Sizes > 256 are rounded up to the next power of two, enabling diff --git a/internal/gpuapi/cuda_mempool.go b/internal/gpuapi/cuda_mempool.go index ce23cb0..c96a8c3 100644 --- a/internal/gpuapi/cuda_mempool.go +++ b/internal/gpuapi/cuda_mempool.go @@ -57,6 +57,11 @@ func (p *CUDAMemPool) ClearCaptureStream() { p.inner.ClearCaptureStream() } +// IsCapturing returns true when capture-aware allocation is active. +func (p *CUDAMemPool) IsCapturing() bool { + return p.inner.IsCapturing() +} + // Inner returns the underlying cuda.MemPool for backward compatibility. func (p *CUDAMemPool) Inner() *cuda.MemPool { return p.inner diff --git a/internal/gpuapi/mempool.go b/internal/gpuapi/mempool.go index 6ecb692..8ee6e38 100644 --- a/internal/gpuapi/mempool.go +++ b/internal/gpuapi/mempool.go @@ -36,4 +36,7 @@ type CaptureAwareAllocator interface { SetCaptureStream(stream unsafe.Pointer) // ClearCaptureStream disables capture-aware allocation. ClearCaptureStream() + // IsCapturing returns true when capture-aware allocation is active + // (i.e., SetCaptureStream has been called and not yet cleared). + IsCapturing() bool } From 06514b640ab1ea34038f1e7fd56dea1b01b545f6 Mon Sep 17 00:00:00 2001 From: David Ndungu Date: Thu, 16 Apr 2026 09:26:29 -0700 Subject: [PATCH 2/3] feat(compute): T2.3 pre-allocate workspace buffers at UploadWeights to avoid capture-time alloc Add preAllocateWorkspaces() that eagerly initializes the FP8 scratchpad (scaleOne pointer + struct) and cuBLASLt handle at the end of UploadWeights, before any CUDA graph capture region begins. These two objects previously used lazy initialization (getFP8Scratch, getLtHandle) which triggered cudaMalloc on first use -- hanging silently on GB10 when first use happened inside capture. Also add captureAllocCount atomic counter to track allocWeight attempts during active capture. EndCapture resets the counter and logs a warning if non-zero. CaptureAllocCount() exposes the counter for testing. --- compute/gpu_engine.go | 57 ++++++++ compute/workspace_prealloc_test.go | 205 +++++++++++++++++++++++++++++ 2 files changed, 262 insertions(+) create mode 100644 compute/workspace_prealloc_test.go diff --git a/compute/gpu_engine.go b/compute/gpu_engine.go index dd2508f..0a08cd1 100644 --- a/compute/gpu_engine.go +++ b/compute/gpu_engine.go @@ -87,6 +87,12 @@ type GPUEngine[T tensor.Numeric] struct { // when cuBLAS receives very large matrices (e.g., 128256x4096 LM head). // Default: DefaultMaxAllocBytes (4 GB). maxAllocBytes int64 + + // captureAllocCount tracks allocWeight calls that occur during an active + // CUDA graph capture. A properly pre-allocated workload should see zero. + // Incremented atomically in allocWeight when capture is detected; + // checked and reset in EndCapture. + captureAllocCount atomic.Int64 } // NewGPUEngine creates a new GPUEngine backed by CUDA via the GRAL abstraction. @@ -570,6 +576,10 @@ func (e *GPUEngine[T]) UploadWeights(tensors []*tensor.TensorNumeric[float32]) e "device", fmt.Sprintf("%d", e.deviceID), "method", method) } + // Pre-allocate all workspace buffers that would otherwise be lazily + // initialized on first use. This ensures no cudaMalloc occurs inside + // a subsequent CUDA graph capture region. + e.preAllocateWorkspaces() return nil } @@ -618,6 +628,7 @@ func (e *GPUEngine[T]) ensureNotCapturing() error { // if invoked while a CUDA graph capture is active on the engine's stream. func (e *GPUEngine[T]) allocWeight(byteSize int) (unsafe.Pointer, error) { if err := e.ensureNotCapturing(); err != nil { + e.captureAllocCount.Add(1) return nil, err } if e.managedMem { @@ -685,6 +696,10 @@ func (e *GPUEngine[T]) EndCapture() (GraphHandle, error) { if cap, ok := e.pool.(gpuapi.CaptureAwareAllocator); ok { defer cap.ClearCaptureStream() } + if n := e.captureAllocCount.Swap(0); n > 0 { + e.logger.Warn("allocWeight called during capture", + "count", fmt.Sprintf("%d", n)) + } s := cuda.StreamFromPtr(e.Stream()) graph, err := streamEndCaptureFn(s) if err != nil { @@ -790,6 +805,48 @@ func (e *GPUEngine[T]) Close() error { return firstErr } +// CaptureAllocCount returns the cumulative number of allocWeight calls that +// were attempted while a CUDA graph capture was active. A properly +// pre-allocated workload should observe zero after EndCapture. +func (e *GPUEngine[T]) CaptureAllocCount() int64 { + return e.captureAllocCount.Load() +} + +// preAllocateWorkspaces eagerly initializes all lazy-allocated workspace +// buffers so that no cudaMalloc occurs inside a CUDA graph capture region. +// Called at the end of UploadWeights, after all weight tensors are on GPU. +// +// For dense float32 workloads, pool.Alloc (arena-backed) is capture-safe via +// CaptureAwareAllocator, but objects allocated outside the arena — the FP8 +// scratchpad and the cuBLASLt handle — use cudaMalloc and would hang if first +// touched during capture on GB10. +func (e *GPUEngine[T]) preAllocateWorkspaces() { + // 1. FP8 scratchpad: allocate scaleOne and the struct itself so that the + // first FP8 MatMul during capture does not trigger cudaMalloc. + if e.fp8Scratch == nil { + if s, err := e.getFP8Scratch(); err != nil { + e.logger.Warn("preAllocateWorkspaces: FP8 scratchpad init failed", + "error", err.Error()) + } else { + _ = s // assigned to e.fp8Scratch inside getFP8Scratch + } + } + + // 2. cuBLASLt handle: cublasLtCreate allocates internal CUDA state. + if e.ltHandle == nil { + if h, err := e.getLtHandle(); err != nil { + e.logger.Warn("preAllocateWorkspaces: cuBLASLt handle init failed", + "error", err.Error()) + } else { + _ = h // assigned to e.ltHandle inside getLtHandle + } + } + + e.logger.Info("workspace buffers pre-allocated", + "fp8Scratch", fmt.Sprintf("%v", e.fp8Scratch != nil), + "ltHandle", fmt.Sprintf("%v", e.ltHandle != nil)) +} + // OOMFallbackCount returns the number of times GPU OOM triggered CPU fallback. func (e *GPUEngine[T]) OOMFallbackCount() int64 { return e.oomFallbackCount.Load() diff --git a/compute/workspace_prealloc_test.go b/compute/workspace_prealloc_test.go new file mode 100644 index 0000000..b394f94 --- /dev/null +++ b/compute/workspace_prealloc_test.go @@ -0,0 +1,205 @@ +package compute + +import ( + "errors" + "testing" + + "github.com/zerfoo/ztensor/internal/cuda" + "github.com/zerfoo/ztensor/log" + "github.com/zerfoo/ztensor/numeric" + "github.com/zerfoo/ztensor/tensor" +) + +// TestPreAllocateWorkspaces_FP8ScratchInitialized verifies that after +// UploadWeights, the FP8 scratchpad is non-nil (eagerly initialized). +func TestPreAllocateWorkspaces_FP8ScratchInitialized(t *testing.T) { + eng := newPreallocEngine(t) + if eng.fp8Scratch != nil { + t.Fatal("precondition: fp8Scratch should be nil before UploadWeights") + } + + if err := eng.UploadWeights(nil); err != nil { + t.Fatalf("UploadWeights: %v", err) + } + + if eng.fp8Scratch == nil { + t.Fatal("fp8Scratch should be non-nil after UploadWeights") + } + if eng.fp8Scratch.scaleOne == nil { + t.Fatal("fp8Scratch.scaleOne should be non-nil after pre-allocation") + } +} + +// TestPreAllocateWorkspaces_CalledByUploadWeights verifies that +// preAllocateWorkspaces fires at the end of UploadWeights even when +// called with an empty weight list (the pre-allocation is unconditional). +func TestPreAllocateWorkspaces_CalledByUploadWeights(t *testing.T) { + eng := newPreallocEngine(t) + + if err := eng.UploadWeights([]*tensor.TensorNumeric[float32]{}); err != nil { + t.Fatalf("UploadWeights: %v", err) + } + + if eng.fp8Scratch == nil { + t.Fatal("fp8Scratch should be non-nil after UploadWeights") + } + if eng.fp8Scratch.scaleOne == nil { + t.Fatal("fp8Scratch.scaleOne should be non-nil after pre-allocation") + } +} + +// TestPreAllocateWorkspaces_TableDriven exercises workspace pre-allocation +// with varying weight list sizes. Pre-allocation is unconditional, so +// fp8Scratch should be non-nil regardless of weight count. +func TestPreAllocateWorkspaces_TableDriven(t *testing.T) { + tests := []struct { + name string + numWeights int + }{ + {name: "no weights", numWeights: 0}, + {name: "one nil entry", numWeights: 1}, + {name: "three nil entries", numWeights: 3}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + eng := newPreallocEngine(t) + pool := eng.pool.(*fakeMemPool) + + // Pass nil tensor entries -- UploadWeights skips them. + weights := make([]*tensor.TensorNumeric[float32], tt.numWeights) + if err := eng.UploadWeights(weights); err != nil { + t.Fatalf("UploadWeights: %v", err) + } + + if eng.fp8Scratch == nil { + t.Error("fp8Scratch should be non-nil after UploadWeights") + } + if eng.fp8Scratch.scaleOne == nil { + t.Error("fp8Scratch.scaleOne should be non-nil") + } + // scaleOne alloc is the minimum: 1 pool.Alloc from getFP8Scratch. + if pool.allocCount < 1 { + t.Errorf("expected at least 1 alloc from pre-allocation, got %d", pool.allocCount) + } + }) + } +} + +// TestCaptureAllocCount_ZeroAfterPrealloc verifies that captureAllocCount +// stays at zero when allocWeight is not called during capture. This is the +// expected state for a properly pre-allocated workload. +func TestCaptureAllocCount_ZeroAfterPrealloc(t *testing.T) { + eng := newPreallocEngine(t) + if err := eng.UploadWeights(nil); err != nil { + t.Fatalf("UploadWeights: %v", err) + } + + if got := eng.CaptureAllocCount(); got != 0 { + t.Fatalf("CaptureAllocCount after UploadWeights: got %d, want 0", got) + } +} + +// TestCaptureAllocCount_IncrementsOnCaptureTimeAlloc verifies that +// allocWeight increments captureAllocCount when capture is active. +func TestCaptureAllocCount_IncrementsOnCaptureTimeAlloc(t *testing.T) { + restore := swapCaptureStatusFn(func(_ *cuda.Stream) (cuda.CaptureStatus, error) { + return cuda.CaptureStatusActive, nil + }) + defer restore() + + eng := &GPUEngine[float32]{stream: fakePtrStream{}} + + // First attempt — should fail with capture sentinel and increment counter. + _, err := eng.allocWeight(4096) + if !errors.Is(err, ErrCaptureIncompatibleAllocation) { + t.Fatalf("allocWeight: expected ErrCaptureIncompatibleAllocation, got %v", err) + } + + if got := eng.CaptureAllocCount(); got != 1 { + t.Fatalf("CaptureAllocCount after 1 attempt: got %d, want 1", got) + } + + // Second attempt — count should increase. + _, _ = eng.allocWeight(8192) + if got := eng.CaptureAllocCount(); got != 2 { + t.Fatalf("CaptureAllocCount after 2 attempts: got %d, want 2", got) + } +} + +// TestCaptureAllocCount_ResetByEndCapture verifies that EndCapture resets +// the captureAllocCount to zero after logging. +func TestCaptureAllocCount_ResetByEndCapture(t *testing.T) { + // Arrange: inject a capture-active status for allocWeight, then swap to + // a non-capture status for EndCapture. + captureActive := true + restore := swapCaptureStatusFn(func(_ *cuda.Stream) (cuda.CaptureStatus, error) { + if captureActive { + return cuda.CaptureStatusActive, nil + } + return cuda.CaptureStatusNone, nil + }) + defer restore() + + eng := &GPUEngine[float32]{ + stream: fakePtrStream{}, + logger: log.Nop(), + } + + // Trigger two allocWeight attempts during capture. + _, _ = eng.allocWeight(4096) + _, _ = eng.allocWeight(8192) + if got := eng.CaptureAllocCount(); got != 2 { + t.Fatalf("CaptureAllocCount before EndCapture: got %d, want 2", got) + } + + // EndCapture will fail (no real graph) but should still reset the counter. + captureActive = false + oldEnd := streamEndCaptureFn + streamEndCaptureFn = func(_ *cuda.Stream) (*cuda.Graph, error) { + return nil, errors.New("synthetic: no graph") + } + defer func() { streamEndCaptureFn = oldEnd }() + + _, _ = eng.EndCapture() + + if got := eng.CaptureAllocCount(); got != 0 { + t.Fatalf("CaptureAllocCount after EndCapture: got %d, want 0", got) + } +} + +// TestPreAllocateWorkspaces_Idempotent verifies that calling +// preAllocateWorkspaces multiple times does not leak or double-allocate. +func TestPreAllocateWorkspaces_Idempotent(t *testing.T) { + eng := newPreallocEngine(t) + pool := eng.pool.(*fakeMemPool) + + eng.preAllocateWorkspaces() + allocsAfterFirst := pool.allocCount + + eng.preAllocateWorkspaces() + allocsAfterSecond := pool.allocCount + + if allocsAfterSecond != allocsAfterFirst { + t.Fatalf("second preAllocateWorkspaces caused %d new allocs, want 0", + allocsAfterSecond-allocsAfterFirst) + } +} + +// newPreallocEngine builds a GPUEngine suitable for testing workspace +// pre-allocation without real CUDA hardware. +func newPreallocEngine(t *testing.T) *GPUEngine[float32] { + t.Helper() + pool := newFakeMemPool() + return &GPUEngine[float32]{ + cpu: NewCPUEngine[float32](numeric.Float32Ops{}), + runtime: fakeRuntime{}, + pool: pool, + stream: fakeStream{}, + logger: log.Nop(), + deviceID: 0, + dtype: DTypeF32, + maxAllocBytes: DefaultMaxAllocBytes, + } +} + From 3c7b8b6cb77cd8c0cbaf95432aa8a09bb7ca9dad Mon Sep 17 00:00:00 2001 From: David Ndungu Date: Thu, 16 Apr 2026 09:28:21 -0700 Subject: [PATCH 3/3] docs(plan): mark T2.2 + T2.3 complete (Wave 4b) --- docs/plan.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/plan.md b/docs/plan.md index 4ca949f..c2e820f 100644 --- a/docs/plan.md +++ b/docs/plan.md @@ -256,10 +256,10 @@ All estimates are rough; refine when a task starts. - [ ] T2.1b zerfoo: switch `graph/cuda_graph.go:beginCapture` to use `WithCapture`. Owner: TBD. Est: 45m. verifies: [UC-002] - Acceptance: Existing zerfoo GGUF inference tests still pass; gemma4e and gemma3 parity suites unchanged. - Dependencies: T2.1a, ztensor version bump merged. -- [ ] T2.2 Introduce a `managedMem` guard in `allocWeight` that routes to `cudaMallocAsync` on the capture stream when `CaptureAwareAllocator` is active. Otherwise fall back to `MallocManaged`. Owner: TBD. Est: 90m. verifies: [UC-002] +- [x] T2.2 Introduce a `managedMem` guard in `allocWeight` that routes to `cudaMallocAsync` on the capture stream when `CaptureAwareAllocator` is active. Otherwise fall back to `MallocManaged`. Owner: task-T2.2. Est: 90m. verifies: [UC-002] Completed: 2026-04-16 - Acceptance: Unit test with a mocked capture stream records an async-alloc node instead of a sync call. - Dependencies: T2.1a. -- [ ] T2.3 Pre-allocate workspace buffers used by `MatMul`, `Add`, and `RMSNorm` variants at `UploadWeights` time so no lazy alloc occurs inside capture for dense float32 workloads. Owner: TBD. Est: 3h. verifies: [UC-001, UC-002] +- [x] T2.3 Pre-allocate workspace buffers used by `MatMul`, `Add`, and `RMSNorm` variants at `UploadWeights` time so no lazy alloc occurs inside capture for dense float32 workloads. Owner: task-T2.3. Est: 3h. verifies: [UC-001, UC-002] Completed: 2026-04-16 - Acceptance: Instrument with a counter; capture region records zero `allocWeight` calls for the CrossAsset workload. - Dependencies: T1.3, T2.1a. - [ ] T2.4 Add unit and integration tests for T2.1 to T2.3. Owner: TBD. Est: 90m. verifies: [infrastructure] @@ -351,8 +351,8 @@ count equals the number of task IDs listed on that wave. #### Wave 4: Fix + fallback in parallel (4 agents) - [x] T2.1a ztensor `WithCapture` helper verifies: [UC-002] 2026-04-16 -- [ ] T2.2 Capture-aware `allocWeight` routing verifies: [UC-002] -- [ ] T2.3 Pre-allocate forward-pass workspace verifies: [UC-001, UC-002] +- [x] T2.2 Capture-aware `allocWeight` routing verifies: [UC-002] 2026-04-16 +- [x] T2.3 Pre-allocate forward-pass workspace verifies: [UC-001, UC-002] 2026-04-16 - [x] T4.1 Capture watchdog verifies: [UC-005] 2026-04-16 #### Wave 5: Tests, linters, zerfoo pickup (4 agents)