diff --git a/src/aiia/model/Model.py b/src/aiia/model/Model.py
index f21a0db..abcc34a 100644
--- a/src/aiia/model/Model.py
+++ b/src/aiia/model/Model.py
@@ -294,63 +294,6 @@ class AIIASparseMoe(AIIAmoe):
         return torch.cat(merged_outputs, dim=0)
 
-class AIIAchunked(AIIA):
-    def __init__(self, config: AIIAConfig, patch_size: int = 16, base_class=AIIABase, **kwargs):
-        super().__init__(config=config, **kwargs)
-        self.config = self.config
-
-        # Update config with new parameters if provided
-        self.config.patch_size = patch_size
-
-        # Initialize base CNN for processing each patch using the specified base class
-        if issubclass(base_class, AIIABase):
-            self.base_cnn = AIIABase(self.config, **kwargs)
-        elif issubclass(base_class, AIIABaseShared): # Add support for AIIABaseShared
-            self.base_cnn = AIIABaseShared(self.config, **kwargs)
-        else:
-            raise ValueError("Invalid base class")
-
-    def forward(self, x):
-        patches = x.unfold(2, self.patch_size, self.patch_size).unfold(3, self.patch_size, self.patch_size)
-        patches = patches.contiguous().view(patches.size(0), patches.size(1), -1, self.patch_size, self.patch_size)
-        patch_outputs = []
-
-        for p in torch.split(patches, 1, dim=2):
-            p = p.squeeze(2)
-            po = self.base_cnn(p)
-            patch_outputs.append(po)
-
-        combined_output = torch.mean(torch.stack(patch_outputs, dim=0), dim=0)
-        return combined_output
-
-class AIIArecursive(AIIA):
-    def __init__(self, config: AIIAConfig, recursion_depth: int = 3, base_class=AIIABase, **kwargs):
-
-        super().__init__(config=config, **kwargs)
-        self.config = self.config
-
-        # Pass recursion_depth as a kwarg to the config
-        self.config.recursion_depth = recursion_depth
-
-        # Initialize chunked CNN with updated config
-        self.chunked_cnn = AIIAchunked(self.config, base_class, **kwargs)
-
-    def forward(self, x, depth=0):
-        if depth == self.recursion_depth:
-            return self.chunked_cnn(x)
-        else:
-            patches = x.unfold(2, 16, 16).unfold(3, 16, 16)
-            patches = patches.contiguous().view(patches.size(0), patches.size(1), -1, 16, 16)
-            processed_patches = []
-
-            for p in torch.split(patches, 1, dim=2):
-                p = p.squeeze(2)
-                pp = self.forward(p, depth + 1)
-                processed_patches.append(pp)
-
-            combined_output = torch.mean(torch.stack(processed_patches, dim=0), dim=0)
-            return combined_output
-
 if __name__ =="__main__":
     config = AIIAConfig()
     model = AIIAmoe(config, num_experts=5)