finetune_class #1
|
@ -5,7 +5,7 @@ import io
|
|||
from torch import nn
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
import torchvision.transforms as transforms
|
||||
from aiia.model import AIIABase, AIIA
|
||||
from aiia.model import AIIABase, AIIA, AIIAConfig
|
||||
from sklearn.model_selection import train_test_split
|
||||
from typing import Dict, List, Union, Optional
|
||||
import base64
|
||||
|
@ -100,8 +100,8 @@ class ImageDataset(Dataset):
|
|||
high_res_stream.close()
|
||||
|
||||
class SuperResolutionModel(AIIA):
|
||||
def __init__(self, base_model):
|
||||
super(SuperResolutionModel, self).__init__()
|
||||
def __init__(self, base_model: AIIA, config: AIIAConfig):
|
||||
super(SuperResolutionModel, self).__init__(config=config)
|
||||
# Use base model as encoder
|
||||
self.encoder = base_model
|
||||
for param in self.encoder.parameters():
|
||||
|
@ -393,7 +393,8 @@ class FineTuner:
|
|||
|
||||
if __name__ == "__main__":
|
||||
# Load your model first
|
||||
model = SuperResolutionModel(base_model=AIIABase.load("/root/vision/AIIA/AIIA-base-512"))
|
||||
config = AIIAConfig.load("/root/vision/AIIA/AIIA-base-512")
|
||||
model = SuperResolutionModel(base_model=AIIABase.load("/root/vision/AIIA/AIIA-base-512"), config=config)
|
||||
|
||||
trainer = FineTuner(
|
||||
model=model,
|
||||
|
|
Loading…
Reference in New Issue