fixed model output #17

Merged
Fabel merged 1 commits from bug_fixes into main 2025-03-03 16:06:58 +00:00
1 changed files with 14 additions and 0 deletions

View File

@ -163,6 +163,20 @@ class AIIAExpert(AIIA):
self.base_cnn = AIIABaseShared(self.config, **kwargs)
else:
raise ValueError("Invalid base class")
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass of the expert model.
Args:
x (torch.Tensor): Input tensor
Returns:
torch.Tensor: Output tensor after processing through base CNN
"""
# Process input through the base CNN
return self.base_cnn(x)
class AIIAmoe(AIIA):
def __init__(self, config: AIIAConfig, num_experts: int = 3, base_class=AIIABase, **kwargs):