Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions graph/cuda_graph.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,22 @@ var debugGraphCapture = os.Getenv("ZERFOO_DEBUG_GPU") == "1"
// (offset_memcpy kernel) and GQA uses GPU RoPE selection (rope_select kernel),
// all position-dependent state is read from GPU memory at replay time, making
// GQA fully capturable.
//
// Gemma4PLECombinedProducer: performs a CPU-side gather over the shared PLE
// embedding table (token ids -> per-layer rows), then calls MulScalar on the
// freshly-allocated CPUStorage tensor. Running this inside a capture stream
// triggers a synchronous H2D cudaMemcpy that CUDA rejects. The producer runs
// once per forward pass before the transformer loop, so placing it in
// pre-capture keeps the layer-body capture region intact. See ADR-088.
var nonCapturableOps = map[string]bool{
"EmbeddingLookup": true,
"Gather": true,
"AutoAttentionMask": true,
"AutoPositionIds": true,
"Slice": true,
"ConstantOfShape": true,
"Shape": true,
"EmbeddingLookup": true,
"Gather": true,
"AutoAttentionMask": true,
"AutoPositionIds": true,
"Slice": true,
"ConstantOfShape": true,
"Shape": true,
"Gemma4PLECombinedProducer": true,
}

// isNonCapturable returns true if the instruction at index i in the plan
Expand Down
Loading