1- # -*- coding:utf-8 -*-
2- __author__ = 'Randolph'
3-
41import os
52import glob
63import json
96
107
118class BestCheckpointSaver (object ):
12- """Maintains a directory containing only the best n checkpoints
9+ """Maintains a directory containing only the best n checkpoints.
1310 Inside the directory is a best_checkpoints JSON file containing a dictionary
1411 mapping of the best checkpoint filepaths to the values by which the checkpoints
1512 are compared. Only the best n checkpoints are contained in the directory and JSON file.
@@ -18,11 +15,12 @@ class BestCheckpointSaver(object):
1815 framework.
1916 """
2017 def __init__ (self , save_dir , num_to_keep = 1 , maximize = True , saver = None ):
21- """Creates a `BestCheckpointSaver`
22- `BestCheckpointSaver` acts as a wrapper class around a `tf.train.Saver`
18+ """Creates a `BestCheckpointSaver`.
19+ `BestCheckpointSaver` acts as a wrapper class around a `tf.train.Saver`.
20+
2321 Args:
24- save_dir: The directory in which the checkpoint files will be saved
25- num_to_keep: The number of best checkpoint files to retain
22+ save_dir: The directory in which the checkpoint files will be saved.
23+ num_to_keep: The number of best checkpoint files to retain.
2624 maximize: Define 'best' values to be the highest values. For example,
2725 set this to True if selecting for the checkpoints with the highest
2826 given accuracy. Or set to False to select for checkpoints with the
@@ -45,10 +43,11 @@ def __init__(self, save_dir, num_to_keep=1, maximize=True, saver=None):
4543
4644 def handle (self , value , sess , global_step ):
4745 """Updates the set of best checkpoints based on the given result.
46+
4847 Args:
4948 value: The value by which to rank the checkpoint.
50- sess: A tf.Session to use to save the checkpoint
51- global_step: The global step
49+ sess: A tf.Session to use to save the checkpoint.
50+ global_step: The global step.
5251 """
5352 current_ckpt = 'model-{}' .format (global_step )
5453 value = float (value )
@@ -117,19 +116,19 @@ def _sort(self, best_checkpoints):
117116
118117
119118def get_best_checkpoint (best_checkpoint_dir , select_maximum_value = True ):
120- """
121- Returns filepath to the best checkpoint
119+ """Returns filepath to the best checkpoint.
122120 Reads the best_checkpoints file in the best_checkpoint_dir directory.
123121 Returns the filepath in the best_checkpoints file associated with
124122 the highest value if select_maximum_value is True, or the filepath
125123 associated with the lowest value if select_maximum_value is False.
124+
126125 Args:
127- best_checkpoint_dir: Directory containing best_checkpoints JSON file
126+ best_checkpoint_dir: Directory containing best_checkpoints JSON file.
128127 select_maximum_value: If True, select the filepath associated
129128 with the highest value. Otherwise, select the filepath associated
130129 with the lowest value.
131130 Returns:
132- The full path to the best checkpoint file
131+ The full path to the best checkpoint file.
133132 """
134133 best_checkpoints_file = os .path .join (best_checkpoint_dir , 'best_checkpoints' )
135134 assert os .path .exists (best_checkpoints_file )
0 commit comments