Skip to content

Commit 76a8cee

Browse files
authored
BUG: Make nvmlGetFieldValues robust to empty field list (#1982)
* BUG: Make nvmlGetFieldValues robust to empty field list * Pre commit
1 parent 11347ff commit 76a8cee

4 files changed

Lines changed: 26 additions & 0 deletions

File tree

‎cuda_bindings/cuda/bindings/nvml.pyx‎

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27030,6 +27030,11 @@ cpdef object device_get_field_values(intptr_t device, values):
2703027030
cdef FieldValue values_ = _cast_field_values(values)
2703127031
cdef nvmlFieldValue_t *ptr = <nvmlFieldValue_t *>values_._get_ptr()
2703227032
cdef unsigned int valuesCount = len(values)
27033+
27034+
# Passing a valuesCount of 0 to nvmlDeviceGetFieldValues returns NVML_INVALID_ARGUMENT
27035+
if valuesCount == 0:
27036+
return values_
27037+
2703327038
with nogil:
2703427039
__status__ = nvmlDeviceGetFieldValues(<Device>device, valuesCount, ptr)
2703527040
check_status(__status__)
@@ -27050,6 +27055,10 @@ cpdef device_clear_field_values(intptr_t device, values):
2705027055
cdef nvmlFieldValue_t *ptr = <nvmlFieldValue_t *>values_._get_ptr()
2705127056
cdef unsigned int valuesCount = len(values)
2705227057

27058+
# Passing a valuesCount of 0 to nvmlDeviceClearFieldValues returns NVML_INVALID_ARGUMENT
27059+
if valuesCount == 0:
27060+
return values_
27061+
2705327062
with nogil:
2705427063
__status__ = nvmlDeviceClearFieldValues(<Device>device, valuesCount, ptr)
2705527064
check_status(__status__)

‎cuda_bindings/tests/nvml/test_nvlink.py‎

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,11 @@ def test_nvlink_get_link_count(all_devices):
1010
Checks that the link count of the device is same.
1111
"""
1212
for device in all_devices:
13+
fields = nvml.FieldValue(0)
14+
assert len(nvml.device_get_field_values(device, fields)) == 0
15+
16+
assert len(nvml.device_get_field_values(device, [])) == 0
17+
1318
fields = nvml.FieldValue(1)
1419
fields[0].field_id = nvml.FieldId.DEV_NVLINK_LINK_COUNT
1520
value = nvml.device_get_field_values(device, fields)[0]

‎cuda_core/cuda/core/system/_device.pyx‎

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -661,6 +661,11 @@ cdef class Device:
661661
:obj:`~_device.FieldValues`
662662
Container of field values corresponding to the requested field IDs.
663663
"""
664+
# Passing a field_ids array of length 0 raises an InvalidArgumentError,
665+
# so avoid that.
666+
if len(field_ids) == 0:
667+
return FieldValues(nvml.FieldValue(0))
668+
664669
return FieldValues(nvml.device_get_field_values(self._handle, field_ids))
665670
666671
def clear_field_values(self, field_ids: list[int | tuple[int, int]]) -> None:
@@ -675,6 +680,11 @@ cdef class Device:
675680
Each item may be either a single value from the :class:`FieldId`
676681
enum, or a pair of (:class:`FieldId`, scope ID).
677682
"""
683+
# Passing a field_ids array of length 0 raises an InvalidArgumentError,
684+
# so avoid that.
685+
if len(field_ids) == 0:
686+
return
687+
678688
nvml.device_clear_field_values(self._handle, field_ids)
679689
680690
##########################################################################

‎cuda_core/tests/system/test_system_device.py‎

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,8 @@ def test_field_values():
360360
# TODO: Are there any fields that return double's? It would be good to
361361
# test those.
362362

363+
assert len(device.get_field_values([])) == 0
364+
363365
field_ids = [
364366
system.FieldId.DEV_TOTAL_ENERGY_CONSUMPTION,
365367
system.FieldId.DEV_PCIE_COUNT_TX_BYTES,

0 commit comments

Comments
 (0)