Skip to content

Commit 645be0e

Browse files
[Common] Use NCCL API to allocate/free memory (ROCm#144)
1 parent a9b1ce0 commit 645be0e

File tree

1 file changed: +12 -0 lines changed

1 file changed: +12 -0 lines changed

‎src/common.cu‎

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1040,9 +1040,15 @@ testResult_t AllocateBuffs(void **sendbuff, size_t sendBytes, void **recvbuff, s
10401040
#endif
10411041
}
10421042
else {
1043+
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,19,0)
1044+
NCCLCHECK(ncclMemAlloc(sendbuff, nbytes));
1045+
NCCLCHECK(ncclMemAlloc(recvbuff, nbytes));
1046+
if (datacheck) NCCLCHECK(ncclMemAlloc(expected, recvBytes));
1047+
#else
10431048
CUDACHECK(cudaMalloc(sendbuff, nbytes));
10441049
CUDACHECK(cudaMalloc(recvbuff, nbytes));
10451050
if (datacheck) CUDACHECK(cudaMalloc(expected, recvBytes));
1051+
#endif
10461052
}
10471053
CUDACHECK(hipMemset(*sendbuff, 1, nbytes));
10481054
if (datacheck) CUDACHECK(hipMemset(*expected, 1, recvBytes));
@@ -1676,9 +1682,15 @@ testResult_t run() {
16761682

16771683
// Free off CUDA allocated memory
16781684
for (int i=0; i<nGpus*nThreads; i++) {
1685+
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,19,0)
1686+
if (sendbuffs[i]) NCCLCHECK(ncclMemFree((char*)sendbuffs[i]));
1687+
if (recvbuffs[i]) NCCLCHECK(ncclMemFree((char*)recvbuffs[i]));
1688+
if (datacheck) NCCLCHECK(ncclMemFree(expected[i]));
1689+
#else
16791690
if (sendbuffs[i]) CUDACHECK(cudaFree((char*)sendbuffs[i]));
16801691
if (recvbuffs[i]) CUDACHECK(cudaFree((char*)recvbuffs[i]));
16811692
if (datacheck) CUDACHECK(cudaFree(expected[i]));
1693+
#endif
16821694
}
16831695
CUDACHECK(cudaFreeHost(delta));
16841696
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,19,0)

0 commit comments

Comments
 (0)