Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions compute/capture_guard_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package compute

import (
"errors"
"testing"
)

// TestEnsureNotCapturing_NilStream verifies that ensureNotCapturing returns
// nil on an engine whose stream is nil (CPU-only runtime). This is the
// common path on machines without CUDA.
func TestEnsureNotCapturing_NilStream(t *testing.T) {
e := &GPUEngine[float32]{}
if err := e.ensureNotCapturing(); err != nil {
t.Fatalf("ensureNotCapturing on nil-stream engine: got %v, want nil", err)
}
}

// TestErrCaptureIncompatibleAllocation_Is verifies that
// ErrCaptureIncompatibleAllocation is a sentinel error usable with
// errors.Is, both directly and when wrapped.
func TestErrCaptureIncompatibleAllocation_Is(t *testing.T) {
if !errors.Is(ErrCaptureIncompatibleAllocation, ErrCaptureIncompatibleAllocation) {
t.Fatal("errors.Is should match sentinel against itself")
}
wrapped := wrapErr(ErrCaptureIncompatibleAllocation)
if !errors.Is(wrapped, ErrCaptureIncompatibleAllocation) {
t.Fatal("errors.Is should see sentinel through a wrapper")
}
}

// wrapErr emulates a caller that wraps the sentinel error with %w.
// Kept local to the test to avoid leaking helpers into the package API.
func wrapErr(err error) error {
return &wrappedErr{inner: err}
}

type wrappedErr struct{ inner error }

func (w *wrappedErr) Error() string { return "wrapped: " + w.inner.Error() }
func (w *wrappedErr) Unwrap() error { return w.inner }
11 changes: 11 additions & 0 deletions compute/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package compute

import "errors"

// ErrCaptureIncompatibleAllocation is returned when a weight allocation
// or upload is attempted while a CUDA graph capture is active on the
// engine's stream. Allocations during capture are not supported and
// would silently hang on GB10. Callers should either upload weights
// before BeginCapture, or catch this error and fall back to an
// uncaptured run.
var ErrCaptureIncompatibleAllocation = errors.New("compute: allocation attempted during active CUDA graph capture")
37 changes: 35 additions & 2 deletions compute/gpu_engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -573,10 +573,38 @@ func (e *GPUEngine[T]) UploadWeights(tensors []*tensor.TensorNumeric[float32]) e
return nil
}

// ensureNotCapturing returns ErrCaptureIncompatibleAllocation if the
// engine's stream is currently capturing a CUDA graph. On CPU-only
// runtimes or when the stream handle is nil, returns nil (no capture
// is possible). If querying capture status itself fails, returns
// that error (do not assume safety on probe failure).
func (e *GPUEngine[T]) ensureNotCapturing() error {
if e.stream == nil {
return nil
}
ptr := e.stream.Ptr()
if ptr == nil {
return nil
}
s := cuda.StreamFromPtr(ptr)
status, err := cuda.StreamCaptureStatus(s)
if err != nil {
return fmt.Errorf("ensureNotCapturing: %w", err)
}
if status == cuda.CaptureStatusActive {
return ErrCaptureIncompatibleAllocation
}
return nil
}

// allocWeight allocates permanent memory for a weight tensor.
// Uses cudaMallocManaged on devices with managed memory support,
// otherwise uses cudaMalloc.
// otherwise uses cudaMalloc. Returns ErrCaptureIncompatibleAllocation
// if invoked while a CUDA graph capture is active on the engine's stream.
func (e *GPUEngine[T]) allocWeight(byteSize int) (unsafe.Pointer, error) {
if err := e.ensureNotCapturing(); err != nil {
return nil, err
}
if e.managedMem {
return cuda.MallocManaged(byteSize)
}
Expand All @@ -585,8 +613,13 @@ func (e *GPUEngine[T]) allocWeight(byteSize int) (unsafe.Pointer, error) {

// uploadBytes copies src bytes into a device (or managed) pointer.
// With managed memory, this is a direct CPU memcpy (no H2D needed).
// Without managed memory, this uses cudaMemcpy H2D.
// Without managed memory, this uses cudaMemcpy H2D. Returns
// ErrCaptureIncompatibleAllocation if invoked while a CUDA graph capture
// is active on the engine's stream.
func (e *GPUEngine[T]) uploadBytes(devPtr unsafe.Pointer, src []byte) error {
if err := e.ensureNotCapturing(); err != nil {
return err
}
if e.managedMem {
dst := unsafe.Slice((*byte)(devPtr), len(src))
copy(dst, src)
Expand Down
Loading
Loading