converted to cnn models
This commit is contained in:
parent
cbacd5e03c
commit
4c19838dab
|
@ -1,144 +1,97 @@
|
|||
from config import AIIAConfig
|
||||
from torch import nn
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import json
|
||||
import os
|
||||
|
||||
from aiia.model.config import AIIAConfig
|
||||
|
||||
|
||||
class AIIA(nn.Module):
|
||||
def __init__(self, config: AIIAConfig):
|
||||
super(AIIA, self).__init__()
|
||||
self.patch_size = 2 * config.radius + 1
|
||||
input_dim = self.patch_size * self.patch_size * config.num_channels
|
||||
|
||||
# Define layers based on the number of hidden layers in the config
|
||||
self.layers = nn.ModuleList()
|
||||
|
||||
# First layer: input_dim to hidden_size
|
||||
self.layers.append(nn.Linear(input_dim, config.hidden_size))
|
||||
|
||||
# Intermediate hidden layers: hidden_size to hidden_size
|
||||
for _ in range(config.num_hidden_layers - 1):
|
||||
self.layers.append(nn.Linear(config.hidden_size, config.hidden_size))
|
||||
|
||||
# Last layer: hidden_size back to input_dim
|
||||
self.layers.append(nn.Linear(config.hidden_size, input_dim))
|
||||
|
||||
# Store the configuration
|
||||
self.config = config
|
||||
|
||||
def forward(self, x):
|
||||
for i, layer in enumerate(self.layers):
|
||||
if i == len(self.layers) - 1:
|
||||
# No activation function on the last layer
|
||||
x = layer(x)
|
||||
else:
|
||||
# Apply the specified activation function to all but the last layer
|
||||
x = self.activation_function(x)
|
||||
return x
|
||||
def save(self, model_path, config_path):
|
||||
torch.save(self.state_dict(), model_path)
|
||||
self.config.save(config_path)
|
||||
|
||||
def activation_function(self, x):
|
||||
if self.config.activation_function == "relu":
|
||||
return torch.relu(x)
|
||||
elif self.config.activation_function == "gelu":
|
||||
return nn.functional.gelu(x)
|
||||
elif self.config.activation_function == "sigmoid":
|
||||
return torch.sigmoid(x)
|
||||
elif self.config.activation_function == "tanh":
|
||||
return torch.tanh(x)
|
||||
@classmethod
|
||||
def load(cls, config_path, model_path):
|
||||
config = AIIAConfig.load(config_path)
|
||||
model = cls(config)
|
||||
model.load_state_dict(torch.load(model_path))
|
||||
return model
|
||||
|
||||
class AIIABase(AIIA):
|
||||
def __init__(self, config: AIIAConfig):
|
||||
super(AIIABase, self).__init__(config)
|
||||
layers = []
|
||||
in_channels = config.num_channels
|
||||
for _ in range(config.num_hidden_layers):
|
||||
layers.extend([
|
||||
nn.Conv2d(in_channels, config.hidden_size, kernel_size=config.kernel_size, padding=1),
|
||||
getattr(nn, config.activation_function)(),
|
||||
nn.MaxPool2d(kernel_size=2)
|
||||
])
|
||||
in_channels = config.hidden_size
|
||||
self.cnn = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
return self.cnn(x)
|
||||
|
||||
class AIIAExpert(AIIA):
|
||||
def __init__(self, config: AIIAConfig):
|
||||
super(AIIAExpert, self).__init__(config)
|
||||
self.base_cnn = AIIABase(config)
|
||||
|
||||
def forward(self, x):
|
||||
return self.base_cnn(x)
|
||||
|
||||
class AIIAmoe(AIIA):
|
||||
def __init__(self, config: AIIAConfig, num_experts: int = 3):
|
||||
super(AIIAmoe, self).__init__(config)
|
||||
self.experts = nn.ModuleList([AIIAExpert(config) for _ in range(num_experts)])
|
||||
self.gate = nn.Sequential(
|
||||
nn.Linear(config.hidden_size, num_experts),
|
||||
nn.Softmax(dim=1)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=1)
|
||||
gate_weights = self.gate(torch.mean(expert_outputs, (2, 3)))
|
||||
merged_output = torch.sum(expert_outputs * gate_weights.unsqueeze(2).unsqueeze(3), dim=1)
|
||||
return merged_output
|
||||
|
||||
class AIIAchunked(AIIA):
|
||||
def __init__(self, config: AIIAConfig, patch_size: int = 16):
|
||||
super(AIIAchunked, self).__init__(config)
|
||||
self.patch_size = patch_size
|
||||
self.base_cnn = AIIABase(config)
|
||||
|
||||
def forward(self, x):
|
||||
patches = x.unfold(2, self.patch_size, self.patch_size).unfold(3, self.patch_size, self.patch_size)
|
||||
patches = patches.contiguous().view(patches.size(0), patches.size(1), -1, self.patch_size, self.patch_size)
|
||||
patch_outputs = []
|
||||
for p in torch.split(patches, 1, dim=2):
|
||||
p = p.squeeze(2)
|
||||
po = self.base_cnn(p)
|
||||
patch_outputs.append(po)
|
||||
combined_output = torch.mean(torch.stack(patch_outputs, dim=0), dim=0)
|
||||
return combined_output
|
||||
|
||||
class AIIAresursive(AIIA):
|
||||
def __init__(self, config: AIIAConfig, recursion_depth: int = 2):
|
||||
super(AIIAresursive, self).__init__(config)
|
||||
self.recursion_depth = recursion_depth
|
||||
self.chunked_cnn = AIIAchunked(config)
|
||||
|
||||
def forward(self, x, depth=0):
|
||||
if depth == self.recursion_depth:
|
||||
return self.chunked_cnn(x)
|
||||
else:
|
||||
raise ValueError(f"Unsupported activation function: {self.config.activation_function}")
|
||||
|
||||
def predict(self, input_image, patch_size=None, stride=None):
|
||||
if patch_size is None:
|
||||
patch_size = 2 * self.config.radius + 1
|
||||
if stride is None:
|
||||
stride = patch_size // 2 # Overlap by half the patch size
|
||||
|
||||
# Extract patches from the input image
|
||||
patches = self.extract_patches(input_image, patch_size, stride)
|
||||
|
||||
# Process each patch through the model
|
||||
with torch.no_grad():
|
||||
predictions = []
|
||||
for patch in patches:
|
||||
patch = patch.view(1, -1).to(self.device)
|
||||
pred = self(patch)
|
||||
predictions.append(pred.view(patch_size, patch_size, self.config.num_channels).cpu())
|
||||
|
||||
# Reconstruct the image from the predicted patches
|
||||
output_image = torch.zeros_like(input_image)
|
||||
count_map = torch.zeros_like(input_image)
|
||||
|
||||
patch_idx = 0
|
||||
for y in range(0, input_image.shape[1] - patch_size + 1, stride):
|
||||
for x in range(0, input_image.shape[2] - patch_size + 1, stride):
|
||||
output_image[:, y:y+patch_size, x:x+patch_size] += predictions[patch_idx]
|
||||
count_map[:, y:y+patch_size, x:x+patch_size] += 1
|
||||
patch_idx += 1
|
||||
|
||||
# Average the overlapping predictions
|
||||
output_image /= count_map
|
||||
|
||||
return output_image
|
||||
|
||||
@staticmethod
|
||||
def extract_patches(image, patch_size, stride):
|
||||
patches = []
|
||||
for y in range(0, image.shape[1] - patch_size + 1, stride):
|
||||
for x in range(0, image.shape[2] - patch_size + 1, stride):
|
||||
patch = image[:, y:y+patch_size, x:x+patch_size]
|
||||
patches.append(patch)
|
||||
return patches
|
||||
|
||||
@staticmethod
|
||||
def extract_patches(image, patch_size, stride=None):
|
||||
if stride is None:
|
||||
stride = patch_size
|
||||
|
||||
C, H, W = image.shape
|
||||
patches = []
|
||||
|
||||
for y in range(0, H - patch_size + 1, stride):
|
||||
for x in range(0, W - patch_size + 1, stride):
|
||||
patch = image[:, y:y+patch_size, x:x+patch_size]
|
||||
patches.append(patch)
|
||||
|
||||
return torch.stack(patches)
|
||||
|
||||
def save(self, folderpath: str):
|
||||
# Ensure the folder exists
|
||||
os.makedirs(folderpath, exist_ok=True)
|
||||
|
||||
# Save the model state dictionary
|
||||
model_state_dict = self.state_dict()
|
||||
|
||||
# Serialize and save the configuration as JSON
|
||||
with open(os.path.join(folderpath, 'config.json'), 'w') as f:
|
||||
json.dump(self.config.__dict__, f)
|
||||
|
||||
# Save the model state dictionary
|
||||
torch.save(model_state_dict, os.path.join(folderpath, 'model.pth'))
|
||||
|
||||
def load(self, folderpath: str):
|
||||
with open(os.path.join(folderpath, 'config.json'), 'r') as f:
|
||||
config_dict = json.load(f)
|
||||
|
||||
# Assuming Config has a constructor that takes a dictionary
|
||||
self.config = AIIAConfig(**config_dict)
|
||||
|
||||
# Load the model state dictionary into the current instance
|
||||
model_state_dict = torch.load(os.path.join(folderpath, 'model.pth'))
|
||||
self.load_state_dict(model_state_dict)
|
||||
|
||||
return config_dict, model_state_dict
|
||||
|
||||
|
||||
class AIIAEncoder(AIIA):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.encoder = torch.nn.Sequential(*list(self.layers.children())[:config.encoder_layers])
|
||||
|
||||
def forward(self, x):
|
||||
return self.encoder(x)
|
||||
patches = x.unfold(2, 16, 16).unfold(3, 16, 16)
|
||||
patches = patches.contiguous().view(patches.size(0), patches.size(1), -1, 16, 16)
|
||||
processed_patches = []
|
||||
for p in torch.split(patches, 1, dim=2):
|
||||
p = p.squeeze(2)
|
||||
pp = self.forward(p, depth + 1)
|
||||
processed_patches.append(pp)
|
||||
combined_output = torch.mean(torch.stack(processed_patches, dim=0), dim=0)
|
||||
return combined_output
|
||||
|
|
|
@ -1,30 +1,44 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import json
|
||||
|
||||
class AIIAConfig:
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str = "AIIA",
|
||||
radius: int = 3,
|
||||
kernel_size: int = 3,
|
||||
activation_function: str = "gelu",
|
||||
hidden_size: int = 128,
|
||||
num_hidden_layers: int = 2,
|
||||
num_channels: int = 3,
|
||||
learning_rate: float = 5e5
|
||||
):
|
||||
self.model_name = model_name
|
||||
self.radius = radius
|
||||
self.kernel_size = kernel_size
|
||||
self.activation_function = activation_function
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_channels = num_channels
|
||||
self.learning_rate = learning_rate
|
||||
|
||||
@property
|
||||
def activation_function(self):
|
||||
return self._activation_function
|
||||
|
||||
@activation_function.setter
|
||||
def activation_function(self, value):
|
||||
attr = getattr(nn, value, None)
|
||||
if attr is None or (not callable(attr) and not isinstance(attr, type(nn.Module))):
|
||||
valid_funcs = [func for func in dir(nn) if callable(getattr(nn, func)) or isinstance(getattr(nn, func), type(nn.Module))]
|
||||
raise ValueError(f"Invalid activation function: {value}. Choose from: {', '.join(valid_funcs)}")
|
||||
self._activation_function = value
|
||||
|
||||
def save(self, file_path):
|
||||
# Save config to JSON
|
||||
with open(file_path, 'w') as f:
|
||||
json.dump(self.__dict__, f)
|
||||
|
||||
@classmethod
|
||||
def load(cls, file_path):
|
||||
# Load config from JSON
|
||||
with open(file_path, 'r') as f:
|
||||
config_dict = json.load(f)
|
||||
return cls(**config_dict)
|
||||
return cls(**config_dict)
|
||||
|
|
Loading…
Reference in New Issue