3838from yolo .config .config import Config , YOLOLayer
3939from yolo .model .yolo import YOLO
4040from yolo .utils .logger import logger
41- from yolo .utils .model_utils import EMA
41+ from yolo .utils .model_utils import EMA , GradientAccumulation
4242from yolo .utils .solver_utils import make_ap_table
4343
4444
@@ -68,7 +68,6 @@ def _init_progress(self, trainer: "Trainer") -> None:
6868 self ._reset_progress_bar_ids ()
6969 reconfigure (** self ._console_kwargs )
7070 self ._console = Console ()
71- self ._console .clear_live ()
7271 self .progress = YOLOCustomProgress (
7372 * self .configure_columns (trainer ),
7473 auto_refresh = False ,
@@ -105,7 +104,7 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch: Any, batch_idx:
105104 self ._update (self .train_progress_bar_id , batch_idx + 1 )
106105 self ._update_metrics (trainer , pl_module )
107106 epoch_descript = "[cyan]Train [white]|"
108- batch_descript = "[green]Train [white]|"
107+ batch_descript = "[green]Batch [white]|"
109108 metrics = self .get_metrics (trainer , pl_module )
110109 metrics .pop ("v_num" )
111110 for metrics_name , metrics_val in metrics .items ():
@@ -238,7 +237,7 @@ def on_validation_batch_end(self, trainer: Trainer, pl_module, outputs, batch, b
238237 logger .log_image ("Prediction" , images , step = step , boxes = [log_bbox (pred_boxes )])
239238
240239
241- def setup_logger (logger_name , quite = False ):
240+ def setup_logger (logger_name , quiet = False ):
242241 class EmojiFormatter (logging .Formatter ):
243242 def format (self , record , emoji = ":high_voltage:" ):
244243 return f"{ emoji } { super ().format (record )} "
@@ -249,17 +248,17 @@ def format(self, record, emoji=":high_voltage:"):
249248 if rich_logger :
250249 rich_logger .handlers .clear ()
251250 rich_logger .addHandler (rich_handler )
252- if quite :
251+ if quiet :
253252 rich_logger .setLevel (logging .ERROR )
254253
255254 coco_logger = logging .getLogger ("faster_coco_eval.core.cocoeval" )
256255 coco_logger .setLevel (logging .ERROR )
257256
258257
259258def setup (cfg : Config ):
260- quite = hasattr (cfg , "quite " )
261- setup_logger ("lightning.fabric" , quite = quite )
262- setup_logger ("lightning.pytorch" , quite = quite )
259+ quiet = hasattr (cfg , "quiet " )
260+ setup_logger ("lightning.fabric" , quiet = quiet )
261+ setup_logger ("lightning.pytorch" , quiet = quiet )
263262
264263 def custom_wandb_log (string = "" , level = int , newline = True , repeat = True , prefix = True , silent = False ):
265264 if silent :
@@ -273,9 +272,12 @@ def custom_wandb_log(string="", level=int, newline=True, repeat=True, prefix=Tru
273272
274273 progress , loggers = [], []
275274
275+ if cfg .task .task == "train" and hasattr (cfg .task .data , "equivalent_batch_size" ):
276+ progress .append (GradientAccumulation (data_cfg = cfg .task .data , scheduler_cfg = cfg .task .scheduler ))
277+
276278 if hasattr (cfg .task , "ema" ) and cfg .task .ema .enable :
277279 progress .append (EMA (cfg .task .ema .decay ))
278- if quite :
280+ if quiet :
279281 logger .setLevel (logging .ERROR )
280282 return progress , loggers , save_path
281283
@@ -336,7 +338,7 @@ def validate_log_directory(cfg: Config, exp_name: str) -> Path:
336338 )
337339
338340 save_path .mkdir (parents = True , exist_ok = True )
339- if not getattr (cfg , "quite " , False ):
341+ if not getattr (cfg , "quiet " , False ):
340342 logger .info (f"π Created log folder: [blue b u]{ save_path } [/]" )
341343 logger .addHandler (FileHandler (save_path / "output.log" ))
342344 return save_path
0 commit comments