develop #18
|
@ -32,7 +32,18 @@ class aiuNN(PreTrainedModel):
|
||||||
def forward(self, x):
    """Run the base model on ``x`` and pass its features through pixel shuffle.

    The base model's return value is normalized to a single feature object:

    * tuple -> its first element;
    * dict  -> ``'last_hidden_state'`` if present and not ``None``,
      otherwise ``'hidden_states'``; when that fallback is itself a
      tuple/list (presumably one entry per layer, as in Hugging Face
      outputs -- TODO confirm for this base model) its last entry is used;
    * anything else -> used as-is.

    Args:
        x: Input forwarded unchanged to ``self.base_model``.

    Returns:
        The result of ``self.pixel_shuffle(self.pixel_shuffle_conv(features))``.

    Raises:
        ValueError: If ``self.base_model`` is ``None``, or if a dict output
            provides neither usable ``'last_hidden_state'`` nor
            ``'hidden_states'``.
    """
    if self.base_model is None:
        raise ValueError("Base model is not loaded. Call 'load_base_model' before forwarding.")

    # Get base features - extract them if they come wrapped in a tuple/dict.
    base_output = self.base_model(x)
    if isinstance(base_output, tuple):
        x = base_output[0]
    elif isinstance(base_output, dict):
        # Explicit None checks: a key that is present but set to None must
        # still fall through to the next candidate.
        x = base_output.get('last_hidden_state')
        if x is None:
            x = base_output.get('hidden_states')
            if isinstance(x, (tuple, list)):
                # A sequence cannot be fed to the conv layer; use its last
                # entry, or treat an empty sequence as "missing".
                x = x[-1] if x else None
        if x is None:
            raise ValueError("Expected 'last_hidden_state' or 'hidden_states' in model output")
    else:
        x = base_output

    x = self.pixel_shuffle_conv(x)  # Expand channels for shuffling
    x = self.pixel_shuffle(x)  # Rearrange channels into spatial dimensions
    return x
|
||||||
|
|
Loading…
Reference in New Issue