import torch
import torch.nn as nn
from aiia.model.Model import AIIAConfig, AIIABase
from transformers import PreTrainedModel
from .config import aiuNNConfig


class aiuNN(PreTrainedModel):
    config_class = aiuNNConfig

    def __init__(self, config: aiuNNConfig):
        super().__init__(config)
        # Keep a reference to the unified upsampler configuration.
        self.config = config

        # The upsampling head depends on the base model's dimensions, so it is
        # built in load_base_model() once the base model is attached.
        self.base_model = None
        self.pixel_shuffle_conv = None
        self.pixel_shuffle = None

    def load_base_model(self, base_model: PreTrainedModel):
        self.base_model = base_model

        # Pixel-shuffle upsampling: expand channels by scale_factor**2, then
        # rearrange the extra channels into spatial resolution.
        scale_factor = self.config.upsample_scale
        out_channels = base_model.config.num_channels * (scale_factor ** 2)
        self.pixel_shuffle_conv = nn.Conv2d(
            in_channels=base_model.config.hidden_size,
            out_channels=out_channels,
            kernel_size=base_model.config.kernel_size,
            padding=1,
        )
        self.pixel_shuffle = nn.PixelShuffle(scale_factor)

    def forward(self, x):
        if self.base_model is None:
            raise ValueError("Base model is not loaded. Call 'load_base_model' before forwarding.")
        x = self.base_model(x)          # Extract base features
        x = self.pixel_shuffle_conv(x)  # Expand channels for shuffling
        x = self.pixel_shuffle(x)       # Rearrange channels into spatial dimensions
        return x


if __name__ == "__main__":
    # Create a configuration and build a base model.
    config = AIIAConfig()
    ai_config = aiuNNConfig()
    base_model = AIIABase(config)

    # Instantiate the upsampler and attach the base model.
    upsampler = aiuNN(config=ai_config)
    upsampler.load_base_model(base_model)

    # Save the model (both configuration and weights).
    upsampler.save_pretrained("aiunn")

    # Reload the saved model with from_pretrained.
    upsampler_loaded = aiuNN.from_pretrained("aiunn")
    print("Updated configuration:", upsampler_loaded.config.__dict__)
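
    # Illustrative sanity check (a minimal sketch, not part of the model code):
    # verify the pixel-shuffle shape arithmetic with stand-in dimensions. The
    # hidden size of 16, the 3 output channels, the 3x3 kernel, and the 32x32
    # input are assumptions chosen for this demo only.
    demo_hidden, demo_out_ch, demo_scale = 16, 3, 2
    demo_conv = nn.Conv2d(demo_hidden, demo_out_ch * demo_scale ** 2, kernel_size=3, padding=1)
    demo_shuffle = nn.PixelShuffle(demo_scale)
    demo_features = torch.randn(1, demo_hidden, 32, 32)
    demo_output = demo_shuffle(demo_conv(demo_features))
    print("Demo pixel-shuffle output shape:", demo_output.shape)  # expected: torch.Size([1, 3, 64, 64])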