class AIIASparseMoe(AIIAmoe):
    """Sparse Mixture-of-Experts variant of ``AIIAmoe``.

    Instead of blending *all* expert outputs by their gate weights, this
    variant picks the ``top_k`` experts per input instance (ranked by the
    gate) and averages only those outputs uniformly.

    NOTE(review): all experts are still evaluated on every input — sparsity
    applies only to the merge step, not to compute. Gate weights are used
    solely for *selection*, never as mixing coefficients — confirm this is
    intended rather than an omission.
    """

    def __init__(self, config: AIIAConfig, num_experts: int = 3, top_k: int = 2, base_class=AIIABase, **kwargs):
        """Build the underlying MoE and record how many experts to keep.

        Args:
            config: shared model configuration forwarded to ``AIIAmoe``.
            num_experts: total number of experts instantiated by the parent.
            top_k: number of highest-gated experts merged per instance.
            base_class: expert architecture, forwarded to ``AIIAmoe``.
        """
        super().__init__(config=config, num_experts=num_experts, base_class=base_class, **kwargs)
        self.top_k = top_k

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run every expert, then average the ``top_k`` gated ones per instance.

        Args:
            x: input batch; each expert is assumed to return a tensor of
               shape (batch, C, H, W) — TODO confirm against the experts.

        Returns:
            Tensor of shape (batch, C, H, W): the uniform mean of the
            selected experts' outputs for each instance.
        """
        # expert_outputs: (batch, num_experts, C, H, W)
        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=1)

        # Gate input mirrors the dense MoE: average over spatial dims (H, W),
        # then over the expert dim, yielding (batch, C) for the gate network.
        spatial_avg = torch.mean(expert_outputs, dim=(3, 4))
        gate_input = torch.mean(spatial_avg, dim=1)
        gate_weights = self.gate(gate_input)

        # Indices of the top-k experts per instance: (batch, top_k).
        _, top_k_indices = gate_weights.topk(self.top_k, dim=-1)

        # Vectorized selection — replaces the original per-instance Python
        # loop with a single gather along the expert dimension (dim=1).
        # Expand indices to (batch, top_k, C, H, W) so gather can pull the
        # full feature map of each selected expert.
        idx = top_k_indices.view(*top_k_indices.shape, 1, 1, 1).expand(
            -1, -1, *expert_outputs.shape[2:]
        )
        selected = torch.gather(expert_outputs, dim=1, index=idx)

        # Uniform average over the selected experts → (batch, C, H, W).
        return selected.mean(dim=1)