@@ -659,6 +659,7 @@ testResult_t threadInit(struct threadArgs* args) {
659659 }
660660 NCCLCHECK (ncclGroupEnd ());
661661#if NCCL_VERSION_CODE >= NCCL_VERSION(2,19,0)
662+ NCCLCHECK (ncclGroupStart ());
662663 void **sendRegHandles = (local_register) ? (void **)malloc (sizeof (*sendRegHandles)*args->nGpus ) : NULL ;
663664 void **recvRegHandles = (local_register) ? (void **)malloc (sizeof (*recvRegHandles)*args->nGpus ) : NULL ;
664665 for (int i=0 ; i<args->nGpus ; i++) {
@@ -673,6 +674,7 @@ testResult_t threadInit(struct threadArgs* args) {
673674 if (local_register) NCCLCHECK (ncclCommRegister (args->comms [i], args->recvbuffs [i], args->maxbytes , &recvRegHandles[i]));
674675 }
675676 }
677+ NCCLCHECK (ncclGroupEnd ());
676678#endif
677679
678680 TESTCHECK (threadRunTests (args));
@@ -1124,6 +1126,7 @@ testResult_t run() {
11241126 NCCLCHECK (ncclGroupEnd ());
11251127 }
11261128#if NCCL_VERSION_CODE >= NCCL_VERSION(2,19,0)
1129+ NCCLCHECK (ncclGroupStart ());
11271130 sendRegHandles = (local_register) ? (void **)malloc (sizeof (*sendRegHandles)*nThreads*nGpus) : NULL ;
11281131 recvRegHandles = (local_register) ? (void **)malloc (sizeof (*recvRegHandles)*nThreads*nGpus) : NULL ;
11291132 for (int i=0 ; i<nGpus*nThreads; i++) {
@@ -1138,6 +1141,7 @@ testResult_t run() {
11381141 if (local_register) NCCLCHECK (ncclCommRegister (comms[i], recvbuffs[i], maxBytes, &recvRegHandles[i]));
11391142 }
11401143 }
1144+ NCCLCHECK (ncclGroupEnd ());
11411145#endif
11421146 }
11431147
0 commit comments