feat/tf_support #37
|
@ -294,63 +294,6 @@ class AIIASparseMoe(AIIAmoe):
|
|||
return torch.cat(merged_outputs, dim=0)
|
||||
|
||||
|
||||
class AIIAchunked(AIIA):
|
||||
def __init__(self, config: AIIAConfig, patch_size: int = 16, base_class=AIIABase, **kwargs):
|
||||
super().__init__(config=config, **kwargs)
|
||||
self.config = self.config
|
||||
|
||||
# Update config with new parameters if provided
|
||||
self.config.patch_size = patch_size
|
||||
|
||||
# Initialize base CNN for processing each patch using the specified base class
|
||||
if issubclass(base_class, AIIABase):
|
||||
self.base_cnn = AIIABase(self.config, **kwargs)
|
||||
elif issubclass(base_class, AIIABaseShared): # Add support for AIIABaseShared
|
||||
self.base_cnn = AIIABaseShared(self.config, **kwargs)
|
||||
else:
|
||||
raise ValueError("Invalid base class")
|
||||
|
||||
def forward(self, x):
|
||||
patches = x.unfold(2, self.patch_size, self.patch_size).unfold(3, self.patch_size, self.patch_size)
|
||||
patches = patches.contiguous().view(patches.size(0), patches.size(1), -1, self.patch_size, self.patch_size)
|
||||
patch_outputs = []
|
||||
|
||||
for p in torch.split(patches, 1, dim=2):
|
||||
p = p.squeeze(2)
|
||||
po = self.base_cnn(p)
|
||||
patch_outputs.append(po)
|
||||
|
||||
combined_output = torch.mean(torch.stack(patch_outputs, dim=0), dim=0)
|
||||
return combined_output
|
||||
|
||||
class AIIArecursive(AIIA):
|
||||
def __init__(self, config: AIIAConfig, recursion_depth: int = 3, base_class=AIIABase, **kwargs):
|
||||
|
||||
super().__init__(config=config, **kwargs)
|
||||
self.config = self.config
|
||||
|
||||
# Pass recursion_depth as a kwarg to the config
|
||||
self.config.recursion_depth = recursion_depth
|
||||
|
||||
# Initialize chunked CNN with updated config
|
||||
self.chunked_cnn = AIIAchunked(self.config, base_class, **kwargs)
|
||||
|
||||
def forward(self, x, depth=0):
|
||||
if depth == self.recursion_depth:
|
||||
return self.chunked_cnn(x)
|
||||
else:
|
||||
patches = x.unfold(2, 16, 16).unfold(3, 16, 16)
|
||||
patches = patches.contiguous().view(patches.size(0), patches.size(1), -1, 16, 16)
|
||||
processed_patches = []
|
||||
|
||||
for p in torch.split(patches, 1, dim=2):
|
||||
p = p.squeeze(2)
|
||||
pp = self.forward(p, depth + 1)
|
||||
processed_patches.append(pp)
|
||||
|
||||
combined_output = torch.mean(torch.stack(processed_patches, dim=0), dim=0)
|
||||
return combined_output
|
||||
|
||||
if __name__ =="__main__":
|
||||
config = AIIAConfig()
|
||||
model = AIIAmoe(config, num_experts=5)
|
||||
|
|
Loading…
Reference in New Issue