models for training

This commit is contained in:
Falko Victor Habel 2025-01-12 20:49:22 +01:00
parent 6757718569
commit b371d747fd
3 changed files with 150 additions and 0 deletions

View File

@ -0,0 +1,4 @@
# Import submodules
from .model import AIIA, AIIAEncoder
from .data import AIIADataLoader
from .model.config import AIIAConfig

144
src/aiia/model/Model.py Normal file
View File

@ -0,0 +1,144 @@
import torch
import torch.nn as nn
import json
import os
from aiia.model.config import AIIAConfig
class AIIA(nn.Module):
def __init__(self, config: AIIAConfig):
super(AIIA, self).__init__()
self.patch_size = 2 * config.radius + 1
input_dim = self.patch_size * self.patch_size * config.num_channels
# Define layers based on the number of hidden layers in the config
self.layers = nn.ModuleList()
# First layer: input_dim to hidden_size
self.layers.append(nn.Linear(input_dim, config.hidden_size))
# Intermediate hidden layers: hidden_size to hidden_size
for _ in range(config.num_hidden_layers - 1):
self.layers.append(nn.Linear(config.hidden_size, config.hidden_size))
# Last layer: hidden_size back to input_dim
self.layers.append(nn.Linear(config.hidden_size, input_dim))
# Store the configuration
self.config = config
def forward(self, x):
for i, layer in enumerate(self.layers):
if i == len(self.layers) - 1:
# No activation function on the last layer
x = layer(x)
else:
# Apply the specified activation function to all but the last layer
x = self.activation_function(x)
return x
def activation_function(self, x):
if self.config.activation_function == "relu":
return torch.relu(x)
elif self.config.activation_function == "gelu":
return nn.functional.gelu(x)
elif self.config.activation_function == "sigmoid":
return torch.sigmoid(x)
elif self.config.activation_function == "tanh":
return torch.tanh(x)
else:
raise ValueError(f"Unsupported activation function: {self.config.activation_function}")
def predict(self, input_image, patch_size=None, stride=None):
if patch_size is None:
patch_size = 2 * self.config.radius + 1
if stride is None:
stride = patch_size // 2 # Overlap by half the patch size
# Extract patches from the input image
patches = self.extract_patches(input_image, patch_size, stride)
# Process each patch through the model
with torch.no_grad():
predictions = []
for patch in patches:
patch = patch.view(1, -1).to(self.device)
pred = self(patch)
predictions.append(pred.view(patch_size, patch_size, self.config.num_channels).cpu())
# Reconstruct the image from the predicted patches
output_image = torch.zeros_like(input_image)
count_map = torch.zeros_like(input_image)
patch_idx = 0
for y in range(0, input_image.shape[1] - patch_size + 1, stride):
for x in range(0, input_image.shape[2] - patch_size + 1, stride):
output_image[:, y:y+patch_size, x:x+patch_size] += predictions[patch_idx]
count_map[:, y:y+patch_size, x:x+patch_size] += 1
patch_idx += 1
# Average the overlapping predictions
output_image /= count_map
return output_image
@staticmethod
def extract_patches(image, patch_size, stride):
patches = []
for y in range(0, image.shape[1] - patch_size + 1, stride):
for x in range(0, image.shape[2] - patch_size + 1, stride):
patch = image[:, y:y+patch_size, x:x+patch_size]
patches.append(patch)
return patches
@staticmethod
def extract_patches(image, patch_size, stride=None):
if stride is None:
stride = patch_size
C, H, W = image.shape
patches = []
for y in range(0, H - patch_size + 1, stride):
for x in range(0, W - patch_size + 1, stride):
patch = image[:, y:y+patch_size, x:x+patch_size]
patches.append(patch)
return torch.stack(patches)
def save(self, folderpath: str):
# Ensure the folder exists
os.makedirs(folderpath, exist_ok=True)
# Save the model state dictionary
model_state_dict = self.state_dict()
# Serialize and save the configuration as JSON
with open(os.path.join(folderpath, 'config.json'), 'w') as f:
json.dump(self.config.__dict__, f)
# Save the model state dictionary
torch.save(model_state_dict, os.path.join(folderpath, 'model.pth'))
def load(self, folderpath: str):
with open(os.path.join(folderpath, 'config.json'), 'r') as f:
config_dict = json.load(f)
# Assuming Config has a constructor that takes a dictionary
self.config = AIIAConfig(**config_dict)
# Load the model state dictionary into the current instance
model_state_dict = torch.load(os.path.join(folderpath, 'model.pth'))
self.load_state_dict(model_state_dict)
return config_dict, model_state_dict
class AIIAEncoder(AIIA):
def __init__(self, config):
super().__init__(config)
self.encoder = torch.nn.Sequential(*list(self.layers.children())[:config.encoder_layers])
def forward(self, x):
return self.encoder(x)

View File

@ -0,0 +1,2 @@
from .config import AIIAConfig
from .Model import AIIA, AIIAEncoder