|
5 | 5 | from __future__ import annotations |
6 | 6 |
|
7 | 7 | cimport cython |
8 | | -from libc.stdint cimport uintptr_t |
| 8 | +from libc.stdint cimport uint8_t, uint16_t, uint32_t, uintptr_t |
| 9 | +from cpython.buffer cimport PyObject_GetBuffer, PyBuffer_Release, Py_buffer, PyBUF_SIMPLE |
9 | 10 |
|
10 | 11 | from cuda.bindings cimport cydriver |
11 | 12 | from cuda.core.experimental._memory._device_memory_resource import DeviceMemoryResource |
@@ -233,65 +234,27 @@ cdef class Buffer: |
233 | 234 |
|
234 | 235 | """ |
235 | 236 | cdef Stream s_stream = Stream_accept(stream) |
236 | | - cdef unsigned char c_value8 |
237 | | - cdef unsigned short c_value16 |
238 | | - cdef unsigned int c_value32 |
239 | | - cdef size_t N |
240 | | - cdef size_t width |
241 | | - cdef unsigned int int_value |
242 | | - |
243 | | - # Get fill pattern from value |
| 237 | + |
| 238 | + # Handle int case: 1-byte fill with automatic overflow checking. |
244 | 239 | if isinstance(value, int): |
245 | | - # We define the int input to mean a 1-byte pattern. |
246 | | - # Match int.to_bytes(1, "little") behavior: raise OverflowError if not in [0, 256). |
247 | | - if value < 0 or value >= 256: |
248 | | - raise OverflowError("int value must be in range [0, 256)") |
249 | | - width = 1 |
250 | | - int_value = <unsigned int>value |
251 | | - else: |
252 | | - try: |
253 | | - mv = memoryview(value) |
254 | | - except TypeError: |
255 | | - raise TypeError( |
256 | | - f"value must be an int or support the buffer protocol, got {type(value).__name__}" |
257 | | - ) from None |
258 | | - width = mv.nbytes |
259 | | - |
260 | | - # Validate width early to avoid copying/processing large invalid inputs. |
261 | | - if width not in (1, 2, 4): |
262 | | - raise ValueError(f"value must be 1, 2, or 4 bytes, got {width}") |
263 | | - |
264 | | - # Convert to a 1-D view of bytes. |
265 | | - # |
266 | | - # Note: NumPy scalar memoryviews are 0-D, and int.from_bytes(mv, ...) errors with |
267 | | - # "0-dim memory has no length". Casting to 'B' gives us a byte-addressable view. |
268 | | - try: |
269 | | - int_value = int.from_bytes(mv.cast("B"), "little") |
270 | | - except TypeError: |
271 | | - int_value = int.from_bytes(mv.tobytes(), "little") |
272 | | - |
273 | | - # Validate buffer size modulus. |
274 | | - cdef size_t buffer_size = self._size |
275 | | - if buffer_size % width != 0: |
276 | | - raise ValueError(f"buffer size ({buffer_size}) must be divisible by {width}") |
277 | | - |
278 | | - # Perform fill based on width |
279 | | - cdef cydriver.CUstream s = s_stream._handle |
280 | | - if width == 1: |
281 | | - c_value8 = <unsigned char>int_value |
282 | | - N = buffer_size |
283 | | - with nogil: |
284 | | - HANDLE_RETURN(cydriver.cuMemsetD8Async(<cydriver.CUdeviceptr>self._ptr, c_value8, N, s)) |
285 | | - elif width == 2: |
286 | | - c_value16 = <unsigned short>int_value |
287 | | - N = buffer_size // 2 |
288 | | - with nogil: |
289 | | - HANDLE_RETURN(cydriver.cuMemsetD16Async(<cydriver.CUdeviceptr>self._ptr, c_value16, N, s)) |
290 | | - else: # width == 4 |
291 | | - c_value32 = <unsigned int>int_value |
292 | | - N = buffer_size // 4 |
293 | | - with nogil: |
294 | | - HANDLE_RETURN(cydriver.cuMemsetD32Async(<cydriver.CUdeviceptr>self._ptr, c_value32, N, s)) |
| 240 | + Buffer_fill_uint8(self, value, s_stream._handle) |
| 241 | + return |
| 242 | + |
| 243 | + # Handle bytes case: direct pointer access without intermediate objects. |
| 244 | + if isinstance(value, bytes): |
| 245 | + Buffer_fill_from_ptr(self, <const char*><bytes>value, len(value), s_stream._handle) |
| 246 | + return |
| 247 | + |
| 248 | + # General buffer protocol path using C buffer API. |
| 249 | + cdef Py_buffer buf |
| 250 | + if PyObject_GetBuffer(value, &buf, PyBUF_SIMPLE) != 0: |
| 251 | + raise TypeError( |
| 252 | + f"value must be an int or support the buffer protocol, got {type(value).__name__}" |
| 253 | + ) |
| 254 | + try: |
| 255 | + Buffer_fill_from_ptr(self, <const char*>buf.buf, buf.len, s_stream._handle) |
| 256 | + finally: |
| 257 | + PyBuffer_Release(&buf) |
295 | 258 |
|
296 | 259 | def __dlpack__( |
297 | 260 | self, |
@@ -420,6 +383,36 @@ cdef inline void Buffer_close(Buffer self, stream): |
420 | 383 | self._alloc_stream = None |
421 | 384 |
|
422 | 385 |
|
| 386 | +cdef inline void Buffer_fill_uint8(Buffer self, uint8_t value, cydriver.CUstream s): |
| 387 | + with nogil: |
| 388 | + HANDLE_RETURN(cydriver.cuMemsetD8Async(<cydriver.CUdeviceptr>self._ptr, value, self._size, s)) |
| 389 | + |
| 390 | + |
| 391 | +cdef inline void Buffer_fill_from_ptr( |
| 392 | + Buffer self, const char* ptr, size_t width, cydriver.CUstream s |
| 393 | +) except *: |
| 394 | + cdef size_t buffer_size = self._size |
| 395 | + |
| 396 | + if width == 1: |
| 397 | + with nogil: |
| 398 | + HANDLE_RETURN(cydriver.cuMemsetD8Async( |
| 399 | + <cydriver.CUdeviceptr>self._ptr, (<uint8_t*>ptr)[0], buffer_size, s)) |
| 400 | + elif width == 2: |
| 401 | + if buffer_size & 0x1: |
| 402 | + raise ValueError(f"buffer size ({buffer_size}) must be divisible by 2") |
| 403 | + with nogil: |
| 404 | + HANDLE_RETURN(cydriver.cuMemsetD16Async( |
| 405 | + <cydriver.CUdeviceptr>self._ptr, (<uint16_t*>ptr)[0], buffer_size // 2, s)) |
| 406 | + elif width == 4: |
| 407 | + if buffer_size & 0x3: |
| 408 | + raise ValueError(f"buffer size ({buffer_size}) must be divisible by 4") |
| 409 | + with nogil: |
| 410 | + HANDLE_RETURN(cydriver.cuMemsetD32Async( |
| 411 | + <cydriver.CUdeviceptr>self._ptr, (<uint32_t*>ptr)[0], buffer_size // 4, s)) |
| 412 | + else: |
| 413 | + raise ValueError(f"value must be 1, 2, or 4 bytes, got {width}") |
| 414 | + |
| 415 | + |
423 | 416 | cdef Buffer_init_mem_attrs(Buffer self): |
424 | 417 | if not self._mem_attrs_inited: |
425 | 418 | query_memory_attrs(self._mem_attrs, self._ptr) |
|
0 commit comments