From 899f7145540a1240f09822bff91a242434410641 Mon Sep 17 00:00:00 2001 From: Falko Habel Date: Mon, 3 Mar 2025 17:05:35 +0100 Subject: [PATCH] fixed model output --- src/aiia/model/Model.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/aiia/model/Model.py b/src/aiia/model/Model.py index 39d5684..3287eab 100644 --- a/src/aiia/model/Model.py +++ b/src/aiia/model/Model.py @@ -163,6 +163,20 @@ class AIIAExpert(AIIA): self.base_cnn = AIIABaseShared(self.config, **kwargs) else: raise ValueError("Invalid base class") + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the expert model. + + Args: + x (torch.Tensor): Input tensor + + Returns: + torch.Tensor: Output tensor after processing through base CNN + """ + # Process input through the base CNN + return self.base_cnn(x) + class AIIAmoe(AIIA): def __init__(self, config: AIIAConfig, num_experts: int = 3, base_class=AIIABase, **kwargs):