Merge pull request 'feat/smoe' () from feat/smoe into main

Reviewed-on: 
This commit is contained in:
Falko Victor Habel 2025-03-28 15:48:40 +00:00
commit 18249a852a
7 changed files with 70 additions and 11 deletions

View File

@ -34,4 +34,4 @@ jobs:
VECTORDB_TOKEN: ${{ secrets.VECTORDB_TOKEN }}
run: |
cd VectorLoader
python -m src.run --full
python -m src.run

View File

@ -10,7 +10,7 @@ include = '\.pyi?$'
[project]
name = "aiia"
version = "0.2.0"
version = "0.2.1"
description = "AIIA Deep Learning Model Implementation"
readme = "README.md"
authors = [

View File

@ -1,6 +1,6 @@
[metadata]
name = aiia
version = 0.2.0
version = 0.2.1
author = Falko Habel
author_email = falko.habel@gmx.de
description = AIIA deep learning model implementation

View File

@ -1,7 +1,7 @@
from .model.Model import AIIABase, AIIABaseShared, AIIAchunked, AIIAExpert, AIIAmoe, AIIA, AIIArecursive
from .model.Model import AIIABase, AIIABaseShared, AIIAchunked, AIIAmoe, AIIASparseMoe, AIIArecursive
from .model.config import AIIAConfig
from .data.DataLoader import DataLoader
from .pretrain.pretrainer import Pretrainer, ProjectionHead
__version__ = "0.2.0"
__version__ = "0.2.1"

View File

@ -260,6 +260,40 @@ class AIIAmoe(AIIA):
return merged_output
class AIIASparseMoe(AIIAmoe):
    """Mixture-of-experts model that merges only the ``top_k`` best-ranked experts.

    Every expert is still evaluated on the full input; the gate is used only
    to rank the experts per sample, and the outputs of the ``top_k``
    highest-ranked experts are averaged with equal weight.

    ``self.experts`` and ``self.gate`` are not created in this class; they are
    expected to be set up by the ``AIIAmoe`` initializer.
    """

    def __init__(self, config: AIIAConfig, num_experts: int = 3, top_k: int = 2, base_class=AIIABase, **kwargs):
        """Build the underlying MoE model and remember how many experts to keep.

        Args:
            config: Model configuration, forwarded to ``AIIAmoe``.
            num_experts: Number of experts, forwarded to ``AIIAmoe``.
            top_k: Number of highest-ranked experts merged per sample.
            base_class: Expert base class, forwarded to ``AIIAmoe``.
            **kwargs: Extra keyword arguments, forwarded to ``AIIAmoe``.
        """
        super().__init__(config=config, num_experts=num_experts, base_class=base_class, **kwargs)
        self.top_k = top_k

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the unweighted mean of the ``top_k`` gate-selected expert outputs.

        Args:
            x: Input batch passed unchanged to every expert.

        Returns:
            One merged output per sample, with the expert axis reduced away.
        """
        # Stack the expert outputs along a new axis 1: (batch, experts, ...).
        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=1)

        # Build the gate input: average over dims 3 and 4 (presumably the
        # spatial H/W axes of each expert output), then over the expert axis.
        spatial_avg = torch.mean(expert_outputs, dim=(3, 4))
        gate_input = torch.mean(spatial_avg, dim=1)
        gate_weights = self.gate(gate_input)

        # Rank experts per sample. Only the indices are used below.
        # NOTE(review): the gate values themselves do not weight the merge;
        # the selected experts are averaged equally - confirm this is intended.
        _, top_k_indices = gate_weights.topk(self.top_k, dim=-1)

        # Gather the selected experts for the whole batch in one indexing op.
        # A (batch, 1) row index broadcast against the (batch, top_k) expert
        # index yields (batch, top_k, ...). Unlike a per-sample loop followed
        # by torch.cat, this also works for an empty batch.
        batch_indices = torch.arange(
            top_k_indices.size(0), device=top_k_indices.device
        ).unsqueeze(1)
        selected_expert_outputs = expert_outputs[batch_indices, top_k_indices]

        # Collapse the selected-expert axis to a single output per sample.
        return selected_expert_outputs.mean(dim=1)
class AIIAchunked(AIIA):
def __init__(self, config: AIIAConfig, patch_size: int = 16, base_class=AIIABase, **kwargs):
super().__init__(config=config, **kwargs)

View File

@ -1,21 +1,20 @@
from .Model import (
AIIA,
AIIABase,
AIIABaseShared,
AIIAchunked,
AIIAExpert,
AIIAmoe,
AIIASparseMoe,
AIIArecursive
)
from .config import AIIAConfig
__all__ = [
"AIIA",
"AIIABase",
"AIIABaseShared",
"AIIAchunked",
"AIIAExpert",
"AIIAmoe",
"AIIASparseMoe",
"AIIArecursive",
"AIIAConfig"
"AIIAConfig",
]

View File

@ -1,6 +1,6 @@
import os
import torch
from aiia import AIIABase, AIIABaseShared, AIIAExpert, AIIAmoe, AIIAchunked, AIIAConfig
from aiia.model.Model import AIIABase, AIIABaseShared, AIIAExpert, AIIAmoe, AIIAchunked, AIIAConfig, AIIASparseMoe
def test_aiiabase_creation():
config = AIIAConfig()
@ -106,6 +106,32 @@ def test_aiiamoe_save_load():
os.remove(os.path.join(save_path, "config.json"))
os.rmdir(save_path)
def test_aiiasparsemoe_creation():
    """A sparse MoE model can be constructed with 5 experts and top_k=2."""
    sparse_moe = AIIASparseMoe(AIIAConfig(), num_experts=5, top_k=2)
    assert isinstance(sparse_moe, AIIASparseMoe)
def test_aiiasparsemoe_save_load():
    """Saving writes model.pth and config.json, and loading returns an AIIASparseMoe.

    Cleanup runs in a ``finally`` block so a failing assertion does not leave
    the save directory behind for later test runs.
    """
    config = AIIAConfig()
    model = AIIASparseMoe(config, num_experts=3, top_k=1)
    save_path = "test_aiiasparsemoe_save_load"
    artifact_paths = [
        os.path.join(save_path, filename)
        for filename in ("model.pth", "config.json")
    ]

    try:
        # Save the model and check both expected artifacts were written.
        model.save(save_path)
        for artifact_path in artifact_paths:
            assert os.path.exists(artifact_path)

        # Load the model back and check it has the expected type.
        loaded_model = AIIASparseMoe.load(save_path)
        assert isinstance(loaded_model, AIIASparseMoe)
    finally:
        # Remove only what exists, so a failure in save() is not masked by
        # a FileNotFoundError raised during cleanup.
        for artifact_path in artifact_paths:
            if os.path.exists(artifact_path):
                os.remove(artifact_path)
        if os.path.isdir(save_path):
            os.rmdir(save_path)
def test_aiiachunked_creation():
config = AIIAConfig()
model = AIIAchunked(config)