fixed model laoading for aiiamoe and aiiasmoe with shared weights
This commit is contained in:
parent
8c9853ade3
commit
daaceaff0b
|
@ -1,8 +1,9 @@
|
||||||
from .config import AIIAConfig
|
from .config import AIIAConfig
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import PreTrainedModel
|
from transformers import PreTrainedModel
|
||||||
|
from safetensors.torch import load_file
|
||||||
import torch
|
import torch
|
||||||
import copy
|
import os
|
||||||
|
|
||||||
|
|
||||||
class AIIABase(PreTrainedModel):
|
class AIIABase(PreTrainedModel):
|
||||||
|
@ -132,16 +133,33 @@ class AIIAmoe(PreTrainedModel):
|
||||||
config_class = AIIAConfig
|
config_class = AIIAConfig
|
||||||
base_model_prefix = "AIIA"
|
base_model_prefix = "AIIA"
|
||||||
|
|
||||||
def __init__(self, config: AIIAConfig, base_class=AIIABase):
|
def __init__(self, config: AIIAConfig, base_class=None):
|
||||||
super().__init__(config=config)
|
super().__init__(config=config)
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
# Get num_experts directly from config instead of parameter
|
# Get num_experts directly from config or set default to 3 if not present
|
||||||
num_experts = getattr(config, "num_experts", 3) # default to 3 if not in config
|
if not hasattr(config, 'num_experts'):
|
||||||
|
config.num_experts = 3
|
||||||
|
|
||||||
|
# Get base_class_name as string (JSON serializable)
|
||||||
|
if not hasattr(config, 'base_class_name'):
|
||||||
|
config.base_class_name = 'AIIABase'
|
||||||
|
|
||||||
|
num_experts = config.num_experts
|
||||||
|
|
||||||
|
# Simple mapping from string to class
|
||||||
|
if base_class is not None:
|
||||||
|
resolved_base_class = base_class
|
||||||
|
else:
|
||||||
|
if config.base_class_name == 'AIIABase':
|
||||||
|
resolved_base_class = AIIABase
|
||||||
|
elif config.base_class_name == 'AIIABaseShared': # Replace with your second base class
|
||||||
|
resolved_base_class = AIIABaseShared
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown base_class_name: {config.base_class_name}")
|
||||||
|
|
||||||
# Initialize multiple experts from the chosen base class
|
|
||||||
self.experts = nn.ModuleList([
|
self.experts = nn.ModuleList([
|
||||||
AIIAExpert(self.config, base_class=base_class)
|
AIIAExpert(self.config, base_class=resolved_base_class)
|
||||||
for _ in range(num_experts)
|
for _ in range(num_experts)
|
||||||
])
|
])
|
||||||
|
|
||||||
|
@ -153,6 +171,36 @@ class AIIAmoe(PreTrainedModel):
|
||||||
nn.Softmax(dim=1)
|
nn.Softmax(dim=1)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, pretrained_model_name_or_path, base_class=None, base_class_name=None, **kwargs):
|
||||||
|
# Load the config
|
||||||
|
config = cls.config_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||||
|
|
||||||
|
# Set base_class_name if provided
|
||||||
|
if base_class_name is not None:
|
||||||
|
config.base_class_name = base_class_name
|
||||||
|
elif not hasattr(config, 'base_class_name'):
|
||||||
|
config.base_class_name = 'AIIABase'
|
||||||
|
|
||||||
|
# Create the model with the correct architecture
|
||||||
|
model = cls(config, base_class=base_class)
|
||||||
|
|
||||||
|
# Try to load weights - check for both formats
|
||||||
|
safetensors_path = os.path.join(pretrained_model_name_or_path, "model.safetensors")
|
||||||
|
pytorch_path = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin")
|
||||||
|
|
||||||
|
if os.path.exists(safetensors_path):
|
||||||
|
state_dict = load_file(safetensors_path)
|
||||||
|
model.load_state_dict(state_dict, strict=False)
|
||||||
|
elif os.path.exists(pytorch_path):
|
||||||
|
state_dict = torch.load(pytorch_path, map_location="cpu")
|
||||||
|
model.load_state_dict(state_dict, strict=False)
|
||||||
|
else:
|
||||||
|
print("No weight files found - model initialized with random weights")
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Forward pass for the Mixture-of-Experts model.
|
Forward pass for the Mixture-of-Experts model.
|
||||||
|
|
Loading…
Reference in New Issue