Compare commits
No commits in common. "89d9efb41f38f742bc1611e40b08a9e7c6a86167" and "4e7d0d806f2a692ff293442cae9a233c9109fd86" have entirely different histories.
89d9efb41f
...
4e7d0d806f
|
@ -164,20 +164,6 @@ class AIIAExpert(AIIA):
|
||||||
else:
|
else:
|
||||||
raise ValueError("Invalid base class")
|
raise ValueError("Invalid base class")
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Forward pass of the expert model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
x (torch.Tensor): Input tensor
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
torch.Tensor: Output tensor after processing through base CNN
|
|
||||||
"""
|
|
||||||
# Process input through the base CNN
|
|
||||||
return self.base_cnn(x)
|
|
||||||
|
|
||||||
|
|
||||||
class AIIAmoe(AIIA):
|
class AIIAmoe(AIIA):
|
||||||
def __init__(self, config: AIIAConfig, num_experts: int = 3, base_class=AIIABase, **kwargs):
|
def __init__(self, config: AIIAConfig, num_experts: int = 3, base_class=AIIABase, **kwargs):
|
||||||
super().__init__(config=config, **kwargs)
|
super().__init__(config=config, **kwargs)
|
||||||
|
|
Loading…
Reference in New Issue