diff --git a/src/aiia/model/Model.py b/src/aiia/model/Model.py index 39d5684..3287eab 100644 --- a/src/aiia/model/Model.py +++ b/src/aiia/model/Model.py @@ -163,6 +163,20 @@ class AIIAExpert(AIIA): self.base_cnn = AIIABaseShared(self.config, **kwargs) else: raise ValueError("Invalid base class") + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the expert model. + + Args: + x (torch.Tensor): Input tensor + + Returns: + torch.Tensor: Output tensor after processing through base CNN + """ + # Process input through the base CNN + return self.base_cnn(x) + class AIIAmoe(AIIA): def __init__(self, config: AIIAConfig, num_experts: int = 3, base_class=AIIABase, **kwargs):