Skip to content

Commit 987b30c

Browse files
committed
Code refactoring.
1 parent 87550e3 commit 987b30c

2 files changed

Lines changed: 16 additions & 17 deletions

File tree

‎utils/checkmate.py‎

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
# -*- coding:utf-8 -*-
2-
__author__ = 'Randolph'
3-
41
import os
52
import glob
63
import json
@@ -9,7 +6,7 @@
96

107

118
class BestCheckpointSaver(object):
12-
"""Maintains a directory containing only the best n checkpoints
9+
"""Maintains a directory containing only the best n checkpoints.
1310
Inside the directory is a best_checkpoints JSON file containing a dictionary
1411
mapping of the best checkpoint filepaths to the values by which the checkpoints
1512
are compared. Only the best n checkpoints are contained in the directory and JSON file.
@@ -18,11 +15,12 @@ class BestCheckpointSaver(object):
1815
framework.
1916
"""
2017
def __init__(self, save_dir, num_to_keep=1, maximize=True, saver=None):
21-
"""Creates a `BestCheckpointSaver`
22-
`BestCheckpointSaver` acts as a wrapper class around a `tf.train.Saver`
18+
"""Creates a `BestCheckpointSaver`.
19+
`BestCheckpointSaver` acts as a wrapper class around a `tf.train.Saver`.
20+
2321
Args:
24-
save_dir: The directory in which the checkpoint files will be saved
25-
num_to_keep: The number of best checkpoint files to retain
22+
save_dir: The directory in which the checkpoint files will be saved.
23+
num_to_keep: The number of best checkpoint files to retain.
2624
maximize: Define 'best' values to be the highest values. For example,
2725
set this to True if selecting for the checkpoints with the highest
2826
given accuracy. Or set to False to select for checkpoints with the
@@ -45,10 +43,11 @@ def __init__(self, save_dir, num_to_keep=1, maximize=True, saver=None):
4543

4644
def handle(self, value, sess, global_step):
4745
"""Updates the set of best checkpoints based on the given result.
46+
4847
Args:
4948
value: The value by which to rank the checkpoint.
50-
sess: A tf.Session to use to save the checkpoint
51-
global_step: The global step
49+
sess: A tf.Session to use to save the checkpoint.
50+
global_step: The global step.
5251
"""
5352
current_ckpt = 'model-{}'.format(global_step)
5453
value = float(value)
@@ -117,19 +116,19 @@ def _sort(self, best_checkpoints):
117116

118117

119118
def get_best_checkpoint(best_checkpoint_dir, select_maximum_value=True):
120-
"""
121-
Returns filepath to the best checkpoint
119+
"""Returns filepath to the best checkpoint.
122120
Reads the best_checkpoints file in the best_checkpoint_dir directory.
123121
Returns the filepath in the best_checkpoints file associated with
124122
the highest value if select_maximum_value is True, or the filepath
125123
associated with the lowest value if select_maximum_value is False.
124+
126125
Args:
127-
best_checkpoint_dir: Directory containing best_checkpoints JSON file
126+
best_checkpoint_dir: Directory containing best_checkpoints JSON file.
128127
select_maximum_value: If True, select the filepath associated
129128
with the highest value. Otherwise, select the filepath associated
130129
with the lowest value.
131130
Returns:
132-
The full path to the best checkpoint file
131+
The full path to the best checkpoint file.
133132
"""
134133
best_checkpoints_file = os.path.join(best_checkpoint_dir, 'best_checkpoints')
135134
assert os.path.exists(best_checkpoints_file)

‎utils/param_parser.py‎

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def parameter_parser():
3131

3232
parser.add_argument("--word2vec-file",
3333
nargs="?",
34-
default="../data/word2vec_100.model",
34+
default="../data/word2vec_100.kv",
3535
help="Word2vec file for embedding characters (the dim need to be the same as embedding dim).")
3636

3737
# Model Hyperparameters
@@ -138,7 +138,7 @@ def parameter_parser():
138138

139139
parser.add_argument("--evaluate-steps",
140140
type=int,
141-
default=50,
141+
default=10,
142142
help="Evaluate model on val set after how many steps. (default: 50)")
143143

144144
parser.add_argument("--norm-ratio",
@@ -153,7 +153,7 @@ def parameter_parser():
153153

154154
parser.add_argument("--checkpoint-steps",
155155
type=int,
156-
default=50,
156+
default=10,
157157
help="Save model after how many steps. (default: 50)")
158158

159159
parser.add_argument("--num-checkpoints",

0 commit comments

Comments
 (0)