fixed model output
This commit is contained in:
parent
c06217db1e
commit
899f714554
|
@ -163,6 +163,20 @@ class AIIAExpert(AIIA):
|
|||
self.base_cnn = AIIABaseShared(self.config, **kwargs)
|
||||
else:
|
||||
raise ValueError("Invalid base class")
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass of the expert model.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output tensor after processing through base CNN
|
||||
"""
|
||||
# Process input through the base CNN
|
||||
return self.base_cnn(x)
|
||||
|
||||
|
||||
class AIIAmoe(AIIA):
|
||||
def __init__(self, config: AIIAConfig, num_experts: int = 3, base_class=AIIABase, **kwargs):
|
||||
|
|
Loading…
Reference in New Issue