import torch.nn as nn

from aiia import AIIA


class Upsampler(AIIA):
    """Wrap an AIIA model and upsample its output spatially by a factor of 2.

    Upsampling uses bilinear interpolation. The settings are also recorded as
    an extra entry in the wrapped model's ``config.layers`` list, so they show
    up in the model configuration (and are saved with it, presumably, if AIIA
    persists ``config`` — verify against AIIA's save/load).
    """

    # Single source of truth for the upsample settings. Both the nn.Upsample
    # layer and the config entry are built from these, so they cannot drift.
    _SCALE_FACTOR = 2
    _MODE = "bilinear"
    _ALIGN_CORNERS = False

    def __init__(self, base_model: AIIA):
        """Build the upsampler around ``base_model``.

        Args:
            base_model: AIIA model whose output is upsampled. Its ``config``
                object is mutated in place (an ``Upsample`` entry is appended
                to ``config.layers``) and then shared as ``self.config``.
        """
        super().__init__(base_model.config)
        self.base_model = base_model

        # Doubles the spatial dimensions via bilinear interpolation.
        self.upsample = nn.Upsample(
            scale_factor=self._SCALE_FACTOR,
            mode=self._MODE,
            align_corners=self._ALIGN_CORNERS,
        )

        # Record the upsample layer in the base model's configuration.
        # base_model.config is guaranteed to exist here: super().__init__
        # above already dereferenced it.
        config = base_model.config
        layers = getattr(config, "layers", None)
        if layers is None:
            layers = []
            config.layers = layers
        # NOTE(review): each construction appends another entry. Wrapping the
        # same base model twice records the layer twice; confirm this is the
        # intended behaviour when models are rebuilt (e.g. via `load`).
        layers.append({
            "name": "Upsample",
            "type": "nn.Upsample",
            "scale_factor": self._SCALE_FACTOR,
            "mode": self._MODE,
            "align_corners": self._ALIGN_CORNERS,
        })
        self.config = config

    def forward(self, x):
        """Run the base model, then upsample its output.

        Bilinear ``nn.Upsample`` only accepts 4-D input, so ``base_model`` is
        assumed to return an (N, C, H, W) tensor — TODO confirm.
        """
        x = self.base_model(x)
        return self.upsample(x)


if __name__ == "__main__":
    # Manual smoke check; presumably "test2" is a previously saved model
    # location understood by AIIA.load — verify before relying on it.
    upsampler = Upsampler.load("test2")
    print("Updated configuration:", upsampler.config.__dict__)