fixed model loading for aiiamoe and aiiasmoe with shared weights
parent 8c9853ade3
commit daaceaff0b
@@ -1,8 +1,9 @@
 from .config import AIIAConfig
 from torch import nn
 from transformers import PreTrainedModel
+from safetensors.torch import load_file
 import torch
 import copy
 import os
 
 
 class AIIABase(PreTrainedModel):
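The newly imported load_file is what the overridden from_pretrained further down uses to read model.safetensors directly. A minimal round-trip sketch of that API (save_file is its counterpart in the same safetensors.torch module; the layer and file name here are arbitrary):

    from safetensors.torch import save_file, load_file
    from torch import nn

    # A state dict of plain tensors can be stored as safetensors; note
    # that save_file rejects tensors that share memory, a relevant
    # constraint for models with tied or shared weights.
    layer = nn.Linear(4, 2)
    save_file(layer.state_dict(), "layer.safetensors")

    # load_file returns a {name: tensor} dict (on CPU by default)
    # that feeds straight into load_state_dict.
    layer.load_state_dict(load_file("layer.safetensors"))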
@@ -132,16 +133,33 @@ class AIIAmoe(PreTrainedModel):
     config_class = AIIAConfig
     base_model_prefix = "AIIA"
 
-    def __init__(self, config: AIIAConfig, base_class=AIIABase):
+    def __init__(self, config: AIIAConfig, base_class=None):
         super().__init__(config=config)
         self.config = config
 
-        # Get num_experts directly from config instead of parameter
-        num_experts = getattr(config, "num_experts", 3)  # default to 3 if not in config
+        # Get num_experts directly from config or set default to 3 if not present
+        if not hasattr(config, 'num_experts'):
+            config.num_experts = 3
+
+        # Get base_class_name as string (JSON serializable)
+        if not hasattr(config, 'base_class_name'):
+            config.base_class_name = 'AIIABase'
+
+        num_experts = config.num_experts
+
+        # Simple mapping from string to class
+        if base_class is not None:
+            resolved_base_class = base_class
+        else:
+            if config.base_class_name == 'AIIABase':
+                resolved_base_class = AIIABase
+            elif config.base_class_name == 'AIIABaseShared':  # Replace with your second base class
+                resolved_base_class = AIIABaseShared
+            else:
+                raise ValueError(f"Unknown base_class_name: {config.base_class_name}")
 
         # Initialize multiple experts from the chosen base class
         self.experts = nn.ModuleList([
-            AIIAExpert(self.config, base_class=base_class)
+            AIIAExpert(self.config, base_class=resolved_base_class)
             for _ in range(num_experts)
         ])
 
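The net effect of the hunk above: the expert base class is now recorded on the config as a JSON-serializable string (base_class_name) rather than defaulting to a Python class object, so it survives a config.json round trip, while an explicitly passed base_class still takes precedence. A usage sketch, assuming AIIAConfig accepts keyword attributes the way PretrainedConfig subclasses typically do:

    # base_class_name is a plain string, so it serializes with the config
    config = AIIAConfig(num_experts=5, base_class_name="AIIABaseShared")
    moe = AIIAmoe(config)                                # experts built from AIIABaseShared

    moe_override = AIIAmoe(config, base_class=AIIABase)  # explicit class wins over the string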
@@ -153,6 +171,36 @@ class AIIAmoe(PreTrainedModel):
             nn.Softmax(dim=1)
         )
 
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, base_class=None, base_class_name=None, **kwargs):
+        # Load the config
+        config = cls.config_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
+
+        # Set base_class_name if provided
+        if base_class_name is not None:
+            config.base_class_name = base_class_name
+        elif not hasattr(config, 'base_class_name'):
+            config.base_class_name = 'AIIABase'
+
+        # Create the model with the correct architecture
+        model = cls(config, base_class=base_class)
+
+        # Try to load weights - check for both formats
+        safetensors_path = os.path.join(pretrained_model_name_or_path, "model.safetensors")
+        pytorch_path = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin")
+
+        if os.path.exists(safetensors_path):
+            state_dict = load_file(safetensors_path)
+            model.load_state_dict(state_dict, strict=False)
+        elif os.path.exists(pytorch_path):
+            state_dict = torch.load(pytorch_path, map_location="cpu")
+            model.load_state_dict(state_dict, strict=False)
+        else:
+            print("No weight files found - model initialized with random weights")
+        return model
+
+
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
         Forward pass for the Mixture-of-Experts model.
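In practice this override means weight loading no longer goes through PreTrainedModel.from_pretrained at all: the architecture is rebuilt from the config first, then whichever checkpoint file exists is loaded with strict=False. A usage sketch (the checkpoint directory name is illustrative):

    # Architecture comes from base_class_name stored in the saved config.json
    model = AIIAmoe.from_pretrained("checkpoints/aiia-moe")

    # Force the shared-weight base class, e.g. for checkpoints saved
    # before base_class_name was written into the config
    model = AIIAmoe.from_pretrained("checkpoints/aiia-moe", base_class_name="AIIABaseShared")

One trade-off worth noting: strict=False silently tolerates missing or unexpected keys. load_state_dict returns both lists, so logging them would surface a naming or shape mismatch instead of leaving a half-initialized model.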