Compare commits
2 Commits
4e7d0d806f
...
89d9efb41f
Author | SHA1 | Date |
---|---|---|
|
89d9efb41f | |
|
899f714554 |
|
@ -164,6 +164,20 @@ class AIIAExpert(AIIA):
|
|||
else:
|
||||
raise ValueError("Invalid base class")
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass of the expert model.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output tensor after processing through base CNN
|
||||
"""
|
||||
# Process input through the base CNN
|
||||
return self.base_cnn(x)
|
||||
|
||||
|
||||
class AIIAmoe(AIIA):
|
||||
def __init__(self, config: AIIAConfig, num_experts: int = 3, base_class=AIIABase, **kwargs):
|
||||
super().__init__(config=config, **kwargs)
|
||||
|
|
Loading…
Reference in New Issue