Added a sparse Mixture-of-Experts (MoE) class with top-k expert routing.
This commit is contained in:
parent
1fcb31b044
commit
10967ea880
|
@ -260,6 +260,40 @@ class AIIAmoe(AIIA):
|
||||||
return merged_output
|
return merged_output
|
||||||
|
|
||||||
|
|
||||||
|
class AIIASparseMoe(AIIAmoe):
    """Sparse Mixture-of-Experts variant of ``AIIAmoe``.

    Evaluates every expert, scores them with the inherited gating network,
    and for each instance in the batch averages the outputs of only the
    ``top_k`` highest-scoring experts.

    NOTE(review): all experts are still computed for every input even though
    only ``top_k`` contribute to the output — this is "sparse" in the merge
    step only, not in compute. True sparse routing would dispatch before
    running the experts; confirm whether that is intended.
    """

    def __init__(self, config: AIIAConfig, num_experts: int = 3, top_k: int = 2, base_class=AIIABase, **kwargs):
        """Initialize the sparse MoE.

        Args:
            config: model configuration forwarded to the parent class.
            num_experts: number of expert sub-networks to instantiate.
            top_k: number of highest-scoring experts merged per instance.
            base_class: expert network class (defaults to ``AIIABase``).
            **kwargs: forwarded to ``AIIAmoe.__init__``.
        """
        super().__init__(config=config, num_experts=num_experts, base_class=base_class, **kwargs)
        self.top_k = top_k

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run all experts and average the top-k per instance.

        Args:
            x: input batch; each expert is assumed to return a 4-D tensor
               ``(batch, channels, height, width)`` — TODO confirm against
               the expert implementation.

        Returns:
            Tensor of shape ``(batch, channels, height, width)``: the mean
            of the ``top_k`` selected expert outputs for each instance.
        """
        # All expert outputs stacked on a new expert axis:
        # (batch, num_experts, channels, height, width).
        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=1)

        # Gate input mirrors the dense MoE: average over spatial dims,
        # then over the expert axis -> (batch, channels).
        spatial_avg = torch.mean(expert_outputs, dim=(3, 4))
        gate_input = torch.mean(spatial_avg, dim=1)
        gate_weights = self.gate(gate_input)

        # Indices of the top-k experts per instance: (batch, top_k).
        _, top_k_indices = gate_weights.topk(self.top_k, dim=-1)

        # Vectorized selection via torch.gather — replaces the original
        # O(batch) Python loop (per-instance fancy indexing + list append
        # + torch.cat) with one call; results are identical.
        batch_size = expert_outputs.size(0)
        feat_dims = expert_outputs.shape[2:]
        idx = top_k_indices.view(batch_size, self.top_k, *([1] * len(feat_dims)))
        idx = idx.expand(-1, -1, *feat_dims)
        selected = expert_outputs.gather(1, idx)  # (batch, top_k, *feat_dims)

        # Average over the selected experts for each instance.
        return selected.mean(dim=1)
|
||||||
|
|
||||||
|
|
||||||
class AIIAchunked(AIIA):
|
class AIIAchunked(AIIA):
|
||||||
def __init__(self, config: AIIAConfig, patch_size: int = 16, base_class=AIIABase, **kwargs):
|
def __init__(self, config: AIIAConfig, patch_size: int = 16, base_class=AIIABase, **kwargs):
|
||||||
super().__init__(config=config, **kwargs)
|
super().__init__(config=config, **kwargs)
|
||||||
|
|
Loading…
Reference in New Issue