diff --git a/src/aiunn/finetune.py b/src/aiunn/finetune.py index d785030..9729208 100644 --- a/src/aiunn/finetune.py +++ b/src/aiunn/finetune.py @@ -12,7 +12,7 @@ from torchvision import transforms from tqdm import tqdm from aiia import AIIABase -from upsampler import Upsampler +from aiunn.upsample import Upsampler # Define a simple EarlyStopping class to monitor the epoch loss. class EarlyStopping: diff --git a/src/aiunn/Upsampler.py b/src/aiunn/upsample.py similarity index 99% rename from src/aiunn/Upsampler.py rename to src/aiunn/upsample.py index af13c82..d2473c3 100644 --- a/src/aiunn/Upsampler.py +++ b/src/aiunn/upsample.py @@ -2,6 +2,7 @@ import torch import torch.nn as nn from aiia import AIIA, AIIAConfig, AIIABase + class Upsampler(AIIA): def init(self, base_model: AIIA): # base_model must be a fully instantiated model (with a .config attribute)