diff --git a/src/aiunn/upsampler/aiunn.py b/src/aiunn/upsampler/aiunn.py
index a5ba3b4..978a75d 100644
--- a/src/aiunn/upsampler/aiunn.py
+++ b/src/aiunn/upsampler/aiunn.py
@@ -17,24 +17,24 @@ class aiuNN(PreTrainedModel):
         # Enhanced approach
         scale_factor = self.config.upsample_scale
-        out_channels = self.base_model.config.num_channels * (scale_factor ** 2)
+        out_channels = self.aiia_model.config.num_channels * (scale_factor ** 2)
         self.pixel_shuffle_conv = nn.Conv2d(
-            in_channels=self.base_model.config.hidden_size,
+            in_channels=self.aiia_model.config.hidden_size,
             out_channels=out_channels,
-            kernel_size=self.base_model.config.kernel_size,
+            kernel_size=self.aiia_model.config.kernel_size,
             padding=1
         )
         self.pixel_shuffle = nn.PixelShuffle(scale_factor)
 
     def load_base_model(self, base_model: PreTrainedModel):
-        self.base_model = base_model
+        self.aiia_model = base_model
 
     def forward(self, x):
-        if self.base_model is None:
+        if self.aiia_model is None:
             raise ValueError("Base model is not loaded. Call 'load_base_model' before forwarding.")
 
         # Get base features - we need to extract the last hidden state if it's returned as part of a tuple/dict
-        base_output = self.base_model(x)
+        base_output = self.aiia_model(x)
         if isinstance(base_output, tuple):
             x = base_output[0]
         elif isinstance(base_output, dict):