
Commit a98af79

Updated LoRA state dict saving
1 parent 1def1c2 commit a98af79

File tree

2 files changed (+17 -12 lines changed):
  groundingdino/util/lora.py
  train.py


groundingdino/util/lora.py

Lines changed: 10 additions & 1 deletion
@@ -4,7 +4,16 @@
 
 
 def get_lora_weights(model):
-    lora_state_dict = get_peft_model_state_dict(model)
+    ## This needs the LoRA config, which is not part of the Grounding DINO model state dict right now; hack to get the LoRA weights
+    #lora_state_dict = get_peft_model_state_dict(model)
+    # Collect LoRA parameters manually
+    lora_state_dict = {}
+    for name, param in model.named_parameters():
+        if 'lora_' in name:
+            lora_state_dict[name] = param.data.cpu()
+    # If no LoRA weights are found, print a warning
+    if not lora_state_dict:
+        print("No LoRA weights found in the model.")
     return lora_state_dict
 
 def load_lora_weights(model, weights_path):
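
get_lora_weights now returns a plain dict of LoRA parameter tensors keyed by parameter name, so the matching loader only has to copy those tensors back into the model's named parameters. Below is a minimal sketch of how load_lora_weights could do that; it assumes the weights file holds exactly the dict produced by get_lora_weights and is not the repository's actual implementation.

    import torch

    def load_lora_weights(model, weights_path):
        # Dict produced by get_lora_weights: parameter name -> CPU tensor
        lora_state_dict = torch.load(weights_path, map_location="cpu")
        own_params = dict(model.named_parameters())
        # Copy each saved LoRA tensor into the parameter with the same name
        for name, tensor in lora_state_dict.items():
            if name in own_params:
                own_params[name].data.copy_(tensor)
            else:
                print(f"Skipping {name}: no matching parameter in the model")
        return model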

train.py

Lines changed: 7 additions & 11 deletions
@@ -518,18 +518,18 @@ def train(
 
     # if we are using LoRA, then it is taken care of while setting up LoRA
     if not use_lora:
+        print("Freezing most of the model except a few layers!")
         freeze_model_layers(model)
 
     print_frozen_status(model)
 
     for epoch in range(num_epochs):
         ## Do visualization on the val dataset passed as input; loop through it
-        ## visualize after every 5 epochs
         if epoch % 5 == 0:
             visualizer.visualize_epoch(model, val_loader, epoch, trainer.prepare_batch)
 
         epoch_losses = defaultdict(list)
-
+
         for batch_idx, batch in enumerate(train_loader):
 
             losses = trainer.train_step(batch)
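
freeze_model_layers and print_frozen_status are not part of this diff, so the sketch below only illustrates the pattern the call sites imply: freeze every parameter, re-enable gradients for a whitelist of submodules, then report the trainable fraction. The trainable_keywords default ("bbox_embed", "class_embed") is an assumption for illustration, not the repository's actual choice.

    def freeze_model_layers(model, trainable_keywords=("bbox_embed", "class_embed")):
        # Freeze everything, then keep gradients only for the whitelisted submodules
        for name, param in model.named_parameters():
            param.requires_grad = any(key in name for key in trainable_keywords)

    def print_frozen_status(model):
        # Report how many parameters will actually receive gradients
        trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
        total = sum(p.numel() for p in model.parameters())
        print(f"Trainable params: {trainable:,} / {total:,} ({100 * trainable / total:.2f}%)")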
@@ -542,24 +542,20 @@ def train(
             loss_str = ", ".join(f"{k}: {v:.4f}" for k, v in losses.items())
             print(f"Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}/{len(train_loader)}, {loss_str}")
             print(f"Learning rate: {trainer.optimizer.param_groups[0]['lr']:.6f}")
-
+            break
+
 
-        # Compute epoch averages
         avg_losses = {k: sum(v)/len(v) for k, v in epoch_losses.items()}
         print(f"Epoch {epoch+1} complete. Average losses:", ", ".join(f"{k}: {v:.4f}" for k, v in avg_losses.items()))
 
         if (epoch + 1) % save_frequency == 0:
-            #continue
             trainer.save_checkpoint(
                 os.path.join(save_dir, f'checkpoint_epoch_{epoch+1}.pth'),
                 epoch,
-                avg_losses
+                avg_losses,
+                use_lora=use_lora
             )
-            #trainer.save_checkpoint(
-            #    os.path.join(save_dir, f'checkpoint.pth'),
-            #    epoch,
-            #    avg_losses
-            #)
+
 
 if __name__ == "__main__":
 
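The new use_lora=use_lora argument suggests save_checkpoint switches between dumping the full model state dict and only the LoRA tensors returned by get_lora_weights. The sketch below shows that branching under the assumption that the trainer holds model and optimizer attributes; the repository's actual Trainer.save_checkpoint is not shown in this commit.

    import torch
    from groundingdino.util.lora import get_lora_weights

    def save_checkpoint(self, path, epoch, avg_losses, use_lora=False):
        # Save only the small set of LoRA tensors when training with LoRA,
        # otherwise fall back to the full model state dict
        model_state = get_lora_weights(self.model) if use_lora else self.model.state_dict()
        torch.save({
            "epoch": epoch,
            "model_state_dict": model_state,
            "optimizer_state_dict": self.optimizer.state_dict(),
            "losses": avg_losses,
            "use_lora": use_lora,
        }, path)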