From bb22d0a6da807568e7bd9d3214409455d8b75713 Mon Sep 17 00:00:00 2001 From: Falko Habel Date: Fri, 31 Jan 2025 17:34:33 +0100 Subject: [PATCH] added decoder models fits --- src/aiunn/finetune.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/aiunn/finetune.py b/src/aiunn/finetune.py index 6121e22..be54fe2 100644 --- a/src/aiunn/finetune.py +++ b/src/aiunn/finetune.py @@ -5,7 +5,7 @@ import io from torch import nn from torch.utils.data import Dataset, DataLoader import torchvision.transforms as transforms -from aiia.model import AIIABase, AIIA +from aiia.model import AIIABase, AIIA, AIIAConfig from sklearn.model_selection import train_test_split from typing import Dict, List, Union, Optional import base64 @@ -100,8 +100,8 @@ class ImageDataset(Dataset): high_res_stream.close() class SuperResolutionModel(AIIA): - def __init__(self, base_model): - super(SuperResolutionModel, self).__init__() + def __init__(self, base_model: AIIA, config: AIIAConfig): + super(SuperResolutionModel, self).__init__(config=config) # Use base model as encoder self.encoder = base_model for param in self.encoder.parameters(): @@ -393,7 +393,8 @@ class FineTuner: if __name__ == "__main__": # Load your model first - model = SuperResolutionModel(base_model=AIIABase.load("/root/vision/AIIA/AIIA-base-512")) + config = AIIAConfig.load("/root/vision/AIIA/AIIA-base-512") + model = SuperResolutionModel(base_model=AIIABase.load("/root/vision/AIIA/AIIA-base-512"), config=config) trainer = FineTuner( model=model,