45 changes: 44 additions & 1 deletion cuda_core/cuda/core/_device.pyx
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0

@@ -1188,6 +1188,49 @@ class Device:
    def __reduce__(self):
        return Device, (self.device_id,)

    def __enter__(self):
        """Set this device as current for the duration of the ``with`` block.

        On exit, the previously current device is restored automatically.
        Nested ``with`` blocks are supported and restore correctly at each
        level.

        Returns
        -------
        Device
            This device instance.

        Examples
        --------
        >>> from cuda.core import Device
        >>> with Device(0) as dev0:
        ...     buf = dev0.allocate(1024)

        See Also
        --------
        set_current : Non-context-manager entry point.
        """
        cdef cydriver.CUcontext prev_ctx
        with nogil:
            HANDLE_RETURN(cydriver.cuCtxGetCurrent(&prev_ctx))
        # Save the incoming context on a thread-local stack so that nested
        # ``with`` blocks unwind correctly and threads do not interfere.
        if not hasattr(_tls, '_ctx_stack'):
            _tls._ctx_stack = []
        _tls._ctx_stack.append(<uintptr_t><void*>prev_ctx)
        self.set_current()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Restore the previously current device upon exiting the ``with`` block.

        Exceptions are not suppressed.
        """
        # Restore the context saved by the matching __enter__, then pop it.
        cdef uintptr_t prev_ctx_ptr = _tls._ctx_stack[-1]
        cdef cydriver.CUcontext prev_ctx = <cydriver.CUcontext><void*>prev_ctx_ptr
        with nogil:
            HANDLE_RETURN(cydriver.cuCtxSetCurrent(prev_ctx))
        _tls._ctx_stack.pop()
        return False

    def set_current(self, ctx: Context = None) -> Context | None:
        """Set device to be used for GPU executions.

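For reviewers: the new `__enter__`/`__exit__` follow the standard thread-local save/restore idiom. Below is a minimal pure-Python sketch of that idiom, runnable on its own; `ContextGuard`, `get_current`, and `set_current` are hypothetical stand-ins for the driver calls, not cuda.core API.

import threading

_tls = threading.local()   # per-thread stack of saved contexts
_cur = threading.local()   # stand-in for the driver's "current context"

def get_current():
    return getattr(_cur, "ctx", None)

def set_current(ctx):
    _cur.ctx = ctx

class ContextGuard:
    """Save the current context on enter; restore it on exit."""
    def __init__(self, ctx):
        self.ctx = ctx

    def __enter__(self):
        if not hasattr(_tls, "stack"):
            _tls.stack = []
        _tls.stack.append(get_current())   # save whatever was current
        set_current(self.ctx)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        set_current(_tls.stack.pop())      # unwind exactly one level
        return False                       # never suppress exceptions

# Nested guards unwind level by level:
with ContextGuard("A"):
    with ContextGuard("B"):
        assert get_current() == "B"
    assert get_current() == "A"
assert get_current() is None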
6 changes: 4 additions & 2 deletions cuda_core/tests/conftest.py
@@ -192,13 +192,15 @@ def _mempool_device_impl(num):
@pytest.fixture
def mempool_device_x2():
    """Fixture that provides two devices if available, otherwise skips test."""
    return _mempool_device_impl(2)
    yield _mempool_device_impl(2)
    _device_unset_current()


@pytest.fixture
def mempool_device_x3():
    """Fixture that provides three devices if available, otherwise skips test."""
    return _mempool_device_impl(3)
    yield _mempool_device_impl(3)
    _device_unset_current()


@pytest.fixture(
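The fixtures above switch from `return` to `yield` so that teardown (`_device_unset_current()`) runs after each test. A generic illustration of the yield-fixture pattern (hypothetical fixture, not from this PR):

import pytest

@pytest.fixture
def tracked_items():
    items = []        # setup: runs before the test body
    yield items       # the test receives this value
    items.clear()     # teardown: runs after the test, even on failure

def test_append(tracked_items):
    tracked_items.append(1)
    assert tracked_items == [1]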
176 changes: 176 additions & 0 deletions cuda_core/tests/test_device.py
@@ -436,3 +436,179 @@ def test_device_set_membership(init_cuda):
    # Same device_id should not add duplicate
    device_set.add(dev0_b)
    assert len(device_set) == 1, "Should not add duplicate device"


# ============================================================================
# Device Context Manager Tests
# ============================================================================


def _get_current_context():
    """Return the current CUcontext as an int (0 means NULL / no context)."""
    return int(handle_return(driver.cuCtxGetCurrent()))


def test_context_manager_basic(deinit_cuda):
    """with Device(0) sets the device as current and restores on exit."""
    assert _get_current_context() == 0, "Should start with no active context"

    with Device(0):
        assert _get_current_context() != 0, "Device should be current inside the with block"

    assert _get_current_context() == 0, "No context should be current after exiting"


def test_context_manager_restores_previous(deinit_cuda):
    """Context manager restores the previously active context, not NULL."""
    dev0 = Device(0)
    dev0.set_current()
    ctx_before = _get_current_context()
    assert ctx_before != 0

    with Device(0):
        pass

    assert _get_current_context() == ctx_before, "Should restore the previous context"


def test_context_manager_exception_safety(deinit_cuda):
    """Device context is restored even when an exception is raised."""
    # Start with no active context so restoration is distinguishable
    assert _get_current_context() == 0

    with pytest.raises(RuntimeError, match="test error"), Device(0):
        assert _get_current_context() != 0, "Device should be active inside the block"
        raise RuntimeError("test error")

    assert _get_current_context() == 0, "Must restore NULL context after exception"


def test_context_manager_returns_device(deinit_cuda):
    """__enter__ returns the Device instance for use in 'as' clause."""
    device = Device(0)
    with device as dev:
        assert dev is device

    assert _get_current_context() == 0


def test_context_manager_nesting_same_device(deinit_cuda):
    """Nested with-blocks on the same device work correctly."""
    dev0 = Device(0)

    with dev0:
        ctx_outer = _get_current_context()
        with dev0:
            ctx_inner = _get_current_context()
            assert ctx_inner == ctx_outer, "Same device should yield same context"
        assert _get_current_context() == ctx_outer, "Outer context restored after inner exit"

    assert _get_current_context() == 0


def test_context_manager_deep_nesting(deinit_cuda):
    """Deep nesting and reentrancy restore correctly at each level."""
    dev0 = Device(0)

    with dev0:
        ctx_level1 = _get_current_context()
        with dev0:
            ctx_level2 = _get_current_context()
            with dev0:
                assert _get_current_context() != 0
            assert _get_current_context() == ctx_level2
        assert _get_current_context() == ctx_level1

    assert _get_current_context() == 0


def test_context_manager_nesting_different_devices(mempool_device_x2):
    """Nested with-blocks on different devices restore correctly."""
    dev0, dev1 = mempool_device_x2
    ctx_dev0 = _get_current_context()

    with dev1:
        ctx_inside = _get_current_context()
        assert ctx_inside != ctx_dev0, "Different device should have different context"

    assert _get_current_context() == ctx_dev0, "Original device context should be restored"


def test_context_manager_deep_nesting_multi_gpu(mempool_device_x2):
    """Deep nesting across multiple devices restores correctly at each level."""
    dev0, dev1 = mempool_device_x2

    with dev0:
        ctx_level0 = _get_current_context()
        with dev1:
            ctx_level1 = _get_current_context()
            assert ctx_level1 != ctx_level0
            with dev0:
                assert _get_current_context() == ctx_level0, "Same device should yield same primary context"
                with dev1:
                    assert _get_current_context() == ctx_level1
                assert _get_current_context() == ctx_level0
            assert _get_current_context() == ctx_level1
        assert _get_current_context() == ctx_level0


def test_context_manager_set_current_inside(mempool_device_x2):
    """set_current() inside a with block does not affect restoration on exit."""
    dev0, dev1 = mempool_device_x2
    ctx_dev0 = _get_current_context()  # dev0 is current from fixture

    with dev0:
        dev1.set_current()  # change the active device inside the block
        assert _get_current_context() != ctx_dev0

    assert _get_current_context() == ctx_dev0, "Must restore the context saved at __enter__"


def test_context_manager_device_usable_after_exit(deinit_cuda):
    """Device singleton is not corrupted after context manager exit."""
    device = Device(0)
    with device:
        pass

    assert _get_current_context() == 0

    # Device should still be usable via set_current
    device.set_current()
    assert _get_current_context() != 0
    stream = device.create_stream()
    assert stream is not None


def test_context_manager_initializes_device(deinit_cuda):
    """with Device(N) should initialize the device, making it ready for use."""
    device = Device(0)
    with device:
        # allocate requires an active context; should not raise
        buf = device.allocate(1024)
        assert buf.handle != 0


def test_context_manager_thread_safety(mempool_device_x3):
    """Concurrent threads using context managers on different devices don't interfere."""
    import concurrent.futures
    import threading

    devices = mempool_device_x3
    barrier = threading.Barrier(len(devices))
    errors = []

    def worker(dev):
        try:
            ctx_before = _get_current_context()
            with dev:
                barrier.wait(timeout=5)
                buf = dev.allocate(1024)
                assert buf.handle != 0
            assert _get_current_context() == ctx_before
        except Exception as e:
            errors.append(e)

    with concurrent.futures.ThreadPoolExecutor(max_workers=len(devices)) as pool:
        pool.map(worker, devices)

    assert not errors, f"Thread errors: {errors}"
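
The thread-safety test works because `threading.local()` gives each thread an independent `_ctx_stack`, so a guard in one thread never pops an entry saved by another. A quick self-contained demonstration of that isolation (illustrative only, not test code from this PR):

import threading

tls = threading.local()

def worker(name, out):
    tls.stack = [name]            # visible only to this thread
    out[name] = list(tls.stack)   # record what this thread observed

out = {}
threads = [threading.Thread(target=worker, args=(f"t{i}", out)) for i in range(3)]
for t in threads:
    t.start()
for t in threads:
    t.join()
assert out == {"t0": ["t0"], "t1": ["t1"], "t2": ["t2"]}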