diff --git a/src/aiunn/finetune.py b/src/aiunn/finetune.py index b4584d8..f53ee73 100644 --- a/src/aiunn/finetune.py +++ b/src/aiunn/finetune.py @@ -5,7 +5,7 @@ from torch.utils.data import Dataset from torchvision import transforms from aiia import AIIA import csv - +from tqdm import tqdm class UpscaleDataset(Dataset): def __init__(self, parquet_file, transform=None): @@ -68,7 +68,10 @@ with open(csv_file, mode='a', newline='') as file: for epoch in range(num_epochs): epoch_loss = 0.0 - for low_res, high_res in data_loader: + # Wrap the data_loader with tqdm for progress tracking + data_loader_with_progress = tqdm(data_loader, desc=f"Epoch {epoch + 1}") + print(f"Epoche: {epoch}") + for low_res, high_res in data_loader_with_progress: low_res = low_res.to(device) high_res = high_res.to(device)