Skip to content

Commit 9eb05cb

Browse files
add configs
1 parent c96d29a commit 9eb05cb

File tree

2 files changed

+71
-2
lines changed

2 files changed

+71
-2
lines changed

‎config.py‎

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from dataclasses import dataclass
2-
from typing import Dict, Optional, Any
2+
from typing import Dict, Optional, Any, List
33
import yaml
44

55

@@ -37,6 +37,33 @@ def from_dict(cls, data: Dict[str, Any]) -> 'ModelConfig':
3737
lora_weigths=str(data.get('lora_weights', None)),
3838
)
3939

40+
@dataclass
41+
class EvaluationConfig:
42+
iou_thresholds: List[float] = None
43+
score_threshold: float = 0.25
44+
max_detections: int = 100
45+
metrics_output_dir: str = 'evaluation_results'
46+
generate_visualizations: bool = True
47+
num_visualization_samples: int = 10
48+
49+
@classmethod
50+
def from_dict(cls, data: Dict[str, Any]) -> 'EvaluationConfig':
51+
if data is None:
52+
return cls()
53+
54+
iou_thresholds = data.get('iou_thresholds', [0.5, 0.75])
55+
if isinstance(iou_thresholds, str):
56+
iou_thresholds = [float(x.strip()) for x in iou_thresholds.split(',')]
57+
58+
return cls(
59+
iou_thresholds=iou_thresholds,
60+
score_threshold=float(data.get('score_threshold', 0.25)),
61+
max_detections=int(data.get('max_detections', 100)),
62+
metrics_output_dir=str(data.get('metrics_output_dir', 'evaluation_results')),
63+
generate_visualizations=bool(data.get('generate_visualizations', True)),
64+
num_visualization_samples=int(data.get('num_visualization_samples', 10))
65+
)
66+
4067
@dataclass
4168
class TrainingConfig:
4269
num_epochs: int = 1000
@@ -46,17 +73,25 @@ class TrainingConfig:
4673
warmup_epochs: int = 5
4774
use_lora: bool = False
4875
visualization_frequency: int = 5
76+
evaluate_during_training: bool = True
77+
evaluation_frequency: int = 10
78+
evaluation: EvaluationConfig = None
4979

5080
@classmethod
5181
def from_dict(cls, data: Dict[str, Any]) -> 'TrainingConfig':
82+
evaluation_config = EvaluationConfig.from_dict(data.get('evaluation', None))
83+
5284
return cls(
5385
num_epochs=int(data.get('num_epochs', 1000)),
5486
learning_rate=float(data.get('learning_rate', 1e-3)),
5587
save_dir=str(data.get('save_dir', 'weights')),
5688
save_frequency=int(data.get('save_frequency', 100)),
5789
warmup_epochs=int(data.get('warmup_epochs', 5)),
5890
use_lora=bool(data.get('use_lora', False)),
59-
visualization_frequency=int(data.get('visualization_frequency', 5))
91+
visualization_frequency=int(data.get('visualization_frequency', 5)),
92+
evaluate_during_training=bool(data.get('evaluate_during_training', True)),
93+
evaluation_frequency=int(data.get('evaluation_frequency', 10)),
94+
evaluation=evaluation_config
6095
)
6196

6297
class ConfigurationManager:

‎configs/evaluation_config.yaml‎

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
data:
2+
train_dir: "path/to/train/images"
3+
train_ann: "path/to/train/annotations.json"
4+
val_dir: "path/to/val/images"
5+
val_ann: "path/to/val/annotations.json"
6+
num_workers: 4
7+
batch_size: 2
8+
9+
model:
10+
config_path: "groundingdino/config/GroundingDINO_SwinT_OGC.py"
11+
weights_path: "weights/groundingdino_swint_ogc.pth"
12+
lora_weights: null # Optional LoRA weights path
13+
14+
training:
15+
num_epochs: 100
16+
learning_rate: 2e-4
17+
save_dir: "results/model_run_1"
18+
save_frequency: 10
19+
warmup_epochs: 5
20+
use_lora: false
21+
visualization_frequency: 10
22+
23+
# Evaluation settings
24+
evaluate_during_training: true
25+
evaluation_frequency: 5 # Evaluate every 5 epochs
26+
27+
# Detailed evaluation configuration
28+
evaluation:
29+
iou_thresholds: [0.5, 0.75] # IoU thresholds for evaluation
30+
score_threshold: 0.3 # Confidence score threshold
31+
max_detections: 100 # Maximum detections per image
32+
metrics_output_dir: "evaluation_results" # Directory to save evaluation results
33+
generate_visualizations: true # Generate visualization images
34+
num_visualization_samples: 10 # Number of samples to visualize

0 commit comments

Comments
 (0)