finetune_class #1

Merged
Fabel merged 96 commits from finetune_class into develop 2025-02-26 12:13:09 +00:00
1 changed files with 3 additions and 3 deletions
Showing only changes of commit 8c49cc7e01 - Show all commits

View File

@ -217,7 +217,7 @@ class ModelTrainer:
"""
self.model.to(self.device)
for epoch in tqdm(num_epochs):
for epoch in tqdm(range(num_epochs), desc="Training"):
print(f"Epoch {epoch+1}/{num_epochs}")
# Train phase
@ -238,7 +238,7 @@ class ModelTrainer:
self.model.train()
running_loss = 0.0
for batch in tqdm(self.train_loader):
for batch in tqdm(self.train_loader, desc="Training"):
low_ress = batch['low_ress'].to(self.device)
high_ress = batch['high_ress'].to(self.device)
@ -266,7 +266,7 @@ class ModelTrainer:
val_loss = 0.0
with torch.no_grad():
for batch in self.val_loader:
for batch in tqdm(self.val_loader, desc="Validation"):
low_ress = batch['low_ress'].to(self.device)
high_ress = batch['high_ress'].to(self.device)