2222#include < vector>
2323#include < utility>
2424#include < errno.h> /* program_invocation_short_name */
25-
25+ # include < dlfcn.h >
2626// #define DEBUG_PRINT
2727
2828#include " verifiable.h"
@@ -35,6 +35,24 @@ int test_ncclVersion = 0; // init'd with ncclGetVersion()
3535int32_t gpu_block3;
3636size_t cache_bytes = 192 * 1024 * 1024 ; // Use 192MB
3737
38+ rcclTestsGetAlgoInfo_t rcclTestsGetAlgoInfo = NULL ;
39+ rcclTestsGetProtocolName_t rcclTestsGetProtocolName = NULL ;
40+ rcclTestsGetAlgoName_t rcclTestsGetAlgoName= NULL ;
41+ static void loadRcclSyms () {
42+ static void * handle = NULL ;
43+ const char * libname = " librccl.so" ;
44+ if (!handle) {
45+ handle = dlopen (libname, RTLD_LAZY | RTLD_LOCAL);
46+ if (!handle) {
47+ fprintf (stderr, " dlopen failed: %s\n " , dlerror ());
48+ return ;
49+ }
50+ }
51+ rcclTestsGetAlgoInfo = (rcclTestsGetAlgoInfo_t) dlsym (handle, " rcclGetAlgoInfo" );
52+ rcclTestsGetAlgoName = (rcclTestsGetAlgoName_t) dlsym (handle, " rcclGetAlgoName" );
53+ rcclTestsGetProtocolName = (rcclTestsGetProtocolName_t) dlsym (handle, " rcclGetProtocolName" );
54+ }
55+
3856// RCCL_FLOAT8 support
3957bool rccl_float8_useFnuz = false ;
4058bool IsArchMatch (char const * arch, char const * target) {
@@ -109,6 +127,7 @@ static int nccltype = ncclFloat;
109127static int ncclroot = 0 ;
110128static int parallel_init = 0 ;
111129static int blocking_coll = 0 ;
130+ static int output_algo_proto_channels = 0 ;
112131static int memorytype = 0 ;
113132static uint32_t cumask[4 ];
114133static int streamnull = 0 ;
@@ -944,8 +963,21 @@ testResult_t TimeTest(struct threadArgs* args, ncclDataType_t type, const char*
944963 TESTCHECK (BenchTime (args, type, op, root, 0 ));
945964 usleep (delay_inout_place);
946965 }
947- if (enable_in_place)
966+ if (enable_in_place)
948967 TESTCHECK (BenchTime (args, type, op, root, 1 ));
968+ if (output_algo_proto_channels) {
969+ if (args->collTest ->getAlgoProtoChannels ) {
970+ int algo, proto, nchannels;
971+ const char * algoName = NULL ;
972+ const char * protoName = NULL ;
973+ TESTCHECK (args->collTest ->getAlgoProtoChannels (args->comms [0 ], args->nbytes / wordSize (type), type, &algo, &proto, &nchannels));
974+ NCCLCHECK (rcclTestsGetAlgoName (algo, &algoName));
975+ NCCLCHECK (rcclTestsGetProtocolName (proto, &protoName));
976+ PRINT (" %8s %8s %10d" , algoName, protoName, nchannels);
977+ } else {
978+ PRINT (" %8s %8s %10s" ," N/A" , " N/A" , " N/A" );
979+ }
980+ }
949981 PRINT (" \n " );
950982 }
951983 --repeat;
@@ -1108,7 +1140,7 @@ int main(int argc, char* argv[]) {
11081140 }
11091141 #endif
11101142 #endif
1111-
1143+ loadRcclSyms ();
11121144 // Parse args
11131145 double parsed;
11141146 int longindex;
@@ -1135,22 +1167,23 @@ int main(int argc, char* argv[]) {
11351167 {" report_cputime" , required_argument, 0 , ' C' },
11361168 {" average" , required_argument, 0 , ' a' },
11371169 {" local_register" , required_argument, 0 , ' R' },
1138- {" memory_type" , required_argument, 0 , ' y' }, // RCCL
1139- {" cumask" , required_argument, 0 , ' u' }, // RCCL
1140- {" out_of_place" , required_argument, 0 , ' O' }, // RCCL
1141- {" delay_inout_place" , required_argument, 0 , ' q' }, // RCCL
1142- {" cache_flush" , required_argument, 0 , ' F' }, // RCCL
1143- {" rotating_tensor" , required_argument, 0 , ' E' }, // RCCL
1144- {" output_file" , required_argument, 0 , ' x' }, // RCCL
1145- {" output_format" , required_argument, 0 , ' Z' }, // RCCL
1170+ {" memory_type" , required_argument, 0 , ' y' }, // RCCL
1171+ {" cumask" , required_argument, 0 , ' u' }, // RCCL
1172+ {" out_of_place" , required_argument, 0 , ' O' }, // RCCL
1173+ {" delay_inout_place" , required_argument, 0 , ' q' }, // RCCL
1174+ {" cache_flush" , required_argument, 0 , ' F' }, // RCCL
1175+ {" rotating_tensor" , required_argument, 0 , ' E' }, // RCCL
1176+ {" output_file" , required_argument, 0 , ' x' }, // RCCL
1177+ {" output_format" , required_argument, 0 , ' Z' }, // RCCL
1178+ {" output_algo_proto_channels" , required_argument, 0 , ' M' }, // RCCL
11461179 {" help" , no_argument, 0 , ' h' },
11471180 {}
11481181 };
11491182
11501183 while (1 ) {
11511184 int c;
11521185
1153- c = getopt_long (argc, argv, " t:g:b:e:i:f:n:m:w:N:p:c:o:d:r:z:y:T:G:C:a:R:Y:u:O:q:F:E:x:Z:h" , longopts, &longindex);
1186+ c = getopt_long (argc, argv, " t:g:b:e:i:f:n:m:w:N:p:c:o:d:r:z:y:T:G:C:a:R:Y:u:O:q:F:E:x:Z:M: h" , longopts, &longindex);
11541187
11551188 if (c == -1 )
11561189 break ;
@@ -1290,6 +1323,10 @@ int main(int argc, char* argv[]) {
12901323 case ' Z' :
12911324 output_format = optarg;
12921325 break ;
1326+ case ' M' :
1327+ output_algo_proto_channels = strtol (optarg, NULL , 0 );
1328+ if (rcclTestsGetAlgoInfo == NULL || rcclTestsGetAlgoName == NULL || rcclTestsGetProtocolName == NULL ) output_algo_proto_channels = 0 ;
1329+ break ;
12931330 case ' h' :
12941331 default :
12951332 if (c != ' h' ) printf (" invalid option '%c'\n " , c);
@@ -1607,27 +1644,39 @@ testResult_t run() {
16071644 }
16081645
16091646 fflush (stdout);
1610-
1647+ const char * extra_col_str[3 ] = {" " , " " , " " };
1648+ if (output_algo_proto_channels) {
1649+ extra_col_str[0 ] = " algo" ;
1650+ extra_col_str[1 ] = " proto" ;
1651+ extra_col_str[2 ] = " nchannels" ;
1652+ }
1653+ const char * header_col_str[3 ] = {" out-of-place in-place " ,
1654+ " out-of-place " ," in-place " };
1655+ int header_index =(enable_out_of_place && enable_in_place) ? 0 : (enable_out_of_place ? 1 : 2 );
16111656 const char * timeStr = report_cputime ? " cputime" : " time" ;
1657+
16121658 PRINT (" #\n " );
1659+ PRINT (" # %10s %12s %8s %6s %6s%s\n " , " " , " " , " " , " " , " " , header_col_str[header_index]);
16131660 if (enable_out_of_place && enable_in_place) {
1614- PRINT (" # %10s %12s %8s %6s %6s out-of-place in-place \n " , " " , " " , " " , " " , " " );
1615- PRINT (" # %10s %12s %8s %6s %6s %7s %6s %6s %6s %7s %6s %6s %6s\n " , " size" , " count" , " type" , " redop" , " root" ,
1616- timeStr, " algbw" , " busbw" , " #wrong" , timeStr, " algbw" , " busbw" , " #wrong" );
1617- PRINT (" # %10s %12s %8s %6s %6s %7s %6s %6s %5s %7s %6s %6s %5s\n " , " (B)" , " (elements)" , " " , " " , " " ,
1618- " (us)" , " (GB/s)" , " (GB/s)" , " " , " (us)" , " (GB/s)" , " (GB/s)" , " " );
1619- } else if (enable_out_of_place) {
1620- PRINT (" # %10s %12s %8s %6s %6s out-of-place \n " , " " , " " , " " , " " , " " );
1621- PRINT (" # %10s %12s %8s %6s %6s %7s %6s %6s %6s\n " , " size" , " count" , " type" , " redop" , " root" ,
1622- timeStr, " algbw" , " busbw" , " #wrong" );
1623- PRINT (" # %10s %12s %8s %6s %6s %7s %6s %6s %5s\n " , " (B)" , " (elements)" , " " , " " , " " ,
1624- " (us)" , " (GB/s)" , " (GB/s)" , " " );
1661+ PRINT (" # %10s %12s %8s %6s %6s %7s %6s %6s %6s %7s %6s %6s %6s %8s %8s %10s\n " ,
1662+ " size" , " count" , " type" , " redop" , " root" ,
1663+ timeStr, " algbw" , " busbw" , " #wrong" ,
1664+ timeStr, " algbw" , " busbw" , " #wrong" ,
1665+ extra_col_str[0 ], extra_col_str[1 ], extra_col_str[2 ]);
1666+ PRINT (" # %10s %12s %8s %6s %6s %7s %6s %6s %5s %7s %6s %6s %5s %8s %8s %10s\n " ,
1667+ " (B)" , " (elements)" , " " , " " , " " ,
1668+ " (us)" , " (GB/s)" , " (GB/s)" , " " ,
1669+ " (us)" , " (GB/s)" , " (GB/s)" , " " ,
1670+ " " , " " , " " );
16251671 } else {
1626- PRINT (" # %10s %12s %8s %6s %6s in-place \n " , " " , " " , " " , " " , " " );
1627- PRINT (" # %10s %12s %8s %6s %6s %7s %6s %6s %6s\n " , " size" , " count" , " type" , " redop" , " root" ,
1628- timeStr, " algbw" , " busbw" , " #wrong" );
1629- PRINT (" # %10s %12s %8s %6s %6s %7s %6s %6s %5s\n " , " (B)" , " (elements)" , " " , " " , " " ,
1630- " (us)" , " (GB/s)" , " (GB/s)" , " " );
1672+ PRINT (" # %10s %12s %8s %6s %6s %7s %6s %6s %6s %8s %8s %10s\n " ,
1673+ " size" , " count" , " type" , " redop" , " root" ,
1674+ timeStr, " algbw" , " busbw" , " #wrong" ,
1675+ extra_col_str[0 ], extra_col_str[1 ], extra_col_str[2 ]);
1676+ PRINT (" # %10s %12s %8s %6s %6s %7s %6s %6s %5s %8s %8s %10s\n " ,
1677+ " (B)" , " (elements)" , " " , " " , " " ,
1678+ " (us)" , " (GB/s)" , " (GB/s)" , " " ,
1679+ " " , " " , " " );
16311680 }
16321681 Reporter reporter (output_file, output_format);
16331682
0 commit comments