feat/tf_support #37

Merged
Fabel merged 13 commits from feat/tf_support into develop 2025-04-16 20:59:48 +00:00
1 changed file with 0 additions and 57 deletions
Showing only changes of commit 0852ddb109


@@ -294,63 +294,6 @@ class AIIASparseMoe(AIIAmoe):
        return torch.cat(merged_outputs, dim=0)


class AIIAchunked(AIIA):
    def __init__(self, config: AIIAConfig, patch_size: int = 16, base_class=AIIABase, **kwargs):
        super().__init__(config=config, **kwargs)
        self.config = self.config

        # Update config with new parameters if provided
        self.config.patch_size = patch_size

        # Initialize base CNN for processing each patch using the specified base class
        if issubclass(base_class, AIIABase):
            self.base_cnn = AIIABase(self.config, **kwargs)
        elif issubclass(base_class, AIIABaseShared):  # Add support for AIIABaseShared
            self.base_cnn = AIIABaseShared(self.config, **kwargs)
        else:
            raise ValueError("Invalid base class")

    def forward(self, x):
        patches = x.unfold(2, self.patch_size, self.patch_size).unfold(3, self.patch_size, self.patch_size)
        patches = patches.contiguous().view(patches.size(0), patches.size(1), -1, self.patch_size, self.patch_size)
        patch_outputs = []

        for p in torch.split(patches, 1, dim=2):
            p = p.squeeze(2)
            po = self.base_cnn(p)
            patch_outputs.append(po)

        combined_output = torch.mean(torch.stack(patch_outputs, dim=0), dim=0)
        return combined_output


class AIIArecursive(AIIA):
    def __init__(self, config: AIIAConfig, recursion_depth: int = 3, base_class=AIIABase, **kwargs):
        super().__init__(config=config, **kwargs)
        self.config = self.config

        # Pass recursion_depth as a kwarg to the config
        self.config.recursion_depth = recursion_depth

        # Initialize chunked CNN with updated config
        self.chunked_cnn = AIIAchunked(self.config, base_class, **kwargs)

    def forward(self, x, depth=0):
        if depth == self.recursion_depth:
            return self.chunked_cnn(x)
        else:
            patches = x.unfold(2, 16, 16).unfold(3, 16, 16)
            patches = patches.contiguous().view(patches.size(0), patches.size(1), -1, 16, 16)
            processed_patches = []

            for p in torch.split(patches, 1, dim=2):
                p = p.squeeze(2)
                pp = self.forward(p, depth + 1)
                processed_patches.append(pp)

            combined_output = torch.mean(torch.stack(processed_patches, dim=0), dim=0)
            return combined_output


if __name__ == "__main__":
    config = AIIAConfig()
    model = AIIAmoe(config, num_experts=5)
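
For reference, the patch extraction used in the removed AIIAchunked.forward can be illustrated with a minimal standalone sketch. The input shape and batch size below are illustrative assumptions, not values taken from this PR; only the unfold/view/split pattern mirrors the deleted code.

import torch

# Assumed example input: batch of 2 RGB images of size 32x32 (not from the PR)
x = torch.randn(2, 3, 32, 32)
patch_size = 16

# Slide non-overlapping windows over height (dim 2) and width (dim 3)
patches = x.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
# -> shape (2, 3, 2, 2, 16, 16)

# Flatten the two window dimensions into a single patch dimension
patches = patches.contiguous().view(patches.size(0), patches.size(1), -1, patch_size, patch_size)
# -> shape (2, 3, 4, 16, 16): four 16x16 patches per image

# Iterate over patches the same way the removed forward pass did
for p in torch.split(patches, 1, dim=2):
    p = p.squeeze(2)  # (2, 3, 16, 16): one patch per image, as fed to the base CNN
    print(p.shape)

Running this prints torch.Size([2, 3, 16, 16]) four times, matching the per-patch tensors that AIIAchunked passed to its base CNN before averaging the outputs.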