diff --git a/src/aiunn/upsampler/aiunn.py b/src/aiunn/upsampler/aiunn.py index 71c77f3..a5ba3b4 100644 --- a/src/aiunn/upsampler/aiunn.py +++ b/src/aiunn/upsampler/aiunn.py @@ -32,7 +32,18 @@ class aiuNN(PreTrainedModel): def forward(self, x): if self.base_model is None: raise ValueError("Base model is not loaded. Call 'load_base_model' before forwarding.") - x = self.base_model(x) # Get base features + + # Get base features - we need to extract the last hidden state if it's returned as part of a tuple/dict + base_output = self.base_model(x) + if isinstance(base_output, tuple): + x = base_output[0] + elif isinstance(base_output, dict): + x = base_output.get('last_hidden_state', base_output.get('hidden_states')) + if x is None: + raise ValueError("Expected 'last_hidden_state' or 'hidden_states' in model output") + else: + x = base_output + x = self.pixel_shuffle_conv(x) # Expand channels for shuffling x = self.pixel_shuffle(x) # Rearrange channels into spatial dimensions return x