From 048a8d98616fafa55263e7cf0c5c3889f06a1ae6 Mon Sep 17 00:00:00 2001 From: Falko Habel Date: Sun, 23 Feb 2025 19:56:17 +0100 Subject: [PATCH] added extra config --- src/aiunn/config.py | 36 +++++++++++++++++++++++++++++++++ src/aiunn/upsampler.py | 45 ++++++++++++++---------------------------- 2 files changed, 51 insertions(+), 30 deletions(-) create mode 100644 src/aiunn/config.py diff --git a/src/aiunn/config.py b/src/aiunn/config.py new file mode 100644 index 0000000..1a7d3a5 --- /dev/null +++ b/src/aiunn/config.py @@ -0,0 +1,36 @@ +from aiia import AIIAConfig + + +class UpsamplerConfig(AIIAConfig): + def __init__( + self, + upsample_scale: int = 2, + upsample_mode: str = 'bilinear', + upsample_align_corners: bool = False, + layers=None, + **kwargs + ): + # Initialize base configuration. + super().__init__(**kwargs) + self.layers = layers if layers is not None else [] + + # Upsampler-specific parameters. + self.upsample_scale = upsample_scale + self.upsample_mode = upsample_mode + self.upsample_align_corners = upsample_align_corners + + # Automatically add the upsample layer details. + self.add_upsample_layer() + + def add_upsample_layer(self): + upsample_layer = { + 'name': 'Upsample', + 'type': 'nn.Upsample', + 'scale_factor': self.upsample_scale, + 'mode': self.upsample_mode, + 'align_corners': self.upsample_align_corners + } + # Add the upsample layer only if not already present. + if not any(layer.get('name') == 'Upsample' for layer in self.layers): + self.layers.append(upsample_layer) + diff --git a/src/aiunn/upsampler.py b/src/aiunn/upsampler.py index d2473c3..fa3b595 100644 --- a/src/aiunn/upsampler.py +++ b/src/aiunn/upsampler.py @@ -3,27 +3,19 @@ import torch.nn as nn from aiia import AIIA, AIIAConfig, AIIABase +# Upsampler model that uses the configuration from the base model. class Upsampler(AIIA): - def init(self, base_model: AIIA): - # base_model must be a fully instantiated model (with a .config attribute) - super().init(base_model.config) + def __init__(self, base_model: AIIABase): + # Assume that base_model.config is an instance of UpsamplerConfig. + super().__init__(base_model.config) self.base_model = base_model - - # Upsample to double the spatial dimensions using bilinear interpolation - self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) - # Update the base model's configuration to include the upsample layer details - if not hasattr(self.base_model.config, 'layers'): - self.base_model.config.layers = [] - - self.base_model.config.layers.append({ - 'name': 'Upsample', - 'type': 'nn.Upsample', - 'scale_factor': 2, - 'mode': 'bilinear', - 'align_corners': False - }) - self.config = self.base_model.config + # Create the upsample layer using values from the configuration. + self.upsample = nn.Upsample( + scale_factor=self.config.upsample_scale, + mode=self.config.upsample_mode, + align_corners=self.config.upsample_align_corners + ) def forward(self, x): x = self.base_model(x) @@ -33,27 +25,20 @@ class Upsampler(AIIA): @classmethod def load(cls, path: str): """ - Override the default load method: - - First, load the base model (which includes its configuration and state_dict) - - Then instantiate the Upsampler with that base model - - Finally, load the Upsampler-specific state dictionary + Load the model: + - First, load the base model (including its configuration and state_dict). + - Then, wrap it with the Upsampler class. + - Finally, load the combined state dictionary. """ - # Load the full base model from the given path. - # (Assuming AIIABase.load is implemented to load the base model correctly.) base_model = AIIABase.load(path) - - # Create a new instance of Upsampler using the loaded base model. instance = cls(base_model) - # Choose your device mapping (cuda if available, otherwise cpu) device = 'cuda' if torch.cuda.is_available() else 'cpu' - - # Load the saved state dictionary that contains weights for both the base model and upsample layer. state_dict = torch.load(f"{path}/model.pth", map_location=device) instance.load_state_dict(state_dict) - return instance + if __name__ == "main": from aiia import AIIABase, AIIAConfig # Create a configuration and build a base model.