From 1a5d98d8bb0d2acb3bc333ff934c41be5796a1bd Mon Sep 17 00:00:00 2001 From: Falko Habel Date: Wed, 26 Mar 2025 13:28:38 +0100 Subject: [PATCH] corrected loading / saving model --- src/aiunn/upsampler/aiunn.py | 101 ++++++++++++++++++++++++++++++----- 1 file changed, 87 insertions(+), 14 deletions(-) diff --git a/src/aiunn/upsampler/aiunn.py b/src/aiunn/upsampler/aiunn.py index ca7ba2b..970dad1 100644 --- a/src/aiunn/upsampler/aiunn.py +++ b/src/aiunn/upsampler/aiunn.py @@ -20,10 +20,10 @@ class aiuNN(AIIA): mode=self.config.upsample_mode, align_corners=self.config.upsample_align_corners ) - # Conversion layer: change from hidden size channels to 3 channels. + # Conversion layer: change from hidden size channels to number of channels from the config. self.to_rgb = nn.Conv2d( in_channels=self.base_model.config.hidden_size, - out_channels=3, + out_channels=self.base_model.config.num_channels, kernel_size=1 ) @@ -35,17 +35,87 @@ class aiuNN(AIIA): return x @classmethod - def load(cls, path, precision: str = None): - # Load the configuration from disk. - config = AIIAConfig.load(path) - # Reconstruct the base model from the loaded configuration. - base_model = AIIABase(config) - # Instantiate the Upsampler using the proper base model. - upsampler = cls(base_model) + def load(cls, path, precision: str = None, **kwargs): + """ + Load an aiuNN model from disk with automatic detection of base model type. - # Load state dict and handle precision conversion if needed. + Args: + path (str): Directory containing the stored configuration and model parameters. + precision (str, optional): Desired precision for the model's parameters. + **kwargs: Additional keyword arguments to override configuration parameters. + + Returns: + An instance of aiuNN with loaded weights.
+ """ + # Load the configuration + config = aiuNNConfig.load(path) + + # Determine the device device = 'cuda' if torch.cuda.is_available() else 'cpu' - state_dict = torch.load(f"{path}/model.pth", map_location=device) + + # Load the state dictionary + state_dict = torch.load(os.path.join(path, "model.pth"), map_location=device) + + # Import all model types + from aiia import AIIABase, AIIABaseShared, AIIAExpert, AIIAmoe, AIIAchunked, AIIArecursive + + # Helper function to detect base class type from key patterns + def detect_base_class_type(keys_prefix): + if any(f"{keys_prefix}.shared_layer" in key for key in state_dict.keys()): + return AIIABaseShared + else: + return AIIABase + + # Detect base model type + base_model = None + + # Check for AIIAmoe with multiple experts + if any("base_model.experts" in key for key in state_dict.keys()): + # Count the number of experts + max_expert_idx = -1 + for key in state_dict.keys(): + if "base_model.experts." in key: + try: + parts = key.split("base_model.experts.")[1].split(".") + expert_idx = int(parts[0]) + max_expert_idx = max(max_expert_idx, expert_idx) + except (ValueError, IndexError): + pass + + if max_expert_idx >= 0: + # Determine the type of base_cnn each expert is using + base_class_for_experts = detect_base_class_type("base_model.experts.0.base_cnn") + + # Create AIIAmoe with the detected expert count and base class + base_model = AIIAmoe(config, num_experts=max_expert_idx+1, base_class=base_class_for_experts, **kwargs) + + # Check for AIIAchunked or AIIArecursive + elif any("base_model.chunked_cnn" in key for key in state_dict.keys()): + if any("recursion_depth" in key for key in state_dict.keys()): + # This is an AIIArecursive model + base_class = detect_base_class_type("base_model.chunked_cnn.base_cnn") + base_model = AIIArecursive(config, base_class=base_class, **kwargs) + else: + # This is an AIIAchunked model + base_class = detect_base_class_type("base_model.chunked_cnn.base_cnn") + base_model = 
AIIAchunked(config, base_class=base_class, **kwargs) + + # Check for AIIAExpert + elif any("base_model.base_cnn" in key for key in state_dict.keys()): + # Determine which base class the expert is using + base_class = detect_base_class_type("base_model.base_cnn") + base_model = AIIAExpert(config, base_class=base_class, **kwargs) + + # If none of the above, use AIIABase or AIIABaseShared directly + else: + base_class = detect_base_class_type("base_model") + base_model = base_class(config, **kwargs) + + # Create the aiuNN model with the detected base model + model = cls(base_model) + + # Handle precision conversion + dtype = None if precision is not None: if precision.lower() == 'fp16': dtype = torch.float16 @@ -57,12 +127,15 @@ class aiuNN(AIIA): dtype = torch.bfloat16 else: raise ValueError("Unsupported precision. Use 'fp16', 'bf16', or leave as None.") - + + if dtype is not None: for key, param in state_dict.items(): if torch.is_tensor(param): state_dict[key] = param.to(dtype) - upsampler.load_state_dict(state_dict) - return upsampler + + # Load the state dict + model.load_state_dict(state_dict) + return model