// nonCapturableOps is the set of op names that are treated as non-capturable
// and therefore kept out of the CUDA graph capture region. A name is a member
// when its entry is true; names that are absent read back as false.
//
// GQA is intentionally absent: with the offset_memcpy kernel and GPU RoPE
// selection (rope_select kernel), all position-dependent state is read from
// GPU memory at replay time, which makes GQA fully capturable.
//
// Gemma4PLECombinedProducer is present because it gathers rows from the
// shared PLE embedding table on the CPU (token ids -> per-layer rows) and then
// calls MulScalar on the freshly allocated CPUStorage tensor. Inside a capture
// stream that triggers a synchronous H2D cudaMemcpy, which CUDA rejects. The
// producer runs once per forward pass before the transformer loop, so running
// it pre-capture leaves the layer-body capture region intact. See ADR-088.
var nonCapturableOps = func() map[string]bool {
	names := [...]string{
		"EmbeddingLookup",
		"Gather",
		"AutoAttentionMask",
		"AutoPositionIds",
		"Slice",
		"ConstantOfShape",
		"Shape",
		"Gemma4PLECombinedProducer",
	}

	set := make(map[string]bool, len(names))
	for _, name := range names {
		set[name] = true
	}
	return set
}()