feat/tf_support #37
|
@ -1,67 +1,48 @@
|
||||||
from .config import AIIAConfig
|
from .config import AIIAConfig
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import PretrainedModel
|
from transformers import PreTrainedModel
|
||||||
import torch
|
import torch
|
||||||
import os
|
|
||||||
import copy
|
import copy
|
||||||
import warnings
|
|
||||||
|
|
||||||
|
|
||||||
class AIIA(nn.Module):
|
class AIIABase(PreTrainedModel):
|
||||||
def __init__(self, config: AIIAConfig, **kwargs):
|
config_class = AIIAConfig
|
||||||
super(AIIA, self).__init__()
|
base_model_prefix = "AIIA"
|
||||||
# Create a deep copy of the configuration to avoid sharing
|
|
||||||
self.config = copy.deepcopy(config)
|
|
||||||
|
|
||||||
# Update the config with any additional keyword arguments
|
def __init__(self, config: AIIAConfig):
|
||||||
for key, value in kwargs.items():
|
super().__init__(config)
|
||||||
setattr(self.config, key, value)
|
|
||||||
|
|
||||||
def save(self, path: str):
|
|
||||||
if not os.path.exists(path):
|
|
||||||
os.makedirs(path, exist_ok=True)
|
|
||||||
torch.save(self.state_dict(), f"{path}/model.pth")
|
|
||||||
self.config.save(path)
|
|
||||||
|
|
||||||
class AIIABase(AIIA):
|
|
||||||
def __init__(self, config: AIIAConfig, **kwargs):
|
|
||||||
super().__init__(config=config, **kwargs)
|
|
||||||
self.config = self.config
|
|
||||||
|
|
||||||
# Initialize layers based on configuration
|
# Initialize layers based on configuration
|
||||||
layers = []
|
layers = []
|
||||||
in_channels = self.config.num_channels
|
in_channels = config.num_channels
|
||||||
|
|
||||||
for _ in range(self.config.num_hidden_layers):
|
for _ in range(config.num_hidden_layers):
|
||||||
layers.extend([
|
layers.extend([
|
||||||
nn.Conv2d(in_channels, self.config.hidden_size,
|
nn.Conv2d(in_channels, config.hidden_size,
|
||||||
kernel_size=self.config.kernel_size, padding=1),
|
kernel_size=config.kernel_size, padding=1),
|
||||||
getattr(nn, self.config.activation_function)(),
|
getattr(nn, config.activation_function)(),
|
||||||
nn.MaxPool2d(kernel_size=1, stride=1)
|
nn.MaxPool2d(kernel_size=1, stride=1)
|
||||||
])
|
])
|
||||||
in_channels = self.config.hidden_size
|
in_channels = config.hidden_size
|
||||||
|
|
||||||
self.cnn = nn.Sequential(*layers)
|
self.cnn = nn.Sequential(*layers)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return self.cnn(x)
|
return self.cnn(x)
|
||||||
|
|
||||||
class AIIABaseShared(AIIA):
|
class AIIABaseShared(PreTrainedModel):
|
||||||
def __init__(self, config: AIIAConfig, **kwargs):
|
config_class = AIIAConfig
|
||||||
|
base_model_prefix = "AIIA"
|
||||||
|
|
||||||
|
def __init__(self, config: AIIAConfig):
|
||||||
|
super().__init__(config)
|
||||||
"""
|
"""
|
||||||
Initialize the AIIABaseShared model.
|
Initialize the AIIABaseShared model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
config (AIIAConfig): Configuration object containing model parameters.
|
config (AIIAConfig): Configuration object containing model parameters.
|
||||||
**kwargs: Additional keyword arguments to override configuration settings.
|
|
||||||
"""
|
"""
|
||||||
super().__init__(config=config, **kwargs)
|
super().__init__(config=config)
|
||||||
|
|
||||||
# Update configuration with new parameters if provided
|
|
||||||
self. config = copy.deepcopy(config)
|
|
||||||
|
|
||||||
for key, value in kwargs.items():
|
|
||||||
setattr(self.config, key, value)
|
|
||||||
|
|
||||||
# Initialize the network components
|
# Initialize the network components
|
||||||
self._initialize_network()
|
self._initialize_network()
|
||||||
|
@ -120,16 +101,17 @@ class AIIABaseShared(AIIA):
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
class AIIAExpert(AIIA):
|
class AIIAExpert(PreTrainedModel):
|
||||||
def __init__(self, config: AIIAConfig, base_class=AIIABase, **kwargs):
|
config_class = AIIAConfig
|
||||||
super().__init__(config=config, **kwargs)
|
base_model_prefix = "AIIA"
|
||||||
self.config = self.config
|
def __init__(self, config: AIIAConfig, base_class=AIIABase):
|
||||||
|
super().__init__(config=config)
|
||||||
|
|
||||||
# Initialize base CNN with configuration and chosen base class
|
# Initialize base CNN with configuration and chosen base class
|
||||||
if issubclass(base_class, AIIABase):
|
if issubclass(base_class, AIIABase):
|
||||||
self.base_cnn = AIIABase(self.config, **kwargs)
|
self.base_cnn = AIIABase(self.config)
|
||||||
elif issubclass(base_class, AIIABaseShared):
|
elif issubclass(base_class, AIIABaseShared):
|
||||||
self.base_cnn = AIIABaseShared(self.config, **kwargs)
|
self.base_cnn = AIIABaseShared(self.config)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Invalid base class")
|
raise ValueError("Invalid base class")
|
||||||
|
|
||||||
|
@ -146,26 +128,26 @@ class AIIAExpert(AIIA):
|
||||||
# Process input through the base CNN
|
# Process input through the base CNN
|
||||||
return self.base_cnn(x)
|
return self.base_cnn(x)
|
||||||
|
|
||||||
class AIIAmoe(AIIA):
|
class AIIAmoe(PreTrainedModel):
|
||||||
def __init__(self, config: AIIAConfig, num_experts: int = 3, base_class=AIIABase, **kwargs):
|
config_class = AIIAConfig
|
||||||
super().__init__(config=config, **kwargs)
|
base_model_prefix = "AIIA"
|
||||||
|
|
||||||
|
def __init__(self, config: AIIAConfig, base_class=AIIABase):
|
||||||
|
super().__init__(config=config)
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
# Update the config to include the number of experts.
|
# Get num_experts directly from config instead of parameter
|
||||||
self.config.num_experts = num_experts
|
num_experts = getattr(config, "num_experts", 3) # default to 3 if not in config
|
||||||
|
|
||||||
# Initialize multiple experts from the chosen base class.
|
# Initialize multiple experts from the chosen base class
|
||||||
self.experts = nn.ModuleList([
|
self.experts = nn.ModuleList([
|
||||||
AIIAExpert(self.config, base_class=base_class, **kwargs)
|
AIIAExpert(self.config, base_class=base_class)
|
||||||
for _ in range(num_experts)
|
for _ in range(num_experts)
|
||||||
])
|
])
|
||||||
|
|
||||||
# To generate gating weights, we first need to determine the feature dimension.
|
gate_in_features = self.config.hidden_size
|
||||||
# Each expert is assumed to return an output of shape (B, C, H, W); after averaging over H and W,
|
|
||||||
# we obtain a tensor of shape (B, C) where C is the number of channels (here assumed to be 224).
|
|
||||||
gate_in_features = 512 # Adjust this if your expert output changes.
|
|
||||||
|
|
||||||
# Create a gating network that maps the aggregated features to num_experts weights.
|
# Create a gating network that maps the aggregated features to num_experts weights
|
||||||
self.gate = nn.Sequential(
|
self.gate = nn.Sequential(
|
||||||
nn.Linear(gate_in_features, num_experts),
|
nn.Linear(gate_in_features, num_experts),
|
||||||
nn.Softmax(dim=1)
|
nn.Softmax(dim=1)
|
||||||
|
@ -209,9 +191,10 @@ class AIIAmoe(AIIA):
|
||||||
|
|
||||||
|
|
||||||
class AIIASparseMoe(AIIAmoe):
|
class AIIASparseMoe(AIIAmoe):
|
||||||
def __init__(self, config: AIIAConfig, num_experts: int = 3, top_k: int = 2, base_class=AIIABase, **kwargs):
|
config_class = AIIAConfig
|
||||||
super().__init__(config=config, num_experts=num_experts, base_class=base_class, **kwargs)
|
base_model_prefix = "AIIA"
|
||||||
self.top_k = top_k
|
def __init__(self, config: AIIAConfig, base_class=AIIABase):
|
||||||
|
super().__init__(config=config, base_class=base_class)
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
# Compute the gate_weights similar to standard moe.
|
# Compute the gate_weights similar to standard moe.
|
||||||
|
@ -221,7 +204,7 @@ class AIIASparseMoe(AIIAmoe):
|
||||||
gate_weights = self.gate(gate_input)
|
gate_weights = self.gate(gate_input)
|
||||||
|
|
||||||
# Select the top-k experts for each input based on gating weights.
|
# Select the top-k experts for each input based on gating weights.
|
||||||
_, top_k_indices = gate_weights.topk(self.top_k, dim=-1)
|
_, top_k_indices = gate_weights.topk(self.config.top_k, dim=-1)
|
||||||
|
|
||||||
# Initialize a list to store outputs from selected experts.
|
# Initialize a list to store outputs from selected experts.
|
||||||
merged_outputs = []
|
merged_outputs = []
|
||||||
|
@ -245,4 +228,4 @@ class AIIASparseMoe(AIIAmoe):
|
||||||
if __name__ =="__main__":
|
if __name__ =="__main__":
|
||||||
config = AIIAConfig()
|
config = AIIAConfig()
|
||||||
model = AIIAmoe(config, num_experts=5)
|
model = AIIAmoe(config, num_experts=5)
|
||||||
model.save("test")
|
model.save_pretrained("test")
|
Loading…
Reference in New Issue