@@ -90,6 +90,8 @@ static int report_cputime = 0;
9090// Report average iteration time: (0=RANK0,1=AVG,2=MIN,3=MAX)
9191static int average = 1 ;
9292#if NCCL_VERSION_CODE >= NCCL_VERSION(2,19,0)
93+ #define LOCAL_REGISTER 1
94+ #define SYMMETRIC_REGISTER 2
9395static int local_register = 0 ;
9496#endif
9597static int minCudaArch = 1 <<30 ;
@@ -660,17 +662,33 @@ testResult_t threadInit(struct threadArgs* args) {
660662 void **sendRegHandles = (local_register) ? (void **)malloc (sizeof (*sendRegHandles)*args->nGpus ) : NULL ;
661663 void **recvRegHandles = (local_register) ? (void **)malloc (sizeof (*recvRegHandles)*args->nGpus ) : NULL ;
662664 for (int i=0 ; i<args->nGpus ; i++) {
663- if (local_register) NCCLCHECK (ncclCommRegister (args->comms [i], args->sendbuffs [i], args->maxbytes , &sendRegHandles[i]));
664- if (local_register) NCCLCHECK (ncclCommRegister (args->comms [i], args->recvbuffs [i], args->maxbytes , &recvRegHandles[i]));
665+ #if NCCL_VERSION_CODE >= NCCL_VERSION(2,27,0)
666+ if (test_ncclVersion >= NCCL_VERSION (2 ,27 ,0 ) && (local_register == SYMMETRIC_REGISTER)) {
667+ NCCLCHECK (ncclCommWindowRegister (args->comms [i], args->sendbuffs [i], args->maxbytes , (ncclWindow_t*)&sendRegHandles[i], NCCL_WIN_COLL_SYMMETRIC));
668+ NCCLCHECK (ncclCommWindowRegister (args->comms [i], args->recvbuffs [i], args->maxbytes , (ncclWindow_t*)&recvRegHandles[i], NCCL_WIN_COLL_SYMMETRIC));
669+ } else
670+ #endif
671+ {
672+ if (local_register) NCCLCHECK (ncclCommRegister (args->comms [i], args->sendbuffs [i], args->maxbytes , &sendRegHandles[i]));
673+ if (local_register) NCCLCHECK (ncclCommRegister (args->comms [i], args->recvbuffs [i], args->maxbytes , &recvRegHandles[i]));
674+ }
665675 }
666676#endif
667677
668678 TESTCHECK (threadRunTests (args));
669679
670680 for (int i=0 ; i<args->nGpus ; i++) {
671681#if NCCL_VERSION_CODE >= NCCL_VERSION(2,19,0)
672- if (local_register) NCCLCHECK (ncclCommDeregister (args->comms [i], sendRegHandles[i]));
673- if (local_register) NCCLCHECK (ncclCommDeregister (args->comms [i], recvRegHandles[i]));
682+ #if NCCL_VERSION_CODE >= NCCL_VERSION(2,27,0)
683+ if (test_ncclVersion >= NCCL_VERSION (2 ,27 ,0 ) && (local_register == SYMMETRIC_REGISTER)) {
684+ NCCLCHECK (ncclCommWindowDeregister (args->comms [i], (ncclWindow_t)sendRegHandles[i]));
685+ NCCLCHECK (ncclCommWindowDeregister (args->comms [i], (ncclWindow_t)recvRegHandles[i]));
686+ } else
687+ #endif
688+ {
689+ if (local_register) NCCLCHECK (ncclCommDeregister (args->comms [i], sendRegHandles[i]));
690+ if (local_register) NCCLCHECK (ncclCommDeregister (args->comms [i], recvRegHandles[i]));
691+ }
674692#endif
675693 NCCLCHECK (ncclCommDestroy (args->comms [i]));
676694 }
@@ -859,8 +877,10 @@ int main(int argc, char* argv[]) {
859877 break ;
860878 case ' R' :
861879#if NCCL_VERSION_CODE >= NCCL_VERSION(2,19,0)
862- if ((int )strtol (optarg, NULL , 0 )) {
863- local_register = 1 ;
880+ local_register = (int )strtol (optarg, NULL , 0 );
881+ if (local_register == SYMMETRIC_REGISTER && test_ncclVersion < NCCL_VERSION (2 ,27 ,0 )) {
882+ printf (" Option -R 2 (symmetric) is not supported before NCCL 2.27. Defaulting to local registration\n " );
883+ local_register = LOCAL_REGISTER;
864884 }
865885#else
866886 printf (" Option -R (register) is not supported before NCCL 2.19. Ignoring\n " );
@@ -897,7 +917,7 @@ int main(int argc, char* argv[]) {
897917 " [-G,--cudagraph <num graph launches>] \n\t "
898918 " [-C,--report_cputime <0/1>] \n\t "
899919 " [-a,--average <0/1/2/3> report average iteration time <0=RANK0/1=AVG/2=MIN/3=MAX>] \n\t "
900- " [-R,--local_register <1/0 > enable local buffer registration on send/recv buffers (default: disable)] \n\t "
920+ " [-R,--local_register <0/1/2 > enable local (1) or symmetric (2) buffer registration on send/recv buffers (default: disable (0) )] \n\t "
901921 " [-h,--help]\n " ,
902922 basename (argv[0 ]));
903923 return 0 ;
@@ -1107,8 +1127,16 @@ testResult_t run() {
11071127 sendRegHandles = (local_register) ? (void **)malloc (sizeof (*sendRegHandles)*nThreads*nGpus) : NULL ;
11081128 recvRegHandles = (local_register) ? (void **)malloc (sizeof (*recvRegHandles)*nThreads*nGpus) : NULL ;
11091129 for (int i=0 ; i<nGpus*nThreads; i++) {
1110- if (local_register) NCCLCHECK (ncclCommRegister (comms[i], sendbuffs[i], maxBytes, &sendRegHandles[i]));
1111- if (local_register) NCCLCHECK (ncclCommRegister (comms[i], recvbuffs[i], maxBytes, &recvRegHandles[i]));
1130+ #if NCCL_VERSION_CODE >= NCCL_VERSION(2,27,0)
1131+ if (test_ncclVersion >= NCCL_VERSION (2 ,27 ,0 ) && (local_register == SYMMETRIC_REGISTER)) {
1132+ NCCLCHECK (ncclCommWindowRegister (comms[i], sendbuffs[i], maxBytes, (ncclWindow_t*)&sendRegHandles[i], NCCL_WIN_COLL_SYMMETRIC));
1133+ NCCLCHECK (ncclCommWindowRegister (comms[i], recvbuffs[i], maxBytes, (ncclWindow_t*)&recvRegHandles[i], NCCL_WIN_COLL_SYMMETRIC));
1134+ } else
1135+ #endif
1136+ {
1137+ if (local_register) NCCLCHECK (ncclCommRegister (comms[i], sendbuffs[i], maxBytes, &sendRegHandles[i]));
1138+ if (local_register) NCCLCHECK (ncclCommRegister (comms[i], recvbuffs[i], maxBytes, &recvRegHandles[i]));
1139+ }
11121140 }
11131141#endif
11141142 }
@@ -1188,8 +1216,16 @@ testResult_t run() {
11881216 if (!parallel_init) {
11891217 for (int i=0 ; i<nGpus*nThreads; ++i) {
11901218#if NCCL_VERSION_CODE >= NCCL_VERSION(2,19,0)
1191- if (local_register) NCCLCHECK (ncclCommDeregister (comms[i], sendRegHandles[i]));
1192- if (local_register) NCCLCHECK (ncclCommDeregister (comms[i], recvRegHandles[i]));
1219+ #if NCCL_VERSION_CODE >= NCCL_VERSION(2,27,0)
1220+ if (test_ncclVersion >= NCCL_VERSION (2 ,27 ,0 ) && (local_register == SYMMETRIC_REGISTER)) {
1221+ NCCLCHECK (ncclCommWindowDeregister (comms[i], (ncclWindow_t)sendRegHandles[i]));
1222+ NCCLCHECK (ncclCommWindowDeregister (comms[i], (ncclWindow_t)recvRegHandles[i]));
1223+ } else
1224+ #endif
1225+ {
1226+ if (local_register) NCCLCHECK (ncclCommDeregister (comms[i], sendRegHandles[i]));
1227+ if (local_register) NCCLCHECK (ncclCommDeregister (comms[i], recvRegHandles[i]));
1228+ }
11931229#endif
11941230 NCCLCHECK (ncclCommDestroy (comms[i]));
11951231 }
0 commit comments