diff --git a/src/aiunn/upsampler.py b/src/aiunn/upsampler.py index fa3b595..efc5c6f 100644 --- a/src/aiunn/upsampler.py +++ b/src/aiunn/upsampler.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn from aiia import AIIA, AIIAConfig, AIIABase - +from config import UpsamplerConfig # Upsampler model that uses the configuration from the base model. class Upsampler(AIIA): @@ -9,7 +9,7 @@ class Upsampler(AIIA): # Assume that base_model.config is an instance of UpsamplerConfig. super().__init__(base_model.config) self.base_model = base_model - + self.config = UpsamplerConfig(self.base_model.config) # Create the upsample layer using values from the configuration. self.upsample = nn.Upsample( scale_factor=self.config.upsample_scale,