Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 47 additions & 1 deletion cuda_core/cuda/core/_memory/_managed_memory_resource.pyx
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

from cuda.bindings cimport cydriver

from cuda.core._memory._memory_pool cimport _MemPool, _MemPoolOptions
from cuda.core._utils.cuda_utils cimport (
HANDLE_RETURN,
check_or_create_options,
)

from dataclasses import dataclass
import threading
import warnings

__all__ = ['ManagedMemoryResource', 'ManagedMemoryResourceOptions']

Expand Down Expand Up @@ -91,6 +95,7 @@ cdef class ManagedMemoryResource(_MemPool):
opts_base._type = cydriver.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_MANAGED

super().__init__(device_id, opts_base)
_check_concurrent_managed_access(self._dev_id)
ELSE:
raise RuntimeError("ManagedMemoryResource requires CUDA 13.0 or later")

Expand All @@ -103,3 +108,44 @@ cdef class ManagedMemoryResource(_MemPool):
def is_host_accessible(self) -> bool:
"""Return True. This memory resource provides host-accessible buffers."""
return True


cdef bint _concurrent_access_warned = False
cdef object _concurrent_access_lock = threading.Lock()


cdef inline _check_concurrent_managed_access(int device_id):
"""Warn once if the device lacks concurrent managed memory access."""
global _concurrent_access_warned
if _concurrent_access_warned:
return

cdef int c_concurrent = 0
with _concurrent_access_lock:
if _concurrent_access_warned:
return

with nogil:
HANDLE_RETURN(cydriver.cuDeviceGetAttribute(
&c_concurrent,
cydriver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS,
device_id))
if not c_concurrent:
warnings.warn(
"This platform does not support concurrent managed memory access "
"(Device.properties.concurrent_managed_access is False). Host access to any managed "
"allocation is forbidden while any GPU kernel is in flight, even "
"if the kernel does not touch that allocation. Failing to "
"synchronize before host access will cause a segfault. "
"See: https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#um-coherency-hd",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

UserWarning,
stacklevel=3
)

_concurrent_access_warned = True


def reset_concurrent_access_warning():
"""Reset the concurrent access warning flag for testing purposes."""
global _concurrent_access_warned
_concurrent_access_warned = False
3 changes: 2 additions & 1 deletion cuda_core/tests/test_build_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@

import pytest

# build_hooks.py imports Cython at the top level, so skip if not available
# build_hooks.py imports Cython and setuptools at the top level, so skip if not available
pytest.importorskip("Cython")
pytest.importorskip("setuptools")


def _load_build_hooks():
Expand Down
78 changes: 78 additions & 0 deletions cuda_core/tests/test_managed_memory_warning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""
Test that a warning is emitted when ManagedMemoryResource is created on a
platform without concurrent managed memory access.

These tests only run on affected platforms (concurrent_managed_access is False).
"""

import warnings

import pytest
from cuda.core import Device, ManagedMemoryResource, ManagedMemoryResourceOptions
from cuda.core._memory._managed_memory_resource import reset_concurrent_access_warning


def _make_managed_mr(device_id):
"""Create a ManagedMemoryResource with an explicit device preference."""
return ManagedMemoryResource(options=ManagedMemoryResourceOptions(preferred_location=device_id))


@pytest.fixture
def device_without_concurrent_managed_access(init_cuda):
"""Return a device that lacks concurrent managed access, or skip."""
device = Device()
device.set_current()

try:
pools_supported = device.properties.memory_pools_supported
except AttributeError:
pytest.skip("ManagedMemoryResource requires CUDA 13.0 or later")

if not pools_supported:
pytest.skip("Device does not support memory pools")

if device.properties.concurrent_managed_access:
pytest.skip("Device supports concurrent managed access; warning not applicable")

return device


def test_warning_emitted(device_without_concurrent_managed_access):
"""ManagedMemoryResource emits a warning when concurrent managed access is unsupported."""
dev_id = device_without_concurrent_managed_access.device_id
reset_concurrent_access_warning()

with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
mr = _make_managed_mr(dev_id)

concurrent_warnings = [
warning for warning in w if "concurrent managed memory access" in str(warning.message).lower()
]
assert len(concurrent_warnings) == 1
assert concurrent_warnings[0].category is UserWarning
assert "segfault" in str(concurrent_warnings[0].message).lower()

mr.close()


def test_warning_emitted_only_once(device_without_concurrent_managed_access):
"""Warning fires only once even when multiple ManagedMemoryResources are created."""
dev_id = device_without_concurrent_managed_access.device_id
reset_concurrent_access_warning()

with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
mr1 = _make_managed_mr(dev_id)
mr2 = _make_managed_mr(dev_id)

concurrent_warnings = [
warning for warning in w if "concurrent managed memory access" in str(warning.message).lower()
]
assert len(concurrent_warnings) == 1

mr1.close()
mr2.close()
Loading