Commit 71aab7b

meatybobby authored and nv-kkudrynski committed
[BERT/TF] Add final loss metrics
1 parent a5feffa commit 71aab7b

1 file changed: 10 additions & 5 deletions

TensorFlow/LanguageModeling/BERT/run_pretraining.py
@@ -104,7 +104,7 @@
 flags.DEFINE_integer("max_eval_steps", 100, "Maximum number of eval steps.")
 
 flags.DEFINE_integer("num_accumulation_steps", 1,
-                    "Number of accumulation steps before gradient update."
+                    "Number of accumulation steps before gradient update."
                     "Global batch size = num_accumulation_steps * train_batch_size")
 
 flags.DEFINE_bool("allreduce_post_accumulation", False, "Whether to all reduce after accumulation of N steps or after each step")
@@ -146,6 +146,7 @@ def after_create_session(self, session, coord):
     self.step_time = 0.0 # time taken per step
     self.init_global_step = session.run(tf.train.get_global_step()) # training starts at init_global_step
     self.skipped = 0
+    self.final_loss = 0
 
   def before_run(self, run_context):
     self.t0 = time.time()
@@ -246,6 +247,7 @@ def after_run(self, run_context, run_values):
         self.count = 0
         self.loss = 0.0
         self.all_count = 0
+        self.final_loss = avg_loss_step
 
 def model_fn_builder(bert_config, init_checkpoint, learning_rate,
                     num_train_steps, num_warmup_steps,
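The two hunks above are the producer side of the change: _LogSessionRunHook gains a final_loss attribute, initialized to 0 when the session is created and overwritten with avg_loss_step each time after_run finishes computing an averaged loss, so the last value survives after training ends. Below is a minimal sketch of that hook pattern, assuming a TF1-style Estimator; the class name _FinalLossHook, the default tensor name, and the simplified constructor are illustrative stand-ins, not the repo's actual _LogSessionRunHook signature.

import tensorflow as tf

class _FinalLossHook(tf.estimator.SessionRunHook):
  """Remembers the most recent loss value so it can be read back after training."""

  def __init__(self, loss_tensor_name="total_loss:0"):  # hypothetical tensor name
    self.loss_tensor_name = loss_tensor_name
    self.final_loss = 0  # stays 0 if no step ever runs, mirroring the commit

  def before_run(self, run_context):
    # Fetch the loss tensor alongside whatever each training step already runs.
    return tf.estimator.SessionRunArgs(fetches=[self.loss_tensor_name])

  def after_run(self, run_context, run_values):
    # Overwritten on every step; once training stops this holds the last loss.
    self.final_loss = run_values.results[0]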
@@ -280,8 +282,8 @@ def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
 
     (masked_lm_loss,
      masked_lm_example_loss, masked_lm_log_probs) = get_masked_lm_output(
-        bert_config, model.get_sequence_output(), model.get_embedding_table(),
-        masked_lm_positions, masked_lm_ids,
+        bert_config, model.get_sequence_output(), model.get_embedding_table(),
+        masked_lm_positions, masked_lm_ids,
         masked_lm_weights)
 
     (next_sentence_loss, next_sentence_example_loss,
@@ -582,7 +584,7 @@ def main(_):
   tf.compat.v1.logging.info("**************************")
 
   # config.gpu_options.per_process_gpu_memory_fraction = 0.7
-  if FLAGS.use_xla:
+  if FLAGS.use_xla:
     config.graph_options.optimizer_options.global_jit_level = tf.compat.v1.OptimizerOptions.ON_1
     config.graph_options.rewrite_options.memory_optimization = rewriter_config_pb2.RewriterConfig.NO_MEM_OPT
   if FLAGS.amp:
@@ -620,7 +622,8 @@ def main(_):
     training_hooks.append(hvd.BroadcastGlobalVariablesHook(0))
   if (not FLAGS.horovod or hvd.rank() == 0):
     global_batch_size = FLAGS.train_batch_size * FLAGS.num_accumulation_steps if not FLAGS.horovod else FLAGS.train_batch_size * FLAGS.num_accumulation_steps * hvd.size()
-    training_hooks.append(_LogSessionRunHook(global_batch_size, FLAGS.num_accumulation_steps, dllogging, FLAGS.display_loss_steps, FLAGS.save_checkpoints_steps, FLAGS.report_loss))
+    log_hook = _LogSessionRunHook(global_batch_size, FLAGS.num_accumulation_steps, dllogging, FLAGS.display_loss_steps, FLAGS.save_checkpoints_steps, FLAGS.report_loss)
+    training_hooks.append(log_hook)
 
   tf.compat.v1.logging.info("***** Running training *****")
   tf.compat.v1.logging.info("  Batch size = %d", FLAGS.train_batch_size)
@@ -649,6 +652,8 @@
       tf.compat.v1.logging.info("Throughput Average (sentences/sec) with overhead = %0.2f", avg_sentences_per_second)
       tf.compat.v1.logging.info("Throughput Average (sentences/sec) = %0.2f", ss_sentences_per_second)
       dllogging.logger.log(step=(), data={"throughput_train": ss_sentences_per_second}, verbosity=Verbosity.DEFAULT)
+      if log_hook.final_loss != 0:
+        dllogging.logger.log(step=(), data={"total_loss": log_hook.final_loss}, verbosity=Verbosity.DEFAULT)
       tf.compat.v1.logging.info("-----------------------------")
 
   if FLAGS.do_eval and (not FLAGS.horovod or hvd.rank() == 0):
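The last two hunks are the consumer side: binding the hook to a local variable log_hook, instead of constructing it inline inside training_hooks.append(...), keeps a reference that outlives estimator.train(), so main() can read log_hook.final_loss afterwards and emit it through DLLogger as total_loss, skipping the log line when it is still 0 (no averaged loss was ever recorded). A hedged usage sketch, reusing the illustrative _FinalLossHook from above; estimator, input_fn, and dllogging stand in for objects run_pretraining.py builds elsewhere:

log_hook = _FinalLossHook()
estimator.train(input_fn=input_fn, hooks=[log_hook], max_steps=FLAGS.num_train_steps)

# The hook object outlives the training call, so its attribute is readable here.
if log_hook.final_loss != 0:
  dllogging.logger.log(step=(), data={"total_loss": log_hook.final_loss},
                       verbosity=Verbosity.DEFAULT)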
