@@ -111,6 +111,7 @@ static int average = 1;
111111static int numDevices = 1 ;
112112static int delay_inout_place = 0 ;
113113static int enable_out_of_place = 1 ;
114+ static int enable_in_place = 1 ;
114115static int enable_cache_flush = 0 ;
115116static int enable_rotating_tensor = 0 ;
116117#if NCCL_VERSION_CODE >= NCCL_VERSION(2,19,0)
@@ -410,7 +411,7 @@ testResult_t CheckData(struct threadArgs* args, ncclDataType_t type, ncclRedOp_t
410411
411412 int64_t *wrongPerGpu = nullptr ;
412413 CUDACHECK (hipHostMalloc ((void **)&wrongPerGpu, args->nGpus *sizeof (int64_t ), cudaHostAllocMapped));
413-
414+
414415 for (int i=0 ; i<args->nGpus ; i++) {
415416 int rank = ((args->proc *args->nThreads + args->thread )*args->nGpus + i);
416417 CUDACHECK (cudaSetDevice (args->gpus [i]));
@@ -450,14 +451,14 @@ testResult_t CheckData(struct threadArgs* args, ncclDataType_t type, ncclRedOp_t
450451 if (args->reportErrors && *wrongElts) args->errors [0 ]++;
451452 return testSuccess;
452453}
453-
454+
454455testResult_t testStreamSynchronize (int ngpus, cudaStream_t* streams, ncclComm_t* comms) {
455456 cudaError_t cudaErr;
456457 int remaining = ngpus;
457458 int * done = (int *)malloc (sizeof (int )*ngpus);
458459 memset (done, 0 , sizeof (int )*ngpus);
459460 timer tim;
460-
461+
461462 while (remaining) {
462463 int idle = 1 ;
463464 for (int i=0 ; i<ngpus; i++) {
@@ -522,7 +523,7 @@ testResult_t startColl(struct threadArgs* args, ncclDataType_t type, ncclRedOp_t
522523 size_t steps = totalnbytes ? args->maxbytes / totalnbytes : 1 ;
523524 shift = totalnbytes * (iter % steps);
524525 }
525-
526+
526527 if (args->nGpus > 1 ) NCCLCHECK (ncclGroupStart ());
527528 for (int i = 0 ; i < args->nGpus ; i++) {
528529#ifndef NCCL_MAJOR
@@ -912,7 +913,8 @@ testResult_t TimeTest(struct threadArgs* args, ncclDataType_t type, const char*
912913 TESTCHECK (BenchTime (args, type, op, root, 0 ));
913914 usleep (delay_inout_place);
914915 }
915- TESTCHECK (BenchTime (args, type, op, root, 1 ));
916+ if (enable_in_place)
917+ TESTCHECK (BenchTime (args, type, op, root, 1 ));
916918 PRINT (" \n " );
917919 }
918920 --repeat;
@@ -1206,10 +1208,11 @@ int main(int argc, char* argv[]) {
12061208 break ;
12071209 case ' O' :
12081210 enable_out_of_place = strtol (optarg, NULL , 0 );
1211+ enable_in_place = enable_out_of_place ? 0 : 1 ;
12091212 break ;
12101213 case ' q' :
12111214 delay_inout_place = (int )strtol (optarg, NULL , 10 );
1212- break ;
1215+ break ;
12131216 case ' F' :
12141217 enable_cache_flush = strtol (optarg, NULL , 0 );
12151218 if (enable_cache_flush > 0 ) {
@@ -1500,14 +1503,20 @@ testResult_t run() {
15001503
15011504 const char * timeStr = report_cputime ? " cputime" : " time" ;
15021505 PRINT (" #\n " );
1503- if (enable_out_of_place) {
1506+ if (enable_out_of_place && enable_in_place ) {
15041507 PRINT (" # %10s %12s %8s %6s %6s out-of-place in-place \n " , " " , " " , " " , " " , " " );
15051508 PRINT (" # %10s %12s %8s %6s %6s %7s %6s %6s %6s %7s %6s %6s %6s\n " , " size" , " count" , " type" , " redop" , " root" ,
15061509 timeStr, " algbw" , " busbw" , " #wrong" , timeStr, " algbw" , " busbw" , " #wrong" );
15071510 PRINT (" # %10s %12s %8s %6s %6s %7s %6s %6s %5s %7s %6s %6s %5s\n " , " (B)" , " (elements)" , " " , " " , " " ,
15081511 " (us)" , " (GB/s)" , " (GB/s)" , " " , " (us)" , " (GB/s)" , " (GB/s)" , " " );
1512+ } else if (enable_out_of_place) {
1513+ PRINT (" # %10s %12s %8s %6s %6s out-of-place \n " , " " , " " , " " , " " , " " );
1514+ PRINT (" # %10s %12s %8s %6s %6s %7s %6s %6s %6s\n " , " size" , " count" , " type" , " redop" , " root" ,
1515+ timeStr, " algbw" , " busbw" , " #wrong" );
1516+ PRINT (" # %10s %12s %8s %6s %6s %7s %6s %6s %5s\n " , " (B)" , " (elements)" , " " , " " , " " ,
1517+ " (us)" , " (GB/s)" , " (GB/s)" , " " );
15091518 } else {
1510- PRINT (" # %10s %12s %8s %6s %6s in-place \n " , " " , " " , " " , " " , " " );
1519+ PRINT (" # %10s %12s %8s %6s %6s in-place \n " , " " , " " , " " , " " , " " );
15111520 PRINT (" # %10s %12s %8s %6s %6s %7s %6s %6s %6s\n " , " size" , " count" , " type" , " redop" , " root" ,
15121521 timeStr, " algbw" , " busbw" , " #wrong" );
15131522 PRINT (" # %10s %12s %8s %6s %6s %7s %6s %6s %5s\n " , " (B)" , " (elements)" , " " , " " , " " ,
@@ -1539,6 +1548,7 @@ testResult_t run() {
15391548 threads[t].args .comms =comms+t*nGpus;
15401549 threads[t].args .streams =streams.data ()+t*nGpus;
15411550 threads[t].args .enable_out_of_place =enable_out_of_place;
1551+ threads[t].args .enable_in_place =enable_in_place;
15421552 threads[t].args .enable_cache_flush = enable_cache_flush;
15431553 threads[t].args .enable_rotating_tensor = enable_rotating_tensor;
15441554 threads[t].args .errors =errors.data ()+t;
0 commit comments