From ced7e8a214ac5d47092a5cfeb1e34b548c3c5507 Mon Sep 17 00:00:00 2001 From: Falko Habel Date: Sat, 19 Apr 2025 22:53:00 +0200 Subject: [PATCH] moved to transformer support, currently dropped fp16 load support --- src/aiunn/upsampler/aiunn.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/src/aiunn/upsampler/aiunn.py b/src/aiunn/upsampler/aiunn.py index 1dfa93e..71c77f3 100644 --- a/src/aiunn/upsampler/aiunn.py +++ b/src/aiunn/upsampler/aiunn.py @@ -9,13 +9,12 @@ import warnings class aiuNN(PreTrainedModel): - def __init__(self, base_model: PreTrainedModel, config:aiuNNConfig): - super().__init__(base_model.config) - self.base_model = base_model - + config_class = aiuNNConfig + def __init__(self, config: aiuNNConfig): + super().__init__(config) # Pass the unified base configuration using the new parameter. self.config = config - + # Enhanced approach scale_factor = self.config.upsample_scale out_channels = self.base_model.config.num_channels * (scale_factor ** 2) @@ -27,13 +26,19 @@ class aiuNN(PreTrainedModel): ) self.pixel_shuffle = nn.PixelShuffle(scale_factor) - + def load_base_model(self, base_model: PreTrainedModel): + self.base_model = base_model + def forward(self, x): + if self.base_model is None: + raise ValueError("Base model is not loaded. Call 'load_base_model' before forwarding.") x = self.base_model(x) # Get base features x = self.pixel_shuffle_conv(x) # Expand channels for shuffling x = self.pixel_shuffle(x) # Rearrange channels into spatial dimensions return x + + if __name__ == "__main__": from aiia import AIIABase, AIIAConfig # Create a configuration and build a base model. @@ -41,11 +46,11 @@ if __name__ == "__main__": ai_config = aiuNNConfig() base_model = AIIABase(config) # Instantiate Upsampler from the base model (works correctly). - upsampler = aiuNN(base_model, config=ai_config) - + upsampler = aiuNN(config=ai_config) + upsampler.load_base_model(base_model) # Save the model (both configuration and weights). upsampler.save_pretrained("aiunn") # Now load using the overridden load method; this will load the complete model. - upsampler_loaded = aiuNN.load("aiunn", precision="bf16") + upsampler_loaded = aiuNN.from_pretrained("aiunn") print("Updated configuration:", upsampler_loaded.config.__dict__)