Skip to content

Commit 645625f

Browse files
committed
Update code to TensorFlow 1.14 version
1 parent 71ff6c3 commit 645625f

16 files changed

Lines changed: 168 additions & 169 deletions

‎ANN/test_ann.py‎

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def test_ann():
8686

8787
test_pre_tk = [0.0] * args.topK
8888
test_rec_tk = [0.0] * args.topK
89-
test_F_tk = [0.0] * args.topK
89+
test_F1_tk = [0.0] * args.topK
9090

9191
# Collect the predictions here
9292
true_labels = []
@@ -148,8 +148,8 @@ def test_ann():
148148
y_pred=np.array(predicted_onehot_labels_ts), average='micro')
149149
test_rec_ts = recall_score(y_true=np.array(true_onehot_labels),
150150
y_pred=np.array(predicted_onehot_labels_ts), average='micro')
151-
test_F_ts = f1_score(y_true=np.array(true_onehot_labels),
152-
y_pred=np.array(predicted_onehot_labels_ts), average='micro')
151+
test_F1_ts = f1_score(y_true=np.array(true_onehot_labels),
152+
y_pred=np.array(predicted_onehot_labels_ts), average='micro')
153153

154154
for top_num in range(args.topK):
155155
test_pre_tk[top_num] = precision_score(y_true=np.array(true_onehot_labels),
@@ -158,9 +158,9 @@ def test_ann():
158158
test_rec_tk[top_num] = recall_score(y_true=np.array(true_onehot_labels),
159159
y_pred=np.array(predicted_onehot_labels_tk[top_num]),
160160
average='micro')
161-
test_F_tk[top_num] = f1_score(y_true=np.array(true_onehot_labels),
162-
y_pred=np.array(predicted_onehot_labels_tk[top_num]),
163-
average='micro')
161+
test_F1_tk[top_num] = f1_score(y_true=np.array(true_onehot_labels),
162+
y_pred=np.array(predicted_onehot_labels_tk[top_num]),
163+
average='micro')
164164

165165
# Calculate the average AUC
166166
test_auc = roc_auc_score(y_true=np.array(true_onehot_labels),
@@ -176,13 +176,13 @@ def test_ann():
176176

177177
# Predict by threshold
178178
logger.info("Predict by threshold: Precision {0:g}, Recall {1:g}, F1 {2:g}"
179-
.format(test_pre_ts, test_rec_ts, test_F_ts))
179+
.format(test_pre_ts, test_rec_ts, test_F1_ts))
180180

181181
# Predict by topK
182182
logger.info("Predict by topK:")
183183
for top_num in range(args.topK):
184184
logger.info("Top{0}: Precision {1:g}, Recall {2:g}, F {3:g}"
185-
.format(top_num + 1, test_pre_tk[top_num], test_rec_tk[top_num], test_F_tk[top_num]))
185+
.format(top_num + 1, test_pre_tk[top_num], test_rec_tk[top_num], test_F1_tk[top_num]))
186186

187187
# Save the prediction result
188188
if not os.path.exists(SAVE_DIR):

‎ANN/train_ann.py‎

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ def validation_step(x_val, y_val, writer=None):
151151

152152
eval_pre_tk = [0.0] * args.topK
153153
eval_rec_tk = [0.0] * args.topK
154-
eval_F_tk = [0.0] * args.topK
154+
eval_F1_tk = [0.0] * args.topK
155155

156156
true_onehot_labels = []
157157
predicted_onehot_scores = []
@@ -202,8 +202,8 @@ def validation_step(x_val, y_val, writer=None):
202202
y_pred=np.array(predicted_onehot_labels_ts), average='micro')
203203
eval_rec_ts = recall_score(y_true=np.array(true_onehot_labels),
204204
y_pred=np.array(predicted_onehot_labels_ts), average='micro')
205-
eval_F_ts = f1_score(y_true=np.array(true_onehot_labels),
206-
y_pred=np.array(predicted_onehot_labels_ts), average='micro')
205+
eval_F1_ts = f1_score(y_true=np.array(true_onehot_labels),
206+
y_pred=np.array(predicted_onehot_labels_ts), average='micro')
207207

208208
for top_num in range(args.topK):
209209
eval_pre_tk[top_num] = precision_score(y_true=np.array(true_onehot_labels),
@@ -212,9 +212,9 @@ def validation_step(x_val, y_val, writer=None):
212212
eval_rec_tk[top_num] = recall_score(y_true=np.array(true_onehot_labels),
213213
y_pred=np.array(predicted_onehot_labels_tk[top_num]),
214214
average='micro')
215-
eval_F_tk[top_num] = f1_score(y_true=np.array(true_onehot_labels),
216-
y_pred=np.array(predicted_onehot_labels_tk[top_num]),
217-
average='micro')
215+
eval_F1_tk[top_num] = f1_score(y_true=np.array(true_onehot_labels),
216+
y_pred=np.array(predicted_onehot_labels_tk[top_num]),
217+
average='micro')
218218

219219
# Calculate the average AUC
220220
eval_auc = roc_auc_score(y_true=np.array(true_onehot_labels),
@@ -223,8 +223,8 @@ def validation_step(x_val, y_val, writer=None):
223223
eval_prc = average_precision_score(y_true=np.array(true_onehot_labels),
224224
y_score=np.array(predicted_onehot_scores), average='micro')
225225

226-
return eval_loss, eval_auc, eval_prc, eval_rec_ts, eval_pre_ts, eval_F_ts, \
227-
eval_rec_tk, eval_pre_tk, eval_F_tk
226+
return eval_loss, eval_auc, eval_prc, eval_pre_ts, eval_rec_ts, eval_F1_ts, \
227+
eval_pre_tk, eval_rec_tk, eval_F1_tk
228228

229229
# Generate batches
230230
batches_train = dh.batch_iter(
@@ -241,21 +241,21 @@ def validation_step(x_val, y_val, writer=None):
241241
if current_step % args.evaluate_steps == 0:
242242
logger.info("\nEvaluation:")
243243
eval_loss, eval_auc, eval_prc, \
244-
eval_rec_ts, eval_pre_ts, eval_F_ts, eval_rec_tk, eval_pre_tk, eval_F_tk = \
244+
eval_pre_ts, eval_rec_ts, eval_F1_ts, eval_pre_tk, eval_rec_tk, eval_F1_tk = \
245245
validation_step(x_val, y_val, writer=validation_summary_writer)
246246

247247
logger.info("All Validation set: Loss {0:g} | AUC {1:g} | AUPRC {2:g}"
248248
.format(eval_loss, eval_auc, eval_prc))
249249

250250
# Predict by threshold
251-
logger.info("Predict by threshold: Precision {0:g}, Recall {1:g}, F {2:g}"
252-
.format(eval_pre_ts, eval_rec_ts, eval_F_ts))
251+
logger.info("Predict by threshold: Precision {0:g}, Recall {1:g}, F1 {2:g}"
252+
.format(eval_pre_ts, eval_rec_ts, eval_F1_ts))
253253

254254
# Predict by topK
255255
logger.info("Predict by topK:")
256256
for top_num in range(args.topK):
257-
logger.info("Top{0}: Precision {1:g}, Recall {2:g}, F {3:g}"
258-
.format(top_num+1, eval_pre_tk[top_num], eval_rec_tk[top_num], eval_F_tk[top_num]))
257+
logger.info("Top{0}: Precision {1:g}, Recall {2:g}, F1 {3:g}"
258+
.format(top_num+1, eval_pre_tk[top_num], eval_rec_tk[top_num], eval_F1_tk[top_num]))
259259
best_saver.handle(eval_prc, sess, current_step)
260260
if current_step % args.checkpoint_steps == 0:
261261
checkpoint_prefix = os.path.join(checkpoint_dir, "model")

‎CNN/test_cnn.py‎

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def test_cnn():
8787

8888
test_pre_tk = [0.0] * args.topK
8989
test_rec_tk = [0.0] * args.topK
90-
test_F_tk = [0.0] * args.topK
90+
test_F1_tk = [0.0] * args.topK
9191

9292
# Collect the predictions here
9393
true_labels = []
@@ -149,8 +149,8 @@ def test_cnn():
149149
y_pred=np.array(predicted_onehot_labels_ts), average='micro')
150150
test_rec_ts = recall_score(y_true=np.array(true_onehot_labels),
151151
y_pred=np.array(predicted_onehot_labels_ts), average='micro')
152-
test_F_ts = f1_score(y_true=np.array(true_onehot_labels),
153-
y_pred=np.array(predicted_onehot_labels_ts), average='micro')
152+
test_F1_ts = f1_score(y_true=np.array(true_onehot_labels),
153+
y_pred=np.array(predicted_onehot_labels_ts), average='micro')
154154

155155
for top_num in range(args.topK):
156156
test_pre_tk[top_num] = precision_score(y_true=np.array(true_onehot_labels),
@@ -159,9 +159,9 @@ def test_cnn():
159159
test_rec_tk[top_num] = recall_score(y_true=np.array(true_onehot_labels),
160160
y_pred=np.array(predicted_onehot_labels_tk[top_num]),
161161
average='micro')
162-
test_F_tk[top_num] = f1_score(y_true=np.array(true_onehot_labels),
163-
y_pred=np.array(predicted_onehot_labels_tk[top_num]),
164-
average='micro')
162+
test_F1_tk[top_num] = f1_score(y_true=np.array(true_onehot_labels),
163+
y_pred=np.array(predicted_onehot_labels_tk[top_num]),
164+
average='micro')
165165

166166
# Calculate the average AUC
167167
test_auc = roc_auc_score(y_true=np.array(true_onehot_labels),
@@ -177,13 +177,13 @@ def test_cnn():
177177

178178
# Predict by threshold
179179
logger.info("Predict by threshold: Precision {0:g}, Recall {1:g}, F1 {2:g}"
180-
.format(test_pre_ts, test_rec_ts, test_F_ts))
180+
.format(test_pre_ts, test_rec_ts, test_F1_ts))
181181

182182
# Predict by topK
183183
logger.info("Predict by topK:")
184184
for top_num in range(args.topK):
185185
logger.info("Top{0}: Precision {1:g}, Recall {2:g}, F {3:g}"
186-
.format(top_num + 1, test_pre_tk[top_num], test_rec_tk[top_num], test_F_tk[top_num]))
186+
.format(top_num + 1, test_pre_tk[top_num], test_rec_tk[top_num], test_F1_tk[top_num]))
187187

188188
# Save the prediction result
189189
if not os.path.exists(SAVE_DIR):

‎CNN/train_cnn.py‎

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def validation_step(x_val, y_val, writer=None):
154154

155155
eval_pre_tk = [0.0] * args.topK
156156
eval_rec_tk = [0.0] * args.topK
157-
eval_F_tk = [0.0] * args.topK
157+
eval_F1_tk = [0.0] * args.topK
158158

159159
true_onehot_labels = []
160160
predicted_onehot_scores = []
@@ -205,8 +205,8 @@ def validation_step(x_val, y_val, writer=None):
205205
y_pred=np.array(predicted_onehot_labels_ts), average='micro')
206206
eval_rec_ts = recall_score(y_true=np.array(true_onehot_labels),
207207
y_pred=np.array(predicted_onehot_labels_ts), average='micro')
208-
eval_F_ts = f1_score(y_true=np.array(true_onehot_labels),
209-
y_pred=np.array(predicted_onehot_labels_ts), average='micro')
208+
eval_F1_ts = f1_score(y_true=np.array(true_onehot_labels),
209+
y_pred=np.array(predicted_onehot_labels_ts), average='micro')
210210

211211
for top_num in range(args.topK):
212212
eval_pre_tk[top_num] = precision_score(y_true=np.array(true_onehot_labels),
@@ -215,9 +215,9 @@ def validation_step(x_val, y_val, writer=None):
215215
eval_rec_tk[top_num] = recall_score(y_true=np.array(true_onehot_labels),
216216
y_pred=np.array(predicted_onehot_labels_tk[top_num]),
217217
average='micro')
218-
eval_F_tk[top_num] = f1_score(y_true=np.array(true_onehot_labels),
219-
y_pred=np.array(predicted_onehot_labels_tk[top_num]),
220-
average='micro')
218+
eval_F1_tk[top_num] = f1_score(y_true=np.array(true_onehot_labels),
219+
y_pred=np.array(predicted_onehot_labels_tk[top_num]),
220+
average='micro')
221221

222222
# Calculate the average AUC
223223
eval_auc = roc_auc_score(y_true=np.array(true_onehot_labels),
@@ -226,8 +226,8 @@ def validation_step(x_val, y_val, writer=None):
226226
eval_prc = average_precision_score(y_true=np.array(true_onehot_labels),
227227
y_score=np.array(predicted_onehot_scores), average='micro')
228228

229-
return eval_loss, eval_auc, eval_prc, eval_rec_ts, eval_pre_ts, eval_F_ts, \
230-
eval_rec_tk, eval_pre_tk, eval_F_tk
229+
return eval_loss, eval_auc, eval_prc, eval_pre_ts, eval_rec_ts, eval_F1_ts, \
230+
eval_pre_tk, eval_rec_tk, eval_F1_tk
231231

232232
# Generate batches
233233
batches_train = dh.batch_iter(
@@ -244,21 +244,21 @@ def validation_step(x_val, y_val, writer=None):
244244
if current_step % args.evaluate_steps == 0:
245245
logger.info("\nEvaluation:")
246246
eval_loss, eval_auc, eval_prc, \
247-
eval_rec_ts, eval_pre_ts, eval_F_ts, eval_rec_tk, eval_pre_tk, eval_F_tk = \
247+
eval_pre_ts, eval_rec_ts, eval_F1_ts, eval_pre_tk, eval_rec_tk, eval_F1_tk = \
248248
validation_step(x_val, y_val, writer=validation_summary_writer)
249249

250250
logger.info("All Validation set: Loss {0:g} | AUC {1:g} | AUPRC {2:g}"
251251
.format(eval_loss, eval_auc, eval_prc))
252252

253253
# Predict by threshold
254-
logger.info("Predict by threshold: Precision {0:g}, Recall {1:g}, F {2:g}"
255-
.format(eval_pre_ts, eval_rec_ts, eval_F_ts))
254+
logger.info("Predict by threshold: Precision {0:g}, Recall {1:g}, F1 {2:g}"
255+
.format(eval_pre_ts, eval_rec_ts, eval_F1_ts))
256256

257257
# Predict by topK
258258
logger.info("Predict by topK:")
259259
for top_num in range(args.topK):
260-
logger.info("Top{0}: Precision {1:g}, Recall {2:g}, F {3:g}"
261-
.format(top_num+1, eval_pre_tk[top_num], eval_rec_tk[top_num], eval_F_tk[top_num]))
260+
logger.info("Top{0}: Precision {1:g}, Recall {2:g}, F1 {3:g}"
261+
.format(top_num+1, eval_pre_tk[top_num], eval_rec_tk[top_num], eval_F1_tk[top_num]))
262262
best_saver.handle(eval_prc, sess, current_step)
263263
if current_step % args.checkpoint_steps == 0:
264264
checkpoint_prefix = os.path.join(checkpoint_dir, "model")

‎CRNN/test_crnn.py‎

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def test_crnn():
8787

8888
test_pre_tk = [0.0] * args.topK
8989
test_rec_tk = [0.0] * args.topK
90-
test_F_tk = [0.0] * args.topK
90+
test_F1_tk = [0.0] * args.topK
9191

9292
# Collect the predictions here
9393
true_labels = []
@@ -149,8 +149,8 @@ def test_crnn():
149149
y_pred=np.array(predicted_onehot_labels_ts), average='micro')
150150
test_rec_ts = recall_score(y_true=np.array(true_onehot_labels),
151151
y_pred=np.array(predicted_onehot_labels_ts), average='micro')
152-
test_F_ts = f1_score(y_true=np.array(true_onehot_labels),
153-
y_pred=np.array(predicted_onehot_labels_ts), average='micro')
152+
test_F1_ts = f1_score(y_true=np.array(true_onehot_labels),
153+
y_pred=np.array(predicted_onehot_labels_ts), average='micro')
154154

155155
for top_num in range(args.topK):
156156
test_pre_tk[top_num] = precision_score(y_true=np.array(true_onehot_labels),
@@ -159,9 +159,9 @@ def test_crnn():
159159
test_rec_tk[top_num] = recall_score(y_true=np.array(true_onehot_labels),
160160
y_pred=np.array(predicted_onehot_labels_tk[top_num]),
161161
average='micro')
162-
test_F_tk[top_num] = f1_score(y_true=np.array(true_onehot_labels),
163-
y_pred=np.array(predicted_onehot_labels_tk[top_num]),
164-
average='micro')
162+
test_F1_tk[top_num] = f1_score(y_true=np.array(true_onehot_labels),
163+
y_pred=np.array(predicted_onehot_labels_tk[top_num]),
164+
average='micro')
165165

166166
# Calculate the average AUC
167167
test_auc = roc_auc_score(y_true=np.array(true_onehot_labels),
@@ -177,13 +177,13 @@ def test_crnn():
177177

178178
# Predict by threshold
179179
logger.info("Predict by threshold: Precision {0:g}, Recall {1:g}, F1 {2:g}"
180-
.format(test_pre_ts, test_rec_ts, test_F_ts))
180+
.format(test_pre_ts, test_rec_ts, test_F1_ts))
181181

182182
# Predict by topK
183183
logger.info("Predict by topK:")
184184
for top_num in range(args.topK):
185185
logger.info("Top{0}: Precision {1:g}, Recall {2:g}, F {3:g}"
186-
.format(top_num + 1, test_pre_tk[top_num], test_rec_tk[top_num], test_F_tk[top_num]))
186+
.format(top_num + 1, test_pre_tk[top_num], test_rec_tk[top_num], test_F1_tk[top_num]))
187187

188188
# Save the prediction result
189189
if not os.path.exists(SAVE_DIR):

0 commit comments

Comments
 (0)