finetune_class #1
|
@ -5,7 +5,7 @@ from torch.utils.data import Dataset
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
from aiia import AIIA
|
from aiia import AIIA
|
||||||
import csv
|
import csv
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
class UpscaleDataset(Dataset):
|
class UpscaleDataset(Dataset):
|
||||||
def __init__(self, parquet_file, transform=None):
|
def __init__(self, parquet_file, transform=None):
|
||||||
|
@ -68,7 +68,10 @@ with open(csv_file, mode='a', newline='') as file:
|
||||||
|
|
||||||
for epoch in range(num_epochs):
|
for epoch in range(num_epochs):
|
||||||
epoch_loss = 0.0
|
epoch_loss = 0.0
|
||||||
for low_res, high_res in data_loader:
|
# Wrap the data_loader with tqdm for progress tracking
|
||||||
|
data_loader_with_progress = tqdm(data_loader, desc=f"Epoch {epoch + 1}")
|
||||||
|
print(f"Epoche: {epoch}")
|
||||||
|
for low_res, high_res in data_loader_with_progress:
|
||||||
low_res = low_res.to(device)
|
low_res = low_res.to(device)
|
||||||
high_res = high_res.to(device)
|
high_res = high_res.to(device)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue