## Problem
There is no context manager for temporarily switching the active CUDA device in `cuda.core`. Users who need to perform work on a specific device and then restore the previous device must manually manage this state:

```python
from cuda.core import Device

# Save current state, switch, do work, restore
dev0 = Device(0)
dev0.set_current()
# ... do work on device 0 ...
dev1 = Device(1)
dev1.set_current()
# ... do work on device 1 ...
dev0.set_current()  # manually restore -- easy to forget, not exception-safe
```

This pattern is error-prone (the restore call is not exception-safe) and verbose compared to the idiomatic Python `with` statement. It appears in real-world code such as dask-cuda, which implemented a workaround for this missing feature.
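An exception-safe version of the manual pattern requires `try`/`finally` boilerplate along these lines (a sketch of the kind of workaround involved, not dask-cuda's actual code):

```python
from cuda.core import Device

dev0 = Device(0)
dev0.set_current()
try:
    dev1 = Device(1)
    dev1.set_current()
    # ... do work on device 1 ...
finally:
    dev0.set_current()  # runs even if the work above raises
```

This is exactly the boilerplate a context manager would absorb.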
## Proposed Design
Add `__enter__` and `__exit__` methods to `Device`, making it usable as a context manager that temporarily activates the device's primary context and restores the previous state on exit.
### API
```python
from cuda.core import Device

dev0 = Device(0)
dev0.set_current()
# ... do work on device 0 ...

with Device(1) as device:
    # device 1 is now current
    stream = device.create_stream()
    # ...
# device 0 is automatically restored here
```

### Semantics
On `__enter__`:

- Query the current CUDA context via `cuCtxGetCurrent` and save it (see *Key design properties* below for where the saved value lives).
- Call `self.set_current()` (which makes this device's primary context current via `cuCtxSetCurrent`).
- Return `self`.

On `__exit__`:

- Restore the saved context via `cuCtxSetCurrent`. If the saved context was `NULL`, set `NULL` (no active context).
- Do NOT suppress exceptions (return `False`).
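Taken together, `with dev:` is semantically equivalent to the following (a sketch using the same `handle_return`/driver-binding helpers as the implementation sketch below):

```python
prev_ctx = handle_return(cuCtxGetCurrent())   # __enter__: save driver state
dev.set_current()                             # __enter__: switch
try:
    ...  # body of the with-block
finally:
    handle_return(cuCtxSetCurrent(prev_ctx))  # __exit__: restore, even on error
```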
### Key design properties
**Stateless restoration (no Python-side device stack).** Each `__enter__` call queries the actual CUDA driver for the current context rather than consulting a Python-side cache. On `__exit__`, it restores exactly what was saved. This is the critical lesson from CuPy's experience (cupy/cupy#6965, cupy/cupy#7427): libraries that maintain their own stack of previous devices break interoperability with libraries that use the CUDA API directly to change or check device state. By always querying and restoring the driver-level state, we interoperate correctly with PyTorch, CuPy, and any other library that uses `cudaGetDevice`/`cudaSetDevice` or `cuCtxGetCurrent`/`cuCtxSetCurrent`.
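To see the failure mode, consider a library that trusts its own cache instead of the driver (illustrative pseudocode; `cuda_set_device` and `torch_set_device` are stand-ins for the real CUDA and PyTorch calls):

```python
broken_stack = [0]                      # the library's cached "current device"

def broken_enter(dev_id):
    broken_stack.append(dev_id)
    cuda_set_device(dev_id)

def broken_exit():
    broken_stack.pop()
    cuda_set_device(broken_stack[-1])   # trusts the cache, not the driver

torch_set_device(2)   # another library switches devices; the cache still says 0
broken_enter(1)       # device is now 1
broken_exit()         # "restores" device 0 -- the surrounding code expected 2
```

Querying `cuCtxGetCurrent` on every `__enter__` makes this scenario impossible by construction.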
**Reentrant and reusable.** Because `Device` is a thread-local singleton and the saved-context state is stored per-`__enter__` invocation (not on the `Device` object itself), the context manager is both reusable and reentrant:
```python
dev0 = Device(0)
dev1 = Device(1)

with dev0:
    with dev1:
        with dev0:  # reentrant -- works correctly
            ...
        # dev1 restored
    # dev0 restored
# original state restored
```

To achieve reentrancy, the saved context must NOT be stored on `self` (the `Device` singleton). Instead, keep the saved state per invocation, either on a thread-local stack or in a separate helper object that acts as the context manager (as in Alternative 1 below). The simplest correct approach: store a per-thread stack of saved contexts on the `Device` class (or at module level), pushing on `__enter__` and popping on `__exit__`.
**Implementation sketch** (in `_device.pyx`):
```python
# Module-level: add a per-thread stack for saved contexts
# (reuse the existing _tls threading.local())

def __enter__(self):
    # Query actual CUDA state -- do NOT use a Python-side device cache
    prev_ctx = handle_return(cuCtxGetCurrent())
    # Store on a per-thread stack so nested `with` works
    if not hasattr(_tls, '_ctx_stack'):
        _tls._ctx_stack = []
    _tls._ctx_stack.append(prev_ctx)
    self.set_current()
    return self

def __exit__(self, exc_type, exc_val, exc_tb):
    prev_ctx = _tls._ctx_stack.pop()
    handle_return(cuCtxSetCurrent(prev_ctx))
    return False
```

Note: the stack here is NOT a device stack -- it is a stack of saved `CUcontext` values that `__exit__` restores. Each entry corresponds to exactly one `__enter__` call. This is fundamentally different from CuPy's old broken approach, which tracked a stack of device IDs and consulted that stack instead of the CUDA API.
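Applied to the nesting example above, the per-thread stack evolves as follows (`ctx0`/`ctx1` denote the primary contexts of devices 0 and 1, `outer` whatever context was current initially; illustrative trace only):

```python
with dev0:          # push outer   stack: [outer]              current: ctx0
    with dev1:      # push ctx0    stack: [outer, ctx0]        current: ctx1
        with dev0:  # push ctx1    stack: [outer, ctx0, ctx1]  current: ctx0
            ...
        # pop -> restore ctx1      stack: [outer, ctx0]
    # pop -> restore ctx0          stack: [outer]
# pop -> restore outer             stack: []
```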
**Interoperability with other libraries.** Because we use `cuCtxSetCurrent` (driver API), and both PyTorch and CuPy use `cudaGetDevice`/`cudaSetDevice` (runtime API), which operate on the same underlying driver state, cross-library nesting works:
```python
with torch.cuda.device(1):
    with Device(2):
        # Both torch and cuda.core see device 2
        ...
    # torch sees device 1 again (cuda.core restored the context)
```

Note that correct cross-library nesting depends on each library querying the CUDA API for the current device on context exit rather than relying on a cached value. Libraries that follow this pattern (including CuPy v12+ and the CUDA runtime API itself) interoperate correctly.
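A quick sanity check of this interop on a machine with 3+ GPUs might look like the following (assuming `torch.cuda.current_device()` reflects live runtime state rather than a cache, which holds for current PyTorch):

```python
import torch
from cuda.core import Device

with torch.cuda.device(1):
    assert torch.cuda.current_device() == 1
    with Device(2):
        # both libraries observe the same driver-level state
        assert torch.cuda.current_device() == 2
    # cuda.core restored the saved context; torch sees device 1 again
    assert torch.cuda.current_device() == 1
```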
## Alternatives Considered
### 1. Separate `Device.activate()` method returning a context manager
```python
with Device(1).activate():
    ...
```

This avoids adding `__enter__`/`__exit__` to the singleton `Device` object. However, it adds API surface for no practical benefit -- the saved context can be stored on a thread-local stack rather than on the `Device` instance, making `Device` itself safe to use directly as a reentrant context manager. The `with Device(1):` syntax is also more natural and matches PyTorch's `with torch.cuda.device(1):` pattern.
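For reference, such a helper could be a thin generator-based wrapper reusing the same `handle_return`/driver-binding helpers as the implementation sketch above (a sketch, not a committed design):

```python
import contextlib

class Device:
    # ... existing Device methods ...

    @contextlib.contextmanager
    def activate(self):
        # The saved context lives in this generator's frame, so each
        # activate() call is independent and naturally reentrant.
        prev_ctx = handle_return(cuCtxGetCurrent())
        self.set_current()
        try:
            yield self
        finally:
            handle_return(cuCtxSetCurrent(prev_ctx))
```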
Rejected because it adds unnecessary indirection.
### 2. Do nothing -- recommend `set_current()` only
Per CuPy's internal policy, context managers for device switching are banned in CuPy's own codebase because they are footguns for library developers. The argument is that `set_current()` is explicit and unambiguous.

However, `cuda.core` targets end users (not just library internals), and the context manager pattern is:

- Exception-safe by default
- Idiomatic Python
- Already provided by PyTorch and CuPy (for end users)
- Requested by downstream users (dask-cuda)

Rejected as the sole approach; `set_current()` remains the recommended style for library code that needs precise control.
### 3. Use `cuCtxPushCurrent`/`cuCtxPopCurrent` instead of `cuCtxSetCurrent`
The CUDA driver provides an explicit context stack via `cuCtxPushCurrent`/`cuCtxPopCurrent`. Using it would make nesting trivially correct. However, `Device.set_current()` currently uses `cuCtxSetCurrent` for primary contexts (not push/pop), and mixing the two models is fragile. The push/pop model also does not interoperate with libraries using `cudaSetDevice` (runtime API). The proposed save-via-query/restore-via-set approach is correct and interoperable.
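For comparison, the push/pop variant would look roughly like this (sketch only; `_get_primary_context` is a hypothetical helper that retrieves this device's primary context):

```python
def __enter__(self):
    # Push this device's primary context onto the driver's own stack
    ctx = self._get_primary_context()
    handle_return(cuCtxPushCurrent(ctx))
    return self

def __exit__(self, exc_type, exc_val, exc_tb):
    # Pop restores whatever was below on the driver stack -- correct only
    # if nothing else manipulated that stack (or called cudaSetDevice)
    # in between
    handle_return(cuCtxPopCurrent())
    return False
```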
Rejected because it would diverge from the runtime API model that other libraries use.
## Open Questions
- Should `__enter__` call `set_current()` even if this device is already current? Calling `cuCtxSetCurrent` with the already-current context is cheap (a no-op at the driver level) and keeps the implementation simple. The alternative (check-and-skip) adds complexity for negligible performance gain. Recommendation: always call `set_current()`.
- What should `__enter__` do if `set_current()` has never been called on this device? Currently, many `Device` properties require `set_current()` to have been called first (`_check_context_initialized`). The context manager should unconditionally call `set_current()`, initializing the device if needed. This is the natural expectation: `with Device(1):` should make device 1 ready for use (see the sketch after this list).
- Should we document cross-library interop expectations? We should document that `with Device(N):` works correctly for `cuda.core` code, and that cross-library nesting works as long as the other library's context manager correctly queries CUDA state on exit rather than relying on a cached value.
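On the second question, the expected behavior in code (reusing `create_stream` from the API example above):

```python
dev = Device(1)     # never activated on this thread
with dev:           # __enter__ calls set_current(), initializing the
                    # primary context as a side effect
    stream = dev.create_stream()  # device is fully usable inside the block
```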
## Test Plan
- Basic usage: `with Device(0):` sets device 0 as current and restores the previous device on exit.
- Exception safety: device is restored even when an exception is raised inside the `with` block (see the sketch after this list).
- Nesting (same device): `with dev0: with dev0:` works without error.
- Nesting (different devices): `with dev0: with dev1:` correctly restores dev0 on exit of the inner block.
- Deep nesting / reentrancy: `with dev0: with dev1: with dev0: with dev1:` restores correctly at each level.
- `Device` remains usable after context manager exit (singleton not corrupted).
- Multi-GPU: requires 2+ GPUs. Verify `cudaGetDevice()` (runtime API) reflects the device set by the context manager.
- Thread safety: context manager state is per-thread (thread-local storage), so concurrent threads using different devices must not interfere.
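A sketch of the exception-safety case (pytest; `current_device_id()` is a hypothetical test helper that queries the runtime API, e.g. `cudaGetDevice`, for the active device):

```python
import pytest
from cuda.core import Device

def current_device_id():
    # Hypothetical helper: query cudaGetDevice() via the CUDA bindings;
    # the exact spelling depends on the bindings version in use.
    raise NotImplementedError

def test_exit_restores_device_on_exception():
    Device(0).set_current()
    with pytest.raises(RuntimeError):
        with Device(1):
            assert current_device_id() == 1
            raise RuntimeError("boom")
    # device 0 is current again even though the block raised
    assert current_device_id() == 0
```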