feat/tf_support #37

Merged
Fabel merged 13 commits from feat/tf_support into develop 2025-04-16 20:59:48 +00:00
13 changed files with 184 additions and 353 deletions

View File

@ -29,7 +29,7 @@ from aiia.model import AIIAConfig
from aiia.pretrain import Pretrainer from aiia.pretrain import Pretrainer
# Create your model # Create your model
config = AIIAConfig(model_name="AIIA-Base-512x20k") config = AIIAConfig(model_type="AIIA-Base-512x20k")
model = AIIABase(config) model = AIIABase(config)
# Initialize pretrainer with the model # Initialize pretrainer with the model

View File

@ -1,10 +1,12 @@
from aiia.model import AIIABase from src.aiia.model import AIIAmoe
from aiia.model import AIIAConfig from src.aiia.model import AIIAConfig
from aiia.pretrain import Pretrainer from src.aiia.pretrain import Pretrainer
# Create your model # Create your model
config = AIIAConfig(model_name="AIIA-Base-512x20k") config = AIIAConfig(num_experts=5)
model = AIIABase(config) model = AIIAmoe(config)
model.save_pretrained("test")
model = AIIAmoe.from_pretrained("test")
# Initialize pretrainer with the model # Initialize pretrainer with the model
pretrainer = Pretrainer(model, learning_rate=1e-4, config=config) pretrainer = Pretrainer(model, learning_rate=1e-4, config=config)

View File

@ -10,7 +10,7 @@ include = '\.pyi?$'
[project] [project]
name = "aiia" name = "aiia"
version = "0.2.1" version = "0.3.0"
description = "AIIA Deep Learning Model Implementation" description = "AIIA Deep Learning Model Implementation"
readme = "README.md" readme = "README.md"
authors = [ authors = [

View File

@ -6,3 +6,4 @@ pillow
pandas pandas
torchvision torchvision
pyarrow pyarrow
transformers>=4.48.0

View File

@ -1,6 +1,6 @@
[metadata] [metadata]
name = aiia name = aiia
version = 0.2.1 version = 0.3.0
author = Falko Habel author = Falko Habel
author_email = falko.habel@gmx.de author_email = falko.habel@gmx.de
description = AIIA deep learning model implementation description = AIIA deep learning model implementation

View File

@ -1,7 +1,7 @@
from .model.Model import AIIABase, AIIABaseShared, AIIAchunked, AIIAmoe, AIIASparseMoe, AIIArecursive from .model.Model import AIIABase, AIIABaseShared, AIIAmoe, AIIASparseMoe
from .model.config import AIIAConfig from .model.config import AIIAConfig
from .data.DataLoader import DataLoader from .data.DataLoader import DataLoader
from .pretrain.pretrainer import Pretrainer, ProjectionHead from .pretrain.pretrainer import Pretrainer, ProjectionHead
__version__ = "0.2.1" __version__ = "0.3.0"

View File

@ -1,119 +1,48 @@
from .config import AIIAConfig from .config import AIIAConfig
from torch import nn from torch import nn
from transformers import PreTrainedModel
import torch import torch
import os
import copy import copy
import warnings
class AIIA(nn.Module): class AIIABase(PreTrainedModel):
def __init__(self, config: AIIAConfig, **kwargs): config_class = AIIAConfig
super(AIIA, self).__init__() base_model_prefix = "AIIA"
# Create a deep copy of the configuration to avoid sharing
self.config = copy.deepcopy(config)
# Update the config with any additional keyword arguments def __init__(self, config: AIIAConfig):
for key, value in kwargs.items(): super().__init__(config)
setattr(self.config, key, value)
def save(self, path: str):
if not os.path.exists(path):
os.makedirs(path, exist_ok=True)
torch.save(self.state_dict(), f"{path}/model.pth")
self.config.save(path)
@classmethod
def load(cls, path, precision: str = None, strict: bool = True, **kwargs):
config = AIIAConfig.load(path)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Load the state dict to analyze structure
model_dict = torch.load(f"{path}/model.pth", map_location=device)
# Special handling for AIIAmoe - detect number of experts from state_dict
if cls.__name__ == "AIIAmoe" and "num_experts" not in kwargs:
# Find maximum expert index
max_expert_idx = -1
for key in model_dict.keys():
if key.startswith("experts."):
parts = key.split(".")
if len(parts) > 1:
try:
expert_idx = int(parts[1])
max_expert_idx = max(max_expert_idx, expert_idx)
except ValueError:
pass
if max_expert_idx >= 0:
# experts.X keys found, use max_expert_idx + 1 as num_experts
kwargs["num_experts"] = max_expert_idx + 1
# Create model with detected structural parameters
model = cls(config, **kwargs)
# Handle precision conversion
dtype = None
if precision is not None:
if precision.lower() == 'fp16':
dtype = torch.float16
elif precision.lower() == 'bf16':
if device == 'cuda' and not torch.cuda.is_bf16_supported():
warnings.warn("BF16 is not supported on this GPU. Falling back to FP16.")
dtype = torch.float16
else:
dtype = torch.bfloat16
else:
raise ValueError("Unsupported precision. Use 'fp16', 'bf16', or leave as None.")
if dtype is not None:
for key, param in model_dict.items():
if torch.is_tensor(param):
model_dict[key] = param.to(dtype)
# Load state dict with strict parameter for flexibility
model.load_state_dict(model_dict, strict=strict)
return model
class AIIABase(AIIA):
def __init__(self, config: AIIAConfig, **kwargs):
super().__init__(config=config, **kwargs)
self.config = self.config
# Initialize layers based on configuration # Initialize layers based on configuration
layers = [] layers = []
in_channels = self.config.num_channels in_channels = config.num_channels
for _ in range(self.config.num_hidden_layers): for _ in range(config.num_hidden_layers):
layers.extend([ layers.extend([
nn.Conv2d(in_channels, self.config.hidden_size, nn.Conv2d(in_channels, config.hidden_size,
kernel_size=self.config.kernel_size, padding=1), kernel_size=config.kernel_size, padding=1),
getattr(nn, self.config.activation_function)(), getattr(nn, config.activation_function)(),
nn.MaxPool2d(kernel_size=1, stride=1) nn.MaxPool2d(kernel_size=1, stride=1)
]) ])
in_channels = self.config.hidden_size in_channels = config.hidden_size
self.cnn = nn.Sequential(*layers) self.cnn = nn.Sequential(*layers)
def forward(self, x): def forward(self, x):
return self.cnn(x) return self.cnn(x)
class AIIABaseShared(AIIA): class AIIABaseShared(PreTrainedModel):
def __init__(self, config: AIIAConfig, **kwargs): config_class = AIIAConfig
base_model_prefix = "AIIA"
def __init__(self, config: AIIAConfig):
super().__init__(config)
""" """
Initialize the AIIABaseShared model. Initialize the AIIABaseShared model.
Args: Args:
config (AIIAConfig): Configuration object containing model parameters. config (AIIAConfig): Configuration object containing model parameters.
**kwargs: Additional keyword arguments to override configuration settings.
""" """
super().__init__(config=config, **kwargs) super().__init__(config=config)
# Update configuration with new parameters if provided
self. config = copy.deepcopy(config)
for key, value in kwargs.items():
setattr(self.config, key, value)
# Initialize the network components # Initialize the network components
self._initialize_network() self._initialize_network()
@ -172,16 +101,17 @@ class AIIABaseShared(AIIA):
return out return out
class AIIAExpert(AIIA): class AIIAExpert(PreTrainedModel):
def __init__(self, config: AIIAConfig, base_class=AIIABase, **kwargs): config_class = AIIAConfig
super().__init__(config=config, **kwargs) base_model_prefix = "AIIA"
self.config = self.config def __init__(self, config: AIIAConfig, base_class=AIIABase):
super().__init__(config=config)
# Initialize base CNN with configuration and chosen base class # Initialize base CNN with configuration and chosen base class
if issubclass(base_class, AIIABase): if issubclass(base_class, AIIABase):
self.base_cnn = AIIABase(self.config, **kwargs) self.base_cnn = AIIABase(self.config)
elif issubclass(base_class, AIIABaseShared): elif issubclass(base_class, AIIABaseShared):
self.base_cnn = AIIABaseShared(self.config, **kwargs) self.base_cnn = AIIABaseShared(self.config)
else: else:
raise ValueError("Invalid base class") raise ValueError("Invalid base class")
@ -198,26 +128,26 @@ class AIIAExpert(AIIA):
# Process input through the base CNN # Process input through the base CNN
return self.base_cnn(x) return self.base_cnn(x)
class AIIAmoe(AIIA): class AIIAmoe(PreTrainedModel):
def __init__(self, config: AIIAConfig, num_experts: int = 3, base_class=AIIABase, **kwargs): config_class = AIIAConfig
super().__init__(config=config, **kwargs) base_model_prefix = "AIIA"
def __init__(self, config: AIIAConfig, base_class=AIIABase):
super().__init__(config=config)
self.config = config self.config = config
# Update the config to include the number of experts. # Get num_experts directly from config instead of parameter
self.config.num_experts = num_experts num_experts = getattr(config, "num_experts", 3) # default to 3 if not in config
# Initialize multiple experts from the chosen base class. # Initialize multiple experts from the chosen base class
self.experts = nn.ModuleList([ self.experts = nn.ModuleList([
AIIAExpert(self.config, base_class=base_class, **kwargs) AIIAExpert(self.config, base_class=base_class)
for _ in range(num_experts) for _ in range(num_experts)
]) ])
# To generate gating weights, we first need to determine the feature dimension. gate_in_features = self.config.hidden_size
# Each expert is assumed to return an output of shape (B, C, H, W); after averaging over H and W,
# we obtain a tensor of shape (B, C) where C is the number of channels (here assumed to be 224).
gate_in_features = 512 # Adjust this if your expert output changes.
# Create a gating network that maps the aggregated features to num_experts weights. # Create a gating network that maps the aggregated features to num_experts weights
self.gate = nn.Sequential( self.gate = nn.Sequential(
nn.Linear(gate_in_features, num_experts), nn.Linear(gate_in_features, num_experts),
nn.Softmax(dim=1) nn.Softmax(dim=1)
@ -261,9 +191,10 @@ class AIIAmoe(AIIA):
class AIIASparseMoe(AIIAmoe): class AIIASparseMoe(AIIAmoe):
def __init__(self, config: AIIAConfig, num_experts: int = 3, top_k: int = 2, base_class=AIIABase, **kwargs): config_class = AIIAConfig
super().__init__(config=config, num_experts=num_experts, base_class=base_class, **kwargs) base_model_prefix = "AIIA"
self.top_k = top_k def __init__(self, config: AIIAConfig, base_class=AIIABase):
super().__init__(config=config, base_class=base_class)
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
# Compute the gate_weights similar to standard moe. # Compute the gate_weights similar to standard moe.
@ -273,7 +204,7 @@ class AIIASparseMoe(AIIAmoe):
gate_weights = self.gate(gate_input) gate_weights = self.gate(gate_input)
# Select the top-k experts for each input based on gating weights. # Select the top-k experts for each input based on gating weights.
_, top_k_indices = gate_weights.topk(self.top_k, dim=-1) _, top_k_indices = gate_weights.topk(self.config.top_k, dim=-1)
# Initialize a list to store outputs from selected experts. # Initialize a list to store outputs from selected experts.
merged_outputs = [] merged_outputs = []
@ -294,64 +225,7 @@ class AIIASparseMoe(AIIAmoe):
return torch.cat(merged_outputs, dim=0) return torch.cat(merged_outputs, dim=0)
class AIIAchunked(AIIA):
def __init__(self, config: AIIAConfig, patch_size: int = 16, base_class=AIIABase, **kwargs):
super().__init__(config=config, **kwargs)
self.config = self.config
# Update config with new parameters if provided
self.config.patch_size = patch_size
# Initialize base CNN for processing each patch using the specified base class
if issubclass(base_class, AIIABase):
self.base_cnn = AIIABase(self.config, **kwargs)
elif issubclass(base_class, AIIABaseShared): # Add support for AIIABaseShared
self.base_cnn = AIIABaseShared(self.config, **kwargs)
else:
raise ValueError("Invalid base class")
def forward(self, x):
patches = x.unfold(2, self.patch_size, self.patch_size).unfold(3, self.patch_size, self.patch_size)
patches = patches.contiguous().view(patches.size(0), patches.size(1), -1, self.patch_size, self.patch_size)
patch_outputs = []
for p in torch.split(patches, 1, dim=2):
p = p.squeeze(2)
po = self.base_cnn(p)
patch_outputs.append(po)
combined_output = torch.mean(torch.stack(patch_outputs, dim=0), dim=0)
return combined_output
class AIIArecursive(AIIA):
def __init__(self, config: AIIAConfig, recursion_depth: int = 3, base_class=AIIABase, **kwargs):
super().__init__(config=config, **kwargs)
self.config = self.config
# Pass recursion_depth as a kwarg to the config
self.config.recursion_depth = recursion_depth
# Initialize chunked CNN with updated config
self.chunked_cnn = AIIAchunked(self.config, base_class, **kwargs)
def forward(self, x, depth=0):
if depth == self.recursion_depth:
return self.chunked_cnn(x)
else:
patches = x.unfold(2, 16, 16).unfold(3, 16, 16)
patches = patches.contiguous().view(patches.size(0), patches.size(1), -1, 16, 16)
processed_patches = []
for p in torch.split(patches, 1, dim=2):
p = p.squeeze(2)
pp = self.forward(p, depth + 1)
processed_patches.append(pp)
combined_output = torch.mean(torch.stack(processed_patches, dim=0), dim=0)
return combined_output
if __name__ =="__main__": if __name__ =="__main__":
config = AIIAConfig() config = AIIAConfig()
model = AIIAmoe(config, num_experts=5) model = AIIAmoe(config, num_experts=5)
model.save("test") model.save_pretrained("test")

View File

@ -1,20 +1,15 @@
from .Model import ( from .Model import (
AIIABase, AIIABase,
AIIABaseShared, AIIABaseShared,
AIIAchunked,
AIIAmoe, AIIAmoe,
AIIASparseMoe, AIIASparseMoe,
AIIArecursive
) )
from .config import AIIAConfig from .config import AIIAConfig
__all__ = [ __all__ = [
"AIIABase", "AIIABase",
"AIIABaseShared", "AIIABaseShared",
"AIIAchunked",
"AIIAmoe", "AIIAmoe",
"AIIASparseMoe", "AIIASparseMoe",
"AIIArecursive",
"AIIAConfig", "AIIAConfig",
] ]

View File

@ -1,28 +1,24 @@
import torch from transformers import PretrainedConfig
import torch.nn as nn import torch.nn as nn
import json
import os
class AIIAConfig(PretrainedConfig):
model_type = "AIIA" # Add this class attribute
class AIIAConfig:
def __init__( def __init__(
self, self,
model_name: str = "AIIA",
kernel_size: int = 3, kernel_size: int = 3,
activation_function: str = "GELU", activation_function: str = "GELU",
hidden_size: int = 512, hidden_size: int = 512,
num_hidden_layers: int = 12, num_hidden_layers: int = 12,
num_channels: int = 3, num_channels: int = 3,
learning_rate: float = 5e-5,
**kwargs **kwargs
): ):
self.model_name = model_name super().__init__(**kwargs)
self.kernel_size = kernel_size self.kernel_size = kernel_size
self.activation_function = activation_function self.activation_function = activation_function
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers self.num_hidden_layers = num_hidden_layers
self.num_channels = num_channels self.num_channels = num_channels
self.learning_rate = learning_rate
# Store additional keyword arguments as attributes # Store additional keyword arguments as attributes
for key, value in kwargs.items(): for key, value in kwargs.items():
@ -51,16 +47,3 @@ class AIIAConfig:
return {k: serialize(v) for k, v in value.items()} return {k: serialize(v) for k, v in value.items()}
return value return value
return {k: serialize(v) for k, v in self.__dict__.items()} return {k: serialize(v) for k, v in self.__dict__.items()}
def save(self, file_path):
if not os.path.exists(file_path):
os.makedirs(file_path, exist_ok=True)
with open(os.path.join(file_path, "config.json"), "w") as f:
# Save the recursively converted dictionary.
json.dump(self.to_dict(), f, indent=4)
@classmethod
def load(cls, file_path):
with open(os.path.join(file_path, "config.json"), "r") as f:
config_dict = json.load(f)
return cls(**config_dict)

View File

@ -3,7 +3,7 @@ from torch import nn
import csv import csv
import pandas as pd import pandas as pd
from tqdm import tqdm from tqdm import tqdm
from ..model.Model import AIIA from transformers import PreTrainedModel
from ..model.config import AIIAConfig from ..model.config import AIIAConfig
from ..data.DataLoader import AIIADataLoader from ..data.DataLoader import AIIADataLoader
import os import os
@ -21,12 +21,12 @@ class ProjectionHead(nn.Module):
return self.conv_rotate(x).mean(dim=(2, 3)) # Global average pooling for rotation task return self.conv_rotate(x).mean(dim=(2, 3)) # Global average pooling for rotation task
class Pretrainer: class Pretrainer:
def __init__(self, model: AIIA, learning_rate=1e-4, config: AIIAConfig=None): def __init__(self, model: PreTrainedModel, learning_rate=1e-4, config: AIIAConfig=None):
""" """
Initialize the pretrainer with a model. Initialize the pretrainer with a model.
Args: Args:
model (AIIA): The model instance to pretrain model (PreTrainedModel): The model instance to pretrain
learning_rate (float): Learning rate for optimization learning_rate (float): Learning rate for optimization
config (dict): Model configuration containing hidden_size config (dict): Model configuration containing hidden_size
""" """
@ -186,11 +186,11 @@ class Pretrainer:
if val_loss < best_val_loss: if val_loss < best_val_loss:
best_val_loss = val_loss best_val_loss = val_loss
self.model.save(output_path) self.model.save_pretrained(output_path)
print("Best model saved!") print("Best model save_pretrainedd!")
losses_path = os.path.join(os.path.dirname(output_path), 'losses.csv') losses_path = os.path.join(os.path.dirname(output_path), 'losses.csv')
self.save_losses(losses_path) self.save_pretrained_losses(losses_path)
def _validate(self, val_loader, criterion_denoise, criterion_rotate): def _validate(self, val_loader, criterion_denoise, criterion_rotate):
"""Perform validation and return average validation loss.""" """Perform validation and return average validation loss."""
@ -216,8 +216,8 @@ class Pretrainer:
return avg_val_loss return avg_val_loss
def save_losses(self, csv_file): def save_pretrained_losses(self, csv_file):
"""Save training and validation losses to a CSV file.""" """save_pretrained training and validation losses to a CSV file."""
data = list(zip( data = list(zip(
range(1, len(self.train_losses) + 1), range(1, len(self.train_losses) + 1),
self.train_losses, self.train_losses,

View File

@ -1,159 +1,133 @@
import os import os
import torch import torch
from aiia.model.Model import AIIABase, AIIABaseShared, AIIAExpert, AIIAmoe, AIIAchunked, AIIAConfig, AIIASparseMoe from aiia.model.Model import AIIABase, AIIABaseShared, AIIAExpert, AIIAmoe, AIIAConfig, AIIASparseMoe
def test_aiiabase_creation(): def test_aiiabase_creation():
config = AIIAConfig() config = AIIAConfig()
model = AIIABase(config) model = AIIABase(config)
assert isinstance(model, AIIABase) assert isinstance(model, AIIABase)
def test_aiiabase_save_load(): def test_aiiabase_save_pretrained_from_pretrained():
config = AIIAConfig() config = AIIAConfig()
model = AIIABase(config) model = AIIABase(config)
save_path = "test_aiiabase_save_load" save_pretrained_path = "test_aiiabase_save_pretrained_load"
# Save the model # Save the model
model.save(save_path) model.save_pretrained(save_pretrained_path)
assert os.path.exists(os.path.join(save_path, "model.pth")) assert os.path.exists(os.path.join(save_pretrained_path, "model.safetensors"))
assert os.path.exists(os.path.join(save_path, "config.json")) assert os.path.exists(os.path.join(save_pretrained_path, "config.json"))
# Load the model # Load the model
loaded_model = AIIABase.load(save_path) loaded_model = AIIABase.from_pretrained(save_pretrained_path)
# Check if the loaded model is an instance of AIIABase # Check if the loaded model is an instance of AIIABase
assert isinstance(loaded_model, AIIABase) assert isinstance(loaded_model, AIIABase)
# Clean up # Clean up
os.remove(os.path.join(save_path, "model.pth")) os.remove(os.path.join(save_pretrained_path, "model.safetensors"))
os.remove(os.path.join(save_path, "config.json")) os.remove(os.path.join(save_pretrained_path, "config.json"))
os.rmdir(save_path) os.rmdir(save_pretrained_path)
def test_aiiabase_shared_creation(): def test_aiiabase_shared_creation():
config = AIIAConfig() config = AIIAConfig()
model = AIIABaseShared(config) model = AIIABaseShared(config)
assert isinstance(model, AIIABaseShared) assert isinstance(model, AIIABaseShared)
def test_aiiabase_shared_save_load(): def test_aiiabase_shared_save_pretrained_from_pretrained():
config = AIIAConfig() config = AIIAConfig()
model = AIIABaseShared(config) model = AIIABaseShared(config)
save_path = "test_aiiabase_shared_save_load" save_pretrained_path = "test_aiiabase_shared_save_pretrained_load"
# Save the model # Save the model
model.save(save_path) model.save_pretrained(save_pretrained_path)
assert os.path.exists(os.path.join(save_path, "model.pth")) assert os.path.exists(os.path.join(save_pretrained_path, "model.safetensors"))
assert os.path.exists(os.path.join(save_path, "config.json")) assert os.path.exists(os.path.join(save_pretrained_path, "config.json"))
# Load the model # Load the model
loaded_model = AIIABaseShared.load(save_path) loaded_model = AIIABaseShared.from_pretrained(save_pretrained_path)
# Check if the loaded model is an instance of AIIABaseShared # Check if the loaded model is an instance of AIIABaseShared
assert isinstance(loaded_model, AIIABaseShared) assert isinstance(loaded_model, AIIABaseShared)
# Clean up # Clean up
os.remove(os.path.join(save_path, "model.pth")) os.remove(os.path.join(save_pretrained_path, "model.safetensors"))
os.remove(os.path.join(save_path, "config.json")) os.remove(os.path.join(save_pretrained_path, "config.json"))
os.rmdir(save_path) os.rmdir(save_pretrained_path)
def test_aiiaexpert_creation(): def test_aiiaexpert_creation():
config = AIIAConfig() config = AIIAConfig()
model = AIIAExpert(config) model = AIIAExpert(config)
assert isinstance(model, AIIAExpert) assert isinstance(model, AIIAExpert)
def test_aiiaexpert_save_load(): def test_aiiaexpert_save_pretrained_from_pretrained():
config = AIIAConfig() config = AIIAConfig()
model = AIIAExpert(config) model = AIIAExpert(config)
save_path = "test_aiiaexpert_save_load" save_pretrained_path = "test_aiiaexpert_save_pretrained_load"
# Save the model # Save the model
model.save(save_path) model.save_pretrained(save_pretrained_path)
assert os.path.exists(os.path.join(save_path, "model.pth")) assert os.path.exists(os.path.join(save_pretrained_path, "model.safetensors"))
assert os.path.exists(os.path.join(save_path, "config.json")) assert os.path.exists(os.path.join(save_pretrained_path, "config.json"))
# Load the model # Load the model
loaded_model = AIIAExpert.load(save_path) loaded_model = AIIAExpert.from_pretrained(save_pretrained_path)
# Check if the loaded model is an instance of AIIAExpert # Check if the loaded model is an instance of AIIAExpert
assert isinstance(loaded_model, AIIAExpert) assert isinstance(loaded_model, AIIAExpert)
# Clean up # Clean up
os.remove(os.path.join(save_path, "model.pth")) os.remove(os.path.join(save_pretrained_path, "model.safetensors"))
os.remove(os.path.join(save_path, "config.json")) os.remove(os.path.join(save_pretrained_path, "config.json"))
os.rmdir(save_path) os.rmdir(save_pretrained_path)
def test_aiiamoe_creation(): def test_aiiamoe_creation():
config = AIIAConfig() config = AIIAConfig(num_experts=3)
model = AIIAmoe(config, num_experts=5) model = AIIAmoe(config)
assert isinstance(model, AIIAmoe) assert isinstance(model, AIIAmoe)
def test_aiiamoe_save_load(): def test_aiiamoe_save_pretrained_from_pretrained():
config = AIIAConfig() config = AIIAConfig(num_experts=3)
model = AIIAmoe(config, num_experts=5) model = AIIAmoe(config)
save_path = "test_aiiamoe_save_load" save_pretrained_path = "test_aiiamoe_save_pretrained_load"
# Save the model # Save the model
model.save(save_path) model.save_pretrained(save_pretrained_path)
assert os.path.exists(os.path.join(save_path, "model.pth")) assert os.path.exists(os.path.join(save_pretrained_path, "model.safetensors"))
assert os.path.exists(os.path.join(save_path, "config.json")) assert os.path.exists(os.path.join(save_pretrained_path, "config.json"))
# Load the model # Load the model
loaded_model = AIIAmoe.load(save_path) loaded_model = AIIAmoe.from_pretrained(save_pretrained_path)
# Check if the loaded model is an instance of AIIAmoe # Check if the loaded model is an instance of AIIAmoe
assert isinstance(loaded_model, AIIAmoe) assert isinstance(loaded_model, AIIAmoe)
# Clean up # Clean up
os.remove(os.path.join(save_path, "model.pth")) os.remove(os.path.join(save_pretrained_path, "model.safetensors"))
os.remove(os.path.join(save_path, "config.json")) os.remove(os.path.join(save_pretrained_path, "config.json"))
os.rmdir(save_path) os.rmdir(save_pretrained_path)
def test_aiiasparsemoe_creation(): def test_aiiasparsemoe_creation():
config = AIIAConfig() config = AIIAConfig(num_experts=5, top_k=2)
model = AIIASparseMoe(config, num_experts=5, top_k=2) model = AIIASparseMoe(config, base_class=AIIABaseShared)
assert isinstance(model, AIIASparseMoe) assert isinstance(model, AIIASparseMoe)
def test_aiiasparsemoe_save_load(): def test_aiiasparsemoe_save_pretrained_from_pretrained():
config = AIIAConfig() config = AIIAConfig(num_experts=3, top_k=1)
model = AIIASparseMoe(config, num_experts=3, top_k=1) model = AIIASparseMoe(config)
save_path = "test_aiiasparsemoe_save_load" save_pretrained_path = "test_aiiasparsemoe_save_pretrained_load"
# Save the model # Save the model
model.save(save_path) model.save_pretrained(save_pretrained_path)
assert os.path.exists(os.path.join(save_path, "model.pth")) assert os.path.exists(os.path.join(save_pretrained_path, "model.safetensors"))
assert os.path.exists(os.path.join(save_path, "config.json")) assert os.path.exists(os.path.join(save_pretrained_path, "config.json"))
# Load the model # Load the model
loaded_model = AIIASparseMoe.load(save_path) loaded_model = AIIASparseMoe.from_pretrained(save_pretrained_path)
# Check if the loaded model is an instance of AIIASparseMoe # Check if the loaded model is an instance of AIIASparseMoe
assert isinstance(loaded_model, AIIASparseMoe) assert isinstance(loaded_model, AIIASparseMoe)
# Clean up # Clean up
os.remove(os.path.join(save_path, "model.pth")) os.remove(os.path.join(save_pretrained_path, "model.safetensors"))
os.remove(os.path.join(save_path, "config.json")) os.remove(os.path.join(save_pretrained_path, "config.json"))
os.rmdir(save_path) os.rmdir(save_pretrained_path)
def test_aiiachunked_creation():
config = AIIAConfig()
model = AIIAchunked(config)
assert isinstance(model, AIIAchunked)
def test_aiiachunked_save_load():
config = AIIAConfig()
model = AIIAchunked(config)
save_path = "test_aiiachunked_save_load"
# Save the model
model.save(save_path)
assert os.path.exists(os.path.join(save_path, "model.pth"))
assert os.path.exists(os.path.join(save_path, "config.json"))
# Load the model
loaded_model = AIIAchunked.load(save_path)
# Check if the loaded model is an instance of AIIAchunked
assert isinstance(loaded_model, AIIAchunked)
# Clean up
os.remove(os.path.join(save_path, "model.pth"))
os.remove(os.path.join(save_path, "config.json"))
os.rmdir(save_path)

View File

@ -1,75 +1,77 @@
import os import os
import tempfile import tempfile
import pytest import pytest
import torch.nn as nn
from aiia import AIIAConfig from aiia import AIIAConfig
def test_aiia_config_initialization(): def test_aiia_config_initialization():
config = AIIAConfig() config = AIIAConfig()
assert config.model_name == "AIIA" assert config.model_type == "AIIA"
assert config.kernel_size == 3 assert config.kernel_size == 3
assert config.activation_function == "GELU" assert config.activation_function == "GELU"
assert config.hidden_size == 512 assert config.hidden_size == 512
assert config.num_hidden_layers == 12 assert config.num_hidden_layers == 12
assert config.num_channels == 3 assert config.num_channels == 3
assert config.learning_rate == 5e-5
def test_aiia_config_custom_initialization(): def test_aiia_config_custom_initialization():
config = AIIAConfig( config = AIIAConfig(
model_name="CustomModel", model_type="CustomModel",
kernel_size=5, kernel_size=5,
activation_function="ReLU", activation_function="ReLU",
hidden_size=1024, hidden_size=1024,
num_hidden_layers=8, num_hidden_layers=8,
num_channels=1, num_channels=1
learning_rate=1e-4
) )
assert config.model_name == "CustomModel" assert config.model_type == "CustomModel"
assert config.kernel_size == 5 assert config.kernel_size == 5
assert config.activation_function == "ReLU" assert config.activation_function == "ReLU"
assert config.hidden_size == 1024 assert config.hidden_size == 1024
assert config.num_hidden_layers == 8 assert config.num_hidden_layers == 8
assert config.num_channels == 1 assert config.num_channels == 1
assert config.learning_rate == 1e-4
def test_aiia_config_invalid_activation_function(): def test_aiia_config_invalid_activation_function():
with pytest.raises(ValueError): with pytest.raises(ValueError):
AIIAConfig(activation_function="InvalidFunction") AIIAConfig(activation_function="InvalidFunction")
def test_aiia_config_to_dict(): def test_aiia_config_to_dict():
config = AIIAConfig() config = AIIAConfig()
config_dict = config.to_dict() config_dict = config.to_dict()
assert isinstance(config_dict, dict) assert isinstance(config_dict, dict)
assert config_dict["model_name"] == "AIIA"
assert config_dict["kernel_size"] == 3 assert config_dict["kernel_size"] == 3
def test_aiia_config_save_and_load():
with tempfile.TemporaryDirectory() as tmpdir:
config = AIIAConfig(model_name="TempModel")
save_path = os.path.join(tmpdir, "config")
config.save(save_path)
loaded_config = AIIAConfig.load(save_path) def test_aiia_config_save_pretrained_and_from_pretrained():
assert loaded_config.model_name == "TempModel" with tempfile.TemporaryDirectory() as tmpdir:
config = AIIAConfig(model_type="TempModel")
save_pretrained_path = os.path.join(tmpdir, "config")
config.save_pretrained(save_pretrained_path)
loaded_config = AIIAConfig.from_pretrained(save_pretrained_path)
assert loaded_config.model_type == "TempModel"
assert loaded_config.kernel_size == 3 assert loaded_config.kernel_size == 3
assert loaded_config.activation_function == "GELU" assert loaded_config.activation_function == "GELU"
def test_aiia_config_save_and_load_with_custom_attributes():
with tempfile.TemporaryDirectory() as tmpdir:
config = AIIAConfig(model_name="TempModel", custom_attr="value")
save_path = os.path.join(tmpdir, "config")
config.save(save_path)
loaded_config = AIIAConfig.load(save_path) def test_aiia_config_save_pretrained_and_load_with_custom_attributes():
assert loaded_config.model_name == "TempModel" with tempfile.TemporaryDirectory() as tmpdir:
config = AIIAConfig(model_type="TempModel", custom_attr="value")
save_pretrained_path = os.path.join(tmpdir, "config")
config.save_pretrained(save_pretrained_path)
loaded_config = AIIAConfig.from_pretrained(save_pretrained_path)
assert loaded_config.model_type == "TempModel"
assert loaded_config.custom_attr == "value" assert loaded_config.custom_attr == "value"
def test_aiia_config_save_and_load_with_nested_attributes():
with tempfile.TemporaryDirectory() as tmpdir:
config = AIIAConfig(model_name="TempModel", nested={"key": "value"})
save_path = os.path.join(tmpdir, "config")
config.save(save_path)
loaded_config = AIIAConfig.load(save_path) def test_aiia_config_save_pretrained_and_load_with_nested_attributes():
assert loaded_config.model_name == "TempModel" with tempfile.TemporaryDirectory() as tmpdir:
config = AIIAConfig(model_type="TempModel", nested={"key": "value"})
save_pretrained_path = os.path.join(tmpdir, "config")
config.save_pretrained(save_pretrained_path)
loaded_config = AIIAConfig.from_pretrained(save_pretrained_path)
assert loaded_config.model_type == "TempModel"
assert loaded_config.nested == {"key": "value"} assert loaded_config.nested == {"key": "value"}

View File

@ -94,7 +94,7 @@ def test_train_happy_path(mock_print, mock_path_join, mock_data_loader, mock_rea
# Execute training with patched methods # Execute training with patched methods
with patch.object(Pretrainer, '_process_batch', return_value=mock_batch_loss) as mock_process_batch, \ with patch.object(Pretrainer, '_process_batch', return_value=mock_batch_loss) as mock_process_batch, \
patch.object(Pretrainer, '_validate', side_effect=[0.8, 0.3]) as mock_validate, \ patch.object(Pretrainer, '_validate', side_effect=[0.8, 0.3]) as mock_validate, \
patch.object(Pretrainer, 'save_losses') as mock_save_losses, \ patch.object(Pretrainer, 'save_pretrained_losses') as mock_save_pretrained_losses, \
patch('builtins.open', mock_open()): patch('builtins.open', mock_open()):
pretrainer.train(dataset_paths, output_path=output_path, num_epochs=2) pretrainer.train(dataset_paths, output_path=output_path, num_epochs=2)
@ -104,10 +104,10 @@ def test_train_happy_path(mock_print, mock_path_join, mock_data_loader, mock_rea
assert mock_process_batch.call_count == 2 assert mock_process_batch.call_count == 2
assert mock_validate.call_count == 2 assert mock_validate.call_count == 2
# Check for "Best model saved!" instead of model.save() # Check for "Best model save_pretrainedd!" instead of model.save_pretrained()
mock_print.assert_any_call("Best model saved!") mock_print.assert_any_call("Best model save_pretrainedd!")
mock_save_losses.assert_called_once() mock_save_pretrained_losses.assert_called_once()
# Verify state changes # Verify state changes
assert len(pretrainer.train_losses) == 2 assert len(pretrainer.train_losses) == 2
@ -153,13 +153,13 @@ def test_train_empty_loaders(mock_data_loader, mock_read_parquet, mock_concat):
pretrainer.projection_head = MagicMock() pretrainer.projection_head = MagicMock()
pretrainer.optimizer = MagicMock() pretrainer.optimizer = MagicMock()
with patch.object(Pretrainer, 'save_losses') as mock_save_losses: with patch.object(Pretrainer, 'save_pretrained_losses') as mock_save_pretrained_losses:
pretrainer.train(['path/to/dataset.parquet'], num_epochs=1) pretrainer.train(['path/to/dataset.parquet'], num_epochs=1)
# Verify empty loader behavior # Verify empty loader behavior
assert len(pretrainer.train_losses) == 1 assert len(pretrainer.train_losses) == 1
assert pretrainer.train_losses[0] == 0.0 assert pretrainer.train_losses[0] == 0.0
mock_save_losses.assert_called_once() mock_save_pretrained_losses.assert_called_once()
@patch('pandas.concat') @patch('pandas.concat')
@patch('pandas.read_parquet') @patch('pandas.read_parquet')
@ -180,7 +180,7 @@ def test_train_none_batch_data(mock_data_loader, mock_read_parquet, mock_concat)
pretrainer.optimizer = MagicMock() pretrainer.optimizer = MagicMock()
with patch.object(Pretrainer, '_process_batch') as mock_process_batch, \ with patch.object(Pretrainer, '_process_batch') as mock_process_batch, \
patch.object(Pretrainer, 'save_losses'): patch.object(Pretrainer, 'save_pretrained_losses'):
pretrainer.train(['path/to/dataset.parquet'], num_epochs=1) pretrainer.train(['path/to/dataset.parquet'], num_epochs=1)
# Verify None batch handling # Verify None batch handling
@ -212,7 +212,7 @@ def test_train_with_custom_parameters(mock_data_loader, mock_read_parquet, mock_
custom_batch_size = 16 custom_batch_size = 16
custom_sample_size = 5000 custom_sample_size = 5000
with patch.object(Pretrainer, 'save_losses'): with patch.object(Pretrainer, 'save_pretrained_losses'):
pretrainer.train( pretrainer.train(
['path/to/dataset.parquet'], ['path/to/dataset.parquet'],
output_path=custom_output_path, output_path=custom_output_path,
@ -233,7 +233,7 @@ def test_train_with_custom_parameters(mock_data_loader, mock_read_parquet, mock_
@patch('aiia.pretrain.pretrainer.AIIADataLoader') @patch('aiia.pretrain.pretrainer.AIIADataLoader')
@patch('builtins.print') # Add this to mock the print function @patch('builtins.print') # Add this to mock the print function
def test_train_validation_loss_improvement(mock_print, mock_data_loader, mock_read_parquet, mock_concat): def test_train_validation_loss_improvement(mock_print, mock_data_loader, mock_read_parquet, mock_concat):
"""Test that model is saved only when validation loss improves.""" """Test that model is save_pretrainedd only when validation loss improves."""
real_df = pd.DataFrame({'image_bytes': [torch.randn(1, 3, 224, 224).tolist()]}) real_df = pd.DataFrame({'image_bytes': [torch.randn(1, 3, 224, 224).tolist()]})
mock_read_parquet.return_value.head.return_value = real_df mock_read_parquet.return_value.head.return_value = real_df
mock_concat.return_value = real_df mock_concat.return_value = real_df
@ -262,11 +262,11 @@ def test_train_validation_loss_improvement(mock_print, mock_data_loader, mock_re
# Test improving validation loss # Test improving validation loss
with patch.object(Pretrainer, '_process_batch', return_value=mock_batch_loss), \ with patch.object(Pretrainer, '_process_batch', return_value=mock_batch_loss), \
patch.object(Pretrainer, '_validate', side_effect=[3.0, 2.0, 1.0]), \ patch.object(Pretrainer, '_validate', side_effect=[3.0, 2.0, 1.0]), \
patch.object(Pretrainer, 'save_losses'): patch.object(Pretrainer, 'save_pretrained_losses'):
pretrainer.train(['path/to/dataset.parquet'], num_epochs=3) pretrainer.train(['path/to/dataset.parquet'], num_epochs=3)
# Check for "Best model saved!" 3 times # Check for "Best model save_pretrainedd!" 3 times
assert mock_print.call_args_list.count(call("Best model saved!")) == 3 assert mock_print.call_args_list.count(call("Best model save_pretrainedd!")) == 3
# Reset for next test # Reset for next test
mock_print.reset_mock() mock_print.reset_mock()
@ -278,11 +278,11 @@ def test_train_validation_loss_improvement(mock_print, mock_data_loader, mock_re
# Test fluctuating validation loss # Test fluctuating validation loss
with patch.object(Pretrainer, '_process_batch', return_value=mock_batch_loss), \ with patch.object(Pretrainer, '_process_batch', return_value=mock_batch_loss), \
patch.object(Pretrainer, '_validate', side_effect=[3.0, 4.0, 2.0]), \ patch.object(Pretrainer, '_validate', side_effect=[3.0, 4.0, 2.0]), \
patch.object(Pretrainer, 'save_losses'): patch.object(Pretrainer, 'save_pretrained_losses'):
pretrainer.train(['path/to/dataset.parquet'], num_epochs=3) pretrainer.train(['path/to/dataset.parquet'], num_epochs=3)
# Should print "Best model saved!" only on first and third epochs # Should print "Best model save_pretrainedd!" only on first and third epochs
assert mock_print.call_args_list.count(call("Best model saved!")) == 2 assert mock_print.call_args_list.count(call("Best model save_pretrainedd!")) == 2
@patch('aiia.pretrain.pretrainer.Pretrainer._process_batch') @patch('aiia.pretrain.pretrainer.Pretrainer._process_batch')
@ -296,13 +296,13 @@ def test_validate(mock_process_batch):
loss = pretrainer._validate(val_loader, criterion_denoise, criterion_rotate) loss = pretrainer._validate(val_loader, criterion_denoise, criterion_rotate)
assert loss == 0.5 assert loss == 0.5
# Test the save_losses method # Test the save_pretrained_losses method
@patch('aiia.pretrain.pretrainer.Pretrainer.save_losses') @patch('aiia.pretrain.pretrainer.Pretrainer.save_pretrained_losses')
def test_save_losses(mock_save_losses): def test_save_pretrained_losses(mock_save_pretrained_losses):
pretrainer = Pretrainer(model=AIIABase(config=AIIAConfig()), config=AIIAConfig()) pretrainer = Pretrainer(model=AIIABase(config=AIIAConfig()), config=AIIAConfig())
pretrainer.train_losses = [0.1, 0.2] pretrainer.train_losses = [0.1, 0.2]
pretrainer.val_losses = [0.3, 0.4] pretrainer.val_losses = [0.3, 0.4]
csv_file = 'losses.csv' csv_file = 'losses.csv'
pretrainer.save_losses(csv_file) pretrainer.save_pretrained_losses(csv_file)
mock_save_losses.assert_called_once_with(csv_file) mock_save_pretrained_losses.assert_called_once_with(csv_file)