256 lines
9.3 KiB
Python
256 lines
9.3 KiB
Python
from .config import AIIAConfig
|
|
from torch import nn
|
|
import torch
|
|
import os
|
|
import copy
|
|
import warnings
|
|
|
|
|
|
class AIIA(nn.Module):
    """Base class for AIIA models: owns a private config copy and implements save/load."""

    def __init__(self, config: AIIAConfig, **kwargs):
        """
        Args:
            config (AIIAConfig): Configuration to copy; the caller's object is never mutated.
            **kwargs: Extra settings written onto the copied config as attributes.
        """
        super().__init__()
        # Deep-copy so per-model overrides (kwargs, or fields that subclasses set on
        # self.config) never leak back into a config object shared with other models.
        self.config = copy.deepcopy(config)
        for key, value in kwargs.items():
            setattr(self.config, key, value)

    def save(self, path: str):
        """
        Save the weights and configuration into the directory ``path``.

        Writes the state dict to ``<path>/model.pth`` and delegates the config to
        ``self.config.save(path)``. The directory is created if it does not exist.
        """
        os.makedirs(path, exist_ok=True)
        torch.save(self.state_dict(), os.path.join(path, "model.pth"))
        self.config.save(path)

    @classmethod
    def load(cls, path, precision: str = None):
        """
        Rebuild a model from a directory written by :meth:`save`.

        Args:
            path: Directory containing the saved config and ``model.pth``.
            precision: ``None`` (keep the stored dtype), ``'fp16'`` or ``'bf16'``
                (case-insensitive). ``'bf16'`` falls back to fp16 with a warning
                when CUDA is available but the GPU does not support bf16.

        Returns:
            A ``cls`` instance with the loaded weights. The model is not moved to
            ``device``; it stays where ``cls(config)`` created it.

        Raises:
            ValueError: If ``precision`` is not one of the supported values.
        """
        config = AIIAConfig.load(path)
        model = cls(config)
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

        dtype = None
        if precision is not None:
            mode = precision.lower()
            if mode == 'fp16':
                dtype = torch.float16
            elif mode == 'bf16':
                if device == 'cuda' and not torch.cuda.is_bf16_supported():
                    warnings.warn("BF16 is not supported on this GPU. Falling back to FP16.")
                    dtype = torch.float16
                else:
                    dtype = torch.bfloat16
            else:
                raise ValueError("Unsupported precision. Use 'fp16', 'bf16', or leave as None.")

        # NOTE(security): torch.load unpickles the file; only load checkpoints from
        # trusted sources.
        model_dict = torch.load(os.path.join(path, "model.pth"), map_location=device)
        model.load_state_dict(model_dict)

        # load_state_dict copies values *into* the model's existing parameters, which
        # keep their own dtype, so casting the state dict beforehand had no effect.
        # Cast the module instead; Module.to(dtype=...) only converts floating-point
        # parameters/buffers and leaves integer buffers untouched.
        if dtype is not None:
            model = model.to(dtype=dtype)
        return model


class AIIABase(AIIA):
    """
    Plain CNN backbone.

    Stacks ``config.num_hidden_layers`` groups of Conv2d -> activation -> MaxPool2d
    inside a single ``nn.Sequential`` stored as ``self.cnn``.
    """

    def __init__(self, config: AIIAConfig, **kwargs):
        super().__init__(config=config, **kwargs)
        cfg = self.config

        modules = []
        channels = cfg.num_channels
        for _ in range(cfg.num_hidden_layers):
            # Creation order (conv, activation, pool) is kept fixed: it determines the
            # Sequential indices in the state dict and the order of weight initialisation.
            modules += [
                nn.Conv2d(channels, cfg.hidden_size, kernel_size=cfg.kernel_size, padding=1),
                getattr(nn, cfg.activation_function)(),
                # A 1x1 pool with stride 1 leaves the tensor unchanged.
                nn.MaxPool2d(kernel_size=1, stride=1),
            ]
            # Every layer after the first consumes the previous layer's hidden channels.
            channels = cfg.hidden_size

        self.cnn = nn.Sequential(*modules)

    def forward(self, x):
        """Run ``x`` through the stacked convolutional layers."""
        return self.cnn(x)


class AIIABaseShared(AIIA):
    """
    CNN backbone made of one "shared" convolution followed by a list of unique convolutions.

    Every convolution is followed by the same activation module and a 1x1 max-pool.
    """

    def __init__(self, config: AIIAConfig, **kwargs):
        """
        Initialize the AIIABaseShared model.

        Args:
            config (AIIAConfig): Configuration object containing model parameters.
            **kwargs: Additional keyword arguments to override configuration settings.
        """
        # AIIA.__init__ already deep-copies ``config`` and applies ``kwargs`` to the
        # copy, so no second copy is needed here.
        super().__init__(config=config, **kwargs)

        # Initialize the network components
        self._initialize_network()
        self._initialize_activation_andPooling()

    def _initialize_network(self):
        """Initialize the shared and unique layers of the network."""
        # Single shared convolution: num_channels -> hidden_size.
        # padding is hard-coded to 1 (it keeps the spatial size only when kernel_size == 3).
        self.shared_layer = nn.Conv2d(
            in_channels=self.config.num_channels,
            out_channels=self.config.hidden_size,
            kernel_size=self.config.kernel_size,
            padding=1,
        )

        # Unique layers, each with its own weights and biases.
        # NOTE(review): exactly one unique layer is created regardless of
        # config.num_hidden_layers -- confirm this is intended.
        self.unique_layers = nn.ModuleList()
        current_in_channels = self.config.hidden_size
        layer = nn.Conv2d(
            in_channels=current_in_channels,
            out_channels=self.config.hidden_size,
            kernel_size=self.config.kernel_size,
            padding=1,
        )
        self.unique_layers.append(layer)

    def _initialize_activation_andPooling(self):
        """Initialize activation function and pooling layers."""
        # Activation class looked up by name on torch.nn (e.g. "ReLU").
        self.activation = getattr(nn, self.config.activation_function)()

        # PyTorch requires pooling padding <= kernel_size // 2, so a 1x1 kernel must use
        # padding=0 (padding=1 raised a RuntimeError on the first forward call). A 1x1
        # pool with stride 1 is the identity, matching AIIABase's pooling layer.
        self.max_pool = nn.MaxPool2d(
            kernel_size=1,
            stride=1,
            padding=0,
        )

    def forward(self, x):
        """Forward pass: shared conv, then each unique conv, each followed by activation + pool."""
        out = self.shared_layer(x)
        out = self.activation(out)
        out = self.max_pool(out)

        for unique_layer in self.unique_layers:
            out = unique_layer(out)
            out = self.activation(out)
            out = self.max_pool(out)

        return out


class AIIAExpert(AIIA):
    """A single expert: a thin wrapper around one AIIABase / AIIABaseShared backbone."""

    def __init__(self, config: AIIAConfig, base_class=AIIABase, **kwargs):
        """
        Args:
            config (AIIAConfig): Model configuration (copied by AIIA.__init__).
            base_class: Backbone class to instantiate; must be AIIABase, AIIABaseShared,
                or a subclass of either.
            **kwargs: Config overrides, forwarded to the backbone as well.

        Raises:
            ValueError: If ``base_class`` is not a supported backbone class.
        """
        super().__init__(config=config, **kwargs)

        # Instantiate the class the caller actually asked for (subclasses included)
        # instead of always falling back to the hard-coded parent class.
        is_backbone = isinstance(base_class, type) and issubclass(
            base_class, (AIIABase, AIIABaseShared)
        )
        if not is_backbone:
            raise ValueError("Invalid base class")
        self.base_cnn = base_class(self.config, **kwargs)

    def forward(self, x):
        """Run ``x`` through the wrapped backbone.

        Without this method nn.Module.__call__ raised NotImplementedError, which broke
        every AIIAmoe forward pass.
        """
        return self.base_cnn(x)


class AIIAmoe(AIIA):
    """Mixture of experts: several AIIAExpert backbones blended by a softmax gate."""

    def __init__(self, config: AIIAConfig, num_experts: int = None, base_class=AIIABase, **kwargs):
        """
        Args:
            config (AIIAConfig): Model configuration (copied by AIIA.__init__).
            num_experts: Number of experts. When omitted, ``config.num_experts`` is used
                if present, otherwise 3. Reading it from the config lets ``AIIA.load``
                (which calls ``cls(config)`` with no extra arguments) rebuild a model
                with the same number of experts it was saved with.
            base_class: Backbone class used by every expert.
            **kwargs: Config overrides, forwarded to each expert.
        """
        super().__init__(config=config, **kwargs)

        if num_experts is None:
            num_experts = getattr(self.config, "num_experts", None)
            if num_experts is None:
                num_experts = 3
        self.config.num_experts = num_experts

        self.experts = nn.ModuleList([
            AIIAExpert(self.config, base_class=base_class, **kwargs)
            for _ in range(self.config.num_experts)
        ])

        # Gate: pooled hidden features (batch, hidden_size) -> expert weights (batch, num_experts).
        self.gate = nn.Sequential(
            nn.Linear(self.config.hidden_size, self.config.num_experts),
            nn.Softmax(dim=1),
        )

    def forward(self, x):
        """
        Blend the expert outputs with per-sample gate weights.

        Assumes batched image input so every expert returns (batch, hidden_size, H, W)
        -- TODO confirm for unbatched use.
        """
        # (batch, num_experts, hidden_size, H, W)
        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=1)

        # Pool over experts and both spatial dims -> (batch, hidden_size), the feature
        # size the gate's Linear layer expects. The previous mean over dims (2, 3)
        # produced (batch, num_experts, W), which did not match the Linear layer.
        pooled = expert_outputs.mean(dim=(1, 3, 4))
        gate_weights = self.gate(pooled)  # (batch, num_experts)

        # Broadcast the weights over (hidden_size, H, W) and sum across experts.
        return (expert_outputs * gate_weights[:, :, None, None, None]).sum(dim=1)


class AIIAchunked(AIIA):
    """
    Tile the input into square patches, run the backbone on each, and average the results.
    """

    def __init__(self, config: AIIAConfig, patch_size: int = None, base_class=AIIABase, **kwargs):
        """
        Args:
            config (AIIAConfig): Model configuration (copied by AIIA.__init__).
            patch_size: Side length of each square tile. When omitted,
                ``config.patch_size`` is used if present, otherwise 16, so that
                ``AIIA.load`` restores the saved value.
            base_class: Backbone class; must be AIIABase, AIIABaseShared, or a subclass.
            **kwargs: Config overrides, forwarded to the backbone.

        Raises:
            ValueError: If ``base_class`` is not a supported backbone class.
        """
        super().__init__(config=config, **kwargs)

        if patch_size is None:
            patch_size = getattr(self.config, "patch_size", None)
            if patch_size is None:
                patch_size = 16
        self.config.patch_size = patch_size

        # Instantiate the requested backbone class (subclasses included).
        is_backbone = isinstance(base_class, type) and issubclass(
            base_class, (AIIABase, AIIABaseShared)
        )
        if not is_backbone:
            raise ValueError("Invalid base class")
        self.base_cnn = base_class(self.config, **kwargs)

    def forward(self, x):
        """
        Split ``x`` into non-overlapping ``patch_size`` x ``patch_size`` tiles along
        dims 2 and 3, apply the backbone to each tile, and return the mean output.

        Rows/columns that do not fill a whole tile are dropped by ``unfold``.
        """
        # The patch size lives on the config; there is no self.patch_size attribute
        # (reading it raised AttributeError).
        size = self.config.patch_size

        # (B, C, nH, nW, size, size) -> (B, C, nH * nW, size, size)
        patches = x.unfold(2, size, size).unfold(3, size, size)
        patches = patches.contiguous().view(patches.size(0), patches.size(1), -1, size, size)

        patch_outputs = [
            self.base_cnn(patch.squeeze(2))
            for patch in torch.split(patches, 1, dim=2)
        ]
        return torch.mean(torch.stack(patch_outputs, dim=0), dim=0)


class AIIArecursive(AIIA):
    """
    Recursively tile the input, handing the deepest level to an AIIAchunked model,
    and average the results at every level.
    """

    def __init__(self, config: AIIAConfig, recursion_depth: int = None, base_class=AIIABase, **kwargs):
        """
        Args:
            config (AIIAConfig): Model configuration (copied by AIIA.__init__).
            recursion_depth: Number of tiling levels before AIIAchunked takes over.
                When omitted, ``config.recursion_depth`` is used if present, otherwise
                3, so that ``AIIA.load`` restores the saved value.
            base_class: Backbone class passed on to AIIAchunked.
            **kwargs: Config overrides, forwarded to AIIAchunked.
        """
        super().__init__(config=config, **kwargs)

        if recursion_depth is None:
            recursion_depth = getattr(self.config, "recursion_depth", None)
            if recursion_depth is None:
                recursion_depth = 3
        self.config.recursion_depth = recursion_depth

        # base_class must be passed by keyword: AIIAchunked's second positional
        # parameter is patch_size, so passing it positionally stored the class
        # object as the patch size.
        self.chunked_cnn = AIIAchunked(self.config, base_class=base_class, **kwargs)

    def forward(self, x, depth=0):
        """
        Process ``x`` at recursion level ``depth``.

        At ``depth >= config.recursion_depth`` the chunked model handles ``x``
        directly. Otherwise ``x`` is tiled along dims 2 and 3 (tile size taken from
        the chunked model's config), each tile is processed one level deeper, and
        the results are averaged.
        """
        # The depth lives on the config; there is no self.recursion_depth attribute.
        # ">=" also stops a caller-supplied depth that is already past the limit.
        if depth >= self.config.recursion_depth:
            return self.chunked_cnn(x)

        size = self.chunked_cnn.config.patch_size
        patches = x.unfold(2, size, size).unfold(3, size, size)
        patches = patches.contiguous().view(patches.size(0), patches.size(1), -1, size, size)

        processed_patches = [
            self.forward(patch.squeeze(2), depth + 1)
            for patch in torch.split(patches, 1, dim=2)
        ]
        return torch.mean(torch.stack(processed_patches, dim=0), dim=0)


if __name__ == "__main__":
    # Smoke run: build a five-expert mixture from the default config and save it
    # into the directory "test" (relative to the working directory).
    demo_model = AIIAmoe(AIIAConfig(), num_experts=5)
    demo_model.save("test")