Skip to content

Commit 9fb0f5b

Browse files
Update 502_GPU.py
There are several little mistakes in this file which will make you fail to debug successfully.
1 parent d87529e commit 9fb0f5b

1 file changed

Lines changed: 4 additions & 4 deletions

File tree

‎tutorial-contents/502_GPU.py‎

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626

2727
# !!!!!!!! Change in here !!!!!!!!! #
2828
test_x = Variable(torch.unsqueeze(test_data.test_data, dim=1)).type(torch.FloatTensor)[:2000].cuda()/255. # Tensor on GPU
29-
test_y = test_data.test_labels[:2000]
29+
test_y = test_data.test_labels[:2000].cuda()
3030

3131

3232
class CNN(nn.Module):
@@ -69,7 +69,7 @@ def forward(self, x):
6969
test_output = cnn(test_x)
7070

7171
# !!!!!!!! Change in here !!!!!!!!! #
72-
pred_y = torch.max(test_output, 1)[1].cup().data.squeeze() # Move to CPU
72+
pred_y = torch.max(test_output, 1)[1].cuda().data.squeeze() # Move to CPU
7373

7474
accuracy = sum(pred_y == test_y) / test_y.size(0)
7575
print('Epoch: ', epoch, '| train loss: %.4f' % loss.data[0], '| test accuracy: %.2f' % accuracy)
@@ -78,7 +78,7 @@ def forward(self, x):
7878
test_output = cnn(test_x[:10])
7979

8080
# !!!!!!!! Change in here !!!!!!!!! #
81-
pred_y = torch.max(test_output, 1)[1].cup().data.numpy().squeeze() # Move to CPU
81+
pred_y = torch.max(test_output, 1)[1].cuda().data.squeeze() # Move to CPU
8282

8383
print(pred_y, 'prediction number')
84-
print(test_y[:10].numpy(), 'real number')
84+
print(test_y[:10], 'real number')

0 commit comments

Comments
 (0)