updated models and config to improve parameter handling and adding a copy function to use the same base config for mutliple models

This commit is contained in:
Falko Victor Habel 2025-01-22 14:23:03 +01:00
parent ab58d352c4
commit 6e6f4c4a21
2 changed files with 33 additions and 40 deletions

View File

@ -2,18 +2,19 @@ from config import AIIAConfig
from torch import nn
import torch
import os
import copy # Add this for deep copying
class AIIA(nn.Module):
def __init__(self, config: AIIAConfig, **kwargs):
super(AIIA, self).__init__()
self.config = config
# Create a deep copy of the configuration to avoid sharing
self.config = copy.deepcopy(config)
# Update the config with any additional keyword arguments
for key, value in kwargs.items():
setattr(self.config, key, value)
def save(self, path: str):
# Create the directory if it doesn't exist
if not os.path.exists(path):
os.makedirs(path, exist_ok=True)
torch.save(self.state_dict(), f"{path}/model.pth")
@ -28,42 +29,42 @@ class AIIA(nn.Module):
class AIIABase(AIIA):
def __init__(self, config: AIIAConfig, **kwargs):
self.config = config
super(AIIABase, self).__init__(config=config)
super().__init__(config=config, **kwargs)
self.config = self.config
# Initialize layers based on configuration
layers = []
in_channels = config.num_channels
in_channels = self.config.num_channels
for _ in range(config.num_hidden_layers):
for _ in range(self.config.num_hidden_layers):
layers.extend([
nn.Conv2d(in_channels, config.hidden_size,
kernel_size=config.kernel_size, padding=1),
getattr(nn, config.activation_function)(),
nn. MaxPool2d(kernel_size=2)
])
in_channels = config.hidden_size
nn.Conv2d(in_channels, self.config.hidden_size,
kernel_size=self.config.kernel_size, padding=1),
getattr(nn, self.config.activation_function)(),
nn.MaxPool2d(kernel_size=2)
])
in_channels = self.config.hidden_size
self.cnn = nn.Sequential(*layers)
def forward(self, x):
return self.cnn(x)
class AIIAExpert(AIIA):
def __init__(self, config: AIIAConfig, **kwargs):
self.config = config
super(AIIAExpert, self).__init__(config=config)
super().__init__(config=config, **kwargs)
self.config = self.config
# Initialize base CNN with configuration
self.base_cnn = AIIABase(self.config, **kwargs)
def forward(self, x):
return self. base_cnn(x)
return self.base_cnn(x)
class AIIAmoe(AIIA):
def __init__(self, config: AIIAConfig, num_experts: int = 3, **kwargs):
self.config = config
super(AIIAmoe, self).__init__(config=config)
super().__init__(config=config, **kwargs)
self.config = self.config
# Update config with new parameters if provided
self.config.num_experts = num_experts
@ -71,26 +72,26 @@ class AIIAmoe(AIIA):
# Initialize multiple experts
self.experts = nn.ModuleList([
AIIAExpert(self.config, **kwargs) for _ in range(num_experts)
])
])
# Create gating network
self. gate = nn.Sequential(
self.gate = nn.Sequential(
nn.Linear(self.config.hidden_size, num_experts),
nn.Softmax(dim=1)
)
)
def forward(self, x):
expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=1)
gate_weights = self.gate(torch.mean(expert_outputs, (2, 3)))
merged_output = torch.sum(
expert_outputs * gate_weights.unsqueeze(2).unsqueeze(3), dim=1
)
)
return merged_output
class AIIAchunked(AIIA):
def __init__(self, config: AIIAConfig, patch_size: int = 16, **kwargs):
self.config = config
super(AIIAchunked, self).__init__(config=config)
super().__init__(config=config, **kwargs)
self.config = self.config
# Update config with new parameters if provided
self.config.patch_size = patch_size
@ -113,12 +114,13 @@ class AIIAchunked(AIIA):
class AIIAresursive(AIIA):
def __init__(self, config: AIIAConfig, recursion_depth: int = 3, **kwargs):
super().__init__(config=config, **kwargs)
self.config = self.config
# Pass recursion_depth as a kwarg to the config
self.config = config
super().__init__(config, recursion_depth=recursion_depth, **kwargs)
# Get recursion_depth from updated config
self.recursion_depth = getattr(self.config, 'recursion_depth', 2)
self.config.recursion_depth = recursion_depth
# Initialize chunked CNN with updated config
self.chunked_cnn = AIIAchunked(self.config, **kwargs)
@ -136,13 +138,4 @@ class AIIAresursive(AIIA):
processed_patches.append(pp)
combined_output = torch.mean(torch.stack(processed_patches, dim=0), dim=0)
return combined_output
config = AIIAConfig()
model = AIIABase(config)
model = AIIAmoe(config=config, num_experts=5)
model = AIIAresursive(config=config)
model.save("moe")
return combined_output

View File

@ -10,7 +10,7 @@ class AIIAConfig:
model_name: str = "AIIA",
kernel_size: int = 5,
activation_function: str = "GELU",
hidden_size: int = 256,
hidden_size: int = 512,
num_hidden_layers: int = 12,
num_channels: int = 3,
learning_rate: float = 5e-5,