updated models for improved config

This commit is contained in:
Falko Victor Habel 2025-01-22 11:19:55 +01:00
parent b87ce68c82
commit 74973a325b
1 changed files with 50 additions and 22 deletions

View File

@ -4,9 +4,13 @@ import torch
class AIIA(nn.Module):
def __init__(self, config: AIIAConfig):
def __init__(self, config: AIIAConfig, **kwargs):
super(AIIA, self).__init__()
self.config = config
# Update the config with any additional keyword arguments
for key, value in kwargs.items():
setattr(self.config, key, value)
def save(self, model_path, config_path):
torch.save(self.state_dict(), model_path)
@ -20,67 +24,89 @@ class AIIA(nn.Module):
return model
class AIIABase(AIIA):
def __init__(self, config: AIIAConfig):
super(AIIABase, self).__init__(config)
def __init__(self, config: AIIAConfig, **kwargs):
super(AIIABase, self).__init__(config, **kwargs)
# Initialize layers based on updated config
layers = []
in_channels = config.num_channels
for _ in range(config.num_hidden_layers):
in_channels = self.config.num_channels
for _ in range(self.config.num_hidden_layers):
layers.extend([
nn.Conv2d(in_channels, config.hidden_size, kernel_size=config.kernel_size, padding=1),
getattr(nn, config.activation_function)(),
nn.Conv2d(in_channels, self.config.hidden_size,
kernel_size=self.config.kernel_size, padding=1),
getattr(nn, self.config.activation_function)(),
nn.MaxPool2d(kernel_size=2)
])
in_channels = config.hidden_size
in_channels = self.config.hidden_size
self.cnn = nn.Sequential(*layers)
def forward(self, x):
return self.cnn(x)
class AIIAExpert(AIIA):
def __init__(self, config: AIIAConfig):
super(AIIAExpert, self).__init__(config)
self.base_cnn = AIIABase(config)
def __init__(self, config: AIIAConfig, **kwargs):
super(AIIAExpert, self).__init__(config, **kwargs)
self.base_cnn = AIIABase(config, **kwargs)
def forward(self, x):
return self.base_cnn(x)
class AIIAmoe(AIIA):
def __init__(self, config: AIIAConfig, num_experts: int = 3):
super(AIIAmoe, self).__init__(config)
self.experts = nn.ModuleList([AIIAExpert(config) for _ in range(num_experts)])
def __init__(self, config: AIIAConfig, **kwargs):
super(AIIAmoe, self).__init__(config, **kwargs)
# Get num_experts from updated config
num_Experts = getattr(self.config, 'num_Experts', 3)
self.experts = nn.ModuleList([AIIAExpert(config, **kwargs) for _ in range(num_Experts)])
# Update gate based on latest config values
self.gate = nn.Sequential(
nn.Linear(config.hidden_size, num_experts),
nn.Linear(self.config.hidden_size, num_Experts),
nn.Softmax(dim=1)
)
def forward(self, x):
expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=1)
gate_weights = self.gate(torch.mean(expert_outputs, (2, 3)))
merged_output = torch.sum(expert_outputs * gate_weights.unsqueeze(2).unsqueeze(3), dim=1)
merged_output = torch.sum(
expert_outputs * gate_weights.unsqueeze(2).unsqueeze(3), dim=1
)
return merged_output
class AIIAchunked(AIIA):
def __init__(self, config: AIIAConfig, patch_size: int = 16):
super(AIIAchunked, self).__init__(config)
def __init__(self, config: AIIAConfig, **kwargs):
super(AIIAchunked, self).__init__(config, **kwargs)
# Get patch_size from updated config
patch_size = getattr(self.config, 'patch_size', 16)
self.patch_size = patch_size
self.base_cnn = AIIABase(config)
# Initialize base CNN with updated config
self.base_cnn = AIIABase(config, **kwargs)
def forward(self, x):
patches = x.unfold(2, self.patch_size, self.patch_size).unfold(3, self.patch_size, self.patch_size)
patches = patches.contiguous().view(patches.size(0), patches.size(1), -1, self.patch_size, self.patch_size)
patch_outputs = []
for p in torch.split(patches, 1, dim=2):
p = p.squeeze(2)
po = self.base_cnn(p)
patch_outputs.append(po)
combined_output = torch.mean(torch.stack(patch_outputs, dim=0), dim=0)
return combined_output
class AIIAresursive(AIIA):
def __init__(self, config: AIIAConfig, recursion_depth: int = 2):
super(AIIAresursive, self).__init__(config)
def __init__(self, config: AIIAConfig, **kwargs):
super(AIIAresursive, self).__init__(config, **kwargs)
# Get recursion_depth from updated config
recursion_depth = getattr(self.config, 'recursion_depth', 2)
self.recursion_depth = recursion_depth
self.chunked_cnn = AIIAchunked(config)
# Initialize chunked CNN with updated config
self.chunked_cnn = AIIAchunked(config, **kwargs)
def forward(self, x, depth=0):
if depth == self.recursion_depth:
@ -89,9 +115,11 @@ class AIIAresursive(AIIA):
patches = x.unfold(2, 16, 16).unfold(3, 16, 16)
patches = patches.contiguous().view(patches.size(0), patches.size(1), -1, 16, 16)
processed_patches = []
for p in torch.split(patches, 1, dim=2):
p = p.squeeze(2)
pp = self.forward(p, depth + 1)
processed_patches.append(pp)
combined_output = torch.mean(torch.stack(processed_patches, dim=0), dim=0)
return combined_output