Skip to content

Commit 5e838ad

Browse files
mberenjk and Marzieh Berenjkoub authored
skipping the prod test for FP8 types in reduce and reduce-scatter (ROCm#111)
* skipping the prod test for FP8 types in reduce and reduce-scatter

---------

Co-authored-by: Marzieh Berenjkoub <mberenjk@amd.com>
1 parent 284ff2a commit 5e838ad

File tree

3 files changed: +11 -3 lines changed

‎src/all_reduce.cu‎

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -65,8 +65,6 @@ testResult_t AllReduceRunTest(struct threadArgs* args, int root, ncclDataType_t
6565
ncclRedOp_t *run_ops;
6666
const char **run_typenames, **run_opnames;
6767
int type_count, op_count;
68-
if((type == ncclFp8E4M3 || type == ncclFp8E5M2) && op == ncclProd)
69-
return testSuccess;
7068

7169
if ((int)type != -1) {
7270
type_count = 1;
@@ -90,8 +88,10 @@ testResult_t AllReduceRunTest(struct threadArgs* args, int root, ncclDataType_t
9088

9189
for (int i=0; i<type_count; i++) {
9290
for (int j=0; j<op_count; j++) {
93-
if((i == ncclFp8E4M3 || i == ncclFp8E5M2) && j == ncclProd)
91+
#if defined(RCCL_FLOAT8)
92+
if((run_types[i] == ncclFp8E4M3 || run_types[i] == ncclFp8E5M2) && run_ops[j] == ncclProd)
9493
continue;
94+
#endif
9595
TESTCHECK(TimeTest(args, run_types[i], run_typenames[i], run_ops[j], run_opnames[j], -1));
9696
}
9797
}

‎src/reduce.cu‎

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -95,6 +95,10 @@ testResult_t ReduceRunTest(struct threadArgs* args, int root, ncclDataType_t typ
9595

9696
for (int i=0; i<type_count; i++) {
9797
for (int j=0; j<op_count; j++) {
98+
#if defined(RCCL_FLOAT8)
99+
if((run_types[i] == ncclFp8E4M3 || run_types[i] == ncclFp8E5M2) && run_ops[j] == ncclProd)
100+
continue;
101+
#endif
98102
for (int k=begin_root; k<=end_root; k++) {
99103
TESTCHECK(TimeTest(args, run_types[i], run_typenames[i], run_ops[j], run_opnames[j], k));
100104
}

‎src/reduce_scatter.cu‎

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -92,6 +92,10 @@ testResult_t ReduceScatterRunTest(struct threadArgs* args, int root, ncclDataTyp
9292

9393
for (int i=0; i<type_count; i++) {
9494
for (int j=0; j<op_count; j++) {
95+
#if defined(RCCL_FLOAT8)
96+
if((run_types[i] == ncclFp8E4M3 || run_types[i] == ncclFp8E5M2) && run_ops[j] == ncclProd)
97+
continue;
98+
#endif
9599
TESTCHECK(TimeTest(args, run_types[i], run_typenames[i], run_ops[j], run_opnames[j], -1));
96100
}
97101
}

0 commit comments

Comments
 (0)