develop #4
|
@ -5,7 +5,7 @@ import io
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.utils.data import Dataset, DataLoader
|
from torch.utils.data import Dataset, DataLoader
|
||||||
import torchvision.transforms as transforms
|
import torchvision.transforms as transforms
|
||||||
from aiia.model import AIIABase, AIIA
|
from aiia.model import AIIABase, AIIA, AIIAConfig
|
||||||
from sklearn.model_selection import train_test_split
|
from sklearn.model_selection import train_test_split
|
||||||
from typing import Dict, List, Union, Optional
|
from typing import Dict, List, Union, Optional
|
||||||
import base64
|
import base64
|
||||||
|
@ -100,8 +100,8 @@ class ImageDataset(Dataset):
|
||||||
high_res_stream.close()
|
high_res_stream.close()
|
||||||
|
|
||||||
class SuperResolutionModel(AIIA):
|
class SuperResolutionModel(AIIA):
|
||||||
def __init__(self, base_model):
|
def __init__(self, base_model: AIIA, config: AIIAConfig):
|
||||||
super(SuperResolutionModel, self).__init__()
|
super(SuperResolutionModel, self).__init__(config=config)
|
||||||
# Use base model as encoder
|
# Use base model as encoder
|
||||||
self.encoder = base_model
|
self.encoder = base_model
|
||||||
for param in self.encoder.parameters():
|
for param in self.encoder.parameters():
|
||||||
|
@ -393,7 +393,8 @@ class FineTuner:
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Load your model first
|
# Load your model first
|
||||||
model = SuperResolutionModel(base_model=AIIABase.load("/root/vision/AIIA/AIIA-base-512"))
|
config = AIIAConfig.load("/root/vision/AIIA/AIIA-base-512")
|
||||||
|
model = SuperResolutionModel(base_model=AIIABase.load("/root/vision/AIIA/AIIA-base-512"), config=config)
|
||||||
|
|
||||||
trainer = FineTuner(
|
trainer = FineTuner(
|
||||||
model=model,
|
model=model,
|
||||||
|
|
Loading…
Reference in New Issue