Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 120 additions & 0 deletions compute/capture_guard_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@ package compute

import (
"errors"
"fmt"
"testing"
"unsafe"

"github.com/zerfoo/ztensor/internal/cuda"
"github.com/zerfoo/ztensor/internal/gpuapi"
)

// TestEnsureNotCapturing_NilStream verifies that ensureNotCapturing returns
Expand All @@ -15,6 +20,73 @@ func TestEnsureNotCapturing_NilStream(t *testing.T) {
}
}

// TestEnsureNotCapturing_NilPtr verifies that ensureNotCapturing returns nil
// when the engine has a stream whose Ptr() is nil. This can happen when a
// stream object is present but the underlying vendor handle was never
// assigned (CPU-shim runtimes).
func TestEnsureNotCapturing_NilPtr(t *testing.T) {
e := &GPUEngine[float32]{stream: nilPtrStream{}}
if err := e.ensureNotCapturing(); err != nil {
t.Fatalf("ensureNotCapturing on nil-ptr stream: got %v, want nil", err)
}
}

// TestEnsureNotCapturing_ProbeStatuses is a table-driven test that walks
// every cudaStreamCaptureStatus value through ensureNotCapturing and asserts
// the mapping to the guard's outcome:
// - None -> nil (allocation allowed)
// - Active -> ErrCaptureIncompatibleAllocation
// - Invalidated -> nil (guard only blocks Active; fallback logic handles Invalidated)
func TestEnsureNotCapturing_ProbeStatuses(t *testing.T) {
tests := []struct {
name string
status cuda.CaptureStatus
want error
}{
{name: "None allows allocation", status: cuda.CaptureStatusNone, want: nil},
{name: "Active blocks allocation", status: cuda.CaptureStatusActive, want: ErrCaptureIncompatibleAllocation},
{name: "Invalidated does not trip the active guard", status: cuda.CaptureStatusInvalidated, want: nil},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
restore := swapCaptureStatusFn(func(_ *cuda.Stream) (cuda.CaptureStatus, error) {
return tc.status, nil
})
defer restore()

e := &GPUEngine[float32]{stream: fakePtrStream{}}
got := e.ensureNotCapturing()
if !errors.Is(got, tc.want) && got != tc.want {
t.Fatalf("ensureNotCapturing(status=%v): got %v, want %v", tc.status, got, tc.want)
}
})
}
}

// TestEnsureNotCapturing_ProbeError verifies that when cudaStreamGetCaptureInfo
// itself fails, ensureNotCapturing returns that error (wrapped for context) and
// does NOT silently treat the stream as safe. Probe failure must propagate so
// callers fail loud instead of racing a hang on GB10.
func TestEnsureNotCapturing_ProbeError(t *testing.T) {
probeErr := errors.New("cudaStreamGetCaptureInfo failed: synthetic")
restore := swapCaptureStatusFn(func(_ *cuda.Stream) (cuda.CaptureStatus, error) {
return cuda.CaptureStatusNone, probeErr
})
defer restore()

e := &GPUEngine[float32]{stream: fakePtrStream{}}
err := e.ensureNotCapturing()
if err == nil {
t.Fatal("ensureNotCapturing: expected error from failing probe, got nil")
}
if !errors.Is(err, probeErr) {
t.Fatalf("ensureNotCapturing: expected error to wrap probe error, got %v", err)
}
if errors.Is(err, ErrCaptureIncompatibleAllocation) {
t.Fatal("ensureNotCapturing: probe error must not be surfaced as ErrCaptureIncompatibleAllocation")
}
}

// TestErrCaptureIncompatibleAllocation_Is verifies that
// ErrCaptureIncompatibleAllocation is a sentinel error usable with
// errors.Is, both directly and when wrapped.
Expand All @@ -28,6 +100,17 @@ func TestErrCaptureIncompatibleAllocation_Is(t *testing.T) {
}
}

// TestErrCaptureIncompatibleAllocation_FmtErrorfWrap verifies that the sentinel
// survives fmt.Errorf("...: %w", ...) wrapping — the idiom callers in
// allocWeight / uploadBytes use indirectly via ensureNotCapturing and that
// downstream callers use when adding their own context.
func TestErrCaptureIncompatibleAllocation_FmtErrorfWrap(t *testing.T) {
wrapped := fmt.Errorf("upload layer %d: %w", 7, ErrCaptureIncompatibleAllocation)
if !errors.Is(wrapped, ErrCaptureIncompatibleAllocation) {
t.Fatalf("errors.Is through fmt.Errorf wrap: got false, want true (err=%v)", wrapped)
}
}

// wrapErr emulates a caller that wraps the sentinel error with %w.
// Kept local to the test to avoid leaking helpers into the package API.
func wrapErr(err error) error {
Expand All @@ -38,3 +121,40 @@ type wrappedErr struct{ inner error }

func (w *wrappedErr) Error() string { return "wrapped: " + w.inner.Error() }
func (w *wrappedErr) Unwrap() error { return w.inner }

// swapCaptureStatusFn replaces the package-level captureStatusFn for a test
// and returns a restore closure. Callers defer restore() to keep tests hermetic.
func swapCaptureStatusFn(fn func(*cuda.Stream) (cuda.CaptureStatus, error)) func() {
prev := captureStatusFn
captureStatusFn = fn
return func() { captureStatusFn = prev }
}

// fakeStreamSentinel backs fakePtrStream.Ptr() with a stable address so that
// escape-analysis does not re-allocate per call and returned pointers remain
// valid for the lifetime of the test binary. The probe is stubbed, so the
// handle is never dereferenced.
var fakeStreamSentinel byte

// fakePtrStream satisfies gpuapi.Stream and returns a non-nil Ptr so that
// ensureNotCapturing proceeds past the early-return guards and exercises the
// probe path. Synchronize / Destroy are never called by the guard.
type fakePtrStream struct{}

func (fakePtrStream) Synchronize() error { return nil }
func (fakePtrStream) Destroy() error { return nil }
func (fakePtrStream) Ptr() unsafe.Pointer { return unsafe.Pointer(&fakeStreamSentinel) }

// nilPtrStream satisfies gpuapi.Stream but returns a nil Ptr. Used to cover
// the "stream present but unbacked" branch of ensureNotCapturing.
type nilPtrStream struct{}

func (nilPtrStream) Synchronize() error { return nil }
func (nilPtrStream) Destroy() error { return nil }
func (nilPtrStream) Ptr() unsafe.Pointer { return nil }

// Compile-time assertions that the fakes satisfy gpuapi.Stream.
var (
_ gpuapi.Stream = fakePtrStream{}
_ gpuapi.Stream = nilPtrStream{}
)
7 changes: 6 additions & 1 deletion compute/gpu_engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -573,6 +573,11 @@ func (e *GPUEngine[T]) UploadWeights(tensors []*tensor.TensorNumeric[float32]) e
return nil
}

// captureStatusFn is the indirection point for cuda.StreamCaptureStatus used
// by ensureNotCapturing. Tests swap it to inject synthetic capture state
// without requiring real CUDA hardware.
var captureStatusFn = cuda.StreamCaptureStatus

// ensureNotCapturing returns ErrCaptureIncompatibleAllocation if the
// engine's stream is currently capturing a CUDA graph. On CPU-only
// runtimes or when the stream handle is nil, returns nil (no capture
Expand All @@ -587,7 +592,7 @@ func (e *GPUEngine[T]) ensureNotCapturing() error {
return nil
}
s := cuda.StreamFromPtr(ptr)
status, err := cuda.StreamCaptureStatus(s)
status, err := captureStatusFn(s)
if err != nil {
return fmt.Errorf("ensureNotCapturing: %w", err)
}
Expand Down
113 changes: 113 additions & 0 deletions compute/gpu_engine_alloc_guard_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
package compute

import (
"errors"
"testing"

"github.com/zerfoo/ztensor/internal/cuda"
)

// TestAllocWeight_PropagatesCaptureSentinel confirms the capture guard's
// sentinel flows out of allocWeight unchanged. A caller wrapping the error
// with fmt.Errorf("%w") must still match the sentinel via errors.Is so that
// fallback paths (CaptureSafe, later epics) can catch the exact failure mode.
func TestAllocWeight_PropagatesCaptureSentinel(t *testing.T) {
restore := swapCaptureStatusFn(func(_ *cuda.Stream) (cuda.CaptureStatus, error) {
return cuda.CaptureStatusActive, nil
})
defer restore()

e := &GPUEngine[float32]{stream: fakePtrStream{}}
ptr, err := e.allocWeight(4096)
if err == nil {
t.Fatal("allocWeight under active capture: expected error, got nil")
}
if !errors.Is(err, ErrCaptureIncompatibleAllocation) {
t.Fatalf("allocWeight: expected ErrCaptureIncompatibleAllocation, got %v", err)
}
if ptr != nil {
t.Fatalf("allocWeight: expected nil pointer on guard trip, got %p", ptr)
}
}

// TestAllocWeight_PropagatesProbeError confirms that if the capture probe
// itself fails, allocWeight returns the wrapped probe error — not the
// sentinel, and not a nil error that would let a hang happen silently.
func TestAllocWeight_PropagatesProbeError(t *testing.T) {
probeErr := errors.New("cudaStreamGetCaptureInfo failed: synthetic")
restore := swapCaptureStatusFn(func(_ *cuda.Stream) (cuda.CaptureStatus, error) {
return cuda.CaptureStatusNone, probeErr
})
defer restore()

e := &GPUEngine[float32]{stream: fakePtrStream{}}
ptr, err := e.allocWeight(4096)
if err == nil {
t.Fatal("allocWeight with failing probe: expected error, got nil")
}
if !errors.Is(err, probeErr) {
t.Fatalf("allocWeight: expected wrapped probe error, got %v", err)
}
if errors.Is(err, ErrCaptureIncompatibleAllocation) {
t.Fatal("allocWeight: probe failure must not be reported as capture sentinel")
}
if ptr != nil {
t.Fatalf("allocWeight: expected nil pointer on probe failure, got %p", ptr)
}
}

// TestUploadBytes_PropagatesCaptureSentinel mirrors the allocWeight test on
// the upload path. uploadBytes is the second weight-load entry point touched
// during UploadWeights, so both must fail loud under active capture.
func TestUploadBytes_PropagatesCaptureSentinel(t *testing.T) {
restore := swapCaptureStatusFn(func(_ *cuda.Stream) (cuda.CaptureStatus, error) {
return cuda.CaptureStatusActive, nil
})
defer restore()

e := &GPUEngine[float32]{stream: fakePtrStream{}}
src := []byte{0x01, 0x02, 0x03, 0x04}
err := e.uploadBytes(nil, src)
if err == nil {
t.Fatal("uploadBytes under active capture: expected error, got nil")
}
if !errors.Is(err, ErrCaptureIncompatibleAllocation) {
t.Fatalf("uploadBytes: expected ErrCaptureIncompatibleAllocation, got %v", err)
}
}

// TestUploadBytes_PropagatesProbeError confirms probe failures propagate out
// of uploadBytes the same way they do out of allocWeight.
func TestUploadBytes_PropagatesProbeError(t *testing.T) {
probeErr := errors.New("cudaStreamGetCaptureInfo failed: synthetic")
restore := swapCaptureStatusFn(func(_ *cuda.Stream) (cuda.CaptureStatus, error) {
return cuda.CaptureStatusNone, probeErr
})
defer restore()

e := &GPUEngine[float32]{stream: fakePtrStream{}}
src := []byte{0x01, 0x02}
err := e.uploadBytes(nil, src)
if err == nil {
t.Fatal("uploadBytes with failing probe: expected error, got nil")
}
if !errors.Is(err, probeErr) {
t.Fatalf("uploadBytes: expected wrapped probe error, got %v", err)
}
if errors.Is(err, ErrCaptureIncompatibleAllocation) {
t.Fatal("uploadBytes: probe failure must not be reported as capture sentinel")
}
}

// TestAllocWeight_PassesWhenNotCapturing_NilStream is a negative control: on
// an engine with a nil stream (CPU-only path), allocWeight must NOT be
// short-circuited by the guard. We cannot safely drive it into the real
// runtime Malloc here (no GPU), but we can confirm the guard returns nil and
// the failure, if any, comes from downstream (runtime == nil panic would
// indicate the guard path is wrong).
func TestEnsureNotCapturing_AllowsAllocationWhenStreamAbsent(t *testing.T) {
e := &GPUEngine[float32]{}
if err := e.ensureNotCapturing(); err != nil {
t.Fatalf("ensureNotCapturing with nil stream: got %v, want nil", err)
}
}
Loading
Loading