@@ -74,7 +74,7 @@ def forward(self, x):
7474 x = self .conv2 (x )
7575 x = x .view (x .size (0 ), - 1 ) # flatten the output of conv2 to (batch_size, 32 * 7 * 7)
7676 output = self .out (x )
77- return output
77+ return output , x # return x for visualization
7878
7979
8080cnn = CNN ()
@@ -83,24 +83,53 @@ def forward(self, x):
8383optimizer = torch .optim .Adam (cnn .parameters (), lr = LR ) # optimize all cnn parameters
8484loss_func = nn .CrossEntropyLoss () # the target label is not one-hotted
8585
# following function (plot_with_labels) is for visualization, can be ignored if not interested
from matplotlib import cm
try:
    from sklearn.manifold import TSNE
    HAS_SK = True
except ImportError:  # only catch a missing package; a bare `except:` would hide real bugs
    HAS_SK = False
    print('Please install sklearn for layer visualization')


def plot_with_labels(lowDWeights, labels):
    """Scatter-plot a 2-D embedding, drawing each point as its class label.

    lowDWeights: (N, 2) array of 2-D coordinates (e.g. t-SNE output).
    labels:      length-N iterable of integer class ids; the
                 ``255 * s / 9`` colour mapping assumes ids in [0, 9].
    """
    plt.cla()
    X, Y = lowDWeights[:, 0], lowDWeights[:, 1]
    for x, y, s in zip(X, Y, labels):
        c = cm.rainbow(int(255 * s / 9))  # map class id onto the rainbow colormap
        plt.text(x, y, s, backgroundcolor=c, fontsize=9)
    plt.xlim(X.min(), X.max())
    plt.ylim(Y.min(), Y.max())
    plt.title('Visualize last layer')
    plt.show()
    plt.pause(0.01)  # give the GUI event loop time to redraw

plt.ion()  # interactive mode: figure updates live during training
# training and testing
for epoch in range(EPOCH):
    for step, (x, y) in enumerate(train_loader):   # gives batch data, normalize x when iterate train_loader
        b_x = Variable(x)   # batch x
        b_y = Variable(y)   # batch y

        output = cnn(b_x)[0]            # cnn output (forward also returns the flattened layer)
        loss = loss_func(output, b_y)   # cross entropy loss
        optimizer.zero_grad()           # clear gradients for this training step
        loss.backward()                 # backpropagation, compute gradients
        optimizer.step()                # apply gradients

        if step % 50 == 0:
            test_output, last_layer = cnn(test_x)
            pred_y = torch.max(test_output, 1)[1].data.squeeze()
            # `.float().mean().item()` instead of builtin sum()/size division:
            # summing the uint8 comparison tensor with builtin sum() can
            # overflow, and `loss.data[0]` raises on PyTorch >= 0.4 where the
            # loss is a 0-dim tensor — `.item()` is the supported accessor.
            accuracy = (pred_y == test_y).float().mean().item()
            print('Epoch: ', epoch, '| train loss: %.4f' % loss.item(), '| test accuracy: %.2f' % accuracy)
            if HAS_SK:
                # Visualization of trained flatten layer (T-SNE)
                tsne = TSNE(perplexity=30, n_components=2, init='pca', n_iter=5000)
                plot_only = 500   # t-SNE is slow; embed only the first 500 test samples
                low_dim_embs = tsne.fit_transform(last_layer.data.numpy()[:plot_only, :])
                labels = test_y.numpy()[:plot_only]
                plot_with_labels(low_dim_embs, labels)
plt.ioff()  # leave interactive mode once training finishes

105134# print 10 predictions from test data
106135test_output = cnn (test_x [:10 ])
0 commit comments