From 933238a5302ccd8647aea2631f30513a2a9ba66c Mon Sep 17 00:00:00 2001 From: Falko Habel Date: Sat, 22 Feb 2025 17:24:02 +0100 Subject: [PATCH] added extra psampling layer [not working] --- src/aiunn/Upsampler.py | 40 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) create mode 100644 src/aiunn/Upsampler.py diff --git a/src/aiunn/Upsampler.py b/src/aiunn/Upsampler.py new file mode 100644 index 0000000..cf95c8e --- /dev/null +++ b/src/aiunn/Upsampler.py @@ -0,0 +1,40 @@ +import torch.nn as nn +from aiia import AIIA + +class Upsampler(AIIA): + def __init__(self, base_model: AIIA): + super().__init__(base_model.config) + self.base_model = base_model + + # Upsample to double the spatial dimensions using bilinear interpolation + self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) + + # Update the base model's configuration to include the upsample layer details + print(self.base_model.config) + if hasattr(self.base_model, 'config'): + # Check if layers attribute exists, if not create it + if not hasattr(self.base_model.config, 'layers'): + setattr(self.base_model.config, 'layers', []) + + # Add the upsample layer configuration + current_layers = getattr(self.base_model.config, 'layers', []) + current_layers.append({ + 'name': 'Upsample', + 'type': 'nn.Upsample', + 'scale_factor': 2, + 'mode': 'bilinear', + 'align_corners': False + }) + setattr(self.base_model.config, 'layers', current_layers) + self.config = self.base_model.config + else: + self.config = {} + + def forward(self, x): + x = self.base_model(x) + x = self.upsample(x) + return x + +if __name__ == "__main__": + upsampler = Upsampler.load("test2") + print("Updated configuration:", upsampler.config.__dict__)