Skip to content

Commit 52e0daf

Browse files
updated configurations
1 parent 8eabf11 commit 52e0daf

File tree

7 files changed

+60
-24
lines changed

7 files changed

+60
-24
lines changed

‎config.py‎

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,14 @@ def from_dict(cls, data: Dict[str, Any]) -> 'DataConfig':
2727
class ModelConfig:
2828
config_path: str
2929
weights_path: str
30+
lora_weigths: str = None
3031

3132
@classmethod
3233
def from_dict(cls, data: Dict[str, Any]) -> 'ModelConfig':
3334
return cls(
3435
config_path=str(data['config_path']),
35-
weights_path=str(data['weights_path'])
36+
weights_path=str(data['weights_path']),
37+
lora_weigths=str(data.get('lora_weights', None)),
3638
)
3739

3840
@dataclass

‎configs/test_config.yaml‎

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
data:
2+
batch_size: 4
3+
num_workers: 8
4+
train_ann: multimodal-data/fashion_dataset_subset/train_annotations.csv
5+
train_dir: multimodal-data/fashion_dataset_subset/images/train
6+
val_ann: multimodal-data/fashion_dataset_subset/val_annotations.csv
7+
val_dir: multimodal-data/fashion_dataset_subset/images/val
8+
model:
9+
config_path: groundingdino/config/GroundingDINO_SwinT_OGC.py
10+
lora_weigths: None
11+
weights_path: weights/groundingdino_swint_ogc.pth
12+
training:
13+
learning_rate: 0.0001
14+
num_epochs: 200
15+
save_dir: weights
16+
save_frequency: 5
17+
use_lora: true
18+
use_lora_layers: true
19+
visualization_frequency: 5
20+
warmup_epochs: 5

‎config.yaml‎ renamed to ‎configs/train_config.yaml‎

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@ model:
1212

1313
training:
1414
num_epochs: 200
15-
learning_rate: 1e-3
15+
learning_rate: 1e-4
1616
save_dir: "weights"
17-
save_frequency: 100
17+
save_frequency: 5
1818
warmup_epochs: 5
1919
use_lora: true
2020
use_lora_layers: true # This applies lora to only bbox pred layer and few transformer decoder layers the number of trainable parameters in this case will be < 1% of total parameters

‎groundingdino/util/inference.py‎

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from groundingdino.util.class_loss import FocalLoss
1919
import os
2020
from groundingdino.util.box_ops import box_cxcywh_to_xyxy
21+
from config import ModelConfig
2122

2223
# ----------------------------------------------------------------------------------------------------------------------
2324
# OLD API
@@ -31,16 +32,27 @@ def preprocess_caption(caption: str) -> str:
3132
return result + "."
3233

3334

34-
def load_model(model_config_path: str, model_checkpoint_path: str, device: str = "cuda",strict: bool =True):
35-
args = SLConfig.fromfile(model_config_path)
36-
args.device = device
37-
model = build_model(args)
38-
checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
35+
def load_weights(model:torch.nn.Module,checkpoint:dict):
3936
if "model" in checkpoint.keys():
4037
model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
4138
else:
4239
# The state dict is the checkpoint
43-
model.load_state_dict(clean_state_dict(checkpoint), strict=True)
40+
model.load_state_dict(clean_state_dict(checkpoint), strict=False)
41+
42+
43+
def load_model(model_config:ModelConfig, use_lora:bool= False, device: str = "cuda",strict: bool =True):
44+
args = SLConfig.fromfile(model_config.config_path)
45+
args.device = device
46+
model = build_model(args)
47+
# Loading main weights if lora is not used these are the only one required
48+
checkpoint = torch.load(model_config.weights_path, map_location="cpu")
49+
print(f"Loading main model Weights!!")
50+
load_weights(model,checkpoint)
51+
if use_lora:
52+
print(f"Loading Lora Weights!!")
53+
checkpoint = torch.load(model_config.lora_weigths, map_location="cpu")
54+
load_weights(model,checkpoint)
55+
4456
model.eval()
4557
return model
4658

‎groundingdino/util/lora.py‎

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,6 @@ def add_lora_to_model(model, rank=8):
6565
"key",
6666
"value",
6767
"dense",
68-
"bbox_embed",
6968
],
7069
lora_dropout=0.1,
7170
bias="none",

‎test.py‎

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
from groundingdino.util.inference import load_model, load_image, predict, annotate
2-
import cv2
32
import torch
43
import torchvision.ops as ops
54
import os
65
from torchvision.ops import box_convert
76
from groundingdino.util.inference import GroundingDINOVisualizer
7+
from config import ConfigurationManager, DataConfig, ModelConfig
88

99
def apply_nms_per_phrase(image_source, boxes, logits, phrases, threshold=0.3):
1010
h, w, _ = image_source.shape
@@ -28,19 +28,17 @@ def apply_nms_per_phrase(image_source, boxes, logits, phrases, threshold=0.3):
2828
return torch.stack(nms_boxes_list), torch.stack(nms_logits_list), nms_phrases_list
2929

3030

31-
def process_image(
32-
model_config="groundingdino/config/GroundingDINO_SwinT_OGC.py",
33-
model_weights="weights/groundingdino_swint_ogc.pth",
34-
image_path="multimodal-data/fashion_dataset_subset/images/val/val_000004.jpg",
35-
text_prompt="shirt .bag .pants",
31+
def process_images(
32+
model,
33+
text_prompt,
34+
data_config,
3635
box_threshold=0.35,
3736
text_threshold=0.25
3837
):
39-
model = load_model(model_config, model_weights)
4038
visualizer = GroundingDINOVisualizer(save_dir="visualizations")
4139

42-
for img in os.listdir('multimodal-data/fashion_dataset_subset/images/val'):
43-
image_path=os.path.join('multimodal-data/fashion_dataset_subset/images/val',img)
40+
for img in os.listdir(data_config.val_dir):
41+
image_path=os.path.join(data_config.val_dir,img)
4442
image_source, image = load_image(image_path)
4543
visualizer.visualize_image(model,image,text_prompt,image_source,img)
4644

@@ -60,5 +58,10 @@ def process_image(
6058

6159

6260
if __name__ == "__main__":
63-
model_weights="weights/groundingdino_swint_ogc.pth"
64-
process_image(model_weights=model_weights)
61+
# Config file of the prediction, the model weights can be complete model weights but if use_lora is true then lora_wights should also be present see example
62+
## config file
63+
config_path="configs/test_config.yaml"
64+
text_prompt="shirt .bag .pants",
65+
data_config, model_config, training_config = ConfigurationManager.load_config(config_path)
66+
model = load_model(model_config,training_config.use_lora)
67+
process_images(model,text_prompt,data_config)

‎train.py‎

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -222,14 +222,14 @@ def save_checkpoint(self, path, epoch, losses, use_lora=False):
222222
if use_lora:
223223
checkpoint = {
224224
'epoch': epoch,
225-
'model_state_dict': get_lora_weights(self.model),
225+
'model': get_lora_weights(self.model),
226226
'optimizer_state_dict': self.optimizer.state_dict(),
227227
'scheduler_state_dict': self.scheduler.state_dict() if self.scheduler else None,
228228
'losses': losses,}
229229
else:
230230
checkpoint = {
231231
'epoch': epoch,
232-
'model_state_dict': self.model.state_dict(),
232+
'model': self.model.state_dict(),
233233
'ema_state_dict': self.ema_model.state_dict() if self.use_ema else None,
234234
'optimizer_state_dict': self.optimizer.state_dict(),
235235
'scheduler_state_dict': self.scheduler.state_dict() if self.scheduler else None,
@@ -322,4 +322,4 @@ def train(config_path: str, save_dir: Optional[str] = None) -> None:
322322

323323

324324
if __name__ == "__main__":
325-
train('config.yaml')
325+
train('configs/train_config.yaml')

0 commit comments

Comments
 (0)