diff --git a/cuda_core/cuda/core/_device.pyx b/cuda_core/cuda/core/_device.pyx
index 9d143679f8..ad12af15c9 100644
--- a/cuda_core/cuda/core/_device.pyx
+++ b/cuda_core/cuda/core/_device.pyx
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # SPDX-License-Identifier: Apache-2.0
 
@@ -1188,6 +1188,51 @@ class Device:
     def __reduce__(self):
         return Device, (self.device_id,)
 
+    def __enter__(self):
+        """Set this device as current for the duration of the ``with`` block.
+
+        On exit, the previously current device is restored automatically.
+        Nested ``with`` blocks are supported and restore correctly at each
+        level.
+
+        Returns
+        -------
+        Device
+            This device instance.
+
+        Examples
+        --------
+        >>> from cuda.core import Device
+        >>> with Device(0) as dev0:
+        ...     buf = dev0.allocate(1024)
+
+        See Also
+        --------
+        set_current : Non-context-manager entry point.
+        """
+        cdef cydriver.CUcontext prev_ctx
+        with nogil:
+            HANDLE_RETURN(cydriver.cuCtxGetCurrent(&prev_ctx))
+        if not hasattr(_tls, '_ctx_stack'):
+            _tls._ctx_stack = []
+        self.set_current()
+        # Push only after set_current() succeeds, so a failed __enter__
+        # cannot leave a stale entry on the per-thread stack.
+        _tls._ctx_stack.append(<uintptr_t>prev_ctx)
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        """Restore the previously current device upon exiting the ``with`` block.
+
+        Exceptions are not suppressed.
+        """
+        cdef uintptr_t prev_ctx_ptr = _tls._ctx_stack.pop()
+        cdef cydriver.CUcontext prev_ctx = <cydriver.CUcontext>prev_ctx_ptr
+        with nogil:
+            HANDLE_RETURN(cydriver.cuCtxSetCurrent(prev_ctx))
+        # Never suppress exceptions raised inside the ``with`` block.
+        return False
+
     def set_current(self, ctx: Context = None) -> Context | None:
         """Set device to be used for GPU executions.
 
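The save/switch/restore protocol above is easier to review in isolation from the driver API. Below is a minimal pure-Python model of the same semantics, assuming only the standard library; `FakeDevice`, `_get_current`, and the `_tls` stack here are illustrative stand-ins for the Cython implementation, not part of the patch:

```python
import threading

_tls = threading.local()  # models the per-thread "current context" state


def _get_current():
    """Stand-in for cuCtxGetCurrent: this thread's current id (None = no context)."""
    return getattr(_tls, "current", None)


class FakeDevice:
    """Illustrative stand-in for Device; not the real implementation."""

    def __init__(self, device_id):
        self.device_id = device_id

    def set_current(self):
        _tls.current = self.device_id  # models cuCtxSetCurrent

    def __enter__(self):
        prev = _get_current()          # save whatever was current (may be None)
        if not hasattr(_tls, "stack"):
            _tls.stack = []
        self.set_current()             # switch first...
        _tls.stack.append(prev)        # ...push only once the switch succeeded
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        _tls.current = _tls.stack.pop()  # restore, even if an exception is in flight
        return False                     # never suppress exceptions


# Nesting restores one level at a time:
with FakeDevice(0):
    with FakeDevice(1):
        assert _get_current() == 1
    assert _get_current() == 0
assert _get_current() is None
```

Because both the saved stack and the current context are thread-local, concurrent threads enter and exit independently, which is what the thread-safety test at the bottom of this patch exercises.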
exiting" + + +def test_context_manager_restores_previous(deinit_cuda): + """Context manager restores the previously active context, not NULL.""" + dev0 = Device(0) + dev0.set_current() + ctx_before = _get_current_context() + assert ctx_before != 0 + + with Device(0): + pass + + assert _get_current_context() == ctx_before, "Should restore the previous context" + + +def test_context_manager_exception_safety(deinit_cuda): + """Device context is restored even when an exception is raised.""" + # Start with no active context so restoration is distinguishable + assert _get_current_context() == 0 + + with pytest.raises(RuntimeError, match="test error"), Device(0): + assert _get_current_context() != 0, "Device should be active inside the block" + raise RuntimeError("test error") + + assert _get_current_context() == 0, "Must restore NULL context after exception" + + +def test_context_manager_returns_device(deinit_cuda): + """__enter__ returns the Device instance for use in 'as' clause.""" + device = Device(0) + with device as dev: + assert dev is device + + assert _get_current_context() == 0 + + +def test_context_manager_nesting_same_device(deinit_cuda): + """Nested with-blocks on the same device work correctly.""" + dev0 = Device(0) + + with dev0: + ctx_outer = _get_current_context() + with dev0: + ctx_inner = _get_current_context() + assert ctx_inner == ctx_outer, "Same device should yield same context" + assert _get_current_context() == ctx_outer, "Outer context restored after inner exit" + + assert _get_current_context() == 0 + + +def test_context_manager_deep_nesting(deinit_cuda): + """Deep nesting and reentrancy restore correctly at each level.""" + dev0 = Device(0) + + with dev0: + ctx_level1 = _get_current_context() + with dev0: + ctx_level2 = _get_current_context() + with dev0: + assert _get_current_context() != 0 + assert _get_current_context() == ctx_level2 + assert _get_current_context() == ctx_level1 + + assert _get_current_context() == 0 + + +def test_context_manager_nesting_different_devices(mempool_device_x2): + """Nested with-blocks on different devices restore correctly.""" + dev0, dev1 = mempool_device_x2 + ctx_dev0 = _get_current_context() + + with dev1: + ctx_inside = _get_current_context() + assert ctx_inside != ctx_dev0, "Different device should have different context" + + assert _get_current_context() == ctx_dev0, "Original device context should be restored" + + +def test_context_manager_deep_nesting_multi_gpu(mempool_device_x2): + """Deep nesting across multiple devices restores correctly at each level.""" + dev0, dev1 = mempool_device_x2 + + with dev0: + ctx_level0 = _get_current_context() + with dev1: + ctx_level1 = _get_current_context() + assert ctx_level1 != ctx_level0 + with dev0: + assert _get_current_context() == ctx_level0, "Same device should yield same primary context" + with dev1: + assert _get_current_context() == ctx_level1 + assert _get_current_context() == ctx_level0 + assert _get_current_context() == ctx_level1 + assert _get_current_context() == ctx_level0 + + +def test_context_manager_set_current_inside(mempool_device_x2): + """set_current() inside a with block does not affect restoration on exit.""" + dev0, dev1 = mempool_device_x2 + ctx_dev0 = _get_current_context() # dev0 is current from fixture + + with dev0: + dev1.set_current() # change the active device inside the block + assert _get_current_context() != ctx_dev0 + + assert _get_current_context() == ctx_dev0, "Must restore the context saved at __enter__" + + +def 
diff --git a/cuda_core/tests/test_device.py b/cuda_core/tests/test_device.py
index d561f92b9e..fd1302f0e8 100644
--- a/cuda_core/tests/test_device.py
+++ b/cuda_core/tests/test_device.py
@@ -436,3 +436,179 @@ def test_device_set_membership(init_cuda):
     # Same device_id should not add duplicate
     device_set.add(dev0_b)
     assert len(device_set) == 1, "Should not add duplicate device"
+
+
+# ============================================================================
+# Device Context Manager Tests
+# ============================================================================
+
+
+def _get_current_context():
+    """Return the current CUcontext as an int (0 means NULL / no context)."""
+    return int(handle_return(driver.cuCtxGetCurrent()))
+
+
+def test_context_manager_basic(deinit_cuda):
+    """with Device(0) sets the device as current and restores on exit."""
+    assert _get_current_context() == 0, "Should start with no active context"
+
+    with Device(0):
+        assert _get_current_context() != 0, "Device should be current inside the with block"
+
+    assert _get_current_context() == 0, "No context should be current after exiting"
+
+
+def test_context_manager_restores_previous(deinit_cuda):
+    """Context manager restores the previously active context, not NULL."""
+    dev0 = Device(0)
+    dev0.set_current()
+    ctx_before = _get_current_context()
+    assert ctx_before != 0
+
+    with Device(0):
+        pass
+
+    assert _get_current_context() == ctx_before, "Should restore the previous context"
+
+
+def test_context_manager_exception_safety(deinit_cuda):
+    """Device context is restored even when an exception is raised."""
+    # Start with no active context so restoration is distinguishable
+    assert _get_current_context() == 0
+
+    with pytest.raises(RuntimeError, match="test error"), Device(0):
+        assert _get_current_context() != 0, "Device should be active inside the block"
+        raise RuntimeError("test error")
+
+    assert _get_current_context() == 0, "Must restore NULL context after exception"
+
+
+def test_context_manager_returns_device(deinit_cuda):
+    """__enter__ returns the Device instance for use in 'as' clause."""
+    device = Device(0)
+    with device as dev:
+        assert dev is device
+
+    assert _get_current_context() == 0
+
+
+def test_context_manager_nesting_same_device(deinit_cuda):
+    """Nested with-blocks on the same device work correctly."""
+    dev0 = Device(0)
+
+    with dev0:
+        ctx_outer = _get_current_context()
+        with dev0:
+            ctx_inner = _get_current_context()
+            assert ctx_inner == ctx_outer, "Same device should yield same context"
+        assert _get_current_context() == ctx_outer, "Outer context restored after inner exit"
+
+    assert _get_current_context() == 0
+
+
+def test_context_manager_deep_nesting(deinit_cuda):
+    """Deep nesting and reentrancy restore correctly at each level."""
+    dev0 = Device(0)
+
+    with dev0:
+        ctx_level1 = _get_current_context()
+        with dev0:
+            ctx_level2 = _get_current_context()
+            with dev0:
+                assert _get_current_context() != 0
+            assert _get_current_context() == ctx_level2
+        assert _get_current_context() == ctx_level1
+
+    assert _get_current_context() == 0
+
+
+def test_context_manager_nesting_different_devices(mempool_device_x2):
+    """Nested with-blocks on different devices restore correctly."""
+    dev0, dev1 = mempool_device_x2
+    ctx_dev0 = _get_current_context()
+
+    with dev1:
+        ctx_inside = _get_current_context()
+        assert ctx_inside != ctx_dev0, "Different device should have different context"
+
+    assert _get_current_context() == ctx_dev0, "Original device context should be restored"
+
+
+def test_context_manager_deep_nesting_multi_gpu(mempool_device_x2):
+    """Deep nesting across multiple devices restores correctly at each level."""
+    dev0, dev1 = mempool_device_x2
+
+    with dev0:
+        ctx_level0 = _get_current_context()
+        with dev1:
+            ctx_level1 = _get_current_context()
+            assert ctx_level1 != ctx_level0
+            with dev0:
+                assert _get_current_context() == ctx_level0, "Same device should yield same primary context"
+                with dev1:
+                    assert _get_current_context() == ctx_level1
+                assert _get_current_context() == ctx_level0
+            assert _get_current_context() == ctx_level1
+        assert _get_current_context() == ctx_level0
+
+
+def test_context_manager_set_current_inside(mempool_device_x2):
+    """set_current() inside a with block does not affect restoration on exit."""
+    dev0, dev1 = mempool_device_x2
+    ctx_dev0 = _get_current_context()  # dev0 is current from fixture
+
+    with dev0:
+        dev1.set_current()  # change the active device inside the block
+        assert _get_current_context() != ctx_dev0
+
+    assert _get_current_context() == ctx_dev0, "Must restore the context saved at __enter__"
+
+
+def test_context_manager_device_usable_after_exit(deinit_cuda):
+    """Device singleton is not corrupted after context manager exit."""
+    device = Device(0)
+    with device:
+        pass
+
+    assert _get_current_context() == 0
+
+    # Device should still be usable via set_current
+    device.set_current()
+    assert _get_current_context() != 0
+    stream = device.create_stream()
+    assert stream is not None
+
+
+def test_context_manager_initializes_device(deinit_cuda):
+    """with Device(N) should initialize the device, making it ready for use."""
+    device = Device(0)
+    with device:
+        # allocate requires an active context; should not raise
+        buf = device.allocate(1024)
+        assert buf.handle != 0
+
+
+def test_context_manager_thread_safety(mempool_device_x3):
+    """Concurrent threads using context managers on different devices don't interfere."""
+    import concurrent.futures
+    import threading
+
+    devices = mempool_device_x3
+    barrier = threading.Barrier(len(devices))
+    errors = []
+
+    def worker(dev):
+        try:
+            ctx_before = _get_current_context()
+            with dev:
+                barrier.wait(timeout=5)
+                buf = dev.allocate(1024)
+                assert buf.handle != 0
+            assert _get_current_context() == ctx_before
+        except Exception as e:
+            errors.append(e)
+
+    with concurrent.futures.ThreadPoolExecutor(max_workers=len(devices)) as pool:
+        pool.map(worker, devices)
+
+    assert not errors, f"Thread errors: {errors}"
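The thread-safety test hinges on two facts: CUDA's current context is thread-local, and the `threading.Barrier` forces every worker to sit inside its `with` block at the same moment, so any cross-thread leakage would be observed. The synchronization skeleton, stripped of CUDA (`do_work` is a placeholder, not anything from the patch):

```python
import concurrent.futures
import threading


def run_workers(n, do_work=lambda i: None):
    barrier = threading.Barrier(n)
    errors = []

    def worker(idx):
        try:
            # Block until every worker reaches this point, guaranteeing the
            # critical sections overlap instead of running one after another.
            barrier.wait(timeout=5)
            do_work(idx)
        except Exception as exc:  # collect rather than lose thread errors
            errors.append(exc)

    with concurrent.futures.ThreadPoolExecutor(max_workers=n) as pool:
        # Exiting the with-block calls shutdown(wait=True), so all workers
        # have finished before `errors` is inspected.
        list(pool.map(worker, range(n)))
    return errors


assert run_workers(3) == []
```

Collecting exceptions into a list rather than letting them die inside the pool is what makes the final `assert not errors` meaningful; a bare `pool.map` would silently discard worker failures unless its results were consumed.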