
Commit 7494be7

fix cls train bug

1 parent 8b780d8 commit 7494be7

3 files changed: +8 / -9 lines

torchocr/data/simple_dataset.py

Lines changed: 1 addition & 1 deletion
@@ -247,6 +247,6 @@ def __getitem__(self, properties):
             outs = None
         if outs is None:
             # during evaluation, we should fix the idx to get same results for many times of evaluation.
-            rnd_idx = (idx + 1) % self.__len__()
+            rnd_idx = np.random.randint(self.__len__()) if self.mode == "train" else (idx + 1) % self.__len__()
             return self.__getitem__([img_width, img_height, rnd_idx, wh_ratio])
         return outs
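To make the behaviour of this hunk easier to follow, here is a minimal, self-contained sketch of the retry path (the class, its samples list, and the failure handling are placeholders, not the real torchocr dataset): a sample that fails to load now falls back to a random index during training, while evaluation keeps the deterministic (idx + 1) fallback so repeated eval runs return identical results.

# Minimal sketch (not the actual SimpleDataSet) of the fallback logic above.
import numpy as np
from torch.utils.data import Dataset

class RetryingDataset(Dataset):
    def __init__(self, samples, mode="train"):
        self.samples = samples      # e.g. pre-decoded items; None marks a broken sample
        self.mode = mode            # "train" or "eval"

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        outs = self.samples[idx]    # stands in for decoding + augmentation, None on failure
        if outs is None:
            # train: a random fallback avoids always retrying the same neighbour;
            # eval: a deterministic fallback keeps repeated evaluations identical.
            rnd_idx = np.random.randint(len(self)) if self.mode == "train" else (idx + 1) % len(self)
            return self.__getitem__(rnd_idx)
        return outs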

torchocr/engine/trainer.py

Lines changed: 6 additions & 7 deletions
@@ -227,9 +227,6 @@ def train(self):
            train_reader_cost = 0.0
            train_batch_cost = 0.0
            reader_start = time.time()
-            if self.local_rank == 0:
-                save_ckpt(self.model, self.cfg, self.optimizer, self.lr_scheduler, epoch, global_step, best_metric,
-                          is_best=False)
            # eval
            if self.local_rank == 0 and epoch > start_eval_epoch and (epoch - start_eval_epoch) % eval_epoch_step == 0:
                cur_metric = self.eval()
@@ -245,14 +242,17 @@ def train(self):
                if cur_metric[self.eval_class.main_indicator] >= best_metric[self.eval_class.main_indicator]:
                    best_metric.update(cur_metric)
                    best_metric['best_epoch'] = epoch
-                    save_ckpt(self.model, self.cfg, self.optimizer, self.lr_scheduler, epoch, global_step, best_metric,
-                              is_best=True)
-
                    if self.writer is not None:
                        self.writer.add_scalar(f'EVAL/best_{self.eval_class.main_indicator}',
                                               best_metric[self.eval_class.main_indicator], global_step)
+                    save_ckpt(self.model, self.cfg, self.optimizer, self.lr_scheduler, epoch, global_step, best_metric,
+                              is_best=True)
                best_str = f"best metric, {', '.join(['{}: {}'.format(k, v) for k, v in best_metric.items()])}"
                self.logger.info(best_str)
+
+            if self.local_rank == 0:
+                save_ckpt(self.model, self.cfg, self.optimizer, self.lr_scheduler, epoch, global_step, best_metric,
+                          is_best=False)
            best_str = f"best metric, {', '.join(['{}: {}'.format(k, v) for k, v in best_metric.items()])}"
            self.logger.info(best_str)
            if self.writer is not None:
@@ -274,7 +274,6 @@ def eval(self):
        for idx, batch in enumerate(self.valid_dataloader):
            batch = [t.to(self.device) for t in batch]
            start = time.time()
-            images = batch[0].to(self.device)
            if self.scaler:
                with torch.cuda.amp.autocast():
                    preds = self.model(batch[0], data=batch[1:])
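Taken together, the two train() hunks change when checkpoints are written. A hedged sketch of the resulting epoch-end order is below; end_of_epoch, evaluate, and the simplified save_ckpt signature are illustrative stand-ins, not the trainer's real API.

# Sketch of the epoch-end ordering after this commit: evaluate first, write the
# "best" checkpoint only on improvement, then always write the regular checkpoint
# on rank 0 so it carries the refreshed best_metric.
def end_of_epoch(epoch, local_rank, start_eval_epoch, eval_epoch_step,
                 best_metric, evaluate, save_ckpt, main_indicator="acc"):
    if local_rank == 0 and epoch > start_eval_epoch and (epoch - start_eval_epoch) % eval_epoch_step == 0:
        cur_metric = evaluate()
        if cur_metric[main_indicator] >= best_metric.get(main_indicator, float("-inf")):
            best_metric.update(cur_metric)
            best_metric["best_epoch"] = epoch
            save_ckpt(is_best=True)   # best checkpoint only when the main indicator improves
    if local_rank == 0:
        save_ckpt(is_best=False)      # regular checkpoint now comes after evaluation

# example wiring with stubs:
metrics = {"acc": 0.0}
end_of_epoch(epoch=3, local_rank=0, start_eval_epoch=0, eval_epoch_step=1,
             best_metric=metrics, evaluate=lambda: {"acc": 0.91},
             save_ckpt=lambda is_best: print("save_ckpt, is_best =", is_best))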

torchocr/postprocess/cls_postprocess.py

Lines changed: 1 addition & 1 deletion
@@ -22,5 +22,5 @@ def __call__(self, preds, batch=None, *args, **kwargs):
        decode_out = [(label_list[idx], preds[i, idx]) for i, idx in enumerate(pred_idxs)]
        if batch is None:
            return decode_out
-        label = [(label_list[idx], 1.0) for idx in batch[1].numpy()]
+        label = [(label_list[idx], 1.0) for idx in batch[1].cpu().numpy()]
        return decode_out, label
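The added .cpu() matters because the eval loop above now moves every tensor in the batch to self.device, and Tensor.numpy() raises on CUDA tensors. A quick illustration (assumes PyTorch; the CUDA branch only runs if a GPU is present):

import torch

labels = torch.tensor([0, 1, 1, 0])
if torch.cuda.is_available():
    labels = labels.cuda()
    # labels.numpy() would now raise:
    # TypeError: can't convert cuda:0 device type tensor to numpy.
    #            Use Tensor.cpu() to copy the tensor to host memory first.

label_ids = labels.cpu().numpy()   # .cpu() is a no-op for tensors already on the CPU
print(label_ids)                   # -> [0 1 1 0]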
