import os
import warnings

import torch
import torch.nn as nn

from aiia.model.Model import AIIA, AIIAConfig, AIIABase
from .config import aiuNNConfig


class aiuNN(AIIA):
    def __init__(self, base_model: AIIA, config: aiuNNConfig):
        super().__init__(base_model.config)
        self.base_model = base_model
        # Keep the upsampler-specific configuration alongside the base one.
        self.config = config

        # Pixel-shuffle upsampling: expand the channel dimension by
        # upsample_scale**2, then rearrange those channels into space.
        scale_factor = self.config.upsample_scale
        out_channels = self.base_model.config.num_channels * (scale_factor ** 2)
        self.pixel_shuffle_conv = nn.Conv2d(
            in_channels=self.base_model.config.hidden_size,
            out_channels=out_channels,
            kernel_size=self.base_model.config.kernel_size,
            padding=1  # preserves spatial size for a 3x3 kernel
        )
        self.pixel_shuffle = nn.PixelShuffle(scale_factor)

    def forward(self, x):
        x = self.base_model(x)          # Extract base features
        x = self.pixel_shuffle_conv(x)  # Expand channels for shuffling
        x = self.pixel_shuffle(x)       # Rearrange channels into spatial dimensions
        return x

    @classmethod
    def load(cls, path, precision: str = None, **kwargs):
        """
        Load an aiuNN model from disk, automatically detecting the base model type.

        Args:
            path (str): Directory containing the stored configuration and model parameters.
            precision (str, optional): Desired precision for the model's parameters
                ('fp16' or 'bf16').
            **kwargs: Additional keyword arguments forwarded to the base model constructor.

        Returns:
            An instance of aiuNN with loaded weights.
        """
        # Load the configuration
        config = aiuNNConfig.load(path)

        # Determine the device
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

        # Load the state dictionary
        state_dict = torch.load(os.path.join(path, "model.pth"), map_location=device)

        # Import all supported base model types
        from aiia.model.Model import (
            AIIABase, AIIABaseShared, AIIAExpert, AIIAmoe, AIIAchunked, AIIArecursive
        )

        # Helper to detect the base class type from key patterns
        def detect_base_class_type(keys_prefix):
            if any(f"{keys_prefix}.shared_layer" in key for key in state_dict.keys()):
                return AIIABaseShared
            return AIIABase

        base_model = None

        # Check for AIIAmoe with multiple experts
        if any("base_model.experts" in key for key in state_dict.keys()):
            # Count the number of experts from the highest expert index
            max_expert_idx = -1
            for key in state_dict.keys():
                if "base_model.experts." in key:
                    try:
                        parts = key.split("base_model.experts.")[1].split(".")
                        expert_idx = int(parts[0])
                        max_expert_idx = max(max_expert_idx, expert_idx)
                    except (ValueError, IndexError):
                        pass

            if max_expert_idx >= 0:
                # Determine the type of base_cnn each expert is using
                base_class_for_experts = detect_base_class_type("base_model.experts.0.base_cnn")
                # Create AIIAmoe with the detected expert count and base class
                base_model = AIIAmoe(config, num_experts=max_expert_idx + 1,
                                     base_class=base_class_for_experts, **kwargs)

        # Check for AIIAchunked or AIIArecursive
        elif any("base_model.chunked_cnn" in key for key in state_dict.keys()):
            base_class = detect_base_class_type("base_model.chunked_cnn.base_cnn")
            if any("recursion_depth" in key for key in state_dict.keys()):
                # This is an AIIArecursive model
                base_model = AIIArecursive(config, base_class=base_class, **kwargs)
            else:
                # This is an AIIAchunked model
                base_model = AIIAchunked(config, base_class=base_class, **kwargs)

        # Check for AIIAExpert
        elif any("base_model.base_cnn" in key for key in state_dict.keys()):
            # Determine which base class the expert is using
            base_class = detect_base_class_type("base_model.base_cnn")
            base_model = AIIAExpert(config, base_class=base_class, **kwargs)

        # If none of the above, use AIIABase or AIIABaseShared directly
        else:
            base_class = detect_base_class_type("base_model")
            base_model = base_class(config, **kwargs)

        # Create the aiuNN model around the detected base model, using the
        # upsampler configuration loaded above.
        model = cls(base_model, config=config)

        # Handle precision conversion
        dtype = None
        if precision is not None:
            if precision.lower() == 'fp16':
                dtype = torch.float16
            elif precision.lower() == 'bf16':
                if device == 'cuda' and not torch.cuda.is_bf16_supported():
                    warnings.warn("BF16 is not supported on this GPU. Falling back to FP16.")
                    dtype = torch.float16
                else:
                    dtype = torch.bfloat16
            else:
                raise ValueError("Unsupported precision. Use 'fp16', 'bf16', or leave as None.")

        if dtype is not None:
            for key, param in state_dict.items():
                if torch.is_tensor(param):
                    state_dict[key] = param.to(dtype)

        # Load the state dict
        model.load_state_dict(state_dict)
        return model


if __name__ == "__main__":
    from aiia import AIIABase, AIIAConfig

    # Create a configuration and build a base model.
    config = AIIAConfig()
    ai_config = aiuNNConfig()
    base_model = AIIABase(config)

    # Instantiate the upsampler from the base model.
    upsampler = aiuNN(base_model, config=ai_config)

    # Save the model (both configuration and weights).
    upsampler.save("aiunn")

    # Load using the overridden load method; this restores the complete model.
    upsampler_loaded = aiuNN.load("aiunn", precision="bf16")
    print("Updated configuration:", upsampler_loaded.config.__dict__)
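
    # A minimal shape-check sketch (not part of the original demo). It assumes
    # AIIABase consumes NCHW image tensors with config.num_channels input
    # channels and preserves spatial size up to the pixel shuffle; the 32x32
    # input below is an arbitrary illustrative choice. If those assumptions
    # hold, the output should be upsample_scale times larger in height and width.
    dummy = torch.randn(1, config.num_channels, 32, 32)
    with torch.no_grad():
        out = upsampler(dummy)
    print(f"Input {tuple(dummy.shape)} -> output {tuple(out.shape)} "
          f"(expected scale: {ai_config.upsample_scale}x)")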