From daaceaff0b0ebf242013bb4d211d7c237fd89bc9 Mon Sep 17 00:00:00 2001
From: Falko Habel
Date: Fri, 6 Jun 2025 13:10:48 +0200
Subject: [PATCH] fixed model loading for aiiamoe and aiiasmoe with shared
 weights

---
 src/aiia/model/Model.py | 60 ++++++++++++++++++++++++++++++++++++-----
 1 file changed, 54 insertions(+), 6 deletions(-)

diff --git a/src/aiia/model/Model.py b/src/aiia/model/Model.py
index 1cdd8dc..670f05f 100644
--- a/src/aiia/model/Model.py
+++ b/src/aiia/model/Model.py
@@ -1,8 +1,9 @@
 from .config import AIIAConfig
 from torch import nn
 from transformers import PreTrainedModel
+from safetensors.torch import load_file
 import torch
-import copy
+import os
 
 
 class AIIABase(PreTrainedModel):
@@ -132,16 +133,33 @@ class AIIAmoe(PreTrainedModel):
     config_class = AIIAConfig
     base_model_prefix = "AIIA"
 
-    def __init__(self, config: AIIAConfig, base_class=AIIABase):
+    def __init__(self, config: AIIAConfig, base_class=None):
         super().__init__(config=config)
         self.config = config
-        # Get num_experts directly from config instead of parameter
-        num_experts = getattr(config, "num_experts", 3)  # default to 3 if not in config
+        # Get num_experts directly from config or set default to 3 if not present
+        if not hasattr(config, 'num_experts'):
+            config.num_experts = 3
+
+        # Store base_class_name as a string (JSON serializable)
+        if not hasattr(config, 'base_class_name'):
+            config.base_class_name = 'AIIABase'
+
+        num_experts = config.num_experts
+
+        # Simple mapping from string to class
+        if base_class is not None:
+            resolved_base_class = base_class
+        else:
+            if config.base_class_name == 'AIIABase':
+                resolved_base_class = AIIABase
+            elif config.base_class_name == 'AIIABaseShared':
+                resolved_base_class = AIIABaseShared
+            else:
+                raise ValueError(f"Unknown base_class_name: {config.base_class_name}")
 
-        # Initialize multiple experts from the chosen base class
         self.experts = nn.ModuleList([
-            AIIAExpert(self.config, base_class=base_class)
+            AIIAExpert(self.config, base_class=resolved_base_class)
             for _ in range(num_experts)
         ])
 
 
@@ -152,6 +170,36 @@ class AIIAmoe(PreTrainedModel):
             nn.Linear(gate_in_features, num_experts),
             nn.Softmax(dim=1)
         )
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, base_class=None, base_class_name=None, **kwargs):
+        # Load the config
+        config = cls.config_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
+
+        # Set base_class_name if provided, defaulting to 'AIIABase'
+        if base_class_name is not None:
+            config.base_class_name = base_class_name
+        elif not hasattr(config, 'base_class_name'):
+            config.base_class_name = 'AIIABase'
+
+        # Create the model with the correct architecture
+        model = cls(config, base_class=base_class)
+
+        # Try to load weights - check for both formats
+        safetensors_path = os.path.join(pretrained_model_name_or_path, "model.safetensors")
+        pytorch_path = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin")
+
+        if os.path.exists(safetensors_path):
+            state_dict = load_file(safetensors_path)
+            model.load_state_dict(state_dict, strict=False)
+        elif os.path.exists(pytorch_path):
+            state_dict = torch.load(pytorch_path, map_location="cpu")
+            model.load_state_dict(state_dict, strict=False)
+        else:
+            print("No weight files found - model initialized with random weights")
+        return model
+
+
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
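
A minimal usage sketch of the new loading path (not part of the patch), assuming
the module layout above; the checkpoint directory name is illustrative, and
AIIABaseShared is the shared-weights base class the patch dispatches to:

    from aiia.model.Model import AIIAmoe, AIIABaseShared

    # 'base_class_name' is persisted in the model config as a plain string,
    # so a checkpoint saved with shared-weight experts should reload with no
    # extra arguments; passing base_class_name (or base_class) overrides it.
    model = AIIAmoe.from_pretrained("checkpoints/aiiamoe-shared",
                                    base_class_name="AIIABaseShared")

    # Equivalent, passing the class object directly:
    model = AIIAmoe.from_pretrained("checkpoints/aiiamoe-shared",
                                    base_class=AIIABaseShared)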