Skip to content

Commit 2fff359

Browse files
hXl3snv-kkudrynski
authored and committed
[ConvNets/MX] Logging and suspend-resume fixes
1 parent e9f7444 commit 2fff359

2 files changed

Lines changed: 25 additions & 8 deletions

File tree

‎MxNet/Classification/RN50v1.5/fit.py‎

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ def should_end(self) -> bool:
8383
return bool(self.t[0] > 0)
8484

8585
def _signal_handler(self, signum, frame):
86+
print("Signal reveived")
8687
self.t[0] = 1
8788

8889

@@ -152,6 +153,8 @@ def float_list(x):
152153
help='file where to save the dllogger log from the experiment')
153154
train.add_argument('--workspace', type=str, default='./',
154155
help='path to directory where results will be stored')
156+
train.add_argument('--logdir', type=str, default=None,
157+
help="path to directory where logs will be stored")
155158
train.add_argument('--no-metrics', action='store_true',
156159
help='do not calculate evaluation metrics (for benchmarking)')
157160
train.add_argument('--benchmark-iters', type=int, default=None,
@@ -199,13 +202,18 @@ def load_model(args, model):
199202
file = list(glob.glob(
200203
f"{args.workspace}/{args.model_prefix}_*.params"))
201204
if len(file) == 0:
202-
return 0
205+
return -1
206+
207+
file = [x for x in sorted(file) if "best.params" not in x]
208+
209+
if len(file) == 0:
210+
return -1
203211

204-
file = [x for x in sorted(file) if "best.params" not in x][-1]
212+
file = file[-1]
205213

206214
epoch = re.match(f".*{args.model_prefix}_([0-9]*)\.params", file)
207215
if epoch is None:
208-
return 0
216+
return -1
209217

210218
epoch = int(epoch.group(1))
211219
model.load_parameters(file)
@@ -427,6 +435,8 @@ def transform_data(images, labels):
427435

428436
durations.append(time.time() - tic)
429437
tic = time.time()
438+
else:
439+
break
430440

431441
durations = durations[min(len(durations) // 10, 100):]
432442
dllogger_epoch_data = {
@@ -453,8 +463,8 @@ def transform_data(images, labels):
453463
accuracy = score.get('accuracy', -1)
454464
save_checkpoint(net, epoch, accuracy, best_accuracy,
455465
model_prefix, args.workspace,
456-
args.save_frequency, kvstore,
457-
force_save=should_break)
466+
args.save_frequency if args.mode == "train_val" else -1,
467+
kvstore, force_save=should_break)
458468
best_accuracy = max(best_accuracy, accuracy)
459469
global_metrics.update_dict(dllogger_epoch_data)
460470
dllogger.log(step=(epoch,), data=dllogger_epoch_data)
@@ -473,6 +483,11 @@ def fit(args, model, data_loader):
473483
# select gpu for horovod process
474484
if 'horovod' in args.kv_store:
475485
args.gpus = [args.gpus[hvd.local_rank()]]
486+
ctx = mx.gpu(hvd.local_rank())
487+
488+
tensor1 = mx.nd.zeros(shape=(1,), dtype='float32', ctx=ctx)
489+
tensor2 = mx.nd.zeros(shape=(1,), dtype='float32', ctx=ctx)
490+
tensor1, tensor2 = hvd.grouped_allreduce([tensor1,tensor2])
476491

477492
if args.amp:
478493
amp.init()
@@ -516,11 +531,12 @@ def fit(args, model, data_loader):
516531
tic = time.time()
517532
return
518533

519-
start_epoch = load_model(args, model)
534+
start_epoch = load_model(args, model) + 1
520535
if start_epoch == 0:
521536
# all initializers should be specified in the model definition.
522537
# if not, this will raise an error
523538
model.initialize(mx.init.Initializer())
539+
logging.info(f"starting epoch {start_epoch}")
524540

525541
# devices for training
526542
devs = list(map(mx.gpu, args.gpus))
@@ -598,7 +614,7 @@ def fit(args, model, data_loader):
598614
optimizer=args.optimizer,
599615
optimizer_params=optimizer_params,
600616
lr_scheduler=lr_scheduler,
601-
model_prefix=os.path.join(args.workspace, args.model_prefix),
617+
model_prefix=args.model_prefix,
602618
)
603619
elif args.mode == 'val':
604620
for epoch in range(args.num_epochs): # loop for benchmarking

‎MxNet/Classification/RN50v1.5/log_utils.py‎

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,11 @@ def format_step(step):
2424
def setup_logging(args):
2525
logging.basicConfig(level=logging.DEBUG, format='{asctime}:{levelname}: {message}', style='{')
2626
if hvd.rank() == 0:
27+
logging_dir = args.logdir if args.logdir is not None else args.workspace
2728
dllogger.init(backends=[
2829
dllogger.StdOutBackend(dllogger.Verbosity.DEFAULT, step_format=format_step),
2930
dllogger.JSONStreamBackend(
30-
dllogger.Verbosity.VERBOSE, os.path.join(args.workspace, args.dllogger_log)),
31+
dllogger.Verbosity.VERBOSE, os.path.join(logging_dir, args.dllogger_log)),
3132
])
3233
else:
3334
dllogger.init([])

0 commit comments

Comments
 (0)