fixed model output #17
|
@ -164,6 +164,20 @@ class AIIAExpert(AIIA):
|
||||||
else:
|
else:
|
||||||
raise ValueError("Invalid base class")
|
raise ValueError("Invalid base class")
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Forward pass of the expert model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (torch.Tensor): Input tensor
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: Output tensor after processing through base CNN
|
||||||
|
"""
|
||||||
|
# Process input through the base CNN
|
||||||
|
return self.base_cnn(x)
|
||||||
|
|
||||||
|
|
||||||
class AIIAmoe(AIIA):
|
class AIIAmoe(AIIA):
|
||||||
def __init__(self, config: AIIAConfig, num_experts: int = 3, base_class=AIIABase, **kwargs):
|
def __init__(self, config: AIIAConfig, num_experts: int = 3, base_class=AIIABase, **kwargs):
|
||||||
super().__init__(config=config, **kwargs)
|
super().__init__(config=config, **kwargs)
|
||||||
|
|
Loading…
Reference in New Issue