Merge pull request 'fixed model output' () from bug_fixes into main

Reviewed-on: 
This commit is contained in:
Falko Victor Habel 2025-03-03 16:06:57 +00:00
commit 89d9efb41f
1 changed files with 14 additions and 0 deletions
src/aiia/model

View File

@ -163,6 +163,20 @@ class AIIAExpert(AIIA):
self.base_cnn = AIIABaseShared(self.config, **kwargs)
else:
raise ValueError("Invalid base class")
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass of the expert model.
Args:
x (torch.Tensor): Input tensor
Returns:
torch.Tensor: Output tensor after processing through base CNN
"""
# Process input through the base CNN
return self.base_cnn(x)
class AIIAmoe(AIIA):
def __init__(self, config: AIIAConfig, num_experts: int = 3, base_class=AIIABase, **kwargs):