diff --git a/src/aiia/model/Model.py b/src/aiia/model/Model.py index 118acb0..1cdd8dc 100644 --- a/src/aiia/model/Model.py +++ b/src/aiia/model/Model.py @@ -1,68 +1,49 @@ from .config import AIIAConfig from torch import nn -from transformers import PretrainedModel +from transformers import PreTrainedModel import torch -import os import copy -import warnings -class AIIA(nn.Module): - def __init__(self, config: AIIAConfig, **kwargs): - super(AIIA, self).__init__() - # Create a deep copy of the configuration to avoid sharing - self.config = copy.deepcopy(config) - - # Update the config with any additional keyword arguments - for key, value in kwargs.items(): - setattr(self.config, key, value) - - def save(self, path: str): - if not os.path.exists(path): - os.makedirs(path, exist_ok=True) - torch.save(self.state_dict(), f"{path}/model.pth") - self.config.save(path) - -class AIIABase(AIIA): - def __init__(self, config: AIIAConfig, **kwargs): - super().__init__(config=config, **kwargs) - self.config = self.config +class AIIABase(PreTrainedModel): + config_class = AIIAConfig + base_model_prefix = "AIIA" + + def __init__(self, config: AIIAConfig): + super().__init__(config) # Initialize layers based on configuration layers = [] - in_channels = self.config.num_channels + in_channels = config.num_channels - for _ in range(self.config.num_hidden_layers): + for _ in range(config.num_hidden_layers): layers.extend([ - nn.Conv2d(in_channels, self.config.hidden_size, - kernel_size=self.config.kernel_size, padding=1), - getattr(nn, self.config.activation_function)(), + nn.Conv2d(in_channels, config.hidden_size, + kernel_size=config.kernel_size, padding=1), + getattr(nn, config.activation_function)(), nn.MaxPool2d(kernel_size=1, stride=1) ]) - in_channels = self.config.hidden_size + in_channels = config.hidden_size self.cnn = nn.Sequential(*layers) def forward(self, x): return self.cnn(x) -class AIIABaseShared(AIIA): - def __init__(self, config: AIIAConfig, **kwargs): +class AIIABaseShared(PreTrainedModel): + config_class = AIIAConfig + base_model_prefix = "AIIA" + + def __init__(self, config: AIIAConfig): + super().__init__(config) """ Initialize the AIIABaseShared model. Args: config (AIIAConfig): Configuration object containing model parameters. - **kwargs: Additional keyword arguments to override configuration settings. """ - super().__init__(config=config, **kwargs) - - # Update configuration with new parameters if provided - self. config = copy.deepcopy(config) - - for key, value in kwargs.items(): - setattr(self.config, key, value) - + super().__init__(config=config) + # Initialize the network components self._initialize_network() self._initialize_activation_andPooling() @@ -120,16 +101,17 @@ class AIIABaseShared(AIIA): return out -class AIIAExpert(AIIA): - def __init__(self, config: AIIAConfig, base_class=AIIABase, **kwargs): - super().__init__(config=config, **kwargs) - self.config = self.config +class AIIAExpert(PreTrainedModel): + config_class = AIIAConfig + base_model_prefix = "AIIA" + def __init__(self, config: AIIAConfig, base_class=AIIABase): + super().__init__(config=config) # Initialize base CNN with configuration and chosen base class if issubclass(base_class, AIIABase): - self.base_cnn = AIIABase(self.config, **kwargs) + self.base_cnn = AIIABase(self.config) elif issubclass(base_class, AIIABaseShared): - self.base_cnn = AIIABaseShared(self.config, **kwargs) + self.base_cnn = AIIABaseShared(self.config) else: raise ValueError("Invalid base class") @@ -146,31 +128,31 @@ class AIIAExpert(AIIA): # Process input through the base CNN return self.base_cnn(x) -class AIIAmoe(AIIA): - def __init__(self, config: AIIAConfig, num_experts: int = 3, base_class=AIIABase, **kwargs): - super().__init__(config=config, **kwargs) +class AIIAmoe(PreTrainedModel): + config_class = AIIAConfig + base_model_prefix = "AIIA" + + def __init__(self, config: AIIAConfig, base_class=AIIABase): + super().__init__(config=config) self.config = config - # Update the config to include the number of experts. - self.config.num_experts = num_experts - - # Initialize multiple experts from the chosen base class. + # Get num_experts directly from config instead of parameter + num_experts = getattr(config, "num_experts", 3) # default to 3 if not in config + + # Initialize multiple experts from the chosen base class self.experts = nn.ModuleList([ - AIIAExpert(self.config, base_class=base_class, **kwargs) + AIIAExpert(self.config, base_class=base_class) for _ in range(num_experts) ]) - # To generate gating weights, we first need to determine the feature dimension. - # Each expert is assumed to return an output of shape (B, C, H, W); after averaging over H and W, - # we obtain a tensor of shape (B, C) where C is the number of channels (here assumed to be 224). - gate_in_features = 512 # Adjust this if your expert output changes. + gate_in_features = self.config.hidden_size - # Create a gating network that maps the aggregated features to num_experts weights. + # Create a gating network that maps the aggregated features to num_experts weights self.gate = nn.Sequential( nn.Linear(gate_in_features, num_experts), nn.Softmax(dim=1) ) - + def forward(self, x: torch.Tensor) -> torch.Tensor: """ Forward pass for the Mixture-of-Experts model. @@ -209,9 +191,10 @@ class AIIAmoe(AIIA): class AIIASparseMoe(AIIAmoe): - def __init__(self, config: AIIAConfig, num_experts: int = 3, top_k: int = 2, base_class=AIIABase, **kwargs): - super().__init__(config=config, num_experts=num_experts, base_class=base_class, **kwargs) - self.top_k = top_k + config_class = AIIAConfig + base_model_prefix = "AIIA" + def __init__(self, config: AIIAConfig, base_class=AIIABase): + super().__init__(config=config, base_class=base_class) def forward(self, x: torch.Tensor) -> torch.Tensor: # Compute the gate_weights similar to standard moe. @@ -221,7 +204,7 @@ class AIIASparseMoe(AIIAmoe): gate_weights = self.gate(gate_input) # Select the top-k experts for each input based on gating weights. - _, top_k_indices = gate_weights.topk(self.top_k, dim=-1) + _, top_k_indices = gate_weights.topk(self.config.top_k, dim=-1) # Initialize a list to store outputs from selected experts. merged_outputs = [] @@ -245,4 +228,4 @@ class AIIASparseMoe(AIIAmoe): if __name__ =="__main__": config = AIIAConfig() model = AIIAmoe(config, num_experts=5) - model.save("test") \ No newline at end of file + model.save_pretrained("test") \ No newline at end of file