fixed tqdm in finetuning
This commit is contained in:
parent
0484ae01b1
commit
8c49cc7e01
|
@ -217,7 +217,7 @@ class ModelTrainer:
|
||||||
"""
|
"""
|
||||||
self.model.to(self.device)
|
self.model.to(self.device)
|
||||||
|
|
||||||
for epoch in tqdm(num_epochs):
|
for epoch in tqdm(range(num_epochs), desc="Training"):
|
||||||
print(f"Epoch {epoch+1}/{num_epochs}")
|
print(f"Epoch {epoch+1}/{num_epochs}")
|
||||||
|
|
||||||
# Train phase
|
# Train phase
|
||||||
|
@ -238,7 +238,7 @@ class ModelTrainer:
|
||||||
self.model.train()
|
self.model.train()
|
||||||
running_loss = 0.0
|
running_loss = 0.0
|
||||||
|
|
||||||
for batch in tqdm(self.train_loader):
|
for batch in tqdm(self.train_loader, desc="Training"):
|
||||||
low_ress = batch['low_ress'].to(self.device)
|
low_ress = batch['low_ress'].to(self.device)
|
||||||
high_ress = batch['high_ress'].to(self.device)
|
high_ress = batch['high_ress'].to(self.device)
|
||||||
|
|
||||||
|
@ -266,7 +266,7 @@ class ModelTrainer:
|
||||||
val_loss = 0.0
|
val_loss = 0.0
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for batch in self.val_loader:
|
for batch in tqdm(self.val_loader, desc="Validation"):
|
||||||
low_ress = batch['low_ress'].to(self.device)
|
low_ress = batch['low_ress'].to(self.device)
|
||||||
high_ress = batch['high_ress'].to(self.device)
|
high_ress = batch['high_ress'].to(self.device)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue