fixed tqdm in finetuning

This commit is contained in:
Falko Victor Habel 2025-01-30 13:51:49 +01:00
parent 0484ae01b1
commit 8c49cc7e01
1 changed files with 3 additions and 3 deletions

View File

@ -217,7 +217,7 @@ class ModelTrainer:
""" """
self.model.to(self.device) self.model.to(self.device)
for epoch in tqdm(num_epochs): for epoch in tqdm(range(num_epochs), desc="Training"):
print(f"Epoch {epoch+1}/{num_epochs}") print(f"Epoch {epoch+1}/{num_epochs}")
# Train phase # Train phase
@ -238,7 +238,7 @@ class ModelTrainer:
self.model.train() self.model.train()
running_loss = 0.0 running_loss = 0.0
for batch in tqdm(self.train_loader): for batch in tqdm(self.train_loader, desc="Training"):
low_ress = batch['low_ress'].to(self.device) low_ress = batch['low_ress'].to(self.device)
high_ress = batch['high_ress'].to(self.device) high_ress = batch['high_ress'].to(self.device)
@ -266,7 +266,7 @@ class ModelTrainer:
val_loss = 0.0 val_loss = 0.0
with torch.no_grad(): with torch.no_grad():
for batch in self.val_loader: for batch in tqdm(self.val_loader, desc="Validation"):
low_ress = batch['low_ress'].to(self.device) low_ress = batch['low_ress'].to(self.device)
high_ress = batch['high_ress'].to(self.device) high_ress = batch['high_ress'].to(self.device)