Skip to content

Commit ae3e635

Browse files
authored
Scaling tests to #ngpus (ROCm#81)
* scaling tests to #ngpus
  Signed-off-by: AtlantaPepsi <hyj1999110@gmail.com>
* switching to rocminfo
---------
Signed-off-by: AtlantaPepsi <hyj1999110@gmail.com>
1 parent 52aee69 commit ae3e635

File tree

5 files changed: 60 additions (+60) and 10 deletions (-10).

‎test/test_AllGather.py‎

Lines changed: 12 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -22,12 +22,22 @@
2222
import os
2323
import subprocess
2424
import itertools
25+
import math
2526

2627
import pytest
2728

29+
ngpus = 0
30+
if os.environ.get('ROCR_VISIBLE_DEVICES') is not None:
31+
ngpus = len(os.environ['ROCR_VISIBLE_DEVICES'].split(","))
32+
elif os.environ.get('HIP_VISIBLE_DEVICES') is not None:
33+
ngpus = len(os.environ['HIP_VISIBLE_DEVICES'].split(","))
34+
else:
35+
ngpus = int(subprocess.check_output("rocminfo | grep \"Device Type:.\s*.GPU\" | wc -l",shell=True))
36+
log_ngpus = int(math.log2(ngpus))
37+
2838
nthreads = ["1"]
2939
nprocs = ["2"]
30-
ngpus_single = ["1","2","4"]
40+
ngpus_single = [str(2**x) for x in range(log_ngpus+1)]
3141
ngpus_mpi = ["1","2"]
3242
byte_range = [("4", "128M")]
3343
op = ["sum", "prod", "min", "max"]
@@ -99,4 +109,4 @@ def test_AllGatherMPI(request, nthreads, nprocs, ngpus_mpi, byte_range, op, step
99109
print(rccl_test.stdout)
100110
pytest.fail("AllGather test error(s) detected.")
101111

102-
assert rccl_test.returncode == 0
112+
assert rccl_test.returncode == 0

‎test/test_AllReduce.py‎

Lines changed: 12 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -22,12 +22,22 @@
2222
import os
2323
import subprocess
2424
import itertools
25+
import math
2526

2627
import pytest
2728

29+
ngpus = 0
30+
if os.environ.get('ROCR_VISIBLE_DEVICES') is not None:
31+
ngpus = len(os.environ['ROCR_VISIBLE_DEVICES'].split(","))
32+
elif os.environ.get('HIP_VISIBLE_DEVICES') is not None:
33+
ngpus = len(os.environ['HIP_VISIBLE_DEVICES'].split(","))
34+
else:
35+
ngpus = int(subprocess.check_output("rocminfo | grep \"Device Type:.\s*.GPU\" | wc -l",shell=True))
36+
log_ngpus = int(math.log2(ngpus))
37+
2838
nthreads = ["1"]
2939
nprocs = ["2"]
30-
ngpus_single = ["1","2","4"]
40+
ngpus_single = [str(2**x) for x in range(log_ngpus+1)]
3141
ngpus_mpi = ["1","2"]
3242
byte_range = [("4", "128M")]
3343
op = ["sum", "prod", "min", "max"]
@@ -99,4 +109,4 @@ def test_AllReduceMPI(request, nthreads, nprocs, ngpus_mpi, byte_range, op, step
99109
print(rccl_test.stdout)
100110
pytest.fail("AllReduce test error(s) detected.")
101111

102-
assert rccl_test.returncode == 0
112+
assert rccl_test.returncode == 0

‎test/test_Broadcast.py‎

Lines changed: 12 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -22,12 +22,22 @@
2222
import os
2323
import subprocess
2424
import itertools
25+
import math
2526

2627
import pytest
2728

29+
ngpus = 0
30+
if os.environ.get('ROCR_VISIBLE_DEVICES') is not None:
31+
ngpus = len(os.environ['ROCR_VISIBLE_DEVICES'].split(","))
32+
elif os.environ.get('HIP_VISIBLE_DEVICES') is not None:
33+
ngpus = len(os.environ['HIP_VISIBLE_DEVICES'].split(","))
34+
else:
35+
ngpus = int(subprocess.check_output("rocminfo | grep \"Device Type:.\s*.GPU\" | wc -l",shell=True))
36+
log_ngpus = int(math.log2(ngpus))
37+
2838
nthreads = ["1"]
2939
nprocs = ["2"]
30-
ngpus_single = ["1","2","4"]
40+
ngpus_single = [str(2**x) for x in range(log_ngpus+1)]
3141
ngpus_mpi = ["1","2"]
3242
byte_range = [("4", "128M")]
3343
op = ["sum", "prod", "min", "max"]
@@ -99,4 +109,4 @@ def test_BroadcastMPI(request, nthreads, nprocs, ngpus_mpi, byte_range, op, step
99109
print(rccl_test.stdout)
100110
pytest.fail("Broadcast test error(s) detected.")
101111

102-
assert rccl_test.returncode == 0
112+
assert rccl_test.returncode == 0

‎test/test_Reduce.py‎

Lines changed: 12 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -22,12 +22,22 @@
2222
import os
2323
import subprocess
2424
import itertools
25+
import math
2526

2627
import pytest
2728

29+
ngpus = 0
30+
if os.environ.get('ROCR_VISIBLE_DEVICES') is not None:
31+
ngpus = len(os.environ['ROCR_VISIBLE_DEVICES'].split(","))
32+
elif os.environ.get('HIP_VISIBLE_DEVICES') is not None:
33+
ngpus = len(os.environ['HIP_VISIBLE_DEVICES'].split(","))
34+
else:
35+
ngpus = int(subprocess.check_output("rocminfo | grep \"Device Type:.\s*.GPU\" | wc -l",shell=True))
36+
log_ngpus = int(math.log2(ngpus))
37+
2838
nthreads = ["1"]
2939
nprocs = ["2"]
30-
ngpus_single = ["1","2","4"]
40+
ngpus_single = [str(2**x) for x in range(log_ngpus+1)]
3141
ngpus_mpi = ["1","2"]
3242
byte_range = [("4", "128M")]
3343
op = ["sum", "prod", "min", "max"]
@@ -99,4 +109,4 @@ def test_ReduceMPI(request, nthreads, nprocs, ngpus_mpi, byte_range, op, step_fa
99109
print(rccl_test.stdout)
100110
pytest.fail("Reduce test error(s) detected.")
101111

102-
assert rccl_test.returncode == 0
112+
assert rccl_test.returncode == 0

‎test/test_ReduceScatter.py‎

Lines changed: 12 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -22,12 +22,22 @@
2222
import os
2323
import subprocess
2424
import itertools
25+
import math
2526

2627
import pytest
2728

29+
ngpus = 0
30+
if os.environ.get('ROCR_VISIBLE_DEVICES') is not None:
31+
ngpus = len(os.environ['ROCR_VISIBLE_DEVICES'].split(","))
32+
elif os.environ.get('HIP_VISIBLE_DEVICES') is not None:
33+
ngpus = len(os.environ['HIP_VISIBLE_DEVICES'].split(","))
34+
else:
35+
ngpus = int(subprocess.check_output("rocminfo | grep \"Device Type:.\s*.GPU\" | wc -l",shell=True))
36+
log_ngpus = int(math.log2(ngpus))
37+
2838
nthreads = ["1"]
2939
nprocs = ["2"]
30-
ngpus_single = ["1","2","4"]
40+
ngpus_single = [str(2**x) for x in range(log_ngpus+1)]
3141
ngpus_mpi = ["1","2"]
3242
byte_range = [("4", "128M")]
3343
op = ["sum", "prod", "min", "max"]
@@ -99,4 +109,4 @@ def test_ReduceScatterMPI(request, nthreads, nprocs, ngpus_mpi, byte_range, op,
99109
print(rccl_test.stdout)
100110
pytest.fail("ReduceScatter test error(s) detected.")
101111

102-
assert rccl_test.returncode == 0
112+
assert rccl_test.returncode == 0

0 commit comments

Comments
 (0)