develop #4

Merged
Fabel merged 103 commits from develop into main 2025-03-01 21:47:17 +00:00
1 changed files with 5 additions and 4 deletions
Showing only changes of commit bb22d0a6da - Show all commits

View File

@ -5,7 +5,7 @@ import io
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from aiia.model import AIIABase, AIIA
from aiia.model import AIIABase, AIIA, AIIAConfig
from sklearn.model_selection import train_test_split
from typing import Dict, List, Union, Optional
import base64
@ -100,8 +100,8 @@ class ImageDataset(Dataset):
high_res_stream.close()
class SuperResolutionModel(AIIA):
def __init__(self, base_model):
super(SuperResolutionModel, self).__init__()
def __init__(self, base_model: AIIA, config: AIIAConfig):
super(SuperResolutionModel, self).__init__(config=config)
# Use base model as encoder
self.encoder = base_model
for param in self.encoder.parameters():
@ -393,7 +393,8 @@ class FineTuner:
if __name__ == "__main__":
# Load your model first
model = SuperResolutionModel(base_model=AIIABase.load("/root/vision/AIIA/AIIA-base-512"))
config = AIIAConfig.load("/root/vision/AIIA/AIIA-base-512")
model = SuperResolutionModel(base_model=AIIABase.load("/root/vision/AIIA/AIIA-base-512"), config=config)
trainer = FineTuner(
model=model,