Skip to content

Commit 067fb15

Browse files
rparolin and claude committed
refactor(cuda.core): factor shared body of _do_batch_{prefetch,discard_prefetch} (N2)
Per Leo's review on PR #1775 (_managed_memory_ops.pyx:425), the two batched-with-locations helpers were byte-for-byte identical except for the driver function being called. Both: - declare the same four std::vectors (ptrs, sizes, loc_arr, loc_indices) - resize and fill them in the same loop - release the GIL and call cuMem{Prefetch,DiscardAndPrefetch}BatchAsync with the same argument shape Introduce a function-pointer typedef _BatchPrefetchFn (the two driver calls share signature), parameterize the shared body as _do_batch_prefetch_op, and have the two callers pass the appropriate driver function. Both the typedef and the helper live inside the IF CUDA_CORE_BUILD_MAJOR >= 13 block since they reference cu13-only types. Net: -28 lines duplication, +25 for the shared helper. No behavior change; tests unaffected. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 36012fd commit 067fb15

1 file changed

Lines changed: 19 additions & 27 deletions

File tree

cuda_core/cuda/core/_memory/_managed_memory_ops.pyx

Lines changed: 19 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -322,8 +322,18 @@ cdef void _do_single_prefetch(Buffer buf, object loc, Stream s):
322322
HANDLE_RETURN(cydriver.cuMemPrefetchAsync(cu_ptr, nbytes, dev_int, hstream))
323323

324324

325-
cdef void _do_batch_prefetch(tuple bufs, tuple locs, Stream s):
326-
IF CUDA_CORE_BUILD_MAJOR >= 13:
325+
IF CUDA_CORE_BUILD_MAJOR >= 13:
326+
# Function-pointer type for cuMemPrefetchBatchAsync /
327+
# cuMemDiscardAndPrefetchBatchAsync; both have identical signatures.
328+
ctypedef cydriver.CUresult (*_BatchPrefetchFn)(
329+
cydriver.CUdeviceptr*, size_t*, size_t,
330+
cydriver.CUmemLocation*, size_t*, size_t,
331+
unsigned long long, cydriver.CUstream,
332+
) except ?cydriver.CUDA_ERROR_NOT_FOUND nogil
333+
334+
335+
cdef void _do_batch_prefetch_op(tuple bufs, tuple locs, Stream s, _BatchPrefetchFn fn):
336+
"""Shared body for batched prefetch / discard-and-prefetch."""
327337
cdef Py_ssize_t n = len(bufs)
328338
cdef cydriver.CUstream hstream = as_cu(s._h_stream)
329339
cdef vector[cydriver.CUdeviceptr] ptrs
@@ -343,11 +353,16 @@ cdef void _do_batch_prefetch(tuple bufs, tuple locs, Stream s):
343353
loc_arr[i] = _to_cumemlocation(locs[i])
344354
loc_indices[i] = <size_t>i
345355
with nogil:
346-
HANDLE_RETURN(cydriver.cuMemPrefetchBatchAsync(
356+
HANDLE_RETURN(fn(
347357
ptrs.data(), sizes.data(), <size_t>n,
348358
loc_arr.data(), loc_indices.data(), <size_t>n,
349359
0, hstream,
350360
))
361+
362+
363+
cdef void _do_batch_prefetch(tuple bufs, tuple locs, Stream s):
364+
IF CUDA_CORE_BUILD_MAJOR >= 13:
365+
_do_batch_prefetch_op(bufs, locs, s, cydriver.cuMemPrefetchBatchAsync)
351366
ELSE:
352367
raise NotImplementedError(
353368
"batched prefetch requires a CUDA 13 build of cuda.core"
@@ -400,30 +415,7 @@ def _do_single_discard_prefetch_py(Buffer buf, location, stream):
400415

401416
cdef void _do_batch_discard_prefetch(tuple bufs, tuple locs, Stream s):
402417
IF CUDA_CORE_BUILD_MAJOR >= 13:
403-
cdef Py_ssize_t n = len(bufs)
404-
cdef cydriver.CUstream hstream = as_cu(s._h_stream)
405-
cdef vector[cydriver.CUdeviceptr] ptrs
406-
cdef vector[size_t] sizes
407-
cdef vector[cydriver.CUmemLocation] loc_arr
408-
cdef vector[size_t] loc_indices
409-
ptrs.resize(n)
410-
sizes.resize(n)
411-
loc_arr.resize(n)
412-
loc_indices.resize(n)
413-
cdef Buffer buf
414-
cdef Py_ssize_t i
415-
for i in range(n):
416-
buf = <Buffer>bufs[i]
417-
ptrs[i] = as_cu(buf._h_ptr)
418-
sizes[i] = buf._size
419-
loc_arr[i] = _to_cumemlocation(locs[i])
420-
loc_indices[i] = <size_t>i
421-
with nogil:
422-
HANDLE_RETURN(cydriver.cuMemDiscardAndPrefetchBatchAsync(
423-
ptrs.data(), sizes.data(), <size_t>n,
424-
loc_arr.data(), loc_indices.data(), <size_t>n,
425-
0, hstream,
426-
))
418+
_do_batch_prefetch_op(bufs, locs, s, cydriver.cuMemDiscardAndPrefetchBatchAsync)
427419
ELSE:
428420
raise NotImplementedError(
429421
"discard_prefetch requires a CUDA 13 build of cuda.core"

0 commit comments

Comments (0)