develop #4
|
@ -217,7 +217,7 @@ class ModelTrainer:
|
|||
"""
|
||||
self.model.to(self.device)
|
||||
|
||||
for epoch in tqdm(num_epochs):
|
||||
for epoch in tqdm(range(num_epochs), desc="Training"):
|
||||
print(f"Epoch {epoch+1}/{num_epochs}")
|
||||
|
||||
# Train phase
|
||||
|
@ -238,7 +238,7 @@ class ModelTrainer:
|
|||
self.model.train()
|
||||
running_loss = 0.0
|
||||
|
||||
for batch in tqdm(self.train_loader):
|
||||
for batch in tqdm(self.train_loader, desc="Training"):
|
||||
low_ress = batch['low_ress'].to(self.device)
|
||||
high_ress = batch['high_ress'].to(self.device)
|
||||
|
||||
|
@ -266,7 +266,7 @@ class ModelTrainer:
|
|||
val_loss = 0.0
|
||||
|
||||
with torch.no_grad():
|
||||
for batch in self.val_loader:
|
||||
for batch in tqdm(self.val_loader, desc="Validation"):
|
||||
low_ress = batch['low_ress'].to(self.device)
|
||||
high_ress = batch['high_ress'].to(self.device)
|
||||
|
||||
|
|
Loading…
Reference in New Issue