@@ -518,18 +518,18 @@ def train(
 
     # if we are using LoRA then it is taken care of while setting up LoRA
     if not use_lora:
+        print("Freezing most of the model except a few layers!")
         freeze_model_layers(model)
 
     print_frozen_status(model)
 
     for epoch in range(num_epochs):
         ## Do visualization on the val dataset passed as input; loop through it
-        ## visualize after every 5 epochs
         if epoch % 5 == 0:
             visualizer.visualize_epoch(model, val_loader, epoch, trainer.prepare_batch)
 
         epoch_losses = defaultdict(list)
-
+
         for batch_idx, batch in enumerate(train_loader):
 
             losses = trainer.train_step(batch)
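For reference, a minimal sketch of the layer-freezing helpers called in this hunk. `freeze_model_layers` and `print_frozen_status` live elsewhere in the repo and are not part of this diff, so the keyword-based choice of which layers stay trainable below is purely an assumption for illustration:

import torch.nn as nn

def freeze_model_layers(model: nn.Module, trainable_keywords=("head",)):
    # Freeze every parameter except those whose name contains one of the
    # keywords; the keyword list is a placeholder, not the project's actual choice.
    for name, param in model.named_parameters():
        param.requires_grad = any(key in name for key in trainable_keywords)

def print_frozen_status(model: nn.Module):
    # Summarize how many parameters remain trainable after freezing.
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f"Trainable parameters: {trainable:,} / {total:,} "
          f"({100.0 * trainable / max(total, 1):.2f}%)")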
@@ -542,24 +542,20 @@ def train(
             loss_str = ", ".join(f"{k}: {v:.4f}" for k, v in losses.items())
             print(f"Epoch {epoch + 1}/{num_epochs}, Batch {batch_idx}/{len(train_loader)}, {loss_str}")
             print(f"Learning rate: {trainer.optimizer.param_groups[0]['lr']:.6f}")
-
+            break
+
 
-        # Compute epoch averages
         avg_losses = {k: sum(v) / len(v) for k, v in epoch_losses.items()}
         print(f"Epoch {epoch + 1} complete. Average losses:", ", ".join(f"{k}: {v:.4f}" for k, v in avg_losses.items()))
 
         if (epoch + 1) % save_frequency == 0:
-            #continue
             trainer.save_checkpoint(
                 os.path.join(save_dir, f'checkpoint_epoch_{epoch + 1}.pth'),
                 epoch,
-                avg_losses
+                avg_losses,
+                use_lora=use_lora
             )
-            #trainer.save_checkpoint(
-            #    os.path.join(save_dir, f'checkpoint.pth'),
-            #    epoch,
-            #    avg_losses
-            #)
+
 
 if __name__ == "__main__":
 
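The second hunk also forwards `use_lora` to `trainer.save_checkpoint`. The real `Trainer.save_checkpoint` is not shown in this commit, so the following is only a hedged sketch of how such a method could use the flag, written as a free function over a trainer object; filtering adapter weights by the substring "lora" is an assumption, not the project's confirmed behavior:

import torch

def save_checkpoint(trainer, path, epoch, avg_losses, use_lora=False):
    # When training with LoRA, keep only the adapter weights in the checkpoint;
    # otherwise save the full model. The 'lora' substring filter is illustrative.
    state_dict = trainer.model.state_dict()
    if use_lora:
        state_dict = {k: v for k, v in state_dict.items() if "lora" in k.lower()}
    torch.save(
        {
            "epoch": epoch,
            "model_state_dict": state_dict,
            "optimizer_state_dict": trainer.optimizer.state_dict(),
            "avg_losses": avg_losses,
            "use_lora": use_lora,
        },
        path,
    )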