fixed model output

This commit is contained in:
Falko Victor Habel 2025-03-03 17:05:35 +01:00
parent c06217db1e
commit 899f714554
1 changed files with 14 additions and 0 deletions

View File

@ -163,6 +163,20 @@ class AIIAExpert(AIIA):
self.base_cnn = AIIABaseShared(self.config, **kwargs)
else:
raise ValueError("Invalid base class")
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass of the expert model.
Args:
x (torch.Tensor): Input tensor
Returns:
torch.Tensor: Output tensor after processing through base CNN
"""
# Process input through the base CNN
return self.base_cnn(x)
class AIIAmoe(AIIA):
def __init__(self, config: AIIAConfig, num_experts: int = 3, base_class=AIIABase, **kwargs):