diff --git a/src/aiunn/finetune.py b/src/aiunn/finetune.py index f26ef93..545de8c 100644 --- a/src/aiunn/finetune.py +++ b/src/aiunn/finetune.py @@ -217,7 +217,7 @@ class ModelTrainer: """ self.model.to(self.device) - for epoch in tqdm(num_epochs): + for epoch in tqdm(range(num_epochs), desc="Training"): print(f"Epoch {epoch+1}/{num_epochs}") # Train phase @@ -238,7 +238,7 @@ class ModelTrainer: self.model.train() running_loss = 0.0 - for batch in tqdm(self.train_loader): + for batch in tqdm(self.train_loader, desc="Training"): low_ress = batch['low_ress'].to(self.device) high_ress = batch['high_ress'].to(self.device) @@ -266,7 +266,7 @@ class ModelTrainer: val_loss = 0.0 with torch.no_grad(): - for batch in self.val_loader: + for batch in tqdm(self.val_loader, desc="Validation"): low_ress = batch['low_ress'].to(self.device) high_ress = batch['high_ress'].to(self.device)