finetune_class #1

Merged
Fabel merged 96 commits from finetune_class into develop 2025-02-26 12:13:09 +00:00
1 changed files with 5 additions and 2 deletions
Showing only changes of commit eacef6af65 - Show all commits

View File

@ -5,7 +5,7 @@ from torch.utils.data import Dataset
from torchvision import transforms from torchvision import transforms
from aiia import AIIA from aiia import AIIA
import csv import csv
from tqdm import tqdm
class UpscaleDataset(Dataset): class UpscaleDataset(Dataset):
def __init__(self, parquet_file, transform=None): def __init__(self, parquet_file, transform=None):
@ -68,7 +68,10 @@ with open(csv_file, mode='a', newline='') as file:
for epoch in range(num_epochs): for epoch in range(num_epochs):
epoch_loss = 0.0 epoch_loss = 0.0
for low_res, high_res in data_loader: # Wrap the data_loader with tqdm for progress tracking
data_loader_with_progress = tqdm(data_loader, desc=f"Epoch {epoch + 1}")
print(f"Epoche: {epoch}")
for low_res, high_res in data_loader_with_progress:
low_res = low_res.to(device) low_res = low_res.to(device)
high_res = high_res.to(device) high_res = high_res.to(device)