added decoder models fits

This commit is contained in:
Falko Victor Habel 2025-01-31 17:34:33 +01:00
parent be74658ceb
commit bb22d0a6da
1 changed files with 5 additions and 4 deletions

View File

@ -5,7 +5,7 @@ import io
from torch import nn from torch import nn
from torch.utils.data import Dataset, DataLoader from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms import torchvision.transforms as transforms
from aiia.model import AIIABase, AIIA from aiia.model import AIIABase, AIIA, AIIAConfig
from sklearn.model_selection import train_test_split from sklearn.model_selection import train_test_split
from typing import Dict, List, Union, Optional from typing import Dict, List, Union, Optional
import base64 import base64
@ -100,8 +100,8 @@ class ImageDataset(Dataset):
high_res_stream.close() high_res_stream.close()
class SuperResolutionModel(AIIA): class SuperResolutionModel(AIIA):
def __init__(self, base_model): def __init__(self, base_model: AIIA, config: AIIAConfig):
super(SuperResolutionModel, self).__init__() super(SuperResolutionModel, self).__init__(config=config)
# Use base model as encoder # Use base model as encoder
self.encoder = base_model self.encoder = base_model
for param in self.encoder.parameters(): for param in self.encoder.parameters():
@ -393,7 +393,8 @@ class FineTuner:
if __name__ == "__main__": if __name__ == "__main__":
# Load your model first # Load your model first
model = SuperResolutionModel(base_model=AIIABase.load("/root/vision/AIIA/AIIA-base-512")) config = AIIAConfig.load("/root/vision/AIIA/AIIA-base-512")
model = SuperResolutionModel(base_model=AIIABase.load("/root/vision/AIIA/AIIA-base-512"), config=config)
trainer = FineTuner( trainer = FineTuner(
model=model, model=model,