models for training
This commit is contained in:
parent 6757718569
commit b371d747fd
@@ -0,0 +1,4 @@
# Import submodules
from .model import AIIA, AIIAEncoder
from .data import AIIADataLoader
from .model.config import AIIAConfig
@@ -0,0 +1,144 @@
import torch
import torch.nn as nn
import json
import os

from aiia.model.config import AIIAConfig

class AIIA(nn.Module):
    def __init__(self, config: AIIAConfig):
        super().__init__()
        self.patch_size = 2 * config.radius + 1
        input_dim = self.patch_size * self.patch_size * config.num_channels

        # Define layers based on the number of hidden layers in the config
        self.layers = nn.ModuleList()

        # First layer: input_dim to hidden_size
        self.layers.append(nn.Linear(input_dim, config.hidden_size))

        # Intermediate hidden layers: hidden_size to hidden_size
        for _ in range(config.num_hidden_layers - 1):
            self.layers.append(nn.Linear(config.hidden_size, config.hidden_size))

        # Last layer: hidden_size back to input_dim
        self.layers.append(nn.Linear(config.hidden_size, input_dim))

        # Store the configuration
        self.config = config

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            if i == len(self.layers) - 1:
                # No activation function on the last layer
                x = layer(x)
            else:
                # Apply the configured activation after every other layer
                x = self.activation_function(layer(x))
        return x

    def activation_function(self, x):
        if self.config.activation_function == "relu":
            return torch.relu(x)
        elif self.config.activation_function == "gelu":
            return nn.functional.gelu(x)
        elif self.config.activation_function == "sigmoid":
            return torch.sigmoid(x)
        elif self.config.activation_function == "tanh":
            return torch.tanh(x)
        else:
            raise ValueError(f"Unsupported activation function: {self.config.activation_function}")

    def predict(self, input_image, patch_size=None, stride=None):
        if patch_size is None:
            patch_size = 2 * self.config.radius + 1
        if stride is None:
            stride = patch_size // 2  # Overlap by half the patch size

        # Extract patches from the input image
        patches = self.extract_patches(input_image, patch_size, stride)

        # Process each patch through the model
        device = next(self.parameters()).device
        with torch.no_grad():
            predictions = []
            for patch in patches:
                patch = patch.reshape(1, -1).to(device)
                pred = self(patch)
                # Reshape back to (channels, patch_size, patch_size) to match the image layout
                predictions.append(pred.view(self.config.num_channels, patch_size, patch_size).cpu())

        # Reconstruct the image from the predicted patches
        output_image = torch.zeros_like(input_image)
        count_map = torch.zeros_like(input_image)

        patch_idx = 0
        for y in range(0, input_image.shape[1] - patch_size + 1, stride):
            for x in range(0, input_image.shape[2] - patch_size + 1, stride):
                output_image[:, y:y+patch_size, x:x+patch_size] += predictions[patch_idx]
                count_map[:, y:y+patch_size, x:x+patch_size] += 1
                patch_idx += 1

        # Average the overlapping predictions (clamping avoids division by
        # zero at border pixels that no patch covers)
        output_image /= count_map.clamp(min=1)

        return output_image

    @staticmethod
    def extract_patches(image, patch_size, stride=None):
        if stride is None:
            stride = patch_size

        C, H, W = image.shape
        patches = []

        for y in range(0, H - patch_size + 1, stride):
            for x in range(0, W - patch_size + 1, stride):
                patch = image[:, y:y+patch_size, x:x+patch_size]
                patches.append(patch)

        return torch.stack(patches)

    def save(self, folderpath: str):
        # Ensure the folder exists
        os.makedirs(folderpath, exist_ok=True)

        # Serialize and save the configuration as JSON
        with open(os.path.join(folderpath, 'config.json'), 'w') as f:
            json.dump(self.config.__dict__, f)

        # Save the model state dictionary
        torch.save(self.state_dict(), os.path.join(folderpath, 'model.pth'))

    def load(self, folderpath: str):
        with open(os.path.join(folderpath, 'config.json'), 'r') as f:
            config_dict = json.load(f)

        # Assumes AIIAConfig accepts the saved fields as keyword arguments
        self.config = AIIAConfig(**config_dict)

        # Load the model state dictionary into the current instance
        model_state_dict = torch.load(os.path.join(folderpath, 'model.pth'))
        self.load_state_dict(model_state_dict)

        return config_dict, model_state_dict


class AIIAEncoder(AIIA):
    def __init__(self, config):
        super().__init__(config)
        # Reuse the first `config.encoder_layers` linear layers as the encoder.
        # Note: nn.Sequential chains them directly, without the activations
        # that AIIA.forward applies between layers.
        self.encoder = nn.Sequential(*list(self.layers)[:config.encoder_layers])

    def forward(self, x):
        return self.encoder(x)
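
For reference, a minimal usage sketch of the model above. AIIAConfig's constructor signature is an assumption here; the attribute names are the ones the model reads, and the image size is arbitrary:

import torch
from aiia.model.config import AIIAConfig
from aiia.model import AIIA

# Constructor signature assumed; attribute names taken from the model code above
config = AIIAConfig(radius=3, num_channels=3, hidden_size=256,
                    num_hidden_layers=4, activation_function="gelu")
model = AIIA(config).eval()

# Patchwise reconstruction of a single (C, H, W) image
image = torch.rand(3, 64, 64)
output = model.predict(image)   # -> torch.Size([3, 64, 64])

With the default stride of patch_size // 2, overlapping patches cover every pixel at least once, so the count_map average in predict is well defined.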
@@ -0,0 +1,2 @@
from .config import AIIAConfig
from .Model import AIIA, AIIAEncoder
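
And a round-trip sketch of the save()/load() API (the folder path is illustrative, with the same assumed AIIAConfig signature as above):

from aiia.model.config import AIIAConfig
from aiia.model import AIIA

config = AIIAConfig(radius=3, num_channels=3, hidden_size=256,
                    num_hidden_layers=4, activation_function="relu")
model = AIIA(config)
model.save("checkpoints/aiia")          # writes config.json and model.pth

restored = AIIA(config)                 # load() requires an already-constructed instance
config_dict, state_dict = restored.load("checkpoints/aiia")

Since load() both mutates the instance and returns the raw dicts, a from_pretrained-style classmethod could avoid building a throwaway config first.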