From a7bab86d8539416397213fb0696ebc59bc0a4eb8 Mon Sep 17 00:00:00 2001 From: Falko Habel Date: Mon, 2 Jun 2025 18:33:45 +0200 Subject: [PATCH] corrected the base model_name --- src/aiunn/upsampler/aiunn.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/aiunn/upsampler/aiunn.py b/src/aiunn/upsampler/aiunn.py index a5ba3b4..978a75d 100644 --- a/src/aiunn/upsampler/aiunn.py +++ b/src/aiunn/upsampler/aiunn.py @@ -17,24 +17,24 @@ class aiuNN(PreTrainedModel): # Enhanced approach scale_factor = self.config.upsample_scale - out_channels = self.base_model.config.num_channels * (scale_factor ** 2) + out_channels = self.aiia_model.config.num_channels * (scale_factor ** 2) self.pixel_shuffle_conv = nn.Conv2d( - in_channels=self.base_model.config.hidden_size, + in_channels=self.aiia_model.config.hidden_size, out_channels=out_channels, - kernel_size=self.base_model.config.kernel_size, + kernel_size=self.aiia_model.config.kernel_size, padding=1 ) self.pixel_shuffle = nn.PixelShuffle(scale_factor) def load_base_model(self, base_model: PreTrainedModel): - self.base_model = base_model + self.aiia_model = base_model def forward(self, x): - if self.base_model is None: + if self.aiia_model is None: raise ValueError("Base model is not loaded. Call 'load_base_model' before forwarding.") # Get base features - we need to extract the last hidden state if it's returned as part of a tuple/dict - base_output = self.base_model(x) + base_output = self.aiia_model(x) if isinstance(base_output, tuple): x = base_output[0] elif isinstance(base_output, dict):