From 2aba93caea15eaf006f5db5cace652518273181c Mon Sep 17 00:00:00 2001
From: Falko Habel
Date: Sat, 22 Feb 2025 17:52:37 +0100
Subject: [PATCH] updated upsampler

---
 src/aiunn/Upsampler.py | 79 +++++++++++++++++++++++++++++-------------
 1 file changed, 54 insertions(+), 25 deletions(-)

diff --git a/src/aiunn/Upsampler.py b/src/aiunn/Upsampler.py
index cf95c8e..af13c82 100644
--- a/src/aiunn/Upsampler.py
+++ b/src/aiunn/Upsampler.py
@@ -1,40 +1,69 @@
+import torch
 import torch.nn as nn
-from aiia import AIIA
+from aiia import AIIA, AIIAConfig, AIIABase
 
 class Upsampler(AIIA):
-    def __init__(self, base_model: AIIA):
-        super().__init__(base_model.config)
+    def __init__(self, base_model: AIIA):
+        # base_model must be a fully instantiated model (with a .config attribute)
+        super().__init__(base_model.config)
         self.base_model = base_model
 
         # Upsample to double the spatial dimensions using bilinear interpolation
         self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
 
         # Update the base model's configuration to include the upsample layer details
-        print(self.base_model.config)
-        if hasattr(self.base_model, 'config'):
-            # Check if layers attribute exists, if not create it
-            if not hasattr(self.base_model.config, 'layers'):
-                setattr(self.base_model.config, 'layers', [])
-
-            # Add the upsample layer configuration
-            current_layers = getattr(self.base_model.config, 'layers', [])
-            current_layers.append({
-                'name': 'Upsample',
-                'type': 'nn.Upsample',
-                'scale_factor': 2,
-                'mode': 'bilinear',
-                'align_corners': False
-            })
-            setattr(self.base_model.config, 'layers', current_layers)
-            self.config = self.base_model.config
-        else:
-            self.config = {}
+        if not hasattr(self.base_model.config, 'layers'):
+            self.base_model.config.layers = []
+
+        self.base_model.config.layers.append({
+            'name': 'Upsample',
+            'type': 'nn.Upsample',
+            'scale_factor': 2,
+            'mode': 'bilinear',
+            'align_corners': False
+        })
+        self.config = self.base_model.config
 
     def forward(self, x):
         x = self.base_model(x)
         x = self.upsample(x)
         return x
 
-if __name__ == "__main__":
-    upsampler = Upsampler.load("test2")
-    print("Updated configuration:", upsampler.config.__dict__)
+    @classmethod
+    def load(cls, path: str):
+        """
+        Override the default load method:
+        - First, load the base model (which includes its configuration and state_dict)
+        - Then instantiate the Upsampler with that base model
+        - Finally, load the Upsampler-specific state dictionary
+        """
+        # Load the full base model from the given path.
+        # (Assuming AIIABase.load is implemented to load the base model correctly.)
+        base_model = AIIABase.load(path)
+
+        # Create a new instance of Upsampler using the loaded base model.
+        instance = cls(base_model)
+
+        # Choose your device mapping (cuda if available, otherwise cpu)
+        device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+        # Load the saved state dictionary that contains weights for both the base model and upsample layer.
+        state_dict = torch.load(f"{path}/model.pth", map_location=device)
+        instance.load_state_dict(state_dict)
+
+        return instance
+
+if __name__ == "__main__":
+    from aiia import AIIABase, AIIAConfig
+    # Create a configuration and build a base model.
+    config = AIIAConfig()
+    base_model = AIIABase(config)
+    # Instantiate Upsampler from the base model (works correctly).
+    upsampler = Upsampler(base_model)
+
+    # Save the model (both configuration and weights).
+    upsampler.save("test2")
+
+    # Now load using the overridden load method; this will load the complete model.
+    upsampler_loaded = Upsampler.load("test2")
+    print("Updated configuration:", upsampler_loaded.config.__dict__)