feat/fix_loading_for_sahred_smoe #44

Merged
Fabel merged 2 commits from feat/fix_loading_for_sahred_smoe into main 2025-06-06 11:15:37 +00:00
4 changed files with 57 additions and 9 deletions

View File: pyproject.toml

@@ -10,7 +10,7 @@ include = '\.pyi?$'
 [project]
 name = "aiia"
-version = "0.3.3"
+version = "0.3.4"
 description = "AIIA Deep Learning Model Implementation"
 readme = "README.md"
 authors = [

View File: setup.cfg

@@ -1,6 +1,6 @@
 [metadata]
 name = aiia
-version = 0.3.3
+version = 0.3.4
 author = Falko Habel
 author_email = falko.habel@gmx.de
 description = AIIA deep learning model implementation

View File: aiia/__init__.py

@@ -4,4 +4,4 @@ from .data.DataLoader import DataLoader
 from .pretrain.pretrainer import Pretrainer, ProjectionHead
-__version__ = "0.3.3"
+__version__ = "0.3.4"

View File

@@ -1,8 +1,9 @@
 from .config import AIIAConfig
 from torch import nn
 from transformers import PreTrainedModel
+from safetensors.torch import load_file
 import torch
-import copy
+import os

 class AIIABase(PreTrainedModel):
@@ -132,16 +133,33 @@ class AIIAmoe(PreTrainedModel):
     config_class = AIIAConfig
     base_model_prefix = "AIIA"

-    def __init__(self, config: AIIAConfig, base_class=AIIABase):
+    def __init__(self, config: AIIAConfig, base_class=None):
         super().__init__(config=config)
         self.config = config

-        # Get num_experts directly from config instead of parameter
-        num_experts = getattr(config, "num_experts", 3)  # default to 3 if not in config
+        # Get num_experts directly from config or set default to 3 if not present
+        if not hasattr(config, 'num_experts'):
+            config.num_experts = 3
+
+        # Get base_class_name as string (JSON serializable)
+        if not hasattr(config, 'base_class_name'):
+            config.base_class_name = 'AIIABase'
+
+        num_experts = config.num_experts
+
+        # Simple mapping from string to class
+        if base_class is not None:
+            resolved_base_class = base_class
+        else:
+            if config.base_class_name == 'AIIABase':
+                resolved_base_class = AIIABase
+            elif config.base_class_name == 'AIIABaseShared':
+                resolved_base_class = AIIABaseShared
+            else:
+                raise ValueError(f"Unknown base_class_name: {config.base_class_name}")

+        # Initialize multiple experts from the chosen base class
         self.experts = nn.ModuleList([
-            AIIAExpert(self.config, base_class=base_class)
+            AIIAExpert(self.config, base_class=resolved_base_class)
             for _ in range(num_experts)
         ])
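Storing the expert base class as the string config.base_class_name, rather than as a class object, is what lets the choice survive serialization: transformers configs are persisted as config.json, and a plain string round-trips through JSON while a class object would not. A minimal sketch of that round trip, assuming AIIAConfig behaves like a transformers.PretrainedConfig (as the later use of cls.config_class.from_pretrained suggests) and using a made-up ./aiia-moe directory:

    # Assumes AIIAConfig is importable from the aiia package (exact module
    # path is not shown in this PR) and subclasses PretrainedConfig.
    config = AIIAConfig()
    config.base_class_name = "AIIABaseShared"  # plain string, JSON-serializable
    config.save_pretrained("./aiia-moe")       # written into config.json
    reloaded = AIIAConfig.from_pretrained("./aiia-moe")
    assert reloaded.base_class_name == "AIIABaseShared"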
@@ -153,6 +171,36 @@ class AIIAmoe(PreTrainedModel):
             nn.Softmax(dim=1)
         )

+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, base_class=None, base_class_name=None, **kwargs):
+        # Load the config
+        config = cls.config_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
+
+        # Set base_class_name if provided
+        if base_class_name is not None:
+            config.base_class_name = base_class_name
+        elif not hasattr(config, 'base_class_name'):
+            config.base_class_name = 'AIIABase'
+
+        # Create the model with the correct architecture
+        model = cls(config, base_class=base_class)
+
+        # Try to load weights - check for both formats
+        safetensors_path = os.path.join(pretrained_model_name_or_path, "model.safetensors")
+        pytorch_path = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin")
+
+        if os.path.exists(safetensors_path):
+            state_dict = load_file(safetensors_path)
+            model.load_state_dict(state_dict, strict=False)
+        elif os.path.exists(pytorch_path):
+            state_dict = torch.load(pytorch_path, map_location="cpu")
+            model.load_state_dict(state_dict, strict=False)
+        else:
+            print("No weight files found - model initialized with random weights")
+
+        return model
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
         Forward pass for the Mixture-of-Experts model.
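End to end, the new loading path can be exercised as below. A hedged usage sketch: the checkpoint directory and the top-level import are assumptions, not taken from this PR.

    from aiia import AIIAmoe  # assuming the package re-exports AIIAmoe

    # Let the saved config.json decide which expert base class to build:
    model = AIIAmoe.from_pretrained("./aiia-moe")

    # Or override it explicitly, e.g. for a shared-SMoE checkpoint saved
    # before base_class_name was recorded in the config:
    model = AIIAmoe.from_pretrained("./aiia-moe", base_class_name="AIIABaseShared")

One caveat: load_state_dict(..., strict=False) tolerates missing and unexpected keys silently, so a mismatched base_class_name will not fail at load time; inspecting the missing_keys/unexpected_keys named tuple that load_state_dict returns is a stricter check.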