fixed model laoading for aiiamoe and aiiasmoe with shared weights

This commit is contained in:
Falko Victor Habel 2025-06-06 13:10:48 +02:00
parent 8c9853ade3
commit daaceaff0b
1 changed files with 54 additions and 6 deletions

View File

@ -1,8 +1,9 @@
from .config import AIIAConfig
from torch import nn
from transformers import PreTrainedModel
from safetensors.torch import load_file
import torch
import copy
import os
class AIIABase(PreTrainedModel):
@ -132,16 +133,33 @@ class AIIAmoe(PreTrainedModel):
config_class = AIIAConfig
base_model_prefix = "AIIA"
def __init__(self, config: AIIAConfig, base_class=AIIABase):
def __init__(self, config: AIIAConfig, base_class=None):
super().__init__(config=config)
self.config = config
# Get num_experts directly from config instead of parameter
num_experts = getattr(config, "num_experts", 3) # default to 3 if not in config
# Get num_experts directly from config or set default to 3 if not present
if not hasattr(config, 'num_experts'):
config.num_experts = 3
# Get base_class_name as string (JSON serializable)
if not hasattr(config, 'base_class_name'):
config.base_class_name = 'AIIABase'
num_experts = config.num_experts
# Simple mapping from string to class
if base_class is not None:
resolved_base_class = base_class
else:
if config.base_class_name == 'AIIABase':
resolved_base_class = AIIABase
elif config.base_class_name == 'AIIABaseShared': # Replace with your second base class
resolved_base_class = AIIABaseShared
else:
raise ValueError(f"Unknown base_class_name: {config.base_class_name}")
# Initialize multiple experts from the chosen base class
self.experts = nn.ModuleList([
AIIAExpert(self.config, base_class=base_class)
AIIAExpert(self.config, base_class=resolved_base_class)
for _ in range(num_experts)
])
@ -152,6 +170,36 @@ class AIIAmoe(PreTrainedModel):
nn.Linear(gate_in_features, num_experts),
nn.Softmax(dim=1)
)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, base_class=None, base_class_name=None, **kwargs):
# Load the config
config = cls.config_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
# Set base_class_name if provided
if base_class_name is not None:
config.base_class_name = base_class_name
elif not hasattr(config, 'base_class_name'):
config.base_class_name = 'AIIABase'
# Create the model with the correct architecture
model = cls(config, base_class=base_class)
# Try to load weights - check for both formats
safetensors_path = os.path.join(pretrained_model_name_or_path, "model.safetensors")
pytorch_path = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin")
if os.path.exists(safetensors_path):
state_dict = load_file(safetensors_path)
model.load_state_dict(state_dict, strict=False)
elif os.path.exists(pytorch_path):
state_dict = torch.load(pytorch_path, map_location="cpu")
model.load_state_dict(state_dict, strict=False)
else:
print("No weight files found - model initialized with random weights")
return model
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""