diff --git a/cuda_core/cuda/core/_memory/_managed_memory_resource.pyx b/cuda_core/cuda/core/_memory/_managed_memory_resource.pyx index 59921bdfc1..ca375d7f5a 100644 --- a/cuda_core/cuda/core/_memory/_managed_memory_resource.pyx +++ b/cuda_core/cuda/core/_memory/_managed_memory_resource.pyx @@ -1,16 +1,20 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations from cuda.bindings cimport cydriver + from cuda.core._memory._memory_pool cimport _MemPool, _MemPoolOptions from cuda.core._utils.cuda_utils cimport ( + HANDLE_RETURN, check_or_create_options, ) from dataclasses import dataclass +import threading +import warnings __all__ = ['ManagedMemoryResource', 'ManagedMemoryResourceOptions'] @@ -91,6 +95,7 @@ cdef class ManagedMemoryResource(_MemPool): opts_base._type = cydriver.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_MANAGED super().__init__(device_id, opts_base) + _check_concurrent_managed_access(self._dev_id) ELSE: raise RuntimeError("ManagedMemoryResource requires CUDA 13.0 or later") @@ -103,3 +108,44 @@ cdef class ManagedMemoryResource(_MemPool): def is_host_accessible(self) -> bool: """Return True. This memory resource provides host-accessible buffers.""" return True + + +cdef bint _concurrent_access_warned = False +cdef object _concurrent_access_lock = threading.Lock() + + +cdef inline _check_concurrent_managed_access(int device_id): + """Warn once if the device lacks concurrent managed memory access.""" + global _concurrent_access_warned + if _concurrent_access_warned: + return + + cdef int c_concurrent = 0 + with _concurrent_access_lock: + if _concurrent_access_warned: + return + + with nogil: + HANDLE_RETURN(cydriver.cuDeviceGetAttribute( + &c_concurrent, + cydriver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS, + device_id)) + if not c_concurrent: + warnings.warn( + "This platform does not support concurrent managed memory access " + "(Device.properties.concurrent_managed_access is False). Host access to any managed " + "allocation is forbidden while any GPU kernel is in flight, even " + "if the kernel does not touch that allocation. Failing to " + "synchronize before host access will cause a segfault. " + "See: https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#um-coherency-hd", + UserWarning, + stacklevel=3 + ) + + _concurrent_access_warned = True + + +def reset_concurrent_access_warning(): + """Reset the concurrent access warning flag for testing purposes.""" + global _concurrent_access_warned + _concurrent_access_warned = False diff --git a/cuda_core/tests/test_build_hooks.py b/cuda_core/tests/test_build_hooks.py index e416503bc0..419efbe065 100644 --- a/cuda_core/tests/test_build_hooks.py +++ b/cuda_core/tests/test_build_hooks.py @@ -24,8 +24,9 @@ import pytest -# build_hooks.py imports Cython at the top level, so skip if not available +# build_hooks.py imports Cython and setuptools at the top level, so skip if not available pytest.importorskip("Cython") +pytest.importorskip("setuptools") def _load_build_hooks(): diff --git a/cuda_core/tests/test_managed_memory_warning.py b/cuda_core/tests/test_managed_memory_warning.py new file mode 100644 index 0000000000..84bd1a2dab --- /dev/null +++ b/cuda_core/tests/test_managed_memory_warning.py @@ -0,0 +1,78 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Test that a warning is emitted when ManagedMemoryResource is created on a +platform without concurrent managed memory access. + +These tests only run on affected platforms (concurrent_managed_access is False). +""" + +import warnings + +import pytest +from cuda.core import Device, ManagedMemoryResource, ManagedMemoryResourceOptions +from cuda.core._memory._managed_memory_resource import reset_concurrent_access_warning + + +def _make_managed_mr(device_id): + """Create a ManagedMemoryResource with an explicit device preference.""" + return ManagedMemoryResource(options=ManagedMemoryResourceOptions(preferred_location=device_id)) + + +@pytest.fixture +def device_without_concurrent_managed_access(init_cuda): + """Return a device that lacks concurrent managed access, or skip.""" + device = Device() + device.set_current() + + try: + pools_supported = device.properties.memory_pools_supported + except AttributeError: + pytest.skip("ManagedMemoryResource requires CUDA 13.0 or later") + + if not pools_supported: + pytest.skip("Device does not support memory pools") + + if device.properties.concurrent_managed_access: + pytest.skip("Device supports concurrent managed access; warning not applicable") + + return device + + +def test_warning_emitted(device_without_concurrent_managed_access): + """ManagedMemoryResource emits a warning when concurrent managed access is unsupported.""" + dev_id = device_without_concurrent_managed_access.device_id + reset_concurrent_access_warning() + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + mr = _make_managed_mr(dev_id) + + concurrent_warnings = [ + warning for warning in w if "concurrent managed memory access" in str(warning.message).lower() + ] + assert len(concurrent_warnings) == 1 + assert concurrent_warnings[0].category is UserWarning + assert "segfault" in str(concurrent_warnings[0].message).lower() + + mr.close() + + +def test_warning_emitted_only_once(device_without_concurrent_managed_access): + """Warning fires only once even when multiple ManagedMemoryResources are created.""" + dev_id = device_without_concurrent_managed_access.device_id + reset_concurrent_access_warning() + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + mr1 = _make_managed_mr(dev_id) + mr2 = _make_managed_mr(dev_id) + + concurrent_warnings = [ + warning for warning in w if "concurrent managed memory access" in str(warning.message).lower() + ] + assert len(concurrent_warnings) == 1 + + mr1.close() + mr2.close()