68 lines
2.5 KiB
Python
68 lines
2.5 KiB
Python
import os
|
|
import torch
|
|
import torch.nn as nn
|
|
import warnings
|
|
from aiia.model.Model import AIIAConfig, AIIABase
|
|
from transformers import PreTrainedModel
|
|
from .config import aiuNNConfig
|
|
import warnings
|
|
|
|
|
|
class aiuNN(PreTrainedModel):
    """Upsampling head that applies a convolution and a pixel shuffle to a base model's output.

    The base model is not created here; it must be attached with
    ``load_base_model`` before ``forward`` is called.
    """

    config_class = aiuNNConfig

    def __init__(self, config: aiuNNConfig):
        """Build the convolution and pixel-shuffle layers described by ``config``.

        Args:
            config: Configuration providing ``upsample_scale``, ``num_channels``,
                ``hidden_size`` and ``kernel_size``.
        """
        super().__init__(config)
        self.config = config

        # Declare the base-model slot up front. Without this, the
        # ``self.aiia_model is None`` check in forward() fails with an
        # AttributeError (raised by nn.Module.__getattr__) instead of the
        # intended ValueError when no base model has been attached.
        self.aiia_model = None

        scale_factor = self.config.upsample_scale
        # PixelShuffle folds scale_factor**2 channels into each output channel,
        # so the convolution has to produce that many channels per image channel.
        out_channels = self.config.num_channels * (scale_factor ** 2)
        self.pixel_shuffle_conv = nn.Conv2d(
            in_channels=self.config.hidden_size,
            out_channels=out_channels,
            kernel_size=self.config.kernel_size,
            # NOTE(review): padding is fixed at 1, which preserves the spatial
            # size only when kernel_size == 3 -- confirm whether other kernel
            # sizes are expected to be supported.
            padding=1,
        )
        self.pixel_shuffle = nn.PixelShuffle(scale_factor)

    def load_base_model(self, base_model: PreTrainedModel):
        """Attach ``base_model`` as the feature extractor used by ``forward``.

        Args:
            base_model: Model whose output is fed into the upsampling layers.
        """
        self.aiia_model = base_model

    def forward(self, x):
        """Run the base model on ``x`` and upsample its output.

        Args:
            x: Input passed unchanged to the attached base model.

        Returns:
            The result of applying the convolution and the pixel shuffle to the
            base model's output.

        Raises:
            ValueError: If no base model has been attached, or if the base model
                returns a dict containing neither ``'last_hidden_state'`` nor
                ``'hidden_states'``.
        """
        if self.aiia_model is None:
            raise ValueError("Base model is not loaded. Call 'load_base_model' before forwarding.")

        # The base model may return a bare value, a tuple or a dict; pick the
        # feature entry out of whichever container comes back.
        base_output = self.aiia_model(x)
        if isinstance(base_output, tuple):
            x = base_output[0]
        elif isinstance(base_output, dict):
            x = base_output.get('last_hidden_state', base_output.get('hidden_states'))
            if x is None:
                raise ValueError("Expected 'last_hidden_state' or 'hidden_states' in model output")
        else:
            x = base_output

        x = self.pixel_shuffle_conv(x)  # Expand channels for shuffling
        x = self.pixel_shuffle(x)  # Rearrange channels into spatial dimensions
        return x
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    from aiia import AIIABase, AIIAConfig

    # Default configurations for the backbone and for the upsampler.
    backbone_cfg = AIIAConfig()
    upsampler_cfg = aiuNNConfig()

    # Build the backbone, then the upsampler, and attach one to the other.
    backbone = AIIABase(backbone_cfg)
    model = aiuNN(config=upsampler_cfg)
    model.load_base_model(backbone)

    # Write the model to the "aiunn" directory.
    model.save_pretrained("aiunn")

    # Read it back from the same directory and show the restored configuration.
    # NOTE(review): no from_pretrained override is visible in this file --
    # confirm whether the reloaded model has a base model attached.
    reloaded = aiuNN.from_pretrained("aiunn")
    print("Updated configuration:", reloaded.config.__dict__)
|