231 lines
8.3 KiB
Python
231 lines
8.3 KiB
Python
from .config import AIIAConfig
|
|
from torch import nn
|
|
from transformers import PreTrainedModel
|
|
import torch
|
|
import copy
|
|
|
|
|
|
class AIIABase(PreTrainedModel):
    """Plain stacked-convolution backbone.

    The network consists of ``config.num_hidden_layers`` stages, each made of
    ``Conv2d -> activation -> MaxPool2d(kernel_size=1, stride=1)``. The first
    stage maps ``config.num_channels`` input channels to ``config.hidden_size``;
    every later stage maps ``hidden_size`` to ``hidden_size``. All stages live
    in one ``nn.Sequential`` stored as ``self.cnn``.
    """

    config_class = AIIAConfig
    base_model_prefix = "AIIA"

    def __init__(self, config: AIIAConfig):
        super().__init__(config)

        modules = []
        channels = config.num_channels
        for _ in range(config.num_hidden_layers):
            conv = nn.Conv2d(
                channels,
                config.hidden_size,
                kernel_size=config.kernel_size,
                padding=1,
            )
            # Activation is looked up by class name on torch.nn and
            # instantiated with no arguments.
            activation = getattr(nn, config.activation_function)()
            # A 1x1, stride-1 max-pool returns its input unchanged. Removing it
            # would shift the Sequential indices (and state_dict keys) of the
            # following layers, so it stays.
            pool = nn.MaxPool2d(kernel_size=1, stride=1)
            modules += (conv, activation, pool)
            # After the first stage every conv consumes hidden_size channels.
            channels = config.hidden_size

        self.cnn = nn.Sequential(*modules)

    def forward(self, x):
        """Run ``x`` through all stages held in ``self.cnn`` and return the result."""
        return self.cnn(x)
|
|
|
|
class AIIABaseShared(PreTrainedModel):
    """CNN backbone with a shared input convolution followed by unique convolutions.

    The forward pass applies ``shared_layer`` -> activation -> max-pool, then,
    for each convolution in ``unique_layers``, that convolution followed by the
    same activation and max-pool modules.
    """

    config_class = AIIAConfig
    base_model_prefix = "AIIA"

    def __init__(self, config: AIIAConfig):
        """
        Initialize the AIIABaseShared model.

        Args:
            config (AIIAConfig): Configuration object containing model parameters.
        """
        # Call the parent constructor exactly once, before any submodule is
        # assigned. The helper methods below read ``self.config``, which the
        # parent constructor sets.
        super().__init__(config=config)

        # Build the convolutions first, then the activation and pooling modules.
        self._initialize_network()
        self._initialize_activation_andPooling()

    def _initialize_network(self):
        """Initialize the shared and unique layers of the network.

        Creates ``self.shared_layer`` (num_channels -> hidden_size) and
        ``self.unique_layers``, a ModuleList holding one
        hidden_size -> hidden_size convolution.
        """
        # Create a single shared convolutional layer.
        # padding=1 is hard-coded (not read from the config); with the default
        # stride/dilation it preserves the spatial size only when kernel_size == 3.
        self.shared_layer = nn.Conv2d(
            in_channels=self.config.num_channels,
            out_channels=self.config.hidden_size,
            kernel_size=self.config.kernel_size,
            padding=1,
        )

        # Initialize the unique layers with separate weights and biases.
        # NOTE(review): exactly one unique layer is created, independent of
        # config.num_hidden_layers -- confirm this is intended.
        self.unique_layers = nn.ModuleList()
        current_in_channels = self.config.hidden_size

        layer = nn.Conv2d(
            in_channels=current_in_channels,
            out_channels=self.config.hidden_size,
            kernel_size=self.config.kernel_size,
            padding=1,  # hard-coded, same as shared_layer
        )

        self.unique_layers.append(layer)

    def _initialize_activation_andPooling(self):
        """Initialize activation function and pooling layers.

        ``self.activation`` is an instance of the ``torch.nn`` class named by
        ``config.activation_function``; the same instance is applied after
        every convolution in ``forward``. ``self.max_pool`` uses a 1x1 window
        with stride 1, which returns its input unchanged.
        """
        # Get activation function class from the nn module by name and instantiate it.
        self.activation = getattr(nn, self.config.activation_function)()

        # Initialize max pooling layer (1x1 window, stride 1: identity).
        self.max_pool = nn.MaxPool2d(
            kernel_size=1,
            stride=1,
        )

    def forward(self, x):
        """Forward pass of the network.

        Args:
            x: Input tensor for ``shared_layer`` (expected to have
                ``config.num_channels`` channels).

        Returns:
            The output of the last unique layer after activation and pooling.
        """
        # Apply shared layer transformation
        out = self.shared_layer(x)

        # Pass through activation function
        out = self.activation(out)

        # Apply max pooling
        out = self.max_pool(out)

        # Pass through unique layers, each followed by the shared activation/pool
        for unique_layer in self.unique_layers:
            out = unique_layer(out)
            out = self.activation(out)
            out = self.max_pool(out)

        return out
|
|
|
|
class AIIAExpert(PreTrainedModel):
    """A single expert: a thin wrapper around one AIIA CNN backbone."""

    config_class = AIIAConfig
    base_model_prefix = "AIIA"

    def __init__(self, config: AIIAConfig, base_class=AIIABase):
        """
        Build the expert's backbone.

        Args:
            config (AIIAConfig): Configuration passed to the backbone.
            base_class (type): Backbone class to instantiate. Must be
                ``AIIABase``, ``AIIABaseShared``, or a subclass of either.

        Raises:
            ValueError: If ``base_class`` is not such a class.
        """
        super().__init__(config=config)

        # Validate without letting issubclass() raise TypeError for non-class
        # arguments, so every invalid input gets the same ValueError.
        is_supported = isinstance(base_class, type) and issubclass(
            base_class, (AIIABase, AIIABaseShared)
        )
        if not is_supported:
            raise ValueError("Invalid base class")

        # Instantiate the requested class itself (not just its parent), so a
        # subclass of one of the supported backbones keeps its overrides.
        self.base_cnn = base_class(self.config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the expert model.

        Args:
            x (torch.Tensor): Input tensor

        Returns:
            torch.Tensor: Output tensor after processing through base CNN
        """
        # Process input through the base CNN
        return self.base_cnn(x)
|
|
|
|
class AIIAmoe(PreTrainedModel):
    """Dense mixture of AIIA experts merged by a learned softmax gate.

    Every expert processes the input. A gating network scores the experts from
    a pooled summary of their outputs, and the result is the gate-weighted sum
    of all expert outputs.
    """

    config_class = AIIAConfig
    base_model_prefix = "AIIA"

    def __init__(self, config: AIIAConfig, base_class=AIIABase):
        """
        Args:
            config (AIIAConfig): Model configuration. ``config.num_experts`` sets
                the number of experts; 3 is used when the attribute is absent.
            base_class (type): Backbone class forwarded to every ``AIIAExpert``.
        """
        super().__init__(config=config)
        self.config = config

        # The expert count is read from the config, not from a constructor argument.
        expert_count = getattr(config, "num_experts", 3)

        self.experts = nn.ModuleList(
            [AIIAExpert(self.config, base_class=base_class) for _ in range(expert_count)]
        )

        # Gate: pooled hidden_size-dim feature summary -> one probability per expert.
        self.gate = nn.Sequential(
            nn.Linear(self.config.hidden_size, expert_count),
            nn.Softmax(dim=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the Mixture-of-Experts model.

        Args:
            x (torch.Tensor): Input tensor, passed unchanged to every expert.

        Returns:
            torch.Tensor: Gate-weighted sum of all expert outputs.
        """
        # Expert outputs are expected to be (B, C, H, W); stacking on dim 1
        # gives (B, num_experts, C, H, W).
        stacked = torch.stack([expert(x) for expert in self.experts], dim=1)

        # Pool over H and W, then over experts: (B, num_experts, C) -> (B, C).
        per_expert_features = stacked.mean(dim=(3, 4))
        summary = per_expert_features.mean(dim=1)

        # (B, num_experts) gate probabilities, reshaped to
        # (B, num_experts, 1, 1, 1) so they broadcast over C, H, W.
        weights = self.gate(summary)[:, :, None, None, None]

        # Weighted sum over the expert axis -> (B, C, H, W).
        return (stacked * weights).sum(dim=1)
|
|
|
|
|
|
class AIIASparseMoe(AIIAmoe):
    """Mixture-of-experts variant that merges only the top-k experts per sample.

    The gating network inherited from ``AIIAmoe`` scores the experts. For each
    sample, the ``config.top_k`` highest-scoring experts are selected and their
    outputs are averaged with equal weight.
    """

    config_class = AIIAConfig
    base_model_prefix = "AIIA"

    def __init__(self, config: AIIAConfig, base_class=AIIABase):
        super().__init__(config=config, base_class=base_class)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass that merges only the top-k experts of each sample.

        Args:
            x (torch.Tensor): Input tensor, passed unchanged to every expert.

        Returns:
            torch.Tensor: For each sample, the unweighted mean of the outputs
            of its ``config.top_k`` highest-scoring experts.

        Raises:
            ValueError: If ``config.top_k`` is not between 1 and the number
                of experts.
        """
        num_experts = len(self.experts)
        top_k = self.config.top_k
        # Fail fast with a clear message. top_k=0 would otherwise average an
        # empty selection (NaN outputs); top_k > num_experts fails inside topk.
        if not 1 <= top_k <= num_experts:
            raise ValueError(
                f"config.top_k must be between 1 and {num_experts}, got {top_k}"
            )

        # Compute the gate weights exactly as in the dense AIIAmoe. Every
        # expert still runs on the whole batch; only the merge is sparse.
        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=1)
        spatial_avg = torch.mean(expert_outputs, dim=(3, 4))
        gate_input = torch.mean(spatial_avg, dim=1)
        gate_weights = self.gate(gate_input)

        # Indices of the top-k experts per sample, shape (B, top_k).
        # NOTE(review): the gate scores are used only to choose indices, which
        # are not differentiable, so this output sends no gradient to
        # self.gate -- confirm this is intended.
        _, top_k_indices = gate_weights.topk(top_k, dim=-1)

        # Select all chosen experts for the whole batch in one indexing step:
        # a (B, 1) batch index broadcast against the (B, top_k) expert indices
        # yields a tensor of shape (B, top_k, *expert_output_dims).
        batch_index = torch.arange(
            expert_outputs.size(0), device=expert_outputs.device
        ).unsqueeze(1)
        selected = expert_outputs[batch_index, top_k_indices]

        # Equal-weight average over the selected experts.
        return selected.mean(dim=1)
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build a 5-expert dense MoE from the default config and save
    # it to the directory "test". AIIAmoe takes no ``num_experts`` argument (it
    # reads ``config.num_experts``), so the expert count is set on the config.
    config = AIIAConfig()
    config.num_experts = 5
    model = AIIAmoe(config)
    model.save_pretrained("test")