Skip to content

Commit a5c539e

Browse files
committed
Add support for Symmetric Memory Registration
From NCCL 2.27.x we can now use the Symmetric Memory APIs (-R 2)
1 parent e041d90 commit a5c539e

File tree

2 files changed

+48
-12
lines changed

2 files changed

+48
-12
lines changed

‎README.md‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ All tests support the same set of arguments :
7878
* `-z,--blocking <0/1>` Make NCCL collective blocking, i.e. have CPUs wait and sync after each collective. Default : 0.
7979
* `-G,--cudagraph <num graph launches>` Capture iterations as a CUDA graph and then replay specified number of times. Default : 0.
8080
* `-C,--report_cputime <0/1>` Report CPU time instead of latency. Default : 0.
81-
* `-R,--local_register <1/0>` enable local buffer registration on send/recv buffers. Default : 0.
81+
* `-R,--local_register <0/1/2>` enable local (1) or symmetric (2) buffer registration on send/recv buffers. Default : 0.
8282
* `-T,--timeout <time in seconds>` timeout each test after specified number of seconds. Default : disabled.
8383

8484
### Running multiple operations in parallel

‎src/common.cu‎

Lines changed: 47 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,8 @@ static int report_cputime = 0;
9090
// Report average iteration time: (0=RANK0,1=AVG,2=MIN,3=MAX)
9191
static int average = 1;
9292
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,19,0)
93+
#define LOCAL_REGISTER 1
94+
#define SYMMETRIC_REGISTER 2
9395
static int local_register = 0;
9496
#endif
9597
static int minCudaArch = 1<<30;
@@ -660,17 +662,33 @@ testResult_t threadInit(struct threadArgs* args) {
660662
void **sendRegHandles = (local_register) ? (void **)malloc(sizeof(*sendRegHandles)*args->nGpus) : NULL;
661663
void **recvRegHandles = (local_register) ? (void **)malloc(sizeof(*recvRegHandles)*args->nGpus) : NULL;
662664
for (int i=0; i<args->nGpus; i++) {
663-
if (local_register) NCCLCHECK(ncclCommRegister(args->comms[i], args->sendbuffs[i], args->maxbytes, &sendRegHandles[i]));
664-
if (local_register) NCCLCHECK(ncclCommRegister(args->comms[i], args->recvbuffs[i], args->maxbytes, &recvRegHandles[i]));
665+
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,27,0)
666+
if (test_ncclVersion >= NCCL_VERSION(2,27,0) && (local_register == SYMMETRIC_REGISTER)) {
667+
NCCLCHECK(ncclCommWindowRegister(args->comms[i], args->sendbuffs[i], args->maxbytes, (ncclWindow_t*)&sendRegHandles[i], NCCL_WIN_COLL_SYMMETRIC));
668+
NCCLCHECK(ncclCommWindowRegister(args->comms[i], args->recvbuffs[i], args->maxbytes, (ncclWindow_t*)&recvRegHandles[i], NCCL_WIN_COLL_SYMMETRIC));
669+
} else
670+
#endif
671+
{
672+
if (local_register) NCCLCHECK(ncclCommRegister(args->comms[i], args->sendbuffs[i], args->maxbytes, &sendRegHandles[i]));
673+
if (local_register) NCCLCHECK(ncclCommRegister(args->comms[i], args->recvbuffs[i], args->maxbytes, &recvRegHandles[i]));
674+
}
665675
}
666676
#endif
667677

668678
TESTCHECK(threadRunTests(args));
669679

670680
for (int i=0; i<args->nGpus; i++) {
671681
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,19,0)
672-
if (local_register) NCCLCHECK(ncclCommDeregister(args->comms[i], sendRegHandles[i]));
673-
if (local_register) NCCLCHECK(ncclCommDeregister(args->comms[i], recvRegHandles[i]));
682+
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,27,0)
683+
if (test_ncclVersion >= NCCL_VERSION(2,27,0) && (local_register == SYMMETRIC_REGISTER)) {
684+
NCCLCHECK(ncclCommWindowDeregister(args->comms[i], (ncclWindow_t)sendRegHandles[i]));
685+
NCCLCHECK(ncclCommWindowDeregister(args->comms[i], (ncclWindow_t)recvRegHandles[i]));
686+
} else
687+
#endif
688+
{
689+
if (local_register) NCCLCHECK(ncclCommDeregister(args->comms[i], sendRegHandles[i]));
690+
if (local_register) NCCLCHECK(ncclCommDeregister(args->comms[i], recvRegHandles[i]));
691+
}
674692
#endif
675693
NCCLCHECK(ncclCommDestroy(args->comms[i]));
676694
}
@@ -859,8 +877,10 @@ int main(int argc, char* argv[]) {
859877
break;
860878
case 'R':
861879
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,19,0)
862-
if ((int)strtol(optarg, NULL, 0)) {
863-
local_register = 1;
880+
local_register = (int)strtol(optarg, NULL, 0);
881+
if (local_register == SYMMETRIC_REGISTER && test_ncclVersion < NCCL_VERSION(2,27,0)) {
882+
printf("Option -R 2 (symmetric) is not supported before NCCL 2.27. Defaulting to local registration\n");
883+
local_register = LOCAL_REGISTER;
864884
}
865885
#else
866886
printf("Option -R (register) is not supported before NCCL 2.19. Ignoring\n");
@@ -897,7 +917,7 @@ int main(int argc, char* argv[]) {
897917
"[-G,--cudagraph <num graph launches>] \n\t"
898918
"[-C,--report_cputime <0/1>] \n\t"
899919
"[-a,--average <0/1/2/3> report average iteration time <0=RANK0/1=AVG/2=MIN/3=MAX>] \n\t"
900-
"[-R,--local_register <1/0> enable local buffer registration on send/recv buffers (default: disable)] \n\t"
920+
"[-R,--local_register <0/1/2> enable local (1) or symmetric (2) buffer registration on send/recv buffers (default: disable (0))] \n\t"
901921
"[-h,--help]\n",
902922
basename(argv[0]));
903923
return 0;
@@ -1107,8 +1127,16 @@ testResult_t run() {
11071127
sendRegHandles = (local_register) ? (void **)malloc(sizeof(*sendRegHandles)*nThreads*nGpus) : NULL;
11081128
recvRegHandles = (local_register) ? (void **)malloc(sizeof(*recvRegHandles)*nThreads*nGpus) : NULL;
11091129
for (int i=0; i<nGpus*nThreads; i++) {
1110-
if (local_register) NCCLCHECK(ncclCommRegister(comms[i], sendbuffs[i], maxBytes, &sendRegHandles[i]));
1111-
if (local_register) NCCLCHECK(ncclCommRegister(comms[i], recvbuffs[i], maxBytes, &recvRegHandles[i]));
1130+
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,27,0)
1131+
if (test_ncclVersion >= NCCL_VERSION(2,27,0) && (local_register == SYMMETRIC_REGISTER)) {
1132+
NCCLCHECK(ncclCommWindowRegister(comms[i], sendbuffs[i], maxBytes, (ncclWindow_t*)&sendRegHandles[i], NCCL_WIN_COLL_SYMMETRIC));
1133+
NCCLCHECK(ncclCommWindowRegister(comms[i], recvbuffs[i], maxBytes, (ncclWindow_t*)&recvRegHandles[i], NCCL_WIN_COLL_SYMMETRIC));
1134+
} else
1135+
#endif
1136+
{
1137+
if (local_register) NCCLCHECK(ncclCommRegister(comms[i], sendbuffs[i], maxBytes, &sendRegHandles[i]));
1138+
if (local_register) NCCLCHECK(ncclCommRegister(comms[i], recvbuffs[i], maxBytes, &recvRegHandles[i]));
1139+
}
11121140
}
11131141
#endif
11141142
}
@@ -1188,8 +1216,16 @@ testResult_t run() {
11881216
if (!parallel_init) {
11891217
for(int i=0; i<nGpus*nThreads; ++i) {
11901218
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,19,0)
1191-
if (local_register) NCCLCHECK(ncclCommDeregister(comms[i], sendRegHandles[i]));
1192-
if (local_register) NCCLCHECK(ncclCommDeregister(comms[i], recvRegHandles[i]));
1219+
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,27,0)
1220+
if (test_ncclVersion >= NCCL_VERSION(2,27,0) && (local_register == SYMMETRIC_REGISTER)) {
1221+
NCCLCHECK(ncclCommWindowDeregister(comms[i], (ncclWindow_t)sendRegHandles[i]));
1222+
NCCLCHECK(ncclCommWindowDeregister(comms[i], (ncclWindow_t)recvRegHandles[i]));
1223+
} else
1224+
#endif
1225+
{
1226+
if (local_register) NCCLCHECK(ncclCommDeregister(comms[i], sendRegHandles[i]));
1227+
if (local_register) NCCLCHECK(ncclCommDeregister(comms[i], recvRegHandles[i]));
1228+
}
11931229
#endif
11941230
NCCLCHECK(ncclCommDestroy(comms[i]));
11951231
}

0 commit comments

Comments
 (0)