Skip to content

Commit 83eaec1

Browse files
authored
Optimize Buffer.fill() to avoid intermediate object creation (NVIDIA#1376)
Use Cython typed parameters and the C buffer API to eliminate overhead: (1) int path: a `uint8_t` parameter gives automatic overflow checking; (2) bytes path: direct `char*` pointer access; (3) general buffer path: `PyObject_GetBuffer` gives direct `void*` access.
1 parent 96a3f50 commit 83eaec1

1 file changed

Lines changed: 52 additions & 59 deletions

File tree

‎cuda_core/cuda/core/experimental/_memory/_buffer.pyx‎

Lines changed: 52 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
from __future__ import annotations
66

77
cimport cython
8-
from libc.stdint cimport uintptr_t
8+
from libc.stdint cimport uint8_t, uint16_t, uint32_t, uintptr_t
9+
from cpython.buffer cimport PyObject_GetBuffer, PyBuffer_Release, Py_buffer, PyBUF_SIMPLE
910

1011
from cuda.bindings cimport cydriver
1112
from cuda.core.experimental._memory._device_memory_resource import DeviceMemoryResource
@@ -233,65 +234,27 @@ cdef class Buffer:
233234
234235
"""
235236
cdef Stream s_stream = Stream_accept(stream)
236-
cdef unsigned char c_value8
237-
cdef unsigned short c_value16
238-
cdef unsigned int c_value32
239-
cdef size_t N
240-
cdef size_t width
241-
cdef unsigned int int_value
242-
243-
# Get fill pattern from value
237+
238+
# Handle int case: 1-byte fill with automatic overflow checking.
244239
if isinstance(value, int):
245-
# We define the int input to mean a 1-byte pattern.
246-
# Match int.to_bytes(1, "little") behavior: raise OverflowError if not in [0, 256).
247-
if value < 0 or value >= 256:
248-
raise OverflowError("int value must be in range [0, 256)")
249-
width = 1
250-
int_value = <unsigned int>value
251-
else:
252-
try:
253-
mv = memoryview(value)
254-
except TypeError:
255-
raise TypeError(
256-
f"value must be an int or support the buffer protocol, got {type(value).__name__}"
257-
) from None
258-
width = mv.nbytes
259-
260-
# Validate width early to avoid copying/processing large invalid inputs.
261-
if width not in (1, 2, 4):
262-
raise ValueError(f"value must be 1, 2, or 4 bytes, got {width}")
263-
264-
# Convert to a 1-D view of bytes.
265-
#
266-
# Note: NumPy scalar memoryviews are 0-D, and int.from_bytes(mv, ...) errors with
267-
# "0-dim memory has no length". Casting to 'B' gives us a byte-addressable view.
268-
try:
269-
int_value = int.from_bytes(mv.cast("B"), "little")
270-
except TypeError:
271-
int_value = int.from_bytes(mv.tobytes(), "little")
272-
273-
# Validate buffer size modulus.
274-
cdef size_t buffer_size = self._size
275-
if buffer_size % width != 0:
276-
raise ValueError(f"buffer size ({buffer_size}) must be divisible by {width}")
277-
278-
# Perform fill based on width
279-
cdef cydriver.CUstream s = s_stream._handle
280-
if width == 1:
281-
c_value8 = <unsigned char>int_value
282-
N = buffer_size
283-
with nogil:
284-
HANDLE_RETURN(cydriver.cuMemsetD8Async(<cydriver.CUdeviceptr>self._ptr, c_value8, N, s))
285-
elif width == 2:
286-
c_value16 = <unsigned short>int_value
287-
N = buffer_size // 2
288-
with nogil:
289-
HANDLE_RETURN(cydriver.cuMemsetD16Async(<cydriver.CUdeviceptr>self._ptr, c_value16, N, s))
290-
else: # width == 4
291-
c_value32 = <unsigned int>int_value
292-
N = buffer_size // 4
293-
with nogil:
294-
HANDLE_RETURN(cydriver.cuMemsetD32Async(<cydriver.CUdeviceptr>self._ptr, c_value32, N, s))
240+
Buffer_fill_uint8(self, value, s_stream._handle)
241+
return
242+
243+
# Handle bytes case: direct pointer access without intermediate objects.
244+
if isinstance(value, bytes):
245+
Buffer_fill_from_ptr(self, <const char*><bytes>value, len(value), s_stream._handle)
246+
return
247+
248+
# General buffer protocol path using C buffer API.
249+
cdef Py_buffer buf
250+
if PyObject_GetBuffer(value, &buf, PyBUF_SIMPLE) != 0:
251+
raise TypeError(
252+
f"value must be an int or support the buffer protocol, got {type(value).__name__}"
253+
)
254+
try:
255+
Buffer_fill_from_ptr(self, <const char*>buf.buf, buf.len, s_stream._handle)
256+
finally:
257+
PyBuffer_Release(&buf)
295258

296259
def __dlpack__(
297260
self,
@@ -420,6 +383,36 @@ cdef inline void Buffer_close(Buffer self, stream):
420383
self._alloc_stream = None
421384

422385

386+
cdef inline void Buffer_fill_uint8(Buffer self, uint8_t value, cydriver.CUstream s) except *:
    """Fill every byte of the buffer with ``value`` on stream ``s``.

    Parameters
    ----------
    value : uint8_t
        The byte to replicate. Because the parameter is typed, a Python int
        passed by the caller is converted by Cython at the call site, which
        is where out-of-range values are rejected.
    s : cydriver.CUstream
        Stream on which the asynchronous memset is enqueued.

    Notes
    -----
    ``except *`` is spelled out so that any exception raised by
    ``HANDLE_RETURN`` propagates to the caller regardless of the Cython
    implicit-noexcept setting, matching ``Buffer_fill_from_ptr``.
    """
    with nogil:
        HANDLE_RETURN(cydriver.cuMemsetD8Async(<cydriver.CUdeviceptr>self._ptr, value, self._size, s))
389+
390+
391+
cdef inline void Buffer_fill_from_ptr(
    Buffer self, const char* ptr, size_t width, cydriver.CUstream s
) except *:
    """Fill the buffer with the 1-, 2-, or 4-byte pattern stored at ``ptr``.

    Parameters
    ----------
    ptr : const char*
        Pointer to the first byte of the fill pattern. At least ``width``
        bytes must be readable. No alignment is required.
    width : size_t
        Pattern length in bytes; must be 1, 2, or 4.
    s : cydriver.CUstream
        Stream on which the asynchronous memset is enqueued.

    Raises
    ------
    ValueError
        If ``width`` is not 1, 2, or 4, or if the buffer size is not a
        multiple of ``width``.
    """
    cdef size_t buffer_size = self._size
    # View the pattern as unsigned bytes so that widening never sign-extends.
    cdef const uint8_t* pattern = <const uint8_t*>ptr
    cdef uint16_t value16
    cdef uint32_t value32

    if width == 1:
        with nogil:
            HANDLE_RETURN(cydriver.cuMemsetD8Async(
                <cydriver.CUdeviceptr>self._ptr, pattern[0], buffer_size, s))
    elif width == 2:
        if buffer_size & 0x1:
            raise ValueError(f"buffer size ({buffer_size}) must be divisible by 2")
        # Assemble the value byte-by-byte instead of dereferencing a
        # uint16_t*: ptr may come from an arbitrary buffer (e.g. a sliced
        # memoryview) and is not guaranteed to be 2-byte aligned. The pattern
        # is interpreted as little-endian.
        value16 = <uint16_t>((<uint32_t>pattern[0]) | ((<uint32_t>pattern[1]) << 8))
        with nogil:
            HANDLE_RETURN(cydriver.cuMemsetD16Async(
                <cydriver.CUdeviceptr>self._ptr, value16, buffer_size // 2, s))
    elif width == 4:
        if buffer_size & 0x3:
            raise ValueError(f"buffer size ({buffer_size}) must be divisible by 4")
        # Same alignment-safe, little-endian assembly as the 2-byte case.
        value32 = (
            (<uint32_t>pattern[0])
            | ((<uint32_t>pattern[1]) << 8)
            | ((<uint32_t>pattern[2]) << 16)
            | ((<uint32_t>pattern[3]) << 24)
        )
        with nogil:
            HANDLE_RETURN(cydriver.cuMemsetD32Async(
                <cydriver.CUdeviceptr>self._ptr, value32, buffer_size // 4, s))
    else:
        raise ValueError(f"value must be 1, 2, or 4 bytes, got {width}")
414+
415+
423416
cdef Buffer_init_mem_attrs(Buffer self):
424417
if not self._mem_attrs_inited:
425418
query_memory_attrs(self._mem_attrs, self._ptr)

0 commit comments

Comments
 (0)