Converted to CNN models

This commit is contained in:
Falko Victor Habel 2025-01-20 13:25:36 +01:00
parent cbacd5e03c
commit 4c19838dab
2 changed files with 106 additions and 139 deletions

View File

@@ -1,144 +1,97 @@
import os
import json
import torch
import torch.nn as nn
from aiia.model.config import AIIAConfig
class AIIA(nn.Module):
    def __init__(self, config: AIIAConfig):
        super(AIIA, self).__init__()
        self.patch_size = 2 * config.radius + 1
        input_dim = self.patch_size * self.patch_size * config.num_channels

        # Define layers based on the number of hidden layers in the config
        self.layers = nn.ModuleList()

        # First layer: input_dim to hidden_size
        self.layers.append(nn.Linear(input_dim, config.hidden_size))

        # Intermediate hidden layers: hidden_size to hidden_size
        for _ in range(config.num_hidden_layers - 1):
            self.layers.append(nn.Linear(config.hidden_size, config.hidden_size))

        # Last layer: hidden_size back to input_dim
        self.layers.append(nn.Linear(config.hidden_size, input_dim))

        # Store the configuration
        self.config = config

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            if i == len(self.layers) - 1:
                # No activation function on the last layer
                x = layer(x)
            else:
                # Apply the layer, then the configured activation function
                x = self.activation_function(layer(x))
        return x

    def save(self, model_path, config_path):
        torch.save(self.state_dict(), model_path)
        self.config.save(config_path)

    def activation_function(self, x):
        # Dispatch on the configured activation name (case-insensitive)
        name = self.config.activation_function.lower()
        if name == "relu":
            return torch.relu(x)
        elif name == "gelu":
            return nn.functional.gelu(x)
        elif name == "sigmoid":
            return torch.sigmoid(x)
        elif name == "tanh":
            return torch.tanh(x)
        else:
            raise ValueError(f"Unsupported activation function: {self.config.activation_function}")

    @classmethod
    def load(cls, config_path, model_path):
        config = AIIAConfig.load(config_path)
        model = cls(config)
        model.load_state_dict(torch.load(model_path))
        return model
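# Illustrative round trip (hypothetical file names), assuming the default config:
# >>> cfg = AIIAConfig()
# >>> model = AIIA(cfg)
# >>> model.save("model.pth", "config.json")
# >>> restored = AIIA.load("config.json", "model.pth")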
class AIIABase(AIIA):
    def __init__(self, config: AIIAConfig):
        super(AIIABase, self).__init__(config)

        layers = []
        in_channels = config.num_channels
        for _ in range(config.num_hidden_layers):
            layers.extend([
                nn.Conv2d(in_channels, config.hidden_size,
                          kernel_size=config.kernel_size,
                          padding=config.kernel_size // 2),  # "same" padding for odd kernel sizes
                getattr(nn, config.activation_function)(),
                nn.MaxPool2d(kernel_size=2)
            ])
            in_channels = config.hidden_size
        self.cnn = nn.Sequential(*layers)

    def forward(self, x):
        return self.cnn(x)
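# Shape sketch (illustrative): with the default config (hidden_size=128,
# num_hidden_layers=2), each conv keeps the resolution ("same" padding) and
# each MaxPool halves it, so 224 -> 112 -> 56:
# >>> base = AIIABase(AIIAConfig())
# >>> base(torch.randn(1, 3, 224, 224)).shape
# torch.Size([1, 128, 56, 56])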
class AIIAExpert(AIIA):
    def __init__(self, config: AIIAConfig):
        super(AIIAExpert, self).__init__(config)
        self.base_cnn = AIIABase(config)

    def forward(self, x):
        return self.base_cnn(x)
class AIIAmoe(AIIA):
    def __init__(self, config: AIIAConfig, num_experts: int = 3):
        super(AIIAmoe, self).__init__(config)
        self.experts = nn.ModuleList([AIIAExpert(config) for _ in range(num_experts)])
        self.gate = nn.Sequential(
            nn.Linear(config.hidden_size, num_experts),
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        # expert_outputs: (batch, num_experts, hidden_size, H, W)
        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=1)
        # Pool over the expert and spatial dimensions to get a (batch, hidden_size)
        # summary for the gate, which returns one weight per expert
        gate_weights = self.gate(torch.mean(expert_outputs, dim=(1, 3, 4)))
        # Broadcast the weights to (batch, num_experts, 1, 1, 1) and merge
        merged_output = torch.sum(
            expert_outputs * gate_weights.unsqueeze(2).unsqueeze(3).unsqueeze(4), dim=1)
        return merged_output
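# Gating sketch (illustrative): three experts give (B, 3, 128, H, W); pooling
# over experts and space yields a (B, 128) gate input, the gate returns (B, 3)
# weights, and the weighted sum restores a single (B, 128, H, W) map:
# >>> moe = AIIAmoe(AIIAConfig(), num_experts=3)
# >>> moe(torch.randn(2, 3, 64, 64)).shape
# torch.Size([2, 128, 16, 16])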
class AIIAchunked(AIIA):
    def __init__(self, config: AIIAConfig, patch_size: int = 16):
        super(AIIAchunked, self).__init__(config)
        self.patch_size = patch_size
        self.base_cnn = AIIABase(config)

    def forward(self, x):
        # Unfold the input into non-overlapping patch_size x patch_size patches
        patches = x.unfold(2, self.patch_size, self.patch_size).unfold(3, self.patch_size, self.patch_size)
        patches = patches.contiguous().view(patches.size(0), patches.size(1), -1, self.patch_size, self.patch_size)
        patch_outputs = []
        for p in torch.split(patches, 1, dim=2):
            p = p.squeeze(2)
            po = self.base_cnn(p)
            patch_outputs.append(po)
        # Average the per-patch feature maps
        combined_output = torch.mean(torch.stack(patch_outputs, dim=0), dim=0)
        return combined_output
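# Chunking sketch (illustrative): a 224x224 input unfolds into 14 * 14 = 196
# non-overlapping 16x16 patches; each runs through the shared base CNN
# (16 -> 8 -> 4 after two pools) and the 196 outputs are averaged:
# >>> chunked = AIIAchunked(AIIAConfig())
# >>> chunked(torch.randn(1, 3, 224, 224)).shape
# torch.Size([1, 128, 4, 4])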
class AIIArecursive(AIIA):
    def __init__(self, config: AIIAConfig, recursion_depth: int = 2):
        super(AIIArecursive, self).__init__(config)
        self.recursion_depth = recursion_depth
        self.chunked_cnn = AIIAchunked(config)

    def forward(self, x, depth=0):
        if depth == self.recursion_depth:
            return self.chunked_cnn(x)
        else:
            # Split the input into non-overlapping 16x16 patches, recurse one
            # level deeper on each patch, then average the results
            patches = x.unfold(2, 16, 16).unfold(3, 16, 16)
            patches = patches.contiguous().view(patches.size(0), patches.size(1), -1, 16, 16)
            processed_patches = []
            for p in torch.split(patches, 1, dim=2):
                p = p.squeeze(2)
                pp = self.forward(p, depth + 1)
                processed_patches.append(pp)
            combined_output = torch.mean(torch.stack(processed_patches, dim=0), dim=0)
            return combined_output
    def predict(self, input_image, patch_size=None, stride=None):
        if patch_size is None:
            patch_size = 2 * self.config.radius + 1
        if stride is None:
            stride = patch_size // 2  # Overlap by half the patch size

        device = next(self.parameters()).device

        # Extract patches from the input image
        patches = self.extract_patches(input_image, patch_size, stride)

        # Process each patch through the model
        with torch.no_grad():
            predictions = []
            for patch in patches:
                patch = patch.view(1, -1).to(device)
                pred = self(patch)
                # Reshape to (H, W, C), then to channel-first (C, H, W)
                predictions.append(
                    pred.view(patch_size, patch_size, self.config.num_channels)
                        .permute(2, 0, 1).cpu())

        # Reconstruct the image from the predicted patches
        output_image = torch.zeros_like(input_image)
        count_map = torch.zeros_like(input_image)
        patch_idx = 0
        for y in range(0, input_image.shape[1] - patch_size + 1, stride):
            for x in range(0, input_image.shape[2] - patch_size + 1, stride):
                output_image[:, y:y+patch_size, x:x+patch_size] += predictions[patch_idx]
                count_map[:, y:y+patch_size, x:x+patch_size] += 1
                patch_idx += 1

        # Average the overlapping predictions (clamp avoids dividing by zero
        # at border pixels no patch covers)
        output_image /= count_map.clamp(min=1)
        return output_image
    @staticmethod
    def extract_patches(image, patch_size, stride=None):
        if stride is None:
            stride = patch_size  # Non-overlapping by default
        C, H, W = image.shape
        patches = []
        for y in range(0, H - patch_size + 1, stride):
            for x in range(0, W - patch_size + 1, stride):
                patch = image[:, y:y+patch_size, x:x+patch_size]
                patches.append(patch)
        return torch.stack(patches)
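# Patch-count sketch (illustrative): with patch_size=7 and stride=3, a
# 3x32x32 image yields floor((32 - 7) / 3) + 1 = 9 positions per axis,
# i.e. 81 patches stacked into a tensor of shape (81, 3, 7, 7).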
    def save(self, folderpath: str):
        # Ensure the folder exists
        os.makedirs(folderpath, exist_ok=True)

        # Serialize and save the configuration as JSON, stripping the leading
        # underscore from property-backed attributes so load() can round-trip
        config_dict = {k.lstrip('_'): v for k, v in self.config.__dict__.items()}
        with open(os.path.join(folderpath, 'config.json'), 'w') as f:
            json.dump(config_dict, f)

        # Save the model state dictionary
        torch.save(self.state_dict(), os.path.join(folderpath, 'model.pth'))
    def load(self, folderpath: str):
        with open(os.path.join(folderpath, 'config.json'), 'r') as f:
            config_dict = json.load(f)

        # AIIAConfig's constructor accepts the saved keys as keyword arguments
        self.config = AIIAConfig(**config_dict)

        # Load the model state dictionary into the current instance
        model_state_dict = torch.load(os.path.join(folderpath, 'model.pth'))
        self.load_state_dict(model_state_dict)
        return config_dict, model_state_dict
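# Folder round trip (illustrative, hypothetical path): save() writes
# config.json and model.pth side by side; load() expects both and assumes the
# current instance already has a matching architecture:
# >>> model.save("checkpoints/aiia")
# >>> config_dict, state = model.load("checkpoints/aiia")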
class AIIAEncoder(AIIA):
    def __init__(self, config):
        super().__init__(config)
        # Keep only the first config.encoder_layers layers of the base model
        self.encoder = torch.nn.Sequential(*list(self.layers.children())[:config.encoder_layers])

    def forward(self, x):
        return self.encoder(x)

View File

@@ -1,30 +1,44 @@
import json
import torch.nn as nn


class AIIAConfig:
    def __init__(
        self,
        model_name: str = "AIIA",
        radius: int = 3,
        kernel_size: int = 3,
        activation_function: str = "GELU",
        hidden_size: int = 128,
        num_hidden_layers: int = 2,
        num_channels: int = 3,
        learning_rate: float = 5e-5
    ):
        self.model_name = model_name
        self.radius = radius
        self.kernel_size = kernel_size
        self.activation_function = activation_function
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_channels = num_channels
        self.learning_rate = learning_rate
    @property
    def activation_function(self):
        return self._activation_function

    @activation_function.setter
    def activation_function(self, value):
        # Accept any callable or nn.Module subclass exposed by torch.nn (e.g. "ReLU", "GELU")
        attr = getattr(nn, value, None)
        if attr is None or (not callable(attr) and not isinstance(attr, type(nn.Module))):
            valid_funcs = [func for func in dir(nn)
                           if callable(getattr(nn, func)) or isinstance(getattr(nn, func), type(nn.Module))]
            raise ValueError(f"Invalid activation function: {value}. Choose from: {', '.join(valid_funcs)}")
        self._activation_function = value
    def save(self, file_path):
        # Save config to JSON, stripping the leading underscore from
        # property-backed attributes so load() can pass them as kwargs
        config_dict = {k.lstrip('_'): v for k, v in self.__dict__.items()}
        with open(file_path, 'w') as f:
            json.dump(config_dict, f)

    @classmethod
    def load(cls, file_path):
        # Load config from JSON
        with open(file_path, 'r') as f:
            config_dict = json.load(f)
        return cls(**config_dict)
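# Config usage (illustrative): the setter validates names against torch.nn,
# so class names such as "ReLU" or "GELU" pass while arbitrary strings raise
# a ValueError; saved configs round-trip through JSON:
# >>> cfg = AIIAConfig(activation_function="ReLU")
# >>> cfg.save("config.json")
# >>> AIIAConfig.load("config.json").activation_function
# 'ReLU'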