use correct model loading

This commit is contained in:
Falko Victor Habel 2025-02-21 20:58:35 +01:00
parent eacef6af65
commit 6a9b4afd91
1 changed files with 2 additions and 2 deletions

View File

@ -3,7 +3,7 @@ import io
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from aiia import AIIA
from aiia import AIIABase
import csv
from tqdm import tqdm
@ -39,7 +39,7 @@ import torch
pretrained_model_path = "/root/vision/AIIA/AIIA-base-512"
# Load the model using the AIIA.load class method (the implementation copied in your query)
model = AIIA.load(pretrained_model_path)
model = AIIABase.load(pretrained_model_path)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
from torch import nn, optim