finetune_class #1

Merged
Fabel merged 96 commits from finetune_class into develop 2025-02-26 12:13:09 +00:00
1 changed files with 3 additions and 3 deletions
Showing only changes of commit 81ceddff3b - Show all commits

View File

@ -8,7 +8,7 @@ import csv
from tqdm import tqdm from tqdm import tqdm
import base64 import base64
from torch.amp import autocast, GradScaler from torch.amp import autocast, GradScaler
import torch
class UpscaleDataset(Dataset): class UpscaleDataset(Dataset):
def __init__(self, parquet_files: list, transform=None): def __init__(self, parquet_files: list, transform=None):
@ -87,7 +87,7 @@ pretrained_model_path = "/root/vision/AIIA/AIIA-base-512"
# Load the model using the AIIA.load class method (the implementation copied in your query) # Load the model using the AIIA.load class method (the implementation copied in your query)
model = AIIABase.load(pretrained_model_path) model = AIIABase.load(pretrained_model_path)
device = 'cpu' #torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device) model = model.to(device)
from torch import nn, optim from torch import nn, optim
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@ -126,7 +126,7 @@ for epoch in range(num_epochs):
optimizer.zero_grad() optimizer.zero_grad()
# Use automatic mixed precision context # Use automatic mixed precision context
with autocast(): with autocast(device_type=device):
outputs = model(low_res) outputs = model(low_res)
loss = criterion(outputs, high_res) loss = criterion(outputs, high_res)