from .config import AIIAConfig
from torch import nn
from transformers import PreTrainedModel
import torch


class AIIABase(PreTrainedModel):
    config_class = AIIAConfig
    base_model_prefix = "AIIA"

    def __init__(self, config: AIIAConfig):
        super().__init__(config)

        # Initialize layers based on configuration
        layers = []
        in_channels = config.num_channels

        for _ in range(config.num_hidden_layers):
            layers.extend([
                nn.Conv2d(in_channels, config.hidden_size, kernel_size=config.kernel_size, padding=1),
                getattr(nn, config.activation_function)(),
                nn.MaxPool2d(kernel_size=1, stride=1)
            ])
            in_channels = config.hidden_size

        self.cnn = nn.Sequential(*layers)

    def forward(self, x):
        return self.cnn(x)


class AIIABaseShared(PreTrainedModel):
    config_class = AIIAConfig
    base_model_prefix = "AIIA"

    def __init__(self, config: AIIAConfig):
        """
        Initialize the AIIABaseShared model.

        Args:
            config (AIIAConfig): Configuration object containing model parameters.
        """
        super().__init__(config)

        # Initialize the network components
        self._initialize_network()
        self._initialize_activation_and_pooling()

    def _initialize_network(self):
        """Initialize the shared and unique layers of the network."""
        # Create a single shared convolutional layer
        self.shared_layer = nn.Conv2d(
            in_channels=self.config.num_channels,
            out_channels=self.config.hidden_size,
            kernel_size=self.config.kernel_size,
            padding=1  # fixed padding of 1
        )

        # Initialize the unique layers with separate weights and biases
        self.unique_layers = nn.ModuleList()
        current_in_channels = self.config.hidden_size

        layer = nn.Conv2d(
            in_channels=current_in_channels,
            out_channels=self.config.hidden_size,
            kernel_size=self.config.kernel_size,
            padding=1  # fixed padding of 1
        )

        self.unique_layers.append(layer)

    def _initialize_activation_and_pooling(self):
        """Initialize the activation function and pooling layers."""
        # Get activation function from the nn module
        self.activation = getattr(nn, self.config.activation_function)()

        # Initialize max pooling layer
        self.max_pool = nn.MaxPool2d(
            kernel_size=1,
            stride=1,
        )

    def forward(self, x):
        """Forward pass of the network."""
        # Apply shared layer transformation
        out = self.shared_layer(x)

        # Pass through activation function
        out = self.activation(out)

        # Apply max pooling
        out = self.max_pool(out)

        # Pass through unique layers
        for unique_layer in self.unique_layers:
            out = unique_layer(out)
            out = self.activation(out)
            out = self.max_pool(out)

        return out


class AIIAExpert(PreTrainedModel):
    config_class = AIIAConfig
    base_model_prefix = "AIIA"

    def __init__(self, config: AIIAConfig, base_class=AIIABase):
        super().__init__(config=config)

        # Initialize the base CNN from the chosen base class
        if issubclass(base_class, (AIIABase, AIIABaseShared)):
            self.base_cnn = base_class(self.config)
        else:
            raise ValueError("Invalid base class")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the expert model.

        Args:
            x (torch.Tensor): Input tensor

        Returns:
            torch.Tensor: Output tensor after processing through base CNN
        """
        # Process input through the base CNN
        return self.base_cnn(x)


class AIIAmoe(PreTrainedModel):
    config_class = AIIAConfig
    base_model_prefix = "AIIA"

    def __init__(self, config: AIIAConfig, base_class=AIIABase):
        super().__init__(config=config)
        self.config = config

        # Get num_experts directly from config instead of parameter
        num_experts = getattr(config, "num_experts", 3)  # default to 3 if not in config

        # Initialize multiple experts from the chosen base class
        self.experts = nn.ModuleList([
            AIIAExpert(self.config, base_class=base_class) for _ in range(num_experts)
        ])

        gate_in_features = self.config.hidden_size

        # Create a gating network that maps the aggregated features to num_experts weights
        self.gate = nn.Sequential(
            nn.Linear(gate_in_features, num_experts),
            nn.Softmax(dim=1)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the Mixture-of-Experts model.

        Args:
            x (torch.Tensor): Input tensor

        Returns:
            torch.Tensor: Merged output tensor from all experts
        """
        # Stack the outputs from each expert.
        # Each expert's output should have shape (B, C, H, W). After stacking, expert_outputs has shape:
        # (B, num_experts, C, H, W)
        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=1)

        # Aggregate spatial features: average across the spatial dimensions (H, W).
        # This results in a tensor with shape (B, num_experts, C)
        spatial_avg = torch.mean(expert_outputs, dim=(3, 4))

        # To feed the gating network, further average across the expert dimension,
        # obtaining a tensor of shape (B, C) that represents the global feature summary.
        gate_input = torch.mean(spatial_avg, dim=1)

        # Compute gating weights using the gating network.
        # The output gate_weights has shape (B, num_experts)
        gate_weights = self.gate(gate_input)

        # Expand the gate weights to match the expert outputs shape so they can be combined.
        # After unsqueezing, gate_weights has shape (B, num_experts, 1, 1, 1)
        gate_weights_expanded = gate_weights.unsqueeze(2).unsqueeze(3).unsqueeze(4)

        # Multiply each expert's output by its corresponding gating weight and sum over experts.
        # The merged_output retains the shape (B, C, H, W)
        merged_output = torch.sum(expert_outputs * gate_weights_expanded, dim=1)

        return merged_output


class AIIASparseMoe(AIIAmoe):
    config_class = AIIAConfig
    base_model_prefix = "AIIA"

    def __init__(self, config: AIIAConfig, base_class=AIIABase):
        super().__init__(config=config, base_class=base_class)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Compute the gate_weights similar to standard moe.
        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=1)
        spatial_avg = torch.mean(expert_outputs, dim=(3, 4))
        gate_input = torch.mean(spatial_avg, dim=1)
        gate_weights = self.gate(gate_input)

        # Select the top-k experts for each input based on gating weights.
        _, top_k_indices = gate_weights.topk(self.config.top_k, dim=-1)

        # Initialize a list to store outputs from selected experts.
        merged_outputs = []

        # Iterate over batch dimension to apply top-k selection per instance.
        for i in range(x.size(0)):
            # Get the indices of top-k experts for current instance.
            instance_top_k_indices = top_k_indices[i]

            # Select outputs from top-k experts.
            selected_expert_outputs = expert_outputs[i][instance_top_k_indices]

            # Average over the selected experts to get a single output per instance.
            averaged_output = torch.mean(selected_expert_outputs, dim=0)
            merged_outputs.append(averaged_output.unsqueeze(0))

        # Stack outputs from all instances back into a batch tensor.
        return torch.cat(merged_outputs, dim=0)


if __name__ == "__main__":
    config = AIIAConfig()
    # AIIAmoe reads the number of experts from the config rather than a constructor argument.
    config.num_experts = 5
    model = AIIAmoe(config)
    model.save_pretrained("test")
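
    # Minimal forward-pass smoke test: a sketch, assuming the default AIIAConfig
    # provides num_channels, hidden_size, kernel_size and a valid activation_function.
    # The batch size of 1 and the 224x224 input resolution are illustrative assumptions.
    dummy_input = torch.randn(1, config.num_channels, 224, 224)
    with torch.no_grad():
        output = model(dummy_input)
    # Output has shape (1, config.hidden_size, H', W'); the spatial size depends on
    # the convolutional stack defined by the config.
    print(output.shape)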