develop #4

Merged
Fabel merged 103 commits from develop into main 2025-03-01 21:47:17 +00:00
1 changed files with 5 additions and 4 deletions
Showing only changes of commit bb22d0a6da - Show all commits

View File

@ -5,7 +5,7 @@ import io
from torch import nn from torch import nn
from torch.utils.data import Dataset, DataLoader from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms import torchvision.transforms as transforms
from aiia.model import AIIABase, AIIA from aiia.model import AIIABase, AIIA, AIIAConfig
from sklearn.model_selection import train_test_split from sklearn.model_selection import train_test_split
from typing import Dict, List, Union, Optional from typing import Dict, List, Union, Optional
import base64 import base64
@ -100,8 +100,8 @@ class ImageDataset(Dataset):
high_res_stream.close() high_res_stream.close()
class SuperResolutionModel(AIIA): class SuperResolutionModel(AIIA):
def __init__(self, base_model): def __init__(self, base_model: AIIA, config: AIIAConfig):
super(SuperResolutionModel, self).__init__() super(SuperResolutionModel, self).__init__(config=config)
# Use base model as encoder # Use base model as encoder
self.encoder = base_model self.encoder = base_model
for param in self.encoder.parameters(): for param in self.encoder.parameters():
@ -393,7 +393,8 @@ class FineTuner:
if __name__ == "__main__": if __name__ == "__main__":
# Load your model first # Load your model first
model = SuperResolutionModel(base_model=AIIABase.load("/root/vision/AIIA/AIIA-base-512")) config = AIIAConfig.load("/root/vision/AIIA/AIIA-base-512")
model = SuperResolutionModel(base_model=AIIABase.load("/root/vision/AIIA/AIIA-base-512"), config=config)
trainer = FineTuner( trainer = FineTuner(
model=model, model=model,