added decoder models fits

This commit is contained in:
Falko Victor Habel 2025-01-31 17:34:33 +01:00
parent be74658ceb
commit bb22d0a6da
1 changed files with 5 additions and 4 deletions

View File

@ -5,7 +5,7 @@ import io
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from aiia.model import AIIABase, AIIA
from aiia.model import AIIABase, AIIA, AIIAConfig
from sklearn.model_selection import train_test_split
from typing import Dict, List, Union, Optional
import base64
@ -100,8 +100,8 @@ class ImageDataset(Dataset):
high_res_stream.close()
class SuperResolutionModel(AIIA):
def __init__(self, base_model):
super(SuperResolutionModel, self).__init__()
def __init__(self, base_model: AIIA, config: AIIAConfig):
super(SuperResolutionModel, self).__init__(config=config)
# Use base model as encoder
self.encoder = base_model
for param in self.encoder.parameters():
@ -393,7 +393,8 @@ class FineTuner:
if __name__ == "__main__":
# Load your model first
model = SuperResolutionModel(base_model=AIIABase.load("/root/vision/AIIA/AIIA-base-512"))
config = AIIAConfig.load("/root/vision/AIIA/AIIA-base-512")
model = SuperResolutionModel(base_model=AIIABase.load("/root/vision/AIIA/AIIA-base-512"), config=config)
trainer = FineTuner(
model=model,