use correct model loading

This commit is contained in:
Falko Victor Habel 2025-02-21 20:58:35 +01:00
parent eacef6af65
commit 6a9b4afd91
1 changed files with 2 additions and 2 deletions

View File

@ -3,7 +3,7 @@ import io
from PIL import Image from PIL import Image
from torch.utils.data import Dataset from torch.utils.data import Dataset
from torchvision import transforms from torchvision import transforms
from aiia import AIIA from aiia import AIIABase
import csv import csv
from tqdm import tqdm from tqdm import tqdm
@ -39,7 +39,7 @@ import torch
pretrained_model_path = "/root/vision/AIIA/AIIA-base-512" pretrained_model_path = "/root/vision/AIIA/AIIA-base-512"
# Load the model using the AIIA.load class method (the implementation copied in your query) # Load the model using the AIIA.load class method (the implementation copied in your query)
model = AIIA.load(pretrained_model_path) model = AIIABase.load(pretrained_model_path)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device) model = model.to(device)
from torch import nn, optim from torch import nn, optim