# File: AIIA/src/aiia/model/Model.py
# (Source-viewer metadata from the original listing: 323 lines, 12 KiB, Python.)

from .config import AIIAConfig
from torch import nn
import torch
import os
import copy
import warnings
class AIIA(nn.Module):
    """Common base class for the AIIA models.

    Every instance owns a private deep copy of its configuration. ``save`` and
    ``load`` persist the weights as ``model.pth`` inside a directory and delegate
    configuration persistence to ``AIIAConfig``.
    """

    def __init__(self, config: AIIAConfig, **kwargs):
        """Store a deep copy of ``config`` with ``kwargs`` applied on top.

        Args:
            config (AIIAConfig): Configuration to copy; the caller's object is not mutated.
            **kwargs: Attribute overrides set on the copied configuration.
        """
        super().__init__()
        # Deep copy so overrides (and later mutations by subclasses) never leak back
        # into the caller's config or into sibling models built from the same config.
        self.config = copy.deepcopy(config)
        for key, value in kwargs.items():
            setattr(self.config, key, value)

    def save(self, path: str):
        """Save the state dict to ``<path>/model.pth`` and call ``self.config.save(path)``.

        Args:
            path (str): Target directory; created (including parents) if missing.
        """
        os.makedirs(path, exist_ok=True)
        torch.save(self.state_dict(), os.path.join(path, "model.pth"))
        self.config.save(path)

    @staticmethod
    def _resolve_dtype(precision, device):
        """Map a precision string to a torch dtype; ``None`` means keep the stored dtype.

        Raises:
            ValueError: If ``precision`` is neither None, 'fp16' nor 'bf16' (case-insensitive).
        """
        if precision is None:
            return None
        name = precision.lower()
        if name == 'fp16':
            return torch.float16
        if name == 'bf16':
            if device == 'cuda' and not torch.cuda.is_bf16_supported():
                warnings.warn("BF16 is not supported on this GPU. Falling back to FP16.")
                return torch.float16
            return torch.bfloat16
        raise ValueError("Unsupported precision. Use 'fp16', 'bf16', or leave as None.")

    @staticmethod
    def _infer_num_experts(state_dict):
        """Return 1 + the largest integer ``N`` among ``experts.N...`` keys, or None if absent."""
        max_expert_idx = -1
        for key in state_dict:
            if not key.startswith("experts."):
                continue
            try:
                max_expert_idx = max(max_expert_idx, int(key.split(".")[1]))
            except ValueError:
                # Second segment is not an integer index; not an expert entry.
                continue
        return max_expert_idx + 1 if max_expert_idx >= 0 else None

    @classmethod
    def load(cls, path, precision: str = None, strict: bool = True, **kwargs):
        """Rebuild a model from a directory written by ``save``.

        Args:
            path: Directory containing ``model.pth`` and the saved configuration.
            precision (str, optional): None (keep stored dtype), 'fp16' or 'bf16'
                (case-insensitive). 'bf16' falls back to fp16 with a warning when
                CUDA is available but does not support bf16.
            strict (bool): Forwarded to ``load_state_dict``.
            **kwargs: Extra constructor arguments. For ``AIIAmoe`` (and subclasses),
                ``num_experts`` is inferred from the checkpoint when not given.

        Returns:
            The constructed model with the checkpoint weights loaded.

        Raises:
            ValueError: If ``precision`` is not supported.
        """
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # Validate the precision up front so a bad argument fails before any file I/O.
        dtype = cls._resolve_dtype(precision, device)

        config = AIIAConfig.load(path)
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints
        # (consider weights_only=True on torch versions that support it).
        model_dict = torch.load(os.path.join(path, "model.pth"), map_location=device)

        # Mixture-of-experts models must know num_experts before construction; recover
        # it from the checkpoint's "experts.N.*" keys unless the caller supplied it.
        is_moe = any(klass.__name__ == "AIIAmoe" for klass in cls.__mro__)
        if is_moe and "num_experts" not in kwargs:
            num_experts = cls._infer_num_experts(model_dict)
            if num_experts is not None:
                kwargs["num_experts"] = num_experts

        model = cls(config, **kwargs)
        model.load_state_dict(model_dict, strict=strict)
        # load_state_dict copies values into the model's existing tensors, which keep
        # their own dtype; the precision change therefore has to be applied to the model.
        if dtype is not None:
            model = model.to(dtype=dtype)
        return model
class AIIABase(AIIA):
    """Plain stacked-CNN backbone.

    Builds ``config.num_hidden_layers`` repetitions of
    ``Conv2d -> activation -> MaxPool2d(kernel_size=1, stride=1)`` inside a single
    ``nn.Sequential`` stored as ``self.cnn``.
    """

    def __init__(self, config: AIIAConfig, **kwargs):
        """
        Args:
            config (AIIAConfig): Reads num_channels, hidden_size, kernel_size,
                num_hidden_layers and activation_function (an ``nn`` class name).
            **kwargs: Configuration overrides (applied by ``AIIA.__init__``).
        """
        super().__init__(config=config, **kwargs)
        cfg = self.config
        layers = []
        in_channels = cfg.num_channels
        for _ in range(cfg.num_hidden_layers):
            layers += [
                # NOTE(review): padding is fixed at 1, which preserves the spatial size
                # only for kernel_size == 3.
                nn.Conv2d(in_channels, cfg.hidden_size,
                          kernel_size=cfg.kernel_size, padding=1),
                getattr(nn, cfg.activation_function)(),
                # A 1x1, stride-1 max-pool is an identity op. It is kept because the
                # Sequential indices (and thus state-dict keys like "cnn.3.weight")
                # depend on it, so removing it would break existing checkpoints.
                nn.MaxPool2d(kernel_size=1, stride=1),
            ]
            in_channels = cfg.hidden_size
        self.cnn = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the convolutional stack and return the result."""
        return self.cnn(x)
class AIIABaseShared(AIIA):
    """CNN backbone: one input convolution followed by a list of further convolutions.

    ``shared_layer`` maps ``num_channels -> hidden_size``; ``unique_layers`` holds one
    ``hidden_size -> hidden_size`` convolution with its own weights. Every convolution
    is followed by the configured activation and an identity max-pool.
    """

    def __init__(self, config: AIIAConfig, **kwargs):
        """
        Initialize the AIIABaseShared model.

        Args:
            config (AIIAConfig): Configuration object containing model parameters.
            **kwargs: Additional keyword arguments to override configuration settings.
        """
        # AIIA.__init__ already deep-copies ``config`` and applies ``kwargs`` to the copy.
        super().__init__(config=config, **kwargs)
        # Construction order (shared conv first, then unique conv) fixes the RNG
        # consumption order of the weight initializers.
        self._initialize_network()
        self._initialize_activation_andPooling()

    def _initialize_network(self):
        """Create ``shared_layer`` and the ``unique_layers`` ModuleList."""
        cfg = self.config
        # NOTE(review): padding is hard-coded to 1 (it is not read from the config),
        # which preserves the spatial size only for kernel_size == 3.
        self.shared_layer = nn.Conv2d(
            in_channels=cfg.num_channels,
            out_channels=cfg.hidden_size,
            kernel_size=cfg.kernel_size,
            padding=1,
        )
        # Exactly one additional convolution with separate weights and bias.
        self.unique_layers = nn.ModuleList([
            nn.Conv2d(
                in_channels=cfg.hidden_size,
                out_channels=cfg.hidden_size,
                kernel_size=cfg.kernel_size,
                padding=1,
            )
        ])

    def _initialize_activation_andPooling(self):
        """Create the activation (``nn`` class named by config.activation_function) and pool."""
        self.activation = getattr(nn, self.config.activation_function)()
        # kernel_size=1 with stride=1 makes this pool an identity op.
        self.max_pool = nn.MaxPool2d(kernel_size=1, stride=1)

    def forward(self, x):
        """Apply the shared conv, then each unique conv, each followed by activation and pool."""
        out = self.max_pool(self.activation(self.shared_layer(x)))
        for unique_layer in self.unique_layers:
            out = self.max_pool(self.activation(unique_layer(out)))
        return out
class AIIAExpert(AIIA):
    """A single expert: owns one backbone instance stored as ``base_cnn``."""

    def __init__(self, config: AIIAConfig, base_class=AIIABase, **kwargs):
        """
        Args:
            config (AIIAConfig): Model configuration (copied by ``AIIA.__init__``).
            base_class: Backbone class to instantiate; must be AIIABase,
                AIIABaseShared, or a subclass of either.
            **kwargs: Configuration overrides, also forwarded to the backbone.

        Raises:
            ValueError: If ``base_class`` is not a supported backbone class.
        """
        super().__init__(config=config, **kwargs)
        if not issubclass(base_class, (AIIABase, AIIABaseShared)):
            raise ValueError("Invalid base class")
        # Instantiate the requested class itself so subclasses of the supported
        # backbones are honoured instead of being replaced by their parent class.
        self.base_cnn = base_class(self.config, **kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the expert model.

        Args:
            x (torch.Tensor): Input tensor

        Returns:
            torch.Tensor: Output of ``base_cnn`` for ``x``.
        """
        return self.base_cnn(x)
class AIIAmoe(AIIA):
    """Mixture of experts: runs every expert and blends their outputs with a learned gate."""

    def __init__(self, config: AIIAConfig, num_experts: int = 3, base_class=AIIABase, **kwargs):
        """
        Args:
            config (AIIAConfig): Model configuration (copied by ``AIIA.__init__``).
            num_experts (int): Number of experts; also recorded as config.num_experts.
            base_class: Backbone class used by each ``AIIAExpert``.
            **kwargs: Configuration overrides, also forwarded to every expert.
        """
        super().__init__(config=config, **kwargs)
        # Work on the private deep copy made by AIIA.__init__ (with kwargs applied),
        # so recording num_experts does not mutate the caller's config object and
        # the overrides are preserved for save().
        self.config.num_experts = num_experts

        self.experts = nn.ModuleList([
            AIIAExpert(self.config, base_class=base_class, **kwargs)
            for _ in range(num_experts)
        ])

        # The gate sees each expert's output averaged over H, W and over experts, i.e.
        # one value per output channel. The backbones' last convolution emits
        # config.hidden_size channels (assuming at least one conv layer), so that is
        # the gate's input width.
        gate_in_features = self.config.hidden_size
        self.gate = nn.Sequential(
            nn.Linear(gate_in_features, num_experts),
            nn.Softmax(dim=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the Mixture-of-Experts model.

        Args:
            x (torch.Tensor): Input tensor

        Returns:
            torch.Tensor: Gate-weighted sum of the expert outputs, shape (B, C, H, W)
            when each expert returns (B, C, H, W).
        """
        # (B, num_experts, C, H, W)
        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=1)
        # Global feature summary (B, C): average over the spatial dims, then over experts.
        gate_input = expert_outputs.mean(dim=(3, 4)).mean(dim=1)
        # (B, num_experts); softmax makes each row sum to 1.
        gate_weights = self.gate(gate_input)
        # Broadcast the weights to (B, num_experts, 1, 1, 1) and sum over experts.
        return (expert_outputs * gate_weights[:, :, None, None, None]).sum(dim=1)
class AIIAchunked(AIIA):
    """Split the input into non-overlapping square patches, run the backbone on each
    patch, and average the per-patch outputs."""

    def __init__(self, config: AIIAConfig, patch_size: int = 16, base_class=AIIABase, **kwargs):
        """
        Args:
            config (AIIAConfig): Model configuration (copied by ``AIIA.__init__``).
            patch_size (int): Side length of the square patches; also recorded as
                config.patch_size.
            base_class: Backbone class; must be AIIABase, AIIABaseShared, or a subclass.
            **kwargs: Configuration overrides, also forwarded to the backbone.

        Raises:
            ValueError: If ``base_class`` is not a supported backbone class.
        """
        super().__init__(config=config, **kwargs)
        self.config.patch_size = patch_size
        # forward() reads self.patch_size; nn.Module raises AttributeError for an
        # attribute that was never assigned, so it must be stored here.
        self.patch_size = patch_size
        if not issubclass(base_class, (AIIABase, AIIABaseShared)):
            raise ValueError("Invalid base class")
        # Instantiate the requested class itself so subclasses are honoured.
        self.base_cnn = base_class(self.config, **kwargs)

    def forward(self, x):
        """Average the backbone output over all ``patch_size`` x ``patch_size`` patches.

        Assumes ``x`` is 4-D (B, C, H, W). Rows/columns that do not fill a whole patch
        are dropped by ``unfold``.
        """
        p = self.patch_size
        # (B, C, H, W) -> (B, C, H // p, W // p, p, p)
        patches = x.unfold(2, p, p).unfold(3, p, p)
        # -> (B, C, N, p, p) with N = (H // p) * (W // p)
        patches = patches.contiguous().view(patches.size(0), patches.size(1), -1, p, p)
        # Each unbound slice is one patch position across the batch: (B, C, p, p).
        patch_outputs = [self.base_cnn(patch) for patch in patches.unbind(dim=2)]
        return torch.stack(patch_outputs, dim=0).mean(dim=0)
class AIIArecursive(AIIA):
    """Recursively split the input into patches, applying ``AIIAchunked`` at the
    deepest level and averaging the results on the way back up."""

    def __init__(self, config: AIIAConfig, recursion_depth: int = 3, base_class=AIIABase, **kwargs):
        """
        Args:
            config (AIIAConfig): Model configuration (copied by ``AIIA.__init__``).
            recursion_depth (int): Depth at which ``chunked_cnn`` is applied; also
                recorded as config.recursion_depth.
            base_class: Backbone class used inside ``AIIAchunked``.
            **kwargs: Configuration overrides, forwarded to ``AIIAchunked`` (a
                ``patch_size`` entry here selects its patch size).
        """
        super().__init__(config=config, **kwargs)
        self.config.recursion_depth = recursion_depth
        # forward() reads this attribute; it must be assigned on the module.
        self.recursion_depth = recursion_depth
        # base_class must be passed by keyword: AIIAchunked's second positional
        # parameter is patch_size.
        self.chunked_cnn = AIIAchunked(self.config, base_class=base_class, **kwargs)
        # Split with the same patch size the chunked model uses (16 unless overridden).
        self.patch_size = self.chunked_cnn.config.patch_size

    def forward(self, x, depth=0):
        """Process ``x`` (assumed 4-D, (B, C, H, W)) at recursion level ``depth``.

        NOTE(review): once the input is already patch_size x patch_size, each further
        level produces a single patch, so extra levels only add overhead.
        """
        # ">=" (not "==") so a negative recursion_depth cannot recurse forever.
        if depth >= self.recursion_depth:
            return self.chunked_cnn(x)
        p = self.patch_size
        # (B, C, H, W) -> (B, C, N, p, p)
        patches = x.unfold(2, p, p).unfold(3, p, p)
        patches = patches.contiguous().view(patches.size(0), patches.size(1), -1, p, p)
        processed_patches = [self.forward(patch, depth + 1) for patch in patches.unbind(dim=2)]
        return torch.stack(processed_patches, dim=0).mean(dim=0)
if __name__ == "__main__":
    # Smoke test: build a five-expert mixture from the default configuration and
    # write its weights and config into the "test" directory.
    demo_config = AIIAConfig()
    demo_model = AIIAmoe(demo_config, num_experts=5)
    demo_model.save("test")