Compare commits
2 Commits
8c9853ade3
...
27693c251c
Author | SHA1 | Date |
---|---|---|
|
27693c251c | |
|
daaceaff0b |
|
@ -10,7 +10,7 @@ include = '\.pyi?$'
|
|||
|
||||
[project]
|
||||
name = "aiia"
|
||||
version = "0.3.3"
|
||||
version = "0.3.4"
|
||||
description = "AIIA Deep Learning Model Implementation"
|
||||
readme = "README.md"
|
||||
authors = [
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
[metadata]
|
||||
name = aiia
|
||||
version = 0.3.3
|
||||
version = 0.3.4
|
||||
author = Falko Habel
|
||||
author_email = falko.habel@gmx.de
|
||||
description = AIIA deep learning model implementation
|
||||
|
|
|
@ -4,4 +4,4 @@ from .data.DataLoader import DataLoader
|
|||
from .pretrain.pretrainer import Pretrainer, ProjectionHead
|
||||
|
||||
|
||||
__version__ = "0.3.3"
|
||||
__version__ = "0.3.4"
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
from .config import AIIAConfig
|
||||
from torch import nn
|
||||
from transformers import PreTrainedModel
|
||||
from safetensors.torch import load_file
|
||||
import torch
|
||||
import copy
|
||||
import os
|
||||
|
||||
|
||||
class AIIABase(PreTrainedModel):
|
||||
|
@ -132,16 +133,33 @@ class AIIAmoe(PreTrainedModel):
|
|||
config_class = AIIAConfig
|
||||
base_model_prefix = "AIIA"
|
||||
|
||||
def __init__(self, config: AIIAConfig, base_class=AIIABase):
|
||||
def __init__(self, config: AIIAConfig, base_class=None):
|
||||
super().__init__(config=config)
|
||||
self.config = config
|
||||
|
||||
# Get num_experts directly from config instead of parameter
|
||||
num_experts = getattr(config, "num_experts", 3) # default to 3 if not in config
|
||||
# Get num_experts directly from config or set default to 3 if not present
|
||||
if not hasattr(config, 'num_experts'):
|
||||
config.num_experts = 3
|
||||
|
||||
# Get base_class_name as string (JSON serializable)
|
||||
if not hasattr(config, 'base_class_name'):
|
||||
config.base_class_name = 'AIIABase'
|
||||
|
||||
num_experts = config.num_experts
|
||||
|
||||
# Simple mapping from string to class
|
||||
if base_class is not None:
|
||||
resolved_base_class = base_class
|
||||
else:
|
||||
if config.base_class_name == 'AIIABase':
|
||||
resolved_base_class = AIIABase
|
||||
elif config.base_class_name == 'AIIABaseShared': # Replace with your second base class
|
||||
resolved_base_class = AIIABaseShared
|
||||
else:
|
||||
raise ValueError(f"Unknown base_class_name: {config.base_class_name}")
|
||||
|
||||
# Initialize multiple experts from the chosen base class
|
||||
self.experts = nn.ModuleList([
|
||||
AIIAExpert(self.config, base_class=base_class)
|
||||
AIIAExpert(self.config, base_class=resolved_base_class)
|
||||
for _ in range(num_experts)
|
||||
])
|
||||
|
||||
|
@ -152,6 +170,36 @@ class AIIAmoe(PreTrainedModel):
|
|||
nn.Linear(gate_in_features, num_experts),
|
||||
nn.Softmax(dim=1)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, base_class=None, base_class_name=None, **kwargs):
|
||||
# Load the config
|
||||
config = cls.config_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
# Set base_class_name if provided
|
||||
if base_class_name is not None:
|
||||
config.base_class_name = base_class_name
|
||||
elif not hasattr(config, 'base_class_name'):
|
||||
config.base_class_name = 'AIIABase'
|
||||
|
||||
# Create the model with the correct architecture
|
||||
model = cls(config, base_class=base_class)
|
||||
|
||||
# Try to load weights - check for both formats
|
||||
safetensors_path = os.path.join(pretrained_model_name_or_path, "model.safetensors")
|
||||
pytorch_path = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin")
|
||||
|
||||
if os.path.exists(safetensors_path):
|
||||
state_dict = load_file(safetensors_path)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
elif os.path.exists(pytorch_path):
|
||||
state_dict = torch.load(pytorch_path, map_location="cpu")
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
else:
|
||||
print("No weight files found - model initialized with random weights")
|
||||
return model
|
||||
|
||||
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue