diff --git a/src/aiunn/finetune.py b/src/aiunn/finetune.py index f53ee73..80dad0b 100644 --- a/src/aiunn/finetune.py +++ b/src/aiunn/finetune.py @@ -3,7 +3,7 @@ import io from PIL import Image from torch.utils.data import Dataset from torchvision import transforms -from aiia import AIIA +from aiia import AIIABase import csv from tqdm import tqdm @@ -39,7 +39,7 @@ import torch pretrained_model_path = "/root/vision/AIIA/AIIA-base-512" # Load the model using the AIIA.load class method (the implementation copied in your query) -model = AIIA.load(pretrained_model_path) +model = AIIABase.load(pretrained_model_path) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = model.to(device) from torch import nn, optim