Skip to content

Commit 6b017b7

Browse files
leofangclaude
andcommitted
Address review: flip strides assertion, add _arr_dtype, merge main
Per @rwgk's review: - Flip strides check to branch on view.strides (all 3 _check_view) - Add _arr_dtype helper using __cuda_array_interface__["typestr"] for torch tensors, restore dtype assertion in CAI _check_view - Merge main to pick up #1998 (numba flags fix) Verified locally: 76/76 tests pass across all three test classes. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 95a41aa commit 6b017b7

1 file changed

Lines changed: 13 additions & 6 deletions

File tree

‎cuda_core/tests/test_utils.py‎

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,12 @@ def _arr_is_writeable(arr):
111111
return arr.flags.writeable if hasattr(arr.flags, "writeable") else True
112112

113113

114+
def _arr_dtype(arr):
115+
if torch is not None and isinstance(arr, torch.Tensor):
116+
return np.dtype(arr.__cuda_array_interface__["typestr"])
117+
return arr.dtype
118+
119+
114120
def _cpu_array_samples():
115121
samples = [
116122
np.empty(3, dtype=np.int32),
@@ -171,8 +177,8 @@ def _check_view(self, view, in_arr):
171177
assert view.shape == expected_shape
172178
assert view.size == _arr_size(in_arr)
173179
strides_in_counts = _arr_strides_in_counts(in_arr)
174-
if _arr_is_c_contiguous(in_arr):
175-
assert view.strides in (None, strides_in_counts)
180+
if view.strides is None:
181+
assert _arr_is_c_contiguous(in_arr)
176182
else:
177183
assert view.strides == strides_in_counts
178184
assert view.device_id == -1
@@ -280,8 +286,8 @@ def _check_view(self, view, in_arr, dev):
280286
assert view.shape == expected_shape
281287
assert view.size == _arr_size(in_arr)
282288
strides_in_counts = _arr_strides_in_counts(in_arr)
283-
if _arr_is_c_contiguous(in_arr):
284-
assert view.strides in (None, strides_in_counts)
289+
if view.strides is None:
290+
assert _arr_is_c_contiguous(in_arr)
285291
else:
286292
assert view.strides == strides_in_counts
287293
assert view.device_id == dev.device_id
@@ -351,10 +357,11 @@ def _check_view(self, view, in_arr, dev):
351357
assert view.shape == expected_shape
352358
assert view.size == _arr_size(in_arr)
353359
strides_in_counts = _arr_strides_in_counts(in_arr)
354-
if _arr_is_c_contiguous(in_arr):
355-
assert view.strides in (None, strides_in_counts)
360+
if view.strides is None:
361+
assert _arr_is_c_contiguous(in_arr)
356362
else:
357363
assert view.strides == strides_in_counts
364+
assert view.dtype == _arr_dtype(in_arr)
358365
assert view.device_id == dev.device_id
359366
assert view.is_device_accessible is True
360367
assert view.exporting_obj is in_arr

0 commit comments

Comments
 (0)