1 parent 51f1c93 commit 528d7ccCopy full SHA for 528d7cc
1 file changed
tutorial-contents/401_CNN.py
@@ -122,7 +122,7 @@ def plot_with_labels(lowDWeights, labels):
122
plt.ioff()
123
124
# print 10 predictions from test data
125
-test_output = cnn(test_x[:10])
+test_output, _ = cnn(test_x[:10])
126
pred_y = torch.max(test_output, 1)[1].data.numpy().squeeze()
127
print(pred_y, 'prediction number')
128
print(test_y[:10].numpy(), 'real number')
0 commit comments