Compare commits
2 Commits
4e7d0d806f
...
89d9efb41f
Author | SHA1 | Date |
---|---|---|
|
89d9efb41f | |
|
899f714554 |
|
@ -164,6 +164,20 @@ class AIIAExpert(AIIA):
|
||||||
else:
|
else:
|
||||||
raise ValueError("Invalid base class")
|
raise ValueError("Invalid base class")
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Forward pass of the expert model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (torch.Tensor): Input tensor
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: Output tensor after processing through base CNN
|
||||||
|
"""
|
||||||
|
# Process input through the base CNN
|
||||||
|
return self.base_cnn(x)
|
||||||
|
|
||||||
|
|
||||||
class AIIAmoe(AIIA):
|
class AIIAmoe(AIIA):
|
||||||
def __init__(self, config: AIIAConfig, num_experts: int = 3, base_class=AIIABase, **kwargs):
|
def __init__(self, config: AIIAConfig, num_experts: int = 3, base_class=AIIABase, **kwargs):
|
||||||
super().__init__(config=config, **kwargs)
|
super().__init__(config=config, **kwargs)
|
||||||
|
|
Loading…
Reference in New Issue