fixed model outputs to not run in an infinte loop
This commit is contained in:
parent
3b140d559c
commit
a725dd4539
|
@ -32,7 +32,18 @@ class aiuNN(PreTrainedModel):
|
|||
def forward(self, x):
|
||||
if self.base_model is None:
|
||||
raise ValueError("Base model is not loaded. Call 'load_base_model' before forwarding.")
|
||||
x = self.base_model(x) # Get base features
|
||||
|
||||
# Get base features - we need to extract the last hidden state if it's returned as part of a tuple/dict
|
||||
base_output = self.base_model(x)
|
||||
if isinstance(base_output, tuple):
|
||||
x = base_output[0]
|
||||
elif isinstance(base_output, dict):
|
||||
x = base_output.get('last_hidden_state', base_output.get('hidden_states'))
|
||||
if x is None:
|
||||
raise ValueError("Expected 'last_hidden_state' or 'hidden_states' in model output")
|
||||
else:
|
||||
x = base_output
|
||||
|
||||
x = self.pixel_shuffle_conv(x) # Expand channels for shuffling
|
||||
x = self.pixel_shuffle(x) # Rearrange channels into spatial dimensions
|
||||
return x
|
||||
|
|
Loading…
Reference in New Issue