Skip to content

Commit 25455d8

Browse files
committed
Add CUDA process checkpointing helpers
1 parent 11347ff commit 25455d8

5 files changed

Lines changed: 404 additions & 2 deletions

File tree

‎cuda_core/cuda/core/__init__.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def _import_versioned_module():
2828
del _import_versioned_module
2929

3030

31-
from cuda.core import system, utils
31+
from cuda.core import checkpoint, system, utils
3232
from cuda.core._device import Device
3333
from cuda.core._event import Event, EventOptions
3434
from cuda.core._graphics import GraphicsResource

‎cuda_core/cuda/core/checkpoint.py‎

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
from collections.abc import Mapping
6+
from dataclasses import dataclass
7+
from enum import IntEnum
8+
from typing import Any
9+
10+
from cuda.core._utils.cuda_utils import handle_return
11+
12+
13+
class ProcessState(IntEnum):
    """
    CUDA checkpoint state for a process.

    ``Process.state`` converts the integer reported by the driver's
    ``cuCheckpointProcessGetState`` call into a member of this enum.
    """

    # NOTE(review): the numeric values are presumably chosen to match the
    # driver's own process-state enumeration -- confirm against the CUDA
    # driver headers before adding or renumbering members.
    RUNNING = 0
    LOCKED = 1
    CHECKPOINTED = 2
    FAILED = 3
22+
23+
24+
@dataclass(frozen=True)
class Process:
    """
    Handle for running CUDA checkpoint operations against a process.

    The instance only stores the target PID; every property and method
    looks up the driver bindings and issues the matching
    ``cuCheckpointProcess*`` call for that PID.

    Parameters
    ----------
    pid : int
        Process ID of the CUDA process.

    Raises
    ------
    TypeError
        If ``pid`` is not an ``int`` (``bool`` is rejected as well).
    ValueError
        If ``pid`` is not positive.
    """

    pid: int

    def __post_init__(self):
        # Validate eagerly so an invalid PID fails at construction time
        # rather than on the first driver call.
        _check_pid(self.pid)

    @property
    def state(self) -> ProcessState:
        """
        Current CUDA checkpoint state of this process.

        Returns
        -------
        ProcessState
            The driver-reported state converted to :class:`ProcessState`.
        """
        drv = _get_driver()
        raw_state = _handle_return(drv, drv.cuCheckpointProcessGetState(self.pid))
        return ProcessState(int(raw_state))

    @property
    def restore_thread_id(self) -> int:
        """
        CUDA restore thread ID reported by the driver for this process.
        """
        drv = _get_driver()
        outcome = drv.cuCheckpointProcessGetRestoreThreadId(self.pid)
        return _handle_return(drv, outcome)

    def lock(self, timeout_ms: int = 0) -> None:
        """
        Lock this process, blocking further CUDA API calls.

        Parameters
        ----------
        timeout_ms : int, optional
            Timeout in milliseconds. A value of 0 indicates no timeout.

        Raises
        ------
        TypeError
            If ``timeout_ms`` is not an ``int``.
        ValueError
            If ``timeout_ms`` is negative.
        """
        drv = _get_driver()
        lock_args = drv.CUcheckpointLockArgs()
        lock_args.timeoutMs = _check_timeout_ms(timeout_ms)
        outcome = drv.cuCheckpointProcessLock(self.pid, lock_args)
        _handle_return(drv, outcome)

    def checkpoint(self) -> None:
        """
        Checkpoint the GPU memory contents of this locked process.
        """
        drv = _get_driver()
        # No optional arguments are supplied to the driver call.
        outcome = drv.cuCheckpointProcessCheckpoint(self.pid, None)
        _handle_return(drv, outcome)

    def restore(self, gpu_mapping: Mapping[Any, Any] | None = None) -> None:
        """
        Restore this checkpointed process.

        Parameters
        ----------
        gpu_mapping : mapping, optional
            GPU UUID remapping from each checkpointed GPU UUID to the GPU
            UUID to restore onto. If provided, the mapping must contain
            every checkpointed GPU UUID. ``None`` or an empty mapping
            results in no restore arguments being passed to the driver.

        Raises
        ------
        TypeError
            If ``gpu_mapping`` is neither ``None`` nor a mapping.
        """
        drv = _get_driver()
        restore_args = _make_restore_args(drv, gpu_mapping)
        outcome = drv.cuCheckpointProcessRestore(self.pid, restore_args)
        _handle_return(drv, outcome)

    def unlock(self) -> None:
        """
        Unlock this locked process so it can resume CUDA API calls.
        """
        drv = _get_driver()
        # No optional arguments are supplied to the driver call.
        outcome = drv.cuCheckpointProcessUnlock(self.pid, None)
        _handle_return(drv, outcome)
99+
100+
101+
def _get_driver():
    """
    Import the CUDA driver bindings and verify checkpoint API support.

    ``cuda.bindings.driver`` is tried first; if that import fails the
    ``cuda.cuda`` module is used instead.

    Returns
    -------
    module
        The imported driver bindings module.

    Raises
    ------
    RuntimeError
        If the module lacks any checkpoint function or argument type that
        this file relies on. The message lists the missing names.
    """
    try:
        from cuda.bindings import driver
    except ImportError:
        from cuda import cuda as driver

    # Every driver entry point and argument struct this module touches.
    needed_symbols = (
        "cuCheckpointProcessCheckpoint",
        "cuCheckpointProcessGetRestoreThreadId",
        "cuCheckpointProcessGetState",
        "cuCheckpointProcessLock",
        "cuCheckpointProcessRestore",
        "cuCheckpointProcessUnlock",
        "CUcheckpointGpuPair",
        "CUcheckpointLockArgs",
        "CUcheckpointRestoreArgs",
    )
    absent = [symbol for symbol in needed_symbols if not hasattr(driver, symbol)]
    if not absent:
        return driver

    raise RuntimeError(
        f"CUDA checkpointing requires cuda.bindings with CUDA checkpoint API support. Missing: {', '.join(absent)}"
    )
124+
125+
126+
def _handle_return(driver, result):
    """
    Check a driver call result, reporting missing checkpoint support clearly.

    Parameters
    ----------
    driver : module
        Driver bindings module; its ``CUresult`` enum supplies the error
        codes that are checked.
    result : sequence
        Value returned by a driver call; element 0 is the status code.

    Returns
    -------
    object
        Whatever ``handle_return(result)`` returns when the status is not
        one of the "unsupported" codes.

    Raises
    ------
    RuntimeError
        If the status equals ``CUDA_ERROR_NOT_FOUND`` or
        ``CUDA_ERROR_NOT_SUPPORTED``.
    """
    status = result[0]
    result_codes = driver.CUresult
    # A code that the bindings do not define resolves to None and therefore
    # never matches a real status value.
    unsupported = tuple(
        getattr(result_codes, code_name, None)
        for code_name in ("CUDA_ERROR_NOT_FOUND", "CUDA_ERROR_NOT_SUPPORTED")
    )
    if status not in unsupported:
        # All other statuses go through the shared error handling.
        return handle_return(result)

    raise RuntimeError(
        "CUDA checkpointing is not supported by the installed NVIDIA driver. "
        "Upgrade to a driver version with CUDA checkpoint API support."
    )
139+
140+
141+
def _check_pid(pid: int) -> int:
    """
    Validate a process ID and return it unchanged.

    Raises
    ------
    TypeError
        If ``pid`` is not an ``int``; ``bool`` is rejected even though it
        subclasses ``int``.
    ValueError
        If ``pid`` is zero or negative.
    """
    is_plain_int = isinstance(pid, int) and not isinstance(pid, bool)
    if not is_plain_int:
        raise TypeError("pid must be an int")
    if pid < 1:
        raise ValueError("pid must be a positive int")
    return pid
147+
148+
149+
# Largest timeout accepted: the driver's lock-argument struct stores the
# timeout in a C ``unsigned int`` field.
_MAX_TIMEOUT_MS = 0xFFFFFFFF


def _check_timeout_ms(timeout_ms: int) -> int:
    """
    Validate a lock timeout in milliseconds and return it unchanged.

    Parameters
    ----------
    timeout_ms : int
        Timeout in milliseconds; 0 means no timeout.

    Returns
    -------
    int
        The validated ``timeout_ms``.

    Raises
    ------
    TypeError
        If ``timeout_ms`` is not an ``int``; ``bool`` is rejected even
        though it subclasses ``int``.
    ValueError
        If ``timeout_ms`` is negative or does not fit in an unsigned
        32-bit integer.
    """
    if isinstance(timeout_ms, bool) or not isinstance(timeout_ms, int):
        raise TypeError("timeout_ms must be an int")
    if timeout_ms < 0:
        raise ValueError("timeout_ms must be >= 0")
    # Reject oversized values here so the caller gets a clear ValueError
    # instead of an overflow error from the binding's attribute setter.
    if timeout_ms > _MAX_TIMEOUT_MS:
        raise ValueError(f"timeout_ms must be <= {_MAX_TIMEOUT_MS}")
    return timeout_ms
155+
156+
157+
def _make_restore_args(driver, gpu_mapping: Mapping[Any, Any] | None):
158+
if gpu_mapping is None:
159+
return None
160+
if not isinstance(gpu_mapping, Mapping):
161+
raise TypeError("gpu_mapping must be a mapping from checkpointed GPU UUID to restore GPU UUID")
162+
163+
pairs = []
164+
for old_uuid, new_uuid in gpu_mapping.items():
165+
pair = driver.CUcheckpointGpuPair()
166+
pair.oldUuid = old_uuid
167+
pair.newUuid = new_uuid
168+
pairs.append(pair)
169+
170+
if not pairs:
171+
return None
172+
173+
args = driver.CUcheckpointRestoreArgs()
174+
args.gpuPairs = pairs
175+
args.gpuPairsCount = len(pairs)
176+
return args
177+
178+
179+
# Public API of this module; the underscore-prefixed helpers are
# intentionally left out.
__all__ = [
    "Process",
    "ProcessState",
]

‎cuda_core/docs/source/api.rst‎

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,22 @@ CUDA compilation toolchain
174174
LinkerOptions
175175

176176

177+
CUDA process checkpointing
178+
--------------------------
179+
180+
.. autosummary::
181+
:toctree: generated/
182+
183+
:template: class.rst
184+
185+
checkpoint.Process
186+
187+
.. autosummary::
188+
:toctree: generated/
189+
190+
checkpoint.ProcessState
191+
192+
177193
CUDA system information and NVIDIA Management Library (NVML)
178194
------------------------------------------------------------
179195

‎cuda_core/docs/source/release/1.0.0-notes.rst‎

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,10 @@ Highlights
1616
New features
1717
------------
1818

19-
- TBD
19+
- Added the :mod:`cuda.core.checkpoint` module for CUDA process checkpointing,
20+
including process state queries, lock/checkpoint/restore/unlock operations,
21+
and GPU UUID remapping support for restore.
22+
(`#1343 <https://github.com/NVIDIA/cuda-python/issues/1343>`__)
2023

2124

2225
Fixes and enhancements

0 commit comments

Comments
 (0)