From 8935fa5e13a746a20051e67b88fb3d281fc1d8f1 Mon Sep 17 00:00:00 2001 From: Falko Habel Date: Mon, 24 Feb 2025 14:57:55 +0100 Subject: [PATCH] doubled config --- src/aiunn/config.py | 3 ++- src/aiunn/upsampler.py | 60 +++++++++++++++++++++++++++++------------- 2 files changed, 43 insertions(+), 20 deletions(-) diff --git a/src/aiunn/config.py b/src/aiunn/config.py index 1a7d3a5..2d5926e 100644 --- a/src/aiunn/config.py +++ b/src/aiunn/config.py @@ -1,4 +1,6 @@ from aiia import AIIAConfig +import os +import json class UpsamplerConfig(AIIAConfig): @@ -33,4 +35,3 @@ class UpsamplerConfig(AIIAConfig): # Add the upsample layer only if not already present. if not any(layer.get('name') == 'Upsample' for layer in self.layers): self.layers.append(upsample_layer) - diff --git a/src/aiunn/upsampler.py b/src/aiunn/upsampler.py index d2eeb82..d2ae0f0 100644 --- a/src/aiunn/upsampler.py +++ b/src/aiunn/upsampler.py @@ -1,9 +1,12 @@ +import os import torch import torch.nn as nn +import warnings from aiia import AIIA, AIIAConfig, AIIABase from config import UpsamplerConfig +import warnings + -# Upsampler model that uses the configuration from the base model. class Upsampler(AIIA): def __init__(self, base_model: AIIABase): super().__init__(base_model.config) @@ -14,43 +17,62 @@ class Upsampler(AIIA): mode=self.config.upsample_mode, align_corners=self.config.upsample_align_corners ) - # Conversion layer: change from 512 channels to 3 channels. - self.to_rgb = nn.Conv2d(in_channels=self.base_model.config.hidden_size, out_channels=3, kernel_size=1) + # Conversion layer: change from hidden size channels to 3 channels. + self.to_rgb = nn.Conv2d( + in_channels=self.base_model.config.hidden_size, + out_channels=3, + kernel_size=1 + ) def forward(self, x): x = self.base_model(x) x = self.upsample(x) x = self.to_rgb(x) # Ensures output has 3 channels. return x - + @classmethod - def load(cls, path: str): - """ - Load the model: - - First, load the base model (including its configuration and state_dict). - - Then, wrap it with the Upsampler class. - - Finally, load the combined state dictionary. - """ - base_model = AIIABase.load(path) - instance = cls(base_model) + def load(cls, path, precision: str = None): + # Load the configuration from disk. + config = AIIAConfig.load(path) + # Reconstruct the base model from the loaded configuration. + base_model = AIIABase(config) + # Instantiate the Upsampler using the proper base model. + upsampler = cls(base_model) + # Load state dict and handle precision conversion if needed. device = 'cuda' if torch.cuda.is_available() else 'cpu' state_dict = torch.load(f"{path}/model.pth", map_location=device) - instance.load_state_dict(state_dict) - return instance + if precision is not None: + if precision.lower() == 'fp16': + dtype = torch.float16 + elif precision.lower() == 'bf16': + if device == 'cuda' and not torch.cuda.is_bf16_supported(): + warnings.warn("BF16 is not supported on this GPU. Falling back to FP16.") + dtype = torch.float16 + else: + dtype = torch.bfloat16 + else: + raise ValueError("Unsupported precision. Use 'fp16', 'bf16', or leave as None.") + + for key, param in state_dict.items(): + if torch.is_tensor(param): + state_dict[key] = param.to(dtype) + upsampler.load_state_dict(state_dict) + return upsampler -if __name__ == "main": + +if __name__ == "__main__": from aiia import AIIABase, AIIAConfig # Create a configuration and build a base model. config = AIIAConfig() - base_model = AIIABase("test2") + base_model = AIIABase(config) # Instantiate Upsampler from the base model (works correctly). upsampler = Upsampler(base_model) # Save the model (both configuration and weights). - upsampler.save("test2") + upsampler.save("hehe") # Now load using the overridden load method; this will load the complete model. - upsampler_loaded = Upsampler.load("test2") + upsampler_loaded = Upsampler.load("hehe", precision="bf16") print("Updated configuration:", upsampler_loaded.config.__dict__)