diff --git a/README.md b/README.md index 333838b..90e2a0b 100644 --- a/README.md +++ b/README.md @@ -29,7 +29,7 @@ from aiia.model import AIIAConfig from aiia.pretrain import Pretrainer # Create your model -config = AIIAConfig(model_name="AIIA-Base-512x20k") +config = AIIAConfig(model_type="AIIA-Base-512x20k") model = AIIABase(config) # Initialize pretrainer with the model diff --git a/example.py b/example.py index 8dbb67e..2ce0b6a 100644 --- a/example.py +++ b/example.py @@ -1,10 +1,12 @@ -from aiia.model import AIIABase -from aiia.model import AIIAConfig -from aiia.pretrain import Pretrainer +from src.aiia.model import AIIAmoe +from src.aiia.model import AIIAConfig +from src.aiia.pretrain import Pretrainer # Create your model -config = AIIAConfig(model_name="AIIA-Base-512x20k") -model = AIIABase(config) +config = AIIAConfig(num_experts=5) +model = AIIAmoe(config) +model.save_pretrained("test") +model = AIIAmoe.from_pretrained("test") # Initialize pretrainer with the model pretrainer = Pretrainer(model, learning_rate=1e-4, config=config) diff --git a/pyproject.toml b/pyproject.toml index e4d1f2a..2220a98 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ include = '\.pyi?$' [project] name = "aiia" -version = "0.2.1" +version = "0.3.0" description = "AIIA Deep Learning Model Implementation" readme = "README.md" authors = [ diff --git a/requirements.txt b/requirements.txt index 8e2d666..60fe329 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,4 +5,5 @@ pytest pillow pandas torchvision -pyarrow \ No newline at end of file +pyarrow +transformers>=4.48.0 \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index eeb961d..26614c8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = aiia -version = 0.2.1 +version = 0.3.0 author = Falko Habel author_email = falko.habel@gmx.de description = AIIA deep learning model implementation diff --git a/src/aiia/__init__.py b/src/aiia/__init__.py index 838c765..699a42f 100644 --- a/src/aiia/__init__.py +++ b/src/aiia/__init__.py @@ -1,7 +1,7 @@ -from .model.Model import AIIABase, AIIABaseShared, AIIAchunked, AIIAmoe, AIIASparseMoe, AIIArecursive +from .model.Model import AIIABase, AIIABaseShared, AIIAmoe, AIIASparseMoe from .model.config import AIIAConfig from .data.DataLoader import DataLoader from .pretrain.pretrainer import Pretrainer, ProjectionHead -__version__ = "0.2.1" +__version__ = "0.3.0" diff --git a/src/aiia/model/Model.py b/src/aiia/model/Model.py index f21a0db..1cdd8dc 100644 --- a/src/aiia/model/Model.py +++ b/src/aiia/model/Model.py @@ -1,120 +1,49 @@ from .config import AIIAConfig from torch import nn +from transformers import PreTrainedModel import torch -import os import copy -import warnings -class AIIA(nn.Module): - def __init__(self, config: AIIAConfig, **kwargs): - super(AIIA, self).__init__() - # Create a deep copy of the configuration to avoid sharing - self.config = copy.deepcopy(config) - - # Update the config with any additional keyword arguments - for key, value in kwargs.items(): - setattr(self.config, key, value) - - def save(self, path: str): - if not os.path.exists(path): - os.makedirs(path, exist_ok=True) - torch.save(self.state_dict(), f"{path}/model.pth") - self.config.save(path) - - @classmethod - def load(cls, path, precision: str = None, strict: bool = True, **kwargs): - config = AIIAConfig.load(path) - device = 'cuda' if torch.cuda.is_available() else 'cpu' - - # Load the state dict to analyze structure - model_dict = torch.load(f"{path}/model.pth", map_location=device) - - # Special handling for AIIAmoe - detect number of experts from state_dict - if cls.__name__ == "AIIAmoe" and "num_experts" not in kwargs: - # Find maximum expert index - max_expert_idx = -1 - for key in model_dict.keys(): - if key.startswith("experts."): - parts = key.split(".") - if len(parts) > 1: - try: - expert_idx = int(parts[1]) - max_expert_idx = max(max_expert_idx, expert_idx) - except ValueError: - pass - - if max_expert_idx >= 0: - # experts.X keys found, use max_expert_idx + 1 as num_experts - kwargs["num_experts"] = max_expert_idx + 1 - - # Create model with detected structural parameters - model = cls(config, **kwargs) - - # Handle precision conversion - dtype = None - if precision is not None: - if precision.lower() == 'fp16': - dtype = torch.float16 - elif precision.lower() == 'bf16': - if device == 'cuda' and not torch.cuda.is_bf16_supported(): - warnings.warn("BF16 is not supported on this GPU. Falling back to FP16.") - dtype = torch.float16 - else: - dtype = torch.bfloat16 - else: - raise ValueError("Unsupported precision. Use 'fp16', 'bf16', or leave as None.") - - if dtype is not None: - for key, param in model_dict.items(): - if torch.is_tensor(param): - model_dict[key] = param.to(dtype) - - # Load state dict with strict parameter for flexibility - model.load_state_dict(model_dict, strict=strict) - return model - - -class AIIABase(AIIA): - def __init__(self, config: AIIAConfig, **kwargs): - super().__init__(config=config, **kwargs) - self.config = self.config +class AIIABase(PreTrainedModel): + config_class = AIIAConfig + base_model_prefix = "AIIA" + + def __init__(self, config: AIIAConfig): + super().__init__(config) # Initialize layers based on configuration layers = [] - in_channels = self.config.num_channels + in_channels = config.num_channels - for _ in range(self.config.num_hidden_layers): + for _ in range(config.num_hidden_layers): layers.extend([ - nn.Conv2d(in_channels, self.config.hidden_size, - kernel_size=self.config.kernel_size, padding=1), - getattr(nn, self.config.activation_function)(), + nn.Conv2d(in_channels, config.hidden_size, + kernel_size=config.kernel_size, padding=1), + getattr(nn, config.activation_function)(), nn.MaxPool2d(kernel_size=1, stride=1) ]) - in_channels = self.config.hidden_size + in_channels = config.hidden_size self.cnn = nn.Sequential(*layers) def forward(self, x): return self.cnn(x) -class AIIABaseShared(AIIA): - def __init__(self, config: AIIAConfig, **kwargs): +class AIIABaseShared(PreTrainedModel): + config_class = AIIAConfig + base_model_prefix = "AIIA" + + def __init__(self, config: AIIAConfig): + super().__init__(config) """ Initialize the AIIABaseShared model. Args: config (AIIAConfig): Configuration object containing model parameters. - **kwargs: Additional keyword arguments to override configuration settings. """ - super().__init__(config=config, **kwargs) - - # Update configuration with new parameters if provided - self. config = copy.deepcopy(config) - - for key, value in kwargs.items(): - setattr(self.config, key, value) - + super().__init__(config=config) + # Initialize the network components self._initialize_network() self._initialize_activation_andPooling() @@ -172,16 +101,17 @@ class AIIABaseShared(AIIA): return out -class AIIAExpert(AIIA): - def __init__(self, config: AIIAConfig, base_class=AIIABase, **kwargs): - super().__init__(config=config, **kwargs) - self.config = self.config +class AIIAExpert(PreTrainedModel): + config_class = AIIAConfig + base_model_prefix = "AIIA" + def __init__(self, config: AIIAConfig, base_class=AIIABase): + super().__init__(config=config) # Initialize base CNN with configuration and chosen base class if issubclass(base_class, AIIABase): - self.base_cnn = AIIABase(self.config, **kwargs) + self.base_cnn = AIIABase(self.config) elif issubclass(base_class, AIIABaseShared): - self.base_cnn = AIIABaseShared(self.config, **kwargs) + self.base_cnn = AIIABaseShared(self.config) else: raise ValueError("Invalid base class") @@ -198,31 +128,31 @@ class AIIAExpert(AIIA): # Process input through the base CNN return self.base_cnn(x) -class AIIAmoe(AIIA): - def __init__(self, config: AIIAConfig, num_experts: int = 3, base_class=AIIABase, **kwargs): - super().__init__(config=config, **kwargs) +class AIIAmoe(PreTrainedModel): + config_class = AIIAConfig + base_model_prefix = "AIIA" + + def __init__(self, config: AIIAConfig, base_class=AIIABase): + super().__init__(config=config) self.config = config - # Update the config to include the number of experts. - self.config.num_experts = num_experts - - # Initialize multiple experts from the chosen base class. + # Get num_experts directly from config instead of parameter + num_experts = getattr(config, "num_experts", 3) # default to 3 if not in config + + # Initialize multiple experts from the chosen base class self.experts = nn.ModuleList([ - AIIAExpert(self.config, base_class=base_class, **kwargs) + AIIAExpert(self.config, base_class=base_class) for _ in range(num_experts) ]) - # To generate gating weights, we first need to determine the feature dimension. - # Each expert is assumed to return an output of shape (B, C, H, W); after averaging over H and W, - # we obtain a tensor of shape (B, C) where C is the number of channels (here assumed to be 224). - gate_in_features = 512 # Adjust this if your expert output changes. + gate_in_features = self.config.hidden_size - # Create a gating network that maps the aggregated features to num_experts weights. + # Create a gating network that maps the aggregated features to num_experts weights self.gate = nn.Sequential( nn.Linear(gate_in_features, num_experts), nn.Softmax(dim=1) ) - + def forward(self, x: torch.Tensor) -> torch.Tensor: """ Forward pass for the Mixture-of-Experts model. @@ -261,9 +191,10 @@ class AIIAmoe(AIIA): class AIIASparseMoe(AIIAmoe): - def __init__(self, config: AIIAConfig, num_experts: int = 3, top_k: int = 2, base_class=AIIABase, **kwargs): - super().__init__(config=config, num_experts=num_experts, base_class=base_class, **kwargs) - self.top_k = top_k + config_class = AIIAConfig + base_model_prefix = "AIIA" + def __init__(self, config: AIIAConfig, base_class=AIIABase): + super().__init__(config=config, base_class=base_class) def forward(self, x: torch.Tensor) -> torch.Tensor: # Compute the gate_weights similar to standard moe. @@ -273,7 +204,7 @@ class AIIASparseMoe(AIIAmoe): gate_weights = self.gate(gate_input) # Select the top-k experts for each input based on gating weights. - _, top_k_indices = gate_weights.topk(self.top_k, dim=-1) + _, top_k_indices = gate_weights.topk(self.config.top_k, dim=-1) # Initialize a list to store outputs from selected experts. merged_outputs = [] @@ -294,64 +225,7 @@ class AIIASparseMoe(AIIAmoe): return torch.cat(merged_outputs, dim=0) -class AIIAchunked(AIIA): - def __init__(self, config: AIIAConfig, patch_size: int = 16, base_class=AIIABase, **kwargs): - super().__init__(config=config, **kwargs) - self.config = self.config - - # Update config with new parameters if provided - self.config.patch_size = patch_size - - # Initialize base CNN for processing each patch using the specified base class - if issubclass(base_class, AIIABase): - self.base_cnn = AIIABase(self.config, **kwargs) - elif issubclass(base_class, AIIABaseShared): # Add support for AIIABaseShared - self.base_cnn = AIIABaseShared(self.config, **kwargs) - else: - raise ValueError("Invalid base class") - - def forward(self, x): - patches = x.unfold(2, self.patch_size, self.patch_size).unfold(3, self.patch_size, self.patch_size) - patches = patches.contiguous().view(patches.size(0), patches.size(1), -1, self.patch_size, self.patch_size) - patch_outputs = [] - - for p in torch.split(patches, 1, dim=2): - p = p.squeeze(2) - po = self.base_cnn(p) - patch_outputs.append(po) - - combined_output = torch.mean(torch.stack(patch_outputs, dim=0), dim=0) - return combined_output - -class AIIArecursive(AIIA): - def __init__(self, config: AIIAConfig, recursion_depth: int = 3, base_class=AIIABase, **kwargs): - - super().__init__(config=config, **kwargs) - self.config = self.config - - # Pass recursion_depth as a kwarg to the config - self.config.recursion_depth = recursion_depth - - # Initialize chunked CNN with updated config - self.chunked_cnn = AIIAchunked(self.config, base_class, **kwargs) - - def forward(self, x, depth=0): - if depth == self.recursion_depth: - return self.chunked_cnn(x) - else: - patches = x.unfold(2, 16, 16).unfold(3, 16, 16) - patches = patches.contiguous().view(patches.size(0), patches.size(1), -1, 16, 16) - processed_patches = [] - - for p in torch.split(patches, 1, dim=2): - p = p.squeeze(2) - pp = self.forward(p, depth + 1) - processed_patches.append(pp) - - combined_output = torch.mean(torch.stack(processed_patches, dim=0), dim=0) - return combined_output - if __name__ =="__main__": config = AIIAConfig() model = AIIAmoe(config, num_experts=5) - model.save("test") \ No newline at end of file + model.save_pretrained("test") \ No newline at end of file diff --git a/src/aiia/model/__init__.py b/src/aiia/model/__init__.py index a45512a..88f3211 100644 --- a/src/aiia/model/__init__.py +++ b/src/aiia/model/__init__.py @@ -1,20 +1,15 @@ from .Model import ( AIIABase, AIIABaseShared, - AIIAchunked, AIIAmoe, AIIASparseMoe, - AIIArecursive ) from .config import AIIAConfig __all__ = [ "AIIABase", "AIIABaseShared", - "AIIAchunked", "AIIAmoe", "AIIASparseMoe", - "AIIArecursive", "AIIAConfig", - ] \ No newline at end of file diff --git a/src/aiia/model/config.py b/src/aiia/model/config.py index a83329a..7bf3d93 100644 --- a/src/aiia/model/config.py +++ b/src/aiia/model/config.py @@ -1,29 +1,25 @@ -import torch +from transformers import PretrainedConfig import torch.nn as nn -import json -import os +class AIIAConfig(PretrainedConfig): + model_type = "AIIA" # Add this class attribute -class AIIAConfig: def __init__( self, - model_name: str = "AIIA", kernel_size: int = 3, activation_function: str = "GELU", hidden_size: int = 512, num_hidden_layers: int = 12, num_channels: int = 3, - learning_rate: float = 5e-5, **kwargs ): - self.model_name = model_name + super().__init__(**kwargs) self.kernel_size = kernel_size self.activation_function = activation_function self.hidden_size = hidden_size self.num_hidden_layers = num_hidden_layers self.num_channels = num_channels - self.learning_rate = learning_rate - + # Store additional keyword arguments as attributes for key, value in kwargs.items(): setattr(self, key, value) @@ -50,17 +46,4 @@ class AIIAConfig: elif isinstance(value, dict): return {k: serialize(v) for k, v in value.items()} return value - return {k: serialize(v) for k, v in self.__dict__.items()} - - def save(self, file_path): - if not os.path.exists(file_path): - os.makedirs(file_path, exist_ok=True) - with open(os.path.join(file_path, "config.json"), "w") as f: - # Save the recursively converted dictionary. - json.dump(self.to_dict(), f, indent=4) - - @classmethod - def load(cls, file_path): - with open(os.path.join(file_path, "config.json"), "r") as f: - config_dict = json.load(f) - return cls(**config_dict) \ No newline at end of file + return {k: serialize(v) for k, v in self.__dict__.items()} \ No newline at end of file diff --git a/src/aiia/pretrain/pretrainer.py b/src/aiia/pretrain/pretrainer.py index 42ba4b8..93df9dd 100644 --- a/src/aiia/pretrain/pretrainer.py +++ b/src/aiia/pretrain/pretrainer.py @@ -3,7 +3,7 @@ from torch import nn import csv import pandas as pd from tqdm import tqdm -from ..model.Model import AIIA +from transformers import PreTrainedModel from ..model.config import AIIAConfig from ..data.DataLoader import AIIADataLoader import os @@ -21,12 +21,12 @@ class ProjectionHead(nn.Module): return self.conv_rotate(x).mean(dim=(2, 3)) # Global average pooling for rotation task class Pretrainer: - def __init__(self, model: AIIA, learning_rate=1e-4, config: AIIAConfig=None): + def __init__(self, model: PreTrainedModel, learning_rate=1e-4, config: AIIAConfig=None): """ Initialize the pretrainer with a model. Args: - model (AIIA): The model instance to pretrain + model (PreTrainedModel): The model instance to pretrain learning_rate (float): Learning rate for optimization config (dict): Model configuration containing hidden_size """ @@ -186,11 +186,11 @@ class Pretrainer: if val_loss < best_val_loss: best_val_loss = val_loss - self.model.save(output_path) - print("Best model saved!") + self.model.save_pretrained(output_path) + print("Best model save_pretrainedd!") losses_path = os.path.join(os.path.dirname(output_path), 'losses.csv') - self.save_losses(losses_path) + self.save_pretrained_losses(losses_path) def _validate(self, val_loader, criterion_denoise, criterion_rotate): """Perform validation and return average validation loss.""" @@ -216,8 +216,8 @@ class Pretrainer: return avg_val_loss - def save_losses(self, csv_file): - """Save training and validation losses to a CSV file.""" + def save_pretrained_losses(self, csv_file): + """save_pretrained training and validation losses to a CSV file.""" data = list(zip( range(1, len(self.train_losses) + 1), self.train_losses, diff --git a/tests/model/test_aiia.py b/tests/model/test_aiia.py index bffa616..4891472 100644 --- a/tests/model/test_aiia.py +++ b/tests/model/test_aiia.py @@ -1,159 +1,133 @@ import os import torch -from aiia.model.Model import AIIABase, AIIABaseShared, AIIAExpert, AIIAmoe, AIIAchunked, AIIAConfig, AIIASparseMoe +from aiia.model.Model import AIIABase, AIIABaseShared, AIIAExpert, AIIAmoe, AIIAConfig, AIIASparseMoe def test_aiiabase_creation(): config = AIIAConfig() model = AIIABase(config) assert isinstance(model, AIIABase) -def test_aiiabase_save_load(): +def test_aiiabase_save_pretrained_from_pretrained(): config = AIIAConfig() model = AIIABase(config) - save_path = "test_aiiabase_save_load" + save_pretrained_path = "test_aiiabase_save_pretrained_load" # Save the model - model.save(save_path) - assert os.path.exists(os.path.join(save_path, "model.pth")) - assert os.path.exists(os.path.join(save_path, "config.json")) + model.save_pretrained(save_pretrained_path) + assert os.path.exists(os.path.join(save_pretrained_path, "model.safetensors")) + assert os.path.exists(os.path.join(save_pretrained_path, "config.json")) # Load the model - loaded_model = AIIABase.load(save_path) + loaded_model = AIIABase.from_pretrained(save_pretrained_path) # Check if the loaded model is an instance of AIIABase assert isinstance(loaded_model, AIIABase) # Clean up - os.remove(os.path.join(save_path, "model.pth")) - os.remove(os.path.join(save_path, "config.json")) - os.rmdir(save_path) + os.remove(os.path.join(save_pretrained_path, "model.safetensors")) + os.remove(os.path.join(save_pretrained_path, "config.json")) + os.rmdir(save_pretrained_path) def test_aiiabase_shared_creation(): config = AIIAConfig() model = AIIABaseShared(config) assert isinstance(model, AIIABaseShared) -def test_aiiabase_shared_save_load(): +def test_aiiabase_shared_save_pretrained_from_pretrained(): config = AIIAConfig() model = AIIABaseShared(config) - save_path = "test_aiiabase_shared_save_load" + save_pretrained_path = "test_aiiabase_shared_save_pretrained_load" # Save the model - model.save(save_path) - assert os.path.exists(os.path.join(save_path, "model.pth")) - assert os.path.exists(os.path.join(save_path, "config.json")) + model.save_pretrained(save_pretrained_path) + assert os.path.exists(os.path.join(save_pretrained_path, "model.safetensors")) + assert os.path.exists(os.path.join(save_pretrained_path, "config.json")) # Load the model - loaded_model = AIIABaseShared.load(save_path) + loaded_model = AIIABaseShared.from_pretrained(save_pretrained_path) # Check if the loaded model is an instance of AIIABaseShared assert isinstance(loaded_model, AIIABaseShared) # Clean up - os.remove(os.path.join(save_path, "model.pth")) - os.remove(os.path.join(save_path, "config.json")) - os.rmdir(save_path) + os.remove(os.path.join(save_pretrained_path, "model.safetensors")) + os.remove(os.path.join(save_pretrained_path, "config.json")) + os.rmdir(save_pretrained_path) def test_aiiaexpert_creation(): config = AIIAConfig() model = AIIAExpert(config) assert isinstance(model, AIIAExpert) -def test_aiiaexpert_save_load(): +def test_aiiaexpert_save_pretrained_from_pretrained(): config = AIIAConfig() model = AIIAExpert(config) - save_path = "test_aiiaexpert_save_load" + save_pretrained_path = "test_aiiaexpert_save_pretrained_load" # Save the model - model.save(save_path) - assert os.path.exists(os.path.join(save_path, "model.pth")) - assert os.path.exists(os.path.join(save_path, "config.json")) + model.save_pretrained(save_pretrained_path) + assert os.path.exists(os.path.join(save_pretrained_path, "model.safetensors")) + assert os.path.exists(os.path.join(save_pretrained_path, "config.json")) # Load the model - loaded_model = AIIAExpert.load(save_path) + loaded_model = AIIAExpert.from_pretrained(save_pretrained_path) # Check if the loaded model is an instance of AIIAExpert assert isinstance(loaded_model, AIIAExpert) # Clean up - os.remove(os.path.join(save_path, "model.pth")) - os.remove(os.path.join(save_path, "config.json")) - os.rmdir(save_path) + os.remove(os.path.join(save_pretrained_path, "model.safetensors")) + os.remove(os.path.join(save_pretrained_path, "config.json")) + os.rmdir(save_pretrained_path) def test_aiiamoe_creation(): - config = AIIAConfig() - model = AIIAmoe(config, num_experts=5) + config = AIIAConfig(num_experts=3) + model = AIIAmoe(config) assert isinstance(model, AIIAmoe) -def test_aiiamoe_save_load(): - config = AIIAConfig() - model = AIIAmoe(config, num_experts=5) - save_path = "test_aiiamoe_save_load" +def test_aiiamoe_save_pretrained_from_pretrained(): + config = AIIAConfig(num_experts=3) + model = AIIAmoe(config) + save_pretrained_path = "test_aiiamoe_save_pretrained_load" # Save the model - model.save(save_path) - assert os.path.exists(os.path.join(save_path, "model.pth")) - assert os.path.exists(os.path.join(save_path, "config.json")) + model.save_pretrained(save_pretrained_path) + assert os.path.exists(os.path.join(save_pretrained_path, "model.safetensors")) + assert os.path.exists(os.path.join(save_pretrained_path, "config.json")) # Load the model - loaded_model = AIIAmoe.load(save_path) + loaded_model = AIIAmoe.from_pretrained(save_pretrained_path) # Check if the loaded model is an instance of AIIAmoe assert isinstance(loaded_model, AIIAmoe) # Clean up - os.remove(os.path.join(save_path, "model.pth")) - os.remove(os.path.join(save_path, "config.json")) - os.rmdir(save_path) + os.remove(os.path.join(save_pretrained_path, "model.safetensors")) + os.remove(os.path.join(save_pretrained_path, "config.json")) + os.rmdir(save_pretrained_path) def test_aiiasparsemoe_creation(): - config = AIIAConfig() - model = AIIASparseMoe(config, num_experts=5, top_k=2) + config = AIIAConfig(num_experts=5, top_k=2) + model = AIIASparseMoe(config, base_class=AIIABaseShared) assert isinstance(model, AIIASparseMoe) -def test_aiiasparsemoe_save_load(): - config = AIIAConfig() - model = AIIASparseMoe(config, num_experts=3, top_k=1) - save_path = "test_aiiasparsemoe_save_load" - +def test_aiiasparsemoe_save_pretrained_from_pretrained(): + config = AIIAConfig(num_experts=3, top_k=1) + model = AIIASparseMoe(config) + save_pretrained_path = "test_aiiasparsemoe_save_pretrained_load" + # Save the model - model.save(save_path) - assert os.path.exists(os.path.join(save_path, "model.pth")) - assert os.path.exists(os.path.join(save_path, "config.json")) + model.save_pretrained(save_pretrained_path) + assert os.path.exists(os.path.join(save_pretrained_path, "model.safetensors")) + assert os.path.exists(os.path.join(save_pretrained_path, "config.json")) # Load the model - loaded_model = AIIASparseMoe.load(save_path) + loaded_model = AIIASparseMoe.from_pretrained(save_pretrained_path) # Check if the loaded model is an instance of AIIASparseMoe assert isinstance(loaded_model, AIIASparseMoe) # Clean up - os.remove(os.path.join(save_path, "model.pth")) - os.remove(os.path.join(save_path, "config.json")) - os.rmdir(save_path) - -def test_aiiachunked_creation(): - config = AIIAConfig() - model = AIIAchunked(config) - assert isinstance(model, AIIAchunked) - -def test_aiiachunked_save_load(): - config = AIIAConfig() - model = AIIAchunked(config) - save_path = "test_aiiachunked_save_load" - - # Save the model - model.save(save_path) - assert os.path.exists(os.path.join(save_path, "model.pth")) - assert os.path.exists(os.path.join(save_path, "config.json")) - - # Load the model - loaded_model = AIIAchunked.load(save_path) - - # Check if the loaded model is an instance of AIIAchunked - assert isinstance(loaded_model, AIIAchunked) - - # Clean up - os.remove(os.path.join(save_path, "model.pth")) - os.remove(os.path.join(save_path, "config.json")) - os.rmdir(save_path) \ No newline at end of file + os.remove(os.path.join(save_pretrained_path, "model.safetensors")) + os.remove(os.path.join(save_pretrained_path, "config.json")) + os.rmdir(save_pretrained_path) \ No newline at end of file diff --git a/tests/model/test_config.py b/tests/model/test_config.py index 5542a79..5789660 100644 --- a/tests/model/test_config.py +++ b/tests/model/test_config.py @@ -1,75 +1,77 @@ import os import tempfile import pytest -import torch.nn as nn from aiia import AIIAConfig + def test_aiia_config_initialization(): config = AIIAConfig() - assert config.model_name == "AIIA" + assert config.model_type == "AIIA" assert config.kernel_size == 3 assert config.activation_function == "GELU" assert config.hidden_size == 512 assert config.num_hidden_layers == 12 assert config.num_channels == 3 - assert config.learning_rate == 5e-5 + def test_aiia_config_custom_initialization(): config = AIIAConfig( - model_name="CustomModel", + model_type="CustomModel", kernel_size=5, activation_function="ReLU", hidden_size=1024, num_hidden_layers=8, - num_channels=1, - learning_rate=1e-4 + num_channels=1 ) - assert config.model_name == "CustomModel" + assert config.model_type == "CustomModel" assert config.kernel_size == 5 assert config.activation_function == "ReLU" assert config.hidden_size == 1024 assert config.num_hidden_layers == 8 assert config.num_channels == 1 - assert config.learning_rate == 1e-4 + def test_aiia_config_invalid_activation_function(): with pytest.raises(ValueError): AIIAConfig(activation_function="InvalidFunction") + def test_aiia_config_to_dict(): config = AIIAConfig() config_dict = config.to_dict() assert isinstance(config_dict, dict) - assert config_dict["model_name"] == "AIIA" assert config_dict["kernel_size"] == 3 -def test_aiia_config_save_and_load(): - with tempfile.TemporaryDirectory() as tmpdir: - config = AIIAConfig(model_name="TempModel") - save_path = os.path.join(tmpdir, "config") - config.save(save_path) - loaded_config = AIIAConfig.load(save_path) - assert loaded_config.model_name == "TempModel" +def test_aiia_config_save_pretrained_and_from_pretrained(): + with tempfile.TemporaryDirectory() as tmpdir: + config = AIIAConfig(model_type="TempModel") + save_pretrained_path = os.path.join(tmpdir, "config") + config.save_pretrained(save_pretrained_path) + + loaded_config = AIIAConfig.from_pretrained(save_pretrained_path) + assert loaded_config.model_type == "TempModel" assert loaded_config.kernel_size == 3 assert loaded_config.activation_function == "GELU" -def test_aiia_config_save_and_load_with_custom_attributes(): - with tempfile.TemporaryDirectory() as tmpdir: - config = AIIAConfig(model_name="TempModel", custom_attr="value") - save_path = os.path.join(tmpdir, "config") - config.save(save_path) - loaded_config = AIIAConfig.load(save_path) - assert loaded_config.model_name == "TempModel" +def test_aiia_config_save_pretrained_and_load_with_custom_attributes(): + with tempfile.TemporaryDirectory() as tmpdir: + config = AIIAConfig(model_type="TempModel", custom_attr="value") + save_pretrained_path = os.path.join(tmpdir, "config") + config.save_pretrained(save_pretrained_path) + + loaded_config = AIIAConfig.from_pretrained(save_pretrained_path) + assert loaded_config.model_type == "TempModel" assert loaded_config.custom_attr == "value" -def test_aiia_config_save_and_load_with_nested_attributes(): - with tempfile.TemporaryDirectory() as tmpdir: - config = AIIAConfig(model_name="TempModel", nested={"key": "value"}) - save_path = os.path.join(tmpdir, "config") - config.save(save_path) - loaded_config = AIIAConfig.load(save_path) - assert loaded_config.model_name == "TempModel" - assert loaded_config.nested == {"key": "value"} +def test_aiia_config_save_pretrained_and_load_with_nested_attributes(): + with tempfile.TemporaryDirectory() as tmpdir: + config = AIIAConfig(model_type="TempModel", nested={"key": "value"}) + save_pretrained_path = os.path.join(tmpdir, "config") + config.save_pretrained(save_pretrained_path) + + loaded_config = AIIAConfig.from_pretrained(save_pretrained_path) + assert loaded_config.model_type == "TempModel" + assert loaded_config.nested == {"key": "value"} \ No newline at end of file diff --git a/tests/pretrain/test_pretrainer.py b/tests/pretrain/test_pretrainer.py index db4c73f..8f4283c 100644 --- a/tests/pretrain/test_pretrainer.py +++ b/tests/pretrain/test_pretrainer.py @@ -94,7 +94,7 @@ def test_train_happy_path(mock_print, mock_path_join, mock_data_loader, mock_rea # Execute training with patched methods with patch.object(Pretrainer, '_process_batch', return_value=mock_batch_loss) as mock_process_batch, \ patch.object(Pretrainer, '_validate', side_effect=[0.8, 0.3]) as mock_validate, \ - patch.object(Pretrainer, 'save_losses') as mock_save_losses, \ + patch.object(Pretrainer, 'save_pretrained_losses') as mock_save_pretrained_losses, \ patch('builtins.open', mock_open()): pretrainer.train(dataset_paths, output_path=output_path, num_epochs=2) @@ -104,10 +104,10 @@ def test_train_happy_path(mock_print, mock_path_join, mock_data_loader, mock_rea assert mock_process_batch.call_count == 2 assert mock_validate.call_count == 2 - # Check for "Best model saved!" instead of model.save() - mock_print.assert_any_call("Best model saved!") + # Check for "Best model save_pretrainedd!" instead of model.save_pretrained() + mock_print.assert_any_call("Best model save_pretrainedd!") - mock_save_losses.assert_called_once() + mock_save_pretrained_losses.assert_called_once() # Verify state changes assert len(pretrainer.train_losses) == 2 @@ -153,13 +153,13 @@ def test_train_empty_loaders(mock_data_loader, mock_read_parquet, mock_concat): pretrainer.projection_head = MagicMock() pretrainer.optimizer = MagicMock() - with patch.object(Pretrainer, 'save_losses') as mock_save_losses: + with patch.object(Pretrainer, 'save_pretrained_losses') as mock_save_pretrained_losses: pretrainer.train(['path/to/dataset.parquet'], num_epochs=1) # Verify empty loader behavior assert len(pretrainer.train_losses) == 1 assert pretrainer.train_losses[0] == 0.0 - mock_save_losses.assert_called_once() + mock_save_pretrained_losses.assert_called_once() @patch('pandas.concat') @patch('pandas.read_parquet') @@ -180,7 +180,7 @@ def test_train_none_batch_data(mock_data_loader, mock_read_parquet, mock_concat) pretrainer.optimizer = MagicMock() with patch.object(Pretrainer, '_process_batch') as mock_process_batch, \ - patch.object(Pretrainer, 'save_losses'): + patch.object(Pretrainer, 'save_pretrained_losses'): pretrainer.train(['path/to/dataset.parquet'], num_epochs=1) # Verify None batch handling @@ -212,7 +212,7 @@ def test_train_with_custom_parameters(mock_data_loader, mock_read_parquet, mock_ custom_batch_size = 16 custom_sample_size = 5000 - with patch.object(Pretrainer, 'save_losses'): + with patch.object(Pretrainer, 'save_pretrained_losses'): pretrainer.train( ['path/to/dataset.parquet'], output_path=custom_output_path, @@ -233,7 +233,7 @@ def test_train_with_custom_parameters(mock_data_loader, mock_read_parquet, mock_ @patch('aiia.pretrain.pretrainer.AIIADataLoader') @patch('builtins.print') # Add this to mock the print function def test_train_validation_loss_improvement(mock_print, mock_data_loader, mock_read_parquet, mock_concat): - """Test that model is saved only when validation loss improves.""" + """Test that model is save_pretrainedd only when validation loss improves.""" real_df = pd.DataFrame({'image_bytes': [torch.randn(1, 3, 224, 224).tolist()]}) mock_read_parquet.return_value.head.return_value = real_df mock_concat.return_value = real_df @@ -262,11 +262,11 @@ def test_train_validation_loss_improvement(mock_print, mock_data_loader, mock_re # Test improving validation loss with patch.object(Pretrainer, '_process_batch', return_value=mock_batch_loss), \ patch.object(Pretrainer, '_validate', side_effect=[3.0, 2.0, 1.0]), \ - patch.object(Pretrainer, 'save_losses'): + patch.object(Pretrainer, 'save_pretrained_losses'): pretrainer.train(['path/to/dataset.parquet'], num_epochs=3) - # Check for "Best model saved!" 3 times - assert mock_print.call_args_list.count(call("Best model saved!")) == 3 + # Check for "Best model save_pretrainedd!" 3 times + assert mock_print.call_args_list.count(call("Best model save_pretrainedd!")) == 3 # Reset for next test mock_print.reset_mock() @@ -278,11 +278,11 @@ def test_train_validation_loss_improvement(mock_print, mock_data_loader, mock_re # Test fluctuating validation loss with patch.object(Pretrainer, '_process_batch', return_value=mock_batch_loss), \ patch.object(Pretrainer, '_validate', side_effect=[3.0, 4.0, 2.0]), \ - patch.object(Pretrainer, 'save_losses'): + patch.object(Pretrainer, 'save_pretrained_losses'): pretrainer.train(['path/to/dataset.parquet'], num_epochs=3) - # Should print "Best model saved!" only on first and third epochs - assert mock_print.call_args_list.count(call("Best model saved!")) == 2 + # Should print "Best model save_pretrainedd!" only on first and third epochs + assert mock_print.call_args_list.count(call("Best model save_pretrainedd!")) == 2 @patch('aiia.pretrain.pretrainer.Pretrainer._process_batch') @@ -296,13 +296,13 @@ def test_validate(mock_process_batch): loss = pretrainer._validate(val_loader, criterion_denoise, criterion_rotate) assert loss == 0.5 -# Test the save_losses method -@patch('aiia.pretrain.pretrainer.Pretrainer.save_losses') -def test_save_losses(mock_save_losses): +# Test the save_pretrained_losses method +@patch('aiia.pretrain.pretrainer.Pretrainer.save_pretrained_losses') +def test_save_pretrained_losses(mock_save_pretrained_losses): pretrainer = Pretrainer(model=AIIABase(config=AIIAConfig()), config=AIIAConfig()) pretrainer.train_losses = [0.1, 0.2] pretrainer.val_losses = [0.3, 0.4] csv_file = 'losses.csv' - pretrainer.save_losses(csv_file) - mock_save_losses.assert_called_once_with(csv_file) \ No newline at end of file + pretrainer.save_pretrained_losses(csv_file) + mock_save_pretrained_losses.assert_called_once_with(csv_file) \ No newline at end of file