feat/tf_support #37

Merged
Fabel merged 13 commits from feat/tf_support into develop 2025-04-16 20:59:48 +00:00
1 changed file with 48 additions and 65 deletions
Showing only changes of commit 3e78a595c9

@@ -1,67 +1,48 @@
 from .config import AIIAConfig
 from torch import nn
-from transformers import PretrainedModel
+from transformers import PreTrainedModel
 import torch
-import os
 import copy
-import warnings

-class AIIA(nn.Module):
-    def __init__(self, config: AIIAConfig, **kwargs):
-        super(AIIA, self).__init__()
-        # Create a deep copy of the configuration to avoid sharing
-        self.config = copy.deepcopy(config)
-        # Update the config with any additional keyword arguments
-        for key, value in kwargs.items():
-            setattr(self.config, key, value)
-
-    def save(self, path: str):
-        if not os.path.exists(path):
-            os.makedirs(path, exist_ok=True)
-        torch.save(self.state_dict(), f"{path}/model.pth")
-        self.config.save(path)
-
-class AIIABase(AIIA):
-    def __init__(self, config: AIIAConfig, **kwargs):
-        super().__init__(config=config, **kwargs)
-        self.config = self.config
+class AIIABase(PreTrainedModel):
+    config_class = AIIAConfig
+    base_model_prefix = "AIIA"
+
+    def __init__(self, config: AIIAConfig):
+        super().__init__(config)

         # Initialize layers based on configuration
         layers = []
-        in_channels = self.config.num_channels
+        in_channels = config.num_channels

-        for _ in range(self.config.num_hidden_layers):
+        for _ in range(config.num_hidden_layers):
             layers.extend([
-                nn.Conv2d(in_channels, self.config.hidden_size,
-                          kernel_size=self.config.kernel_size, padding=1),
-                getattr(nn, self.config.activation_function)(),
+                nn.Conv2d(in_channels, config.hidden_size,
+                          kernel_size=config.kernel_size, padding=1),
+                getattr(nn, config.activation_function)(),
                 nn.MaxPool2d(kernel_size=1, stride=1)
             ])
-            in_channels = self.config.hidden_size
+            in_channels = config.hidden_size

         self.cnn = nn.Sequential(*layers)

     def forward(self, x):
         return self.cnn(x)

-class AIIABaseShared(AIIA):
-    def __init__(self, config: AIIAConfig, **kwargs):
+class AIIABaseShared(PreTrainedModel):
+    config_class = AIIAConfig
+    base_model_prefix = "AIIA"
+
+    def __init__(self, config: AIIAConfig):
+        super().__init__(config)
         """
         Initialize the AIIABaseShared model.

         Args:
             config (AIIAConfig): Configuration object containing model parameters.
-            **kwargs: Additional keyword arguments to override configuration settings.
         """
-        super().__init__(config=config, **kwargs)
-
-        # Update configuration with new parameters if provided
-        self. config = copy.deepcopy(config)
-
-        for key, value in kwargs.items():
-            setattr(self.config, key, value)
+        super().__init__(config=config)

         # Initialize the network components
         self._initialize_network()
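The pattern this diff applies to every class: subclass `PreTrainedModel`, declare a `config_class` and `base_model_prefix`, and drop both the deep-copied config bookkeeping and the hand-rolled `save()`, since persistence now flows through the standard `save_pretrained`/`from_pretrained` machinery. A minimal self-contained sketch of that pattern (the `ToyConfig`/`ToyModel` names are illustrative, not part of this PR):

```python
from torch import nn
from transformers import PretrainedConfig, PreTrainedModel

class ToyConfig(PretrainedConfig):
    model_type = "toy"

    def __init__(self, hidden_size: int = 16, **kwargs):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size

class ToyModel(PreTrainedModel):
    config_class = ToyConfig      # lets from_pretrained rebuild the config
    base_model_prefix = "toy"     # key prefix used in saved state dicts

    def __init__(self, config: ToyConfig):
        super().__init__(config)  # PreTrainedModel stores config as self.config
        self.proj = nn.Linear(config.hidden_size, config.hidden_size)

    def forward(self, x):
        return self.proj(x)

model = ToyModel(ToyConfig())
model.save_pretrained("toy-checkpoint")        # writes config.json + weights
reloaded = ToyModel.from_pretrained("toy-checkpoint")
```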
@@ -120,16 +101,17 @@ class AIIABaseShared(AIIA):
         return out

-class AIIAExpert(AIIA):
-    def __init__(self, config: AIIAConfig, base_class=AIIABase, **kwargs):
-        super().__init__(config=config, **kwargs)
-        self.config = self.config
+class AIIAExpert(PreTrainedModel):
+    config_class = AIIAConfig
+    base_model_prefix = "AIIA"
+
+    def __init__(self, config: AIIAConfig, base_class=AIIABase):
+        super().__init__(config=config)

         # Initialize base CNN with configuration and chosen base class
         if issubclass(base_class, AIIABase):
-            self.base_cnn = AIIABase(self.config, **kwargs)
+            self.base_cnn = AIIABase(self.config)
         elif issubclass(base_class, AIIABaseShared):
-            self.base_cnn = AIIABaseShared(self.config, **kwargs)
+            self.base_cnn = AIIABaseShared(self.config)
         else:
             raise ValueError("Invalid base class")
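With `**kwargs` gone, the expert's backbone is chosen solely by the `base_class` argument. Note the dispatch still works after the migration even though `AIIABaseShared` no longer shares an ancestor with `AIIABase` (both now subclass `PreTrainedModel` directly): the first `issubclass` check fails for `AIIABaseShared` and execution falls through to the `elif`. Usage, for illustration:

```python
config = AIIAConfig()
dense_expert = AIIAExpert(config)                              # AIIABase backbone (default)
shared_expert = AIIAExpert(config, base_class=AIIABaseShared)  # shared-weight backbone
```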
@@ -146,26 +128,26 @@ class AIIAExpert(AIIA):
         # Process input through the base CNN
         return self.base_cnn(x)

-class AIIAmoe(AIIA):
-    def __init__(self, config: AIIAConfig, num_experts: int = 3, base_class=AIIABase, **kwargs):
-        super().__init__(config=config, **kwargs)
+class AIIAmoe(PreTrainedModel):
+    config_class = AIIAConfig
+    base_model_prefix = "AIIA"
+
+    def __init__(self, config: AIIAConfig, base_class=AIIABase):
+        super().__init__(config=config)
         self.config = config

-        # Update the config to include the number of experts.
-        self.config.num_experts = num_experts
+        # Get num_experts directly from config instead of parameter
+        num_experts = getattr(config, "num_experts", 3)  # default to 3 if not in config

-        # Initialize multiple experts from the chosen base class.
+        # Initialize multiple experts from the chosen base class
         self.experts = nn.ModuleList([
-            AIIAExpert(self.config, base_class=base_class, **kwargs)
+            AIIAExpert(self.config, base_class=base_class)
             for _ in range(num_experts)
         ])

-        # To generate gating weights, we first need to determine the feature dimension.
-        # Each expert is assumed to return an output of shape (B, C, H, W); after averaging over H and W,
-        # we obtain a tensor of shape (B, C) where C is the number of channels (here assumed to be 224).
-        gate_in_features = 512  # Adjust this if your expert output changes.
+        gate_in_features = self.config.hidden_size

-        # Create a gating network that maps the aggregated features to num_experts weights.
+        # Create a gating network that maps the aggregated features to num_experts weights
         self.gate = nn.Sequential(
             nn.Linear(gate_in_features, num_experts),
             nn.Softmax(dim=1)
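The deleted comments carried the reasoning behind the gate's input width: each expert emits `(B, C, H, W)`, the spatial dimensions are averaged away, and the resulting `(B, C)` features feed the gate, so `C` must equal the expert's output channels. Deriving it from `config.hidden_size` keeps the `nn.Linear` consistent with whatever width the backbone is configured for, where the old hard-coded 512 silently broke for other sizes. A hedged sketch of the dense combination this implies (the actual `AIIAmoe.forward` falls outside this hunk):

```python
import torch

def dense_moe_forward(experts, gate, x):
    # Stack expert outputs: (B, E, hidden_size, H, W).
    expert_outputs = torch.stack([expert(x) for expert in experts], dim=1)

    # Average over experts and space to get (B, hidden_size) gate features.
    gate_input = expert_outputs.mean(dim=1).mean(dim=(-2, -1))

    # (B, E) softmax weights from the Linear + Softmax gate.
    weights = gate(gate_input)

    # Weighted sum over the expert dimension, broadcasting over C, H, W.
    return (expert_outputs * weights[:, :, None, None, None]).sum(dim=1)
```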
@@ -209,9 +191,10 @@ class AIIAmoe(AIIA):

 class AIIASparseMoe(AIIAmoe):
-    def __init__(self, config: AIIAConfig, num_experts: int = 3, top_k: int = 2, base_class=AIIABase, **kwargs):
-        super().__init__(config=config, num_experts=num_experts, base_class=base_class, **kwargs)
-        self.top_k = top_k
+    config_class = AIIAConfig
+    base_model_prefix = "AIIA"
+
+    def __init__(self, config: AIIAConfig, base_class=AIIABase):
+        super().__init__(config=config, base_class=base_class)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # Compute the gate_weights similar to standard moe.
@@ -221,7 +204,7 @@ class AIIASparseMoe(AIIAmoe):
         gate_weights = self.gate(gate_input)

         # Select the top-k experts for each input based on gating weights.
-        _, top_k_indices = gate_weights.topk(self.top_k, dim=-1)
+        _, top_k_indices = gate_weights.topk(self.config.top_k, dim=-1)

         # Initialize a list to store outputs from selected experts.
         merged_outputs = []
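`topk` returns the scores and indices of the `config.top_k` strongest experts per sample; only those experts contribute to the output, which is the sparse part of `AIIASparseMoe`. The merge loop itself falls outside this hunk; a sketch of one way those pieces fit together (illustrative only, and for clarity it runs each selected expert on the full batch and masks afterwards rather than gathering routed samples):

```python
import torch

def sparse_merge(experts, gate_weights, top_k_indices, x):
    merged = torch.zeros_like(experts[0](x))  # (B, C, H, W) accumulator
    for expert_idx in top_k_indices.unique().tolist():
        # Samples whose top-k set includes this expert.
        routed = (top_k_indices == expert_idx).any(dim=-1)
        weight = gate_weights[:, expert_idx] * routed.float()
        merged += weight[:, None, None, None] * experts[expert_idx](x)
    return merged
```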
@@ -245,4 +228,4 @@ class AIIASparseMoe(AIIAmoe):
 if __name__ =="__main__":
     config = AIIAConfig()
     model = AIIAmoe(config, num_experts=5)
-    model.save("test")
+    model.save_pretrained("test")
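One caveat in the unchanged `__main__` block: `AIIAmoe(config, num_experts=5)` matches the old signature, but the new constructor takes only `config` and `base_class` and reads `num_experts` from the config, so the keyword argument would now raise a `TypeError`. Post-migration usage would look like this (assuming `AIIAConfig`, like other `PretrainedConfig` subclasses, accepts attributes set after construction):

```python
config = AIIAConfig()
config.num_experts = 5   # read via getattr(config, "num_experts", 3) in AIIAmoe.__init__
config.top_k = 2         # AIIASparseMoe.forward dereferences config.top_k directly

model = AIIAmoe(config)
model.save_pretrained("test")             # standard HF serialization
reloaded = AIIAmoe.from_pretrained("test")
```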