diff --git a/src/aiunn/upsampler/aiunn.py b/src/aiunn/upsampler/aiunn.py index ecb3c21..1dfa93e 100644 --- a/src/aiunn/upsampler/aiunn.py +++ b/src/aiunn/upsampler/aiunn.py @@ -2,13 +2,14 @@ import os import torch import torch.nn as nn import warnings -from aiia.model.Model import AIIA, AIIAConfig, AIIABase +from aiia.model.Model import AIIAConfig, AIIABase +from transformers import PreTrainedModel from .config import aiuNNConfig import warnings -class aiuNN(AIIA): - def __init__(self, base_model: AIIA, config:aiuNNConfig): +class aiuNN(PreTrainedModel): + def __init__(self, base_model: PreTrainedModel, config:aiuNNConfig): super().__init__(base_model.config) self.base_model = base_model @@ -33,112 +34,6 @@ class aiuNN(AIIA): x = self.pixel_shuffle(x) # Rearrange channels into spatial dimensions return x - - @classmethod - def load(cls, path, precision: str = None, **kwargs): - """ - Load a aiuNN model from disk with automatic detection of base model type. - - Args: - path (str): Directory containing the stored configuration and model parameters. - precision (str, optional): Desired precision for the model's parameters. - **kwargs: Additional keyword arguments to override configuration parameters. - - Returns: - An instance of aiuNN with loaded weights. - """ - # Load the configuration - config = aiuNNConfig.load(path) - - # Determine the device - device = 'cuda' if torch.cuda.is_available() else 'cpu' - - # Load the state dictionary - state_dict = torch.load(os.path.join(path, "model.pth"), map_location=device) - - # Import all model types - from aiia.model.Model import AIIABase, AIIABaseShared, AIIAExpert, AIIAmoe, AIIAchunked, AIIArecursive - - # Helper function to detect base class type from key patterns - def detect_base_class_type(keys_prefix): - if any(f"{keys_prefix}.shared_layer" in key for key in state_dict.keys()): - return AIIABaseShared - else: - return AIIABase - - # Detect base model type - base_model = None - - # Check for AIIAmoe with multiple experts - if any("base_model.experts" in key for key in state_dict.keys()): - # Count the number of experts - max_expert_idx = -1 - for key in state_dict.keys(): - if "base_model.experts." in key: - try: - parts = key.split("base_model.experts.")[1].split(".") - expert_idx = int(parts[0]) - max_expert_idx = max(max_expert_idx, expert_idx) - except (ValueError, IndexError): - pass - - if max_expert_idx >= 0: - # Determine the type of base_cnn each expert is using - base_class_for_experts = detect_base_class_type("base_model.experts.0.base_cnn") - - # Create AIIAmoe with the detected expert count and base class - base_model = AIIAmoe(config, num_experts=max_expert_idx+1, base_class=base_class_for_experts, **kwargs) - - # Check for AIIAchunked or AIIArecursive - elif any("base_model.chunked_cnn" in key for key in state_dict.keys()): - if any("recursion_depth" in key for key in state_dict.keys()): - # This is an AIIArecursive model - base_class = detect_base_class_type("base_model.chunked_cnn.base_cnn") - base_model = AIIArecursive(config, base_class=base_class, **kwargs) - else: - # This is an AIIAchunked model - base_class = detect_base_class_type("base_model.chunked_cnn.base_cnn") - base_model = AIIAchunked(config, base_class=base_class, **kwargs) - - # Check for AIIAExpert - elif any("base_model.base_cnn" in key for key in state_dict.keys()): - # Determine which base class the expert is using - base_class = detect_base_class_type("base_model.base_cnn") - base_model = AIIAExpert(config, base_class=base_class, **kwargs) - - # If none of the above, use AIIABase or AIIABaseShared directly - else: - base_class = detect_base_class_type("base_model") - base_model = base_class(config, **kwargs) - - # Create the aiuNN model with the detected base model - model = cls(base_model, config=base_model.config) - - # Handle precision conversion - dtype = None - if precision is not None: - if precision.lower() == 'fp16': - dtype = torch.float16 - elif precision.lower() == 'bf16': - if device == 'cuda' and not torch.cuda.is_bf16_supported(): - warnings.warn("BF16 is not supported on this GPU. Falling back to FP16.") - dtype = torch.float16 - else: - dtype = torch.bfloat16 - else: - raise ValueError("Unsupported precision. Use 'fp16', 'bf16', or leave as None.") - - if dtype is not None: - for key, param in state_dict.items(): - if torch.is_tensor(param): - state_dict[key] = param.to(dtype) - - # Load the state dict - model.load_state_dict(state_dict) - return model - - - if __name__ == "__main__": from aiia import AIIABase, AIIAConfig # Create a configuration and build a base model. @@ -149,7 +44,7 @@ if __name__ == "__main__": upsampler = aiuNN(base_model, config=ai_config) # Save the model (both configuration and weights). - upsampler.save("aiunn") + upsampler.save_pretrained("aiunn") # Now load using the overridden load method; this will load the complete model. upsampler_loaded = aiuNN.load("aiunn", precision="bf16")