# aiuNN/src/aiunn/upsampler/aiunn.py
# (57 lines, 2.0 KiB, Python)

import os
import torch
import torch.nn as nn
import warnings
from aiia.model.Model import AIIAConfig, AIIABase
from transformers import PreTrainedModel
from .config import aiuNNConfig
import warnings
class aiuNN(PreTrainedModel):
    """Upsampling head that runs on top of a separately attached base model.

    The forward pass feeds the input through the attached base model, widens
    the channel dimension with a convolution, and then uses
    ``nn.PixelShuffle`` to rearrange those channels into spatial resolution.

    The base model is not created by the constructor; attach one with
    ``load_base_model`` before calling the model.
    """

    config_class = aiuNNConfig

    def __init__(self, config: aiuNNConfig):
        """Build the convolution and pixel-shuffle layers described by ``config``.

        Args:
            config: Configuration providing ``upsample_scale``,
                ``num_channels``, ``hidden_size`` and ``kernel_size``.
        """
        super().__init__(config)
        self.config = config

        # No base model is attached at construction time (that happens later
        # in load_base_model), so the layer sizes are taken from this model's
        # own configuration instead of from a base model's.
        # NOTE(review): assumes aiuNNConfig carries num_channels, hidden_size
        # and kernel_size - confirm against its definition.
        scale_factor = self.config.upsample_scale

        # PixelShuffle folds scale_factor**2 channels into each output channel,
        # so the convolution must emit that many channels per image channel.
        out_channels = self.config.num_channels * (scale_factor ** 2)
        self.pixel_shuffle_conv = nn.Conv2d(
            in_channels=self.config.hidden_size,
            out_channels=out_channels,
            kernel_size=self.config.kernel_size,
            # NOTE(review): a fixed padding of 1 keeps the spatial size only
            # for kernel_size == 3; other kernel sizes change the output size.
            padding=1,
        )
        self.pixel_shuffle = nn.PixelShuffle(scale_factor)

    def load_base_model(self, base_model: PreTrainedModel) -> None:
        """Attach the model whose output feeds the upsampling head.

        Args:
            base_model: Module to run first in ``forward``. Being a module, it
                is registered as the sub-module named ``base_model``.
        """
        self.base_model = base_model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the base model, then the convolution and pixel shuffle.

        Args:
            x: Input passed unchanged to the attached base model.

        Returns:
            The pixel-shuffled output of the convolution head.

        Raises:
            ValueError: If no base model has been attached with
                ``load_base_model``.
        """
        # Read the attached model from the sub-module registry rather than via
        # ``self.base_model``. The registry lookup yields None when nothing has
        # been attached, which is what the guard below needs.
        # NOTE(review): transformers' PreTrainedModel appears to define its own
        # ``base_model`` property, which would take precedence over the
        # registered sub-module on attribute access - confirm against the
        # installed version. The registry lookup is correct either way.
        base_model = self._modules.get("base_model")
        if base_model is None:
            raise ValueError("Base model is not loaded. Call 'load_base_model' before forwarding.")

        x = base_model(x)  # Get base features
        x = self.pixel_shuffle_conv(x)  # Expand channels for shuffling
        x = self.pixel_shuffle(x)  # Rearrange channels into spatial dimensions
        return x
if __name__ == "__main__":
    from aiia import AIIABase, AIIAConfig

    # Demo round trip: build an upsampler, save it, and load it back.
    save_dir = "aiunn"

    # Default configurations for the base model and for the upsampler.
    base_cfg = AIIAConfig()
    upsampler_cfg = aiuNNConfig()

    # Build the base model first, then the upsampler, and attach the former.
    backbone = AIIABase(base_cfg)
    model = aiuNN(config=upsampler_cfg)
    model.load_base_model(backbone)

    # Write configuration and weights to disk.
    model.save_pretrained(save_dir)

    # Reload from the saved directory and show the resulting configuration.
    # NOTE(review): aiuNN defines no loading override in this file, so whether
    # the attached base model is restored depends on the inherited
    # from_pretrained - confirm.
    restored = aiuNN.from_pretrained(save_dir)
    print("Updated configuration:", vars(restored.config))