@@ -104,7 +104,7 @@
104 | 104 | flags.DEFINE_integer("max_eval_steps", 100, "Maximum number of eval steps.")
105 | 105 |
106 | 106 | flags.DEFINE_integer("num_accumulation_steps", 1, |
107 | | - "Number of accumulation steps before gradient update." |
| 107 | + "Number of accumulation steps before gradient update." |
108 | 108 | "Global batch size = num_accumulation_steps * train_batch_size") |
109 | 109 |
110 | 110 | flags.DEFINE_bool("allreduce_post_accumulation", False, "Whether to all reduce after accumulation of N steps or after each step") |
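A quick aside on the help string above: the effective batch size per optimizer update multiplies the per-GPU batch by the accumulation steps, and by the worker count under Horovod, exactly as the expression in main() later in this diff computes it. A minimal sketch with illustrative values, not this script's defaults:

```python
# Illustrative values only; the real ones come from the flags above.
train_batch_size = 8         # per-GPU micro-batch per session step
num_accumulation_steps = 4   # steps whose gradients are summed before applying
world_size = 8               # hvd.size() under Horovod; 1 without it

# Matches the global_batch_size expression used in main() further down.
global_batch_size = train_batch_size * num_accumulation_steps * world_size
print(global_batch_size)     # 256 sentences contribute to each weight update
```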
@@ -146,6 +146,7 @@ def after_create_session(self, session, coord): |
146 | 146 | self.step_time = 0.0 # time taken per step |
147 | 147 | self.init_global_step = session.run(tf.train.get_global_step()) # training starts at init_global_step |
148 | 148 | self.skipped = 0 |
| 149 | + self.final_loss = 0 # last averaged loss, read by main() after training ends
149 | 150 |
150 | 151 | def before_run(self, run_context): |
151 | 152 | self.t0 = time.time() |
@@ -246,6 +247,7 @@ def after_run(self, run_context, run_values): |
246 | 247 | self.count = 0 |
247 | 248 | self.loss = 0.0 |
248 | 249 | self.all_count = 0 |
| 250 | + self.final_loss = avg_loss_step # keep the latest average loss for end-of-run logging
249 | 251 |
250 | 252 | def model_fn_builder(bert_config, init_checkpoint, learning_rate, |
251 | 253 | num_train_steps, num_warmup_steps, |
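The two hook changes above follow a common `tf.estimator` pattern: the hook keeps the last averaged loss on an attribute (`final_loss`) so the caller can read it after `estimator.train(...)` returns, which is also why the hook gets bound to a name (`log_hook`) later in this diff instead of being appended inline. A minimal self-contained sketch of that pattern, with hypothetical names (`LastLossHook`, the `"total_loss:0"` tensor name) rather than the PR's `_LogSessionRunHook`:

```python
import tensorflow as tf

class LastLossHook(tf.estimator.SessionRunHook):
    """Keeps the most recent loss so the caller can read it after training."""

    def __init__(self):
        self.final_loss = 0.0  # mirrors the sentinel used in this PR

    def before_run(self, run_context):
        # Fetch the loss alongside whatever the train op already runs.
        # "total_loss:0" is an assumed tensor name for this sketch.
        loss = run_context.session.graph.get_tensor_by_name("total_loss:0")
        return tf.estimator.SessionRunArgs(loss)

    def after_run(self, run_context, run_values):
        # Overwritten every step, so whatever value training ends on is kept.
        self.final_loss = run_values.results
```

After `estimator.train(..., hooks=[hook])` returns, `hook.final_loss` is a plain Python float that can be handed to any logger, which is what main() does with dllogging below.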
@@ -280,8 +282,8 @@ def model_fn(features, labels, mode, params): # pylint: disable=unused-argument |
280 | 282 |
281 | 283 | (masked_lm_loss, |
282 | 284 | masked_lm_example_loss, masked_lm_log_probs) = get_masked_lm_output( |
283 | | - bert_config, model.get_sequence_output(), model.get_embedding_table(), |
284 | | - masked_lm_positions, masked_lm_ids, |
| 285 | + bert_config, model.get_sequence_output(), model.get_embedding_table(), |
| 286 | + masked_lm_positions, masked_lm_ids, |
285 | 287 | masked_lm_weights) |
286 | 288 |
287 | 289 | (next_sentence_loss, next_sentence_example_loss, |
@@ -582,7 +584,7 @@ def main(_): |
582 | 584 | tf.compat.v1.logging.info("**************************") |
583 | 585 |
584 | 586 | # config.gpu_options.per_process_gpu_memory_fraction = 0.7 |
585 | | - if FLAGS.use_xla: |
| 587 | + if FLAGS.use_xla: |
586 | 588 | config.graph_options.optimizer_options.global_jit_level = tf.compat.v1.OptimizerOptions.ON_1 |
587 | 589 | config.graph_options.rewrite_options.memory_optimization = rewriter_config_pb2.RewriterConfig.NO_MEM_OPT |
588 | 590 | if FLAGS.amp: |
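For context around this hunk: the body of the `FLAGS.amp` branch is elided by the diff, so the AMP line in the sketch below is an assumption based on the flag's name, not a quote of this file. A sketch of how these session-config toggles fit together:

```python
import tensorflow as tf
from tensorflow.core.protobuf import rewriter_config_pb2

config = tf.compat.v1.ConfigProto()
use_xla, amp = True, True  # stand-ins for FLAGS.use_xla / FLAGS.amp

if use_xla:
    # Compile the graph with XLA, and disable Grappler's memory optimizer,
    # which can conflict with XLA's own buffer planning (as in the hunk above).
    config.graph_options.optimizer_options.global_jit_level = (
        tf.compat.v1.OptimizerOptions.ON_1)
    config.graph_options.rewrite_options.memory_optimization = (
        rewriter_config_pb2.RewriterConfig.NO_MEM_OPT)
if amp:
    # Assumed: enable the auto_mixed_precision Grappler pass so eligible
    # ops are rewritten to float16.
    config.graph_options.rewrite_options.auto_mixed_precision = 1
```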
@@ -620,7 +622,8 @@ def main(_): |
620 | 622 | training_hooks.append(hvd.BroadcastGlobalVariablesHook(0)) |
621 | 623 | if (not FLAGS.horovod or hvd.rank() == 0): |
622 | 624 | global_batch_size = FLAGS.train_batch_size * FLAGS.num_accumulation_steps if not FLAGS.horovod else FLAGS.train_batch_size * FLAGS.num_accumulation_steps * hvd.size() |
623 | | - training_hooks.append(_LogSessionRunHook(global_batch_size, FLAGS.num_accumulation_steps, dllogging, FLAGS.display_loss_steps, FLAGS.save_checkpoints_steps, FLAGS.report_loss)) |
| 625 | + log_hook = _LogSessionRunHook(global_batch_size, FLAGS.num_accumulation_steps, dllogging, FLAGS.display_loss_steps, FLAGS.save_checkpoints_steps, FLAGS.report_loss) |
| 626 | + training_hooks.append(log_hook) |
624 | 627 |
625 | 628 | tf.compat.v1.logging.info("***** Running training *****") |
626 | 629 | tf.compat.v1.logging.info(" Batch size = %d", FLAGS.train_batch_size) |
@@ -649,6 +652,8 @@ def main(_): |
649 | 652 | tf.compat.v1.logging.info("Throughput Average (sentences/sec) with overhead = %0.2f", avg_sentences_per_second) |
650 | 653 | tf.compat.v1.logging.info("Throughput Average (sentences/sec) = %0.2f", ss_sentences_per_second) |
651 | 654 | dllogging.logger.log(step=(), data={"throughput_train": ss_sentences_per_second}, verbosity=Verbosity.DEFAULT) |
| 655 | + if log_hook.final_loss != 0: |
| 656 | + dllogging.logger.log(step=(), data={"total_loss": log_hook.final_loss}, verbosity=Verbosity.DEFAULT) |
652 | 657 | tf.compat.v1.logging.info("-----------------------------") |
653 | 658 |
654 | 659 | if FLAGS.do_eval and (not FLAGS.horovod or hvd.rank() == 0): |
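One hedged aside on the new logging above: `final_loss` starts at `0`, so `0` doubles as a "never set" sentinel, and a run whose last averaged loss were exactly `0.0` would not be logged. That is harmless for a pretraining loss that stays positive; a hypothetical variant (not the PR's code) avoids the ambiguity by using `None`, shown here with a stand-in class so the example runs on its own:

```python
# Hypothetical variant of the sentinel, not this PR's code.
class LogHook:
    def __init__(self):
        self.final_loss = None           # None instead of 0 in after_create_session

    def after_run(self, avg_loss_step):
        self.final_loss = avg_loss_step  # same assignment as in the diff

log_hook = LogHook()
log_hook.after_run(0.0)                  # an exactly-zero final loss...
if log_hook.final_loss is not None:      # ...is still reported with this check
    print({"total_loss": log_hook.final_loss})
```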