finetune_class #1
|
@ -12,7 +12,7 @@ from torchvision import transforms
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from aiia import AIIABase
|
from aiia import AIIABase
|
||||||
from upsampler import Upsampler
|
from aiunn.upsample import Upsampler
|
||||||
|
|
||||||
# Define a simple EarlyStopping class to monitor the epoch loss.
|
# Define a simple EarlyStopping class to monitor the epoch loss.
|
||||||
class EarlyStopping:
|
class EarlyStopping:
|
||||||
|
|
|
@ -2,6 +2,7 @@ import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from aiia import AIIA, AIIAConfig, AIIABase
|
from aiia import AIIA, AIIAConfig, AIIABase
|
||||||
|
|
||||||
|
|
||||||
class Upsampler(AIIA):
|
class Upsampler(AIIA):
|
||||||
def init(self, base_model: AIIA):
|
def init(self, base_model: AIIA):
|
||||||
# base_model must be a fully instantiated model (with a .config attribute)
|
# base_model must be a fully instantiated model (with a .config attribute)
|
Loading…
Reference in New Issue