added tqdm and removed doubles
This commit is contained in:
parent
34c547fb23
commit
0484ae01b1
|
@ -9,14 +9,9 @@ from aiia.model import AIIABase, AIIA
|
||||||
from sklearn.model_selection import train_test_split
|
from sklearn.model_selection import train_test_split
|
||||||
from typing import Dict, List, Union, Optional
|
from typing import Dict, List, Union, Optional
|
||||||
import base64
|
import base64
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch.utils.data import Dataset
|
|
||||||
from PIL import Image, ImageFile
|
|
||||||
import io
|
|
||||||
import base64
|
|
||||||
import pandas as pd
|
|
||||||
|
|
||||||
class ImageDataset(Dataset):
|
class ImageDataset(Dataset):
|
||||||
def __init__(self, dataframe, transform=None):
|
def __init__(self, dataframe, transform=None):
|
||||||
|
@ -222,7 +217,7 @@ class ModelTrainer:
|
||||||
"""
|
"""
|
||||||
self.model.to(self.device)
|
self.model.to(self.device)
|
||||||
|
|
||||||
for epoch in range(num_epochs):
|
for epoch in tqdm(num_epochs):
|
||||||
print(f"Epoch {epoch+1}/{num_epochs}")
|
print(f"Epoch {epoch+1}/{num_epochs}")
|
||||||
|
|
||||||
# Train phase
|
# Train phase
|
||||||
|
@ -243,7 +238,7 @@ class ModelTrainer:
|
||||||
self.model.train()
|
self.model.train()
|
||||||
running_loss = 0.0
|
running_loss = 0.0
|
||||||
|
|
||||||
for batch in self.train_loader:
|
for batch in tqdm(self.train_loader):
|
||||||
low_ress = batch['low_ress'].to(self.device)
|
low_ress = batch['low_ress'].to(self.device)
|
||||||
high_ress = batch['high_ress'].to(self.device)
|
high_ress = batch['high_ress'].to(self.device)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue