11from dataclasses import dataclass
2- from typing import Dict , Optional , Any
2+ from typing import Dict , Optional , Any , List
33import yaml
44
55
@@ -37,6 +37,33 @@ def from_dict(cls, data: Dict[str, Any]) -> 'ModelConfig':
3737 lora_weigths = str (data .get ('lora_weights' , None )),
3838 )
3939
40+ @dataclass
41+ class EvaluationConfig :
42+ iou_thresholds : List [float ] = None
43+ score_threshold : float = 0.25
44+ max_detections : int = 100
45+ metrics_output_dir : str = 'evaluation_results'
46+ generate_visualizations : bool = True
47+ num_visualization_samples : int = 10
48+
49+ @classmethod
50+ def from_dict (cls , data : Dict [str , Any ]) -> 'EvaluationConfig' :
51+ if data is None :
52+ return cls ()
53+
54+ iou_thresholds = data .get ('iou_thresholds' , [0.5 , 0.75 ])
55+ if isinstance (iou_thresholds , str ):
56+ iou_thresholds = [float (x .strip ()) for x in iou_thresholds .split (',' )]
57+
58+ return cls (
59+ iou_thresholds = iou_thresholds ,
60+ score_threshold = float (data .get ('score_threshold' , 0.25 )),
61+ max_detections = int (data .get ('max_detections' , 100 )),
62+ metrics_output_dir = str (data .get ('metrics_output_dir' , 'evaluation_results' )),
63+ generate_visualizations = bool (data .get ('generate_visualizations' , True )),
64+ num_visualization_samples = int (data .get ('num_visualization_samples' , 10 ))
65+ )
66+
4067@dataclass
4168class TrainingConfig :
4269 num_epochs : int = 1000
@@ -46,17 +73,25 @@ class TrainingConfig:
4673 warmup_epochs : int = 5
4774 use_lora : bool = False
4875 visualization_frequency : int = 5
76+ evaluate_during_training : bool = True
77+ evaluation_frequency : int = 10
78+ evaluation : EvaluationConfig = None
4979
5080 @classmethod
5181 def from_dict (cls , data : Dict [str , Any ]) -> 'TrainingConfig' :
82+ evaluation_config = EvaluationConfig .from_dict (data .get ('evaluation' , None ))
83+
5284 return cls (
5385 num_epochs = int (data .get ('num_epochs' , 1000 )),
5486 learning_rate = float (data .get ('learning_rate' , 1e-3 )),
5587 save_dir = str (data .get ('save_dir' , 'weights' )),
5688 save_frequency = int (data .get ('save_frequency' , 100 )),
5789 warmup_epochs = int (data .get ('warmup_epochs' , 5 )),
5890 use_lora = bool (data .get ('use_lora' , False )),
59- visualization_frequency = int (data .get ('visualization_frequency' , 5 ))
91+ visualization_frequency = int (data .get ('visualization_frequency' , 5 )),
92+ evaluate_during_training = bool (data .get ('evaluate_during_training' , True )),
93+ evaluation_frequency = int (data .get ('evaluation_frequency' , 10 )),
94+ evaluation = evaluation_config
6095 )
6196
6297class ConfigurationManager :
0 commit comments