From daaceaff0b0ebf242013bb4d211d7c237fd89bc9 Mon Sep 17 00:00:00 2001 From: Falko Habel Date: Fri, 6 Jun 2025 13:10:48 +0200 Subject: [PATCH 1/2] fixed model laoading for aiiamoe and aiiasmoe with shared weights --- src/aiia/model/Model.py | 60 ++++++++++++++++++++++++++++++++++++----- 1 file changed, 54 insertions(+), 6 deletions(-) diff --git a/src/aiia/model/Model.py b/src/aiia/model/Model.py index 1cdd8dc..670f05f 100644 --- a/src/aiia/model/Model.py +++ b/src/aiia/model/Model.py @@ -1,8 +1,9 @@ from .config import AIIAConfig from torch import nn from transformers import PreTrainedModel +from safetensors.torch import load_file import torch -import copy +import os class AIIABase(PreTrainedModel): @@ -132,16 +133,33 @@ class AIIAmoe(PreTrainedModel): config_class = AIIAConfig base_model_prefix = "AIIA" - def __init__(self, config: AIIAConfig, base_class=AIIABase): + def __init__(self, config: AIIAConfig, base_class=None): super().__init__(config=config) self.config = config - # Get num_experts directly from config instead of parameter - num_experts = getattr(config, "num_experts", 3) # default to 3 if not in config + # Get num_experts directly from config or set default to 3 if not present + if not hasattr(config, 'num_experts'): + config.num_experts = 3 + + # Get base_class_name as string (JSON serializable) + if not hasattr(config, 'base_class_name'): + config.base_class_name = 'AIIABase' + + num_experts = config.num_experts + + # Simple mapping from string to class + if base_class is not None: + resolved_base_class = base_class + else: + if config.base_class_name == 'AIIABase': + resolved_base_class = AIIABase + elif config.base_class_name == 'AIIABaseShared': # Replace with your second base class + resolved_base_class = AIIABaseShared + else: + raise ValueError(f"Unknown base_class_name: {config.base_class_name}") - # Initialize multiple experts from the chosen base class self.experts = nn.ModuleList([ - AIIAExpert(self.config, base_class=base_class) + AIIAExpert(self.config, base_class=resolved_base_class) for _ in range(num_experts) ]) @@ -152,6 +170,36 @@ class AIIAmoe(PreTrainedModel): nn.Linear(gate_in_features, num_experts), nn.Softmax(dim=1) ) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, base_class=None, base_class_name=None, **kwargs): + # Load the config + config = cls.config_class.from_pretrained(pretrained_model_name_or_path, **kwargs) + + # Set base_class_name if provided + if base_class_name is not None: + config.base_class_name = base_class_name + elif not hasattr(config, 'base_class_name'): + config.base_class_name = 'AIIABase' + + # Create the model with the correct architecture + model = cls(config, base_class=base_class) + + # Try to load weights - check for both formats + safetensors_path = os.path.join(pretrained_model_name_or_path, "model.safetensors") + pytorch_path = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin") + + if os.path.exists(safetensors_path): + state_dict = load_file(safetensors_path) + model.load_state_dict(state_dict, strict=False) + elif os.path.exists(pytorch_path): + state_dict = torch.load(pytorch_path, map_location="cpu") + model.load_state_dict(state_dict, strict=False) + else: + print("No weight files found - model initialized with random weights") + return model + + def forward(self, x: torch.Tensor) -> torch.Tensor: """ -- 2.34.1 From 27693c251c5ed8d8b8f5182bf8e8bc41a7244fb0 Mon Sep 17 00:00:00 2001 From: Falko Habel Date: Fri, 6 Jun 2025 13:11:16 +0200 Subject: [PATCH 2/2] increased version number --- pyproject.toml | 2 +- setup.cfg | 2 +- src/aiia/__init__.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 4b8bd0e..49ac92d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ include = '\.pyi?$' [project] name = "aiia" -version = "0.3.3" +version = "0.3.4" description = "AIIA Deep Learning Model Implementation" readme = "README.md" authors = [ diff --git a/setup.cfg b/setup.cfg index 202e051..bec1fd3 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = aiia -version = 0.3.3 +version = 0.3.4 author = Falko Habel author_email = falko.habel@gmx.de description = AIIA deep learning model implementation diff --git a/src/aiia/__init__.py b/src/aiia/__init__.py index 9a9e138..71374df 100644 --- a/src/aiia/__init__.py +++ b/src/aiia/__init__.py @@ -4,4 +4,4 @@ from .data.DataLoader import DataLoader from .pretrain.pretrainer import Pretrainer, ProjectionHead -__version__ = "0.3.3" +__version__ = "0.3.4" -- 2.34.1