develop #41

Merged
Fabel merged 27 commits from develop into main 2025-04-17 17:08:57 +00:00
13 changed files with 584 additions and 560 deletions

View File

@@ -29,7 +29,7 @@ from aiia.model import AIIAConfig
 from aiia.pretrain import Pretrainer

 # Create your model
-config = AIIAConfig(model_name="AIIA-Base-512x20k")
+config = AIIAConfig(model_type="AIIA-Base-512x20k")
 model = AIIABase(config)

 # Initialize pretrainer with the model

View File

@@ -1,24 +1,22 @@
-from aiia.model import AIIABase
-from aiia.model import AIIAConfig
-from aiia.pretrain import Pretrainer
+from src.aiia.model import AIIAmoe
+from src.aiia.model import AIIAConfig
+from src.aiia.pretrain import Pretrainer

 # Create your model
-config = AIIAConfig(model_name="AIIA-Base-512x20k")
-model = AIIABase(config)
+config = AIIAConfig(num_experts=5)
+model = AIIAmoe(config)
+model.save_pretrained("test")
+model = AIIAmoe.from_pretrained("test")

 # Initialize pretrainer with the model
 pretrainer = Pretrainer(model, learning_rate=1e-4, config=config)

-# List of dataset paths
-dataset_paths = [
-    "/path/to/dataset1.parquet",
-    "/path/to/dataset2.parquet"
-]
-# Start training with multiple datasets
+# Set checkpoint directory
+checkpoint_dir = "checkpoints/my_model"
+# Start training (will automatically load checkpoint if available)
 pretrainer.train(
-    dataset_paths=dataset_paths,
-    num_epochs=10,
-    batch_size=2,
-    sample_size=10000
+    dataset_paths=["path/to/dataset1.parquet", "path/to/dataset2.parquet"],
+    output_path="trained_models/my_model",
+    checkpoint_dir=checkpoint_dir,
+    num_epochs=10
 )
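For illustration only (not part of this diff), resuming from one specific checkpoint file with the checkpoint API introduced here might look roughly like this; the checkpoint file name is a hypothetical placeholder:

from src.aiia.model import AIIAmoe, AIIAConfig
from src.aiia.pretrain import Pretrainer

config = AIIAConfig(num_experts=5)
model = AIIAmoe(config)
pretrainer = Pretrainer(model, learning_rate=1e-4, config=config)
# load_checkpoint() returns (epoch, batch) when a checkpoint is found, otherwise None
resumed = pretrainer.load_checkpoint("checkpoints/my_model", specific_checkpoint="checkpoint_epoch2_batch150.pt")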

View File

@@ -10,7 +10,7 @@ include = '\.pyi?$'
 [project]
 name = "aiia"
-version = "0.2.1"
+version = "0.3.1"
 description = "AIIA Deep Learning Model Implementation"
 readme = "README.md"
 authors = [

View File

@@ -6,3 +6,4 @@ pillow
 pandas
 torchvision
 pyarrow
+transformers>=4.48.0

View File

@@ -1,6 +1,6 @@
 [metadata]
 name = aiia
-version = 0.2.1
+version = 0.3.1
 author = Falko Habel
 author_email = falko.habel@gmx.de
 description = AIIA deep learning model implementation

View File

@@ -1,7 +1,7 @@
-from .model.Model import AIIABase, AIIABaseShared, AIIAchunked, AIIAmoe, AIIASparseMoe, AIIArecursive
+from .model.Model import AIIABase, AIIABaseShared, AIIAmoe, AIIASparseMoe
 from .model.config import AIIAConfig
 from .data.DataLoader import DataLoader
 from .pretrain.pretrainer import Pretrainer, ProjectionHead
-__version__ = "0.2.1"
+__version__ = "0.3.1"
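For orientation, the package's public surface after this change, as a hedged import sketch (AIIAchunked and AIIArecursive are no longer exported):

import aiia
from aiia import AIIABase, AIIABaseShared, AIIAmoe, AIIASparseMoe, AIIAConfig, Pretrainer, ProjectionHead

assert aiia.__version__ == "0.3.1"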

View File

@@ -1,119 +1,48 @@
 from .config import AIIAConfig
 from torch import nn
+from transformers import PreTrainedModel
 import torch
-import os
 import copy
-import warnings

-class AIIA(nn.Module):
-    def __init__(self, config: AIIAConfig, **kwargs):
-        super(AIIA, self).__init__()
-        # Create a deep copy of the configuration to avoid sharing
-        self.config = copy.deepcopy(config)
-        # Update the config with any additional keyword arguments
-        for key, value in kwargs.items():
-            setattr(self.config, key, value)
-    def save(self, path: str):
-        if not os.path.exists(path):
-            os.makedirs(path, exist_ok=True)
-        torch.save(self.state_dict(), f"{path}/model.pth")
-        self.config.save(path)
-    @classmethod
-    def load(cls, path, precision: str = None, strict: bool = True, **kwargs):
-        config = AIIAConfig.load(path)
-        device = 'cuda' if torch.cuda.is_available() else 'cpu'
-        # Load the state dict to analyze structure
-        model_dict = torch.load(f"{path}/model.pth", map_location=device)
-        # Special handling for AIIAmoe - detect number of experts from state_dict
-        if cls.__name__ == "AIIAmoe" and "num_experts" not in kwargs:
-            # Find maximum expert index
-            max_expert_idx = -1
-            for key in model_dict.keys():
-                if key.startswith("experts."):
-                    parts = key.split(".")
-                    if len(parts) > 1:
-                        try:
-                            expert_idx = int(parts[1])
-                            max_expert_idx = max(max_expert_idx, expert_idx)
-                        except ValueError:
-                            pass
-            if max_expert_idx >= 0:
-                # experts.X keys found, use max_expert_idx + 1 as num_experts
-                kwargs["num_experts"] = max_expert_idx + 1
-        # Create model with detected structural parameters
-        model = cls(config, **kwargs)
-        # Handle precision conversion
-        dtype = None
-        if precision is not None:
-            if precision.lower() == 'fp16':
-                dtype = torch.float16
-            elif precision.lower() == 'bf16':
-                if device == 'cuda' and not torch.cuda.is_bf16_supported():
-                    warnings.warn("BF16 is not supported on this GPU. Falling back to FP16.")
-                    dtype = torch.float16
-                else:
-                    dtype = torch.bfloat16
-            else:
-                raise ValueError("Unsupported precision. Use 'fp16', 'bf16', or leave as None.")
-        if dtype is not None:
-            for key, param in model_dict.items():
-                if torch.is_tensor(param):
-                    model_dict[key] = param.to(dtype)
-        # Load state dict with strict parameter for flexibility
-        model.load_state_dict(model_dict, strict=strict)
-        return model

-class AIIABase(AIIA):
-    def __init__(self, config: AIIAConfig, **kwargs):
-        super().__init__(config=config, **kwargs)
-        self.config = self.config
+class AIIABase(PreTrainedModel):
+    config_class = AIIAConfig
+    base_model_prefix = "AIIA"
+    def __init__(self, config: AIIAConfig):
+        super().__init__(config)
         # Initialize layers based on configuration
         layers = []
-        in_channels = self.config.num_channels
-        for _ in range(self.config.num_hidden_layers):
+        in_channels = config.num_channels
+        for _ in range(config.num_hidden_layers):
             layers.extend([
-                nn.Conv2d(in_channels, self.config.hidden_size,
-                          kernel_size=self.config.kernel_size, padding=1),
-                getattr(nn, self.config.activation_function)(),
+                nn.Conv2d(in_channels, config.hidden_size,
+                          kernel_size=config.kernel_size, padding=1),
+                getattr(nn, config.activation_function)(),
                 nn.MaxPool2d(kernel_size=1, stride=1)
             ])
-            in_channels = self.config.hidden_size
+            in_channels = config.hidden_size
         self.cnn = nn.Sequential(*layers)
     def forward(self, x):
         return self.cnn(x)

-class AIIABaseShared(AIIA):
-    def __init__(self, config: AIIAConfig, **kwargs):
+class AIIABaseShared(PreTrainedModel):
+    config_class = AIIAConfig
+    base_model_prefix = "AIIA"
+    def __init__(self, config: AIIAConfig):
+        super().__init__(config)
         """
         Initialize the AIIABaseShared model.
         Args:
             config (AIIAConfig): Configuration object containing model parameters.
-            **kwargs: Additional keyword arguments to override configuration settings.
         """
-        super().__init__(config=config, **kwargs)
-        # Update configuration with new parameters if provided
-        self. config = copy.deepcopy(config)
-        for key, value in kwargs.items():
-            setattr(self.config, key, value)
+        super().__init__(config=config)
         # Initialize the network components
         self._initialize_network()
@@ -172,16 +101,17 @@ class AIIABaseShared(AIIA):
         return out

-class AIIAExpert(AIIA):
-    def __init__(self, config: AIIAConfig, base_class=AIIABase, **kwargs):
-        super().__init__(config=config, **kwargs)
-        self.config = self.config
+class AIIAExpert(PreTrainedModel):
+    config_class = AIIAConfig
+    base_model_prefix = "AIIA"
+    def __init__(self, config: AIIAConfig, base_class=AIIABase):
+        super().__init__(config=config)
         # Initialize base CNN with configuration and chosen base class
         if issubclass(base_class, AIIABase):
-            self.base_cnn = AIIABase(self.config, **kwargs)
+            self.base_cnn = AIIABase(self.config)
         elif issubclass(base_class, AIIABaseShared):
-            self.base_cnn = AIIABaseShared(self.config, **kwargs)
+            self.base_cnn = AIIABaseShared(self.config)
         else:
             raise ValueError("Invalid base class")
@@ -198,26 +128,26 @@ class AIIAExpert(AIIA):
         # Process input through the base CNN
         return self.base_cnn(x)

-class AIIAmoe(AIIA):
-    def __init__(self, config: AIIAConfig, num_experts: int = 3, base_class=AIIABase, **kwargs):
-        super().__init__(config=config, **kwargs)
+class AIIAmoe(PreTrainedModel):
+    config_class = AIIAConfig
+    base_model_prefix = "AIIA"
+    def __init__(self, config: AIIAConfig, base_class=AIIABase):
+        super().__init__(config=config)
         self.config = config
-        # Update the config to include the number of experts.
-        self.config.num_experts = num_experts
-        # Initialize multiple experts from the chosen base class.
+        # Get num_experts directly from config instead of parameter
+        num_experts = getattr(config, "num_experts", 3)  # default to 3 if not in config
+        # Initialize multiple experts from the chosen base class
         self.experts = nn.ModuleList([
-            AIIAExpert(self.config, base_class=base_class, **kwargs)
+            AIIAExpert(self.config, base_class=base_class)
             for _ in range(num_experts)
         ])
-        # To generate gating weights, we first need to determine the feature dimension.
-        # Each expert is assumed to return an output of shape (B, C, H, W); after averaging over H and W,
-        # we obtain a tensor of shape (B, C) where C is the number of channels (here assumed to be 224).
-        gate_in_features = 512  # Adjust this if your expert output changes.
-        # Create a gating network that maps the aggregated features to num_experts weights.
+        gate_in_features = self.config.hidden_size
+        # Create a gating network that maps the aggregated features to num_experts weights
         self.gate = nn.Sequential(
             nn.Linear(gate_in_features, num_experts),
             nn.Softmax(dim=1)
@@ -261,9 +191,10 @@ class AIIAmoe(AIIA):
 class AIIASparseMoe(AIIAmoe):
-    def __init__(self, config: AIIAConfig, num_experts: int = 3, top_k: int = 2, base_class=AIIABase, **kwargs):
-        super().__init__(config=config, num_experts=num_experts, base_class=base_class, **kwargs)
-        self.top_k = top_k
+    config_class = AIIAConfig
+    base_model_prefix = "AIIA"
+    def __init__(self, config: AIIAConfig, base_class=AIIABase):
+        super().__init__(config=config, base_class=base_class)
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # Compute the gate_weights similar to standard moe.
@@ -273,7 +204,7 @@ class AIIASparseMoe(AIIAmoe):
         gate_weights = self.gate(gate_input)
         # Select the top-k experts for each input based on gating weights.
-        _, top_k_indices = gate_weights.topk(self.top_k, dim=-1)
+        _, top_k_indices = gate_weights.topk(self.config.top_k, dim=-1)
         # Initialize a list to store outputs from selected experts.
         merged_outputs = []
@@ -294,64 +225,7 @@ class AIIASparseMoe(AIIAmoe):
         return torch.cat(merged_outputs, dim=0)

-class AIIAchunked(AIIA):
-    def __init__(self, config: AIIAConfig, patch_size: int = 16, base_class=AIIABase, **kwargs):
-        super().__init__(config=config, **kwargs)
-        self.config = self.config
-        # Update config with new parameters if provided
-        self.config.patch_size = patch_size
-        # Initialize base CNN for processing each patch using the specified base class
-        if issubclass(base_class, AIIABase):
-            self.base_cnn = AIIABase(self.config, **kwargs)
-        elif issubclass(base_class, AIIABaseShared):  # Add support for AIIABaseShared
-            self.base_cnn = AIIABaseShared(self.config, **kwargs)
-        else:
-            raise ValueError("Invalid base class")
-    def forward(self, x):
-        patches = x.unfold(2, self.patch_size, self.patch_size).unfold(3, self.patch_size, self.patch_size)
-        patches = patches.contiguous().view(patches.size(0), patches.size(1), -1, self.patch_size, self.patch_size)
-        patch_outputs = []
-        for p in torch.split(patches, 1, dim=2):
-            p = p.squeeze(2)
-            po = self.base_cnn(p)
-            patch_outputs.append(po)
-        combined_output = torch.mean(torch.stack(patch_outputs, dim=0), dim=0)
-        return combined_output

-class AIIArecursive(AIIA):
-    def __init__(self, config: AIIAConfig, recursion_depth: int = 3, base_class=AIIABase, **kwargs):
-        super().__init__(config=config, **kwargs)
-        self.config = self.config
-        # Pass recursion_depth as a kwarg to the config
-        self.config.recursion_depth = recursion_depth
-        # Initialize chunked CNN with updated config
-        self.chunked_cnn = AIIAchunked(self.config, base_class, **kwargs)
-    def forward(self, x, depth=0):
-        if depth == self.recursion_depth:
-            return self.chunked_cnn(x)
-        else:
-            patches = x.unfold(2, 16, 16).unfold(3, 16, 16)
-            patches = patches.contiguous().view(patches.size(0), patches.size(1), -1, 16, 16)
-            processed_patches = []
-            for p in torch.split(patches, 1, dim=2):
-                p = p.squeeze(2)
-                pp = self.forward(p, depth + 1)
-                processed_patches.append(pp)
-            combined_output = torch.mean(torch.stack(processed_patches, dim=0), dim=0)
-            return combined_output

 if __name__ =="__main__":
     config = AIIAConfig()
     model = AIIAmoe(config, num_experts=5)
-    model.save("test")
+    model.save_pretrained("test")
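A rough usage sketch of the config-driven construction above (the input shape, expert count, and directory name are arbitrary example values, not part of the diff):

import torch
from aiia.model import AIIAmoe, AIIAConfig

config = AIIAConfig(num_experts=4)                  # num_experts now lives on the config
model = AIIAmoe(config)
features = model(torch.randn(1, 3, 224, 224))       # expert outputs are blended via the softmax gate
model.save_pretrained("aiia-moe-demo")               # Hugging Face-style persistence
reloaded = AIIAmoe.from_pretrained("aiia-moe-demo")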

View File

@@ -1,20 +1,15 @@
 from .Model import (
     AIIABase,
     AIIABaseShared,
-    AIIAchunked,
     AIIAmoe,
     AIIASparseMoe,
-    AIIArecursive
 )
 from .config import AIIAConfig

 __all__ = [
     "AIIABase",
     "AIIABaseShared",
-    "AIIAchunked",
     "AIIAmoe",
     "AIIASparseMoe",
-    "AIIArecursive",
     "AIIAConfig",
 ]

View File

@@ -1,28 +1,24 @@
-import torch
+from transformers import PretrainedConfig
 import torch.nn as nn
-import json
-import os

-class AIIAConfig:
+class AIIAConfig(PretrainedConfig):
+    model_type = "AIIA"  # Add this class attribute
     def __init__(
         self,
-        model_name: str = "AIIA",
         kernel_size: int = 3,
         activation_function: str = "GELU",
         hidden_size: int = 512,
         num_hidden_layers: int = 12,
         num_channels: int = 3,
-        learning_rate: float = 5e-5,
         **kwargs
     ):
-        self.model_name = model_name
+        super().__init__(**kwargs)
         self.kernel_size = kernel_size
         self.activation_function = activation_function
         self.hidden_size = hidden_size
         self.num_hidden_layers = num_hidden_layers
         self.num_channels = num_channels
-        self.learning_rate = learning_rate
         # Store additional keyword arguments as attributes
         for key, value in kwargs.items():
@@ -51,16 +47,3 @@ class AIIAConfig:
                 return {k: serialize(v) for k, v in value.items()}
             return value
         return {k: serialize(v) for k, v in self.__dict__.items()}
-    def save(self, file_path):
-        if not os.path.exists(file_path):
-            os.makedirs(file_path, exist_ok=True)
-        with open(os.path.join(file_path, "config.json"), "w") as f:
-            # Save the recursively converted dictionary.
-            json.dump(self.to_dict(), f, indent=4)
-    @classmethod
-    def load(cls, file_path):
-        with open(os.path.join(file_path, "config.json"), "r") as f:
-            config_dict = json.load(f)
-        return cls(**config_dict)
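A short hedged sketch of the Hugging Face-style round-trip this change enables (the directory name is a placeholder):

from aiia.model import AIIAConfig

config = AIIAConfig(hidden_size=256, num_hidden_layers=6, custom_attr="value")
config.save_pretrained("aiia-config-demo")                 # writes config.json via PretrainedConfig
reloaded = AIIAConfig.from_pretrained("aiia-config-demo")
assert reloaded.hidden_size == 256 and reloaded.custom_attr == "value"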

View File

@@ -1,9 +1,11 @@
 import torch
 from torch import nn
 import csv
+import datetime
+import time
 import pandas as pd
 from tqdm import tqdm
-from ..model.Model import AIIA
+from transformers import PreTrainedModel
 from ..model.config import AIIAConfig
 from ..data.DataLoader import AIIADataLoader
 import os
@@ -21,7 +23,7 @@ class ProjectionHead(nn.Module):
         return self.conv_rotate(x).mean(dim=(2, 3))  # Global average pooling for rotation task

 class Pretrainer:
-    def __init__(self, model: AIIA, learning_rate=1e-4, config: AIIAConfig=None):
+    def __init__(self, model: PreTrainedModel, learning_rate=1e-4, config: AIIAConfig=None):
         """
         Initialize the pretrainer with a model.
@@ -112,20 +114,169 @@ class Pretrainer:
         return batch_loss

-    def train(self, dataset_paths,output_path:str="AIIA", column="image_bytes", num_epochs=3, batch_size=2, sample_size=10000):
-        """
-        Train the model using multiple specified datasets.
-        Args:
-            dataset_paths (list): List of paths to parquet datasets
-            num_epochs (int): Number of training epochs
-            batch_size (int): Batch size for training
-            sample_size (int): Number of samples to use from each dataset
-        """
+    def _save_checkpoint(self, checkpoint_dir, epoch, batch_count, checkpoint_name):
+        """Save a model checkpoint.
+        Args:
+            checkpoint_dir (str): Directory to save the checkpoint
+            epoch (int): Current epoch number
+            batch_count (int): Current batch count
+            checkpoint_name (str): Name for the checkpoint file
+        Returns:
+            str: Path to the saved checkpoint
+        """
+        checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name)
+        checkpoint_data = {
+            'epoch': epoch + 1,
+            'batch': batch_count,
+            'model_state_dict': self.model.state_dict(),
+            'projection_head_state_dict': self.projection_head.state_dict(),
+            'optimizer_state_dict': self.optimizer.state_dict(),
+            'train_losses': self.train_losses,
+            'val_losses': self.val_losses,
+        }
+        torch.save(checkpoint_data, checkpoint_path)
+        return checkpoint_path
+    def load_checkpoint(self, checkpoint_dir, specific_checkpoint=None):
+        """
+        Check for checkpoints and load if available.
+        Args:
+            checkpoint_dir (str): Directory where checkpoints are stored
+            specific_checkpoint (str, optional): Specific checkpoint file to load. If None, loads the most recent.
+        Returns:
+            tuple: (loaded_epoch, loaded_batch) if checkpoint was loaded, None otherwise
+        """
+        # Create checkpoint directory if it doesn't exist
+        os.makedirs(checkpoint_dir, exist_ok=True)
+        # If a specific checkpoint is requested
+        if specific_checkpoint:
+            checkpoint_path = os.path.join(checkpoint_dir, specific_checkpoint)
+            if os.path.exists(checkpoint_path):
+                return self._load_checkpoint_file(checkpoint_path)
+            else:
+                print(f"Specified checkpoint {specific_checkpoint} not found.")
+                return None
+        # Find all checkpoint files
+        checkpoint_files = [f for f in os.listdir(checkpoint_dir) if f.startswith("checkpoint_") and f.endswith(".pt")]
+        if not checkpoint_files:
+            print("No checkpoints found in directory.")
+            return None
+        # Find the most recent checkpoint
+        checkpoint_files.sort(key=lambda x: os.path.getmtime(os.path.join(checkpoint_dir, x)), reverse=True)
+        most_recent = checkpoint_files[0]
+        checkpoint_path = os.path.join(checkpoint_dir, most_recent)
+        return self._load_checkpoint_file(checkpoint_path)
+    def _load_checkpoint_file(self, checkpoint_path):
+        """
+        Load a specific checkpoint file.
+        Args:
+            checkpoint_path (str): Path to the checkpoint file
+        Returns:
+            tuple: (loaded_epoch, loaded_batch) if checkpoint was loaded, None otherwise
+        """
+        try:
+            checkpoint = torch.load(checkpoint_path, map_location=self.device)
+            # Load model state
+            self.model.load_state_dict(checkpoint['model_state_dict'])
+            # Load projection head state
+            self.projection_head.load_state_dict(checkpoint['projection_head_state_dict'])
+            # Load optimizer state
+            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+            # Load loss history
+            self.train_losses = checkpoint.get('train_losses', [])
+            self.val_losses = checkpoint.get('val_losses', [])
+            loaded_epoch = checkpoint['epoch']
+            loaded_batch = checkpoint['batch']
+            print(f"Checkpoint loaded from {checkpoint_path}")
+            print(f"Resuming from epoch {loaded_epoch}, batch {loaded_batch}")
+            return loaded_epoch, loaded_batch
+        except Exception as e:
+            print(f"Error loading checkpoint: {e}")
+            return None
+    def train(self, dataset_paths, output_path="AIIA", column="image_bytes",
+              num_epochs=3, batch_size=2, sample_size=10000, checkpoint_dir=None):
+        """Train the model using multiple specified datasets with checkpoint resumption support."""
         if not dataset_paths:
             raise ValueError("No dataset paths provided")
-        # Read and merge all datasets
+        self._initialize_checkpoint_variables()
+        start_epoch, start_batch, resume_training = self._load_checkpoints(checkpoint_dir)
+        dataframes = self._load_and_merge_datasets(dataset_paths, sample_size)
+        aiia_loader = self._initialize_data_loader(dataframes, column, batch_size)
+        criterion_denoise, criterion_rotate, best_val_loss = self._initialize_loss_functions()
+        for epoch in range(start_epoch, num_epochs):
+            print(f"\nEpoch {epoch+1}/{num_epochs}")
+            print("-" * 20)
+            total_train_loss, batch_count = self._training_phase(aiia_loader.train_loader,
+                                                                 start_batch if (epoch == start_epoch and resume_training) else 0,
+                                                                 criterion_denoise,
+                                                                 criterion_rotate)
+            avg_train_loss = total_train_loss / max(batch_count, 1)
+            self.train_losses.append(avg_train_loss)
+            print(f"Training Loss: {avg_train_loss:.4f}")
+            val_loss = self._validation_phase(aiia_loader.val_loader, criterion_denoise, criterion_rotate)
+            if val_loss < best_val_loss:
+                best_val_loss = val_loss
+                self.model.save(output_path)
+                print("Best model saved!")
+        losses_path = os.path.join(os.path.dirname(output_path), 'losses.csv')
+        self.save_losses(losses_path)
+    def _initialize_checkpoint_variables(self):
+        """Initialize checkpoint tracking variables."""
+        self.last_checkpoint_time = time.time()
+        self.checkpoint_interval = 2 * 60 * 60  # 2 hours in seconds
+        self.last_22_date = None
+        self.recent_checkpoints = []
+    def _load_checkpoints(self, checkpoint_dir):
+        """Load checkpoints and return start epoch, batch, and resumption flag."""
+        start_epoch = 0
+        start_batch = 0
+        resume_training = False
+        if checkpoint_dir is not None:
+            os.makedirs(checkpoint_dir, exist_ok=True)
+            checkpoint_info = self.load_checkpoint(checkpoint_dir)
+            if checkpoint_info:
+                start_epoch, start_batch = checkpoint_info
+                resume_training = True
+                # Adjust epoch to be 0-indexed for the loop
+                start_epoch -= 1
+        return start_epoch, start_batch, resume_training
+    def _load_and_merge_datasets(self, dataset_paths, sample_size):
+        """Load and merge datasets."""
         dataframes = []
         for path in dataset_paths:
             try:
@@ -137,10 +288,11 @@ class Pretrainer:
         if not dataframes:
             raise ValueError("No valid datasets could be loaded")
-        merged_df = pd.concat(dataframes, ignore_index=True)
-        # Initialize data loader
-        aiia_loader = AIIADataLoader(
+        return pd.concat(dataframes, ignore_index=True)
+    def _initialize_data_loader(self, merged_df, column, batch_size):
+        """Initialize the data loader."""
+        return AIIADataLoader(
             merged_df,
             column=column,
             batch_size=batch_size,
@@ -148,24 +300,30 @@ class Pretrainer:
             collate_fn=self.safe_collate
         )
+    def _initialize_loss_functions(self):
+        """Initialize loss functions and tracking variables."""
         criterion_denoise = nn.MSELoss()
         criterion_rotate = nn.CrossEntropyLoss()
         best_val_loss = float('inf')
+        return criterion_denoise, criterion_rotate, best_val_loss
-        for epoch in range(num_epochs):
-            print(f"\nEpoch {epoch+1}/{num_epochs}")
-            print("-" * 20)
-            # Training phase
+    def _training_phase(self, train_loader, skip_batches, criterion_denoise, criterion_rotate):
+        """Handle the training phase."""
         self.model.train()
         self.projection_head.train()
         total_train_loss = 0.0
         batch_count = 0
-        for batch_data in tqdm(aiia_loader.train_loader):
+        train_batches = list(enumerate(train_loader))
+        for i, batch_data in tqdm(train_batches[skip_batches:],
+                                  initial=skip_batches,
+                                  total=len(train_batches)):
             if batch_data is None:
                 continue
+            current_batch = i + 1
+            self._handle_checkpoints(current_batch)
             self.optimizer.zero_grad()
             batch_loss = self._process_batch(batch_data, criterion_denoise, criterion_rotate)
@@ -175,22 +333,42 @@ class Pretrainer:
             total_train_loss += batch_loss.item()
             batch_count += 1
-        avg_train_loss = total_train_loss / max(batch_count, 1)
-        self.train_losses.append(avg_train_loss)
-        print(f"Training Loss: {avg_train_loss:.4f}")
-        # Validation phase
+        return total_train_loss, batch_count
+    def _handle_checkpoints(self, current_batch):
+        """Handle checkpoint saving logic."""
+        current_time = time.time()
+        current_dt = datetime.datetime.now(datetime.timezone(datetime.timedelta(hours=2)))  # German time
+        today = current_dt.date()
+        if self.checkpoint_dir and (current_time - self.last_checkpoint_time) >= self.checkpoint_interval:
+            checkpoint_name = f"checkpoint_epoch{self.current_epoch+1}_batch{current_batch}.pt"
+            checkpoint_path = self._save_checkpoint(self.checkpoint_dir, self.current_epoch, current_batch, checkpoint_name)
+            # Track and maintain only 3 recent checkpoints
+            self.recent_checkpoints.append(checkpoint_path)
+            if len(self.recent_checkpoints) > 3:
+                oldest = self.recent_checkpoints.pop(0)
+                if os.path.exists(oldest):
+                    os.remove(oldest)
+            self.last_checkpoint_time = current_time
+            print(f"Checkpoint saved at {checkpoint_path}")
+        # Special 22:00 checkpoint (considering it's currently 10:15 PM)
+        is_22_oclock = current_dt.hour == 22 and current_dt.minute < 15
+        if self.checkpoint_dir and is_22_oclock and self.last_22_date != today:
+            checkpoint_name = f"checkpoint_22h_{today.strftime('%Y%m%d')}.pt"
+            checkpoint_path = self._save_checkpoint(self.checkpoint_dir, self.current_epoch, current_batch, checkpoint_name)
+            self.last_22_date = today
+            print(f"22:00 Checkpoint saved at {checkpoint_path}")
+    def _validation_phase(self, val_loader, criterion_denoise, criterion_rotate):
+        """Handle the validation phase."""
         self.model.eval()
         self.projection_head.eval()
-        val_loss = self._validate(aiia_loader.val_loader, criterion_denoise, criterion_rotate)
+        return self._validate(val_loader, criterion_denoise, criterion_rotate)
-        if val_loss < best_val_loss:
-            best_val_loss = val_loss
-            self.model.save(output_path)
-            print("Best model saved!")
-        losses_path = os.path.join(os.path.dirname(output_path), 'losses.csv')
-        self.save_losses(losses_path)
     def _validate(self, val_loader, criterion_denoise, criterion_rotate):
         """Perform validation and return average validation loss."""

View File

@@ -1,159 +1,133 @@
 import os
 import torch
-from aiia.model.Model import AIIABase, AIIABaseShared, AIIAExpert, AIIAmoe, AIIAchunked, AIIAConfig, AIIASparseMoe
+from aiia.model.Model import AIIABase, AIIABaseShared, AIIAExpert, AIIAmoe, AIIAConfig, AIIASparseMoe

 def test_aiiabase_creation():
     config = AIIAConfig()
     model = AIIABase(config)
     assert isinstance(model, AIIABase)

-def test_aiiabase_save_load():
+def test_aiiabase_save_pretrained_from_pretrained():
     config = AIIAConfig()
     model = AIIABase(config)
-    save_path = "test_aiiabase_save_load"
+    save_pretrained_path = "test_aiiabase_save_pretrained_load"
     # Save the model
-    model.save(save_path)
-    assert os.path.exists(os.path.join(save_path, "model.pth"))
-    assert os.path.exists(os.path.join(save_path, "config.json"))
+    model.save_pretrained(save_pretrained_path)
+    assert os.path.exists(os.path.join(save_pretrained_path, "model.safetensors"))
+    assert os.path.exists(os.path.join(save_pretrained_path, "config.json"))
     # Load the model
-    loaded_model = AIIABase.load(save_path)
+    loaded_model = AIIABase.from_pretrained(save_pretrained_path)
     # Check if the loaded model is an instance of AIIABase
     assert isinstance(loaded_model, AIIABase)
     # Clean up
-    os.remove(os.path.join(save_path, "model.pth"))
-    os.remove(os.path.join(save_path, "config.json"))
-    os.rmdir(save_path)
+    os.remove(os.path.join(save_pretrained_path, "model.safetensors"))
+    os.remove(os.path.join(save_pretrained_path, "config.json"))
+    os.rmdir(save_pretrained_path)

 def test_aiiabase_shared_creation():
     config = AIIAConfig()
     model = AIIABaseShared(config)
     assert isinstance(model, AIIABaseShared)

-def test_aiiabase_shared_save_load():
+def test_aiiabase_shared_save_pretrained_from_pretrained():
     config = AIIAConfig()
     model = AIIABaseShared(config)
-    save_path = "test_aiiabase_shared_save_load"
+    save_pretrained_path = "test_aiiabase_shared_save_pretrained_load"
     # Save the model
-    model.save(save_path)
-    assert os.path.exists(os.path.join(save_path, "model.pth"))
-    assert os.path.exists(os.path.join(save_path, "config.json"))
+    model.save_pretrained(save_pretrained_path)
+    assert os.path.exists(os.path.join(save_pretrained_path, "model.safetensors"))
+    assert os.path.exists(os.path.join(save_pretrained_path, "config.json"))
     # Load the model
-    loaded_model = AIIABaseShared.load(save_path)
+    loaded_model = AIIABaseShared.from_pretrained(save_pretrained_path)
     # Check if the loaded model is an instance of AIIABaseShared
     assert isinstance(loaded_model, AIIABaseShared)
     # Clean up
-    os.remove(os.path.join(save_path, "model.pth"))
-    os.remove(os.path.join(save_path, "config.json"))
-    os.rmdir(save_path)
+    os.remove(os.path.join(save_pretrained_path, "model.safetensors"))
+    os.remove(os.path.join(save_pretrained_path, "config.json"))
+    os.rmdir(save_pretrained_path)

 def test_aiiaexpert_creation():
     config = AIIAConfig()
     model = AIIAExpert(config)
     assert isinstance(model, AIIAExpert)

-def test_aiiaexpert_save_load():
+def test_aiiaexpert_save_pretrained_from_pretrained():
     config = AIIAConfig()
     model = AIIAExpert(config)
-    save_path = "test_aiiaexpert_save_load"
+    save_pretrained_path = "test_aiiaexpert_save_pretrained_load"
     # Save the model
-    model.save(save_path)
-    assert os.path.exists(os.path.join(save_path, "model.pth"))
-    assert os.path.exists(os.path.join(save_path, "config.json"))
+    model.save_pretrained(save_pretrained_path)
+    assert os.path.exists(os.path.join(save_pretrained_path, "model.safetensors"))
+    assert os.path.exists(os.path.join(save_pretrained_path, "config.json"))
     # Load the model
-    loaded_model = AIIAExpert.load(save_path)
+    loaded_model = AIIAExpert.from_pretrained(save_pretrained_path)
     # Check if the loaded model is an instance of AIIAExpert
     assert isinstance(loaded_model, AIIAExpert)
     # Clean up
-    os.remove(os.path.join(save_path, "model.pth"))
-    os.remove(os.path.join(save_path, "config.json"))
-    os.rmdir(save_path)
+    os.remove(os.path.join(save_pretrained_path, "model.safetensors"))
+    os.remove(os.path.join(save_pretrained_path, "config.json"))
+    os.rmdir(save_pretrained_path)

 def test_aiiamoe_creation():
-    config = AIIAConfig()
-    model = AIIAmoe(config, num_experts=5)
+    config = AIIAConfig(num_experts=3)
+    model = AIIAmoe(config)
     assert isinstance(model, AIIAmoe)

-def test_aiiamoe_save_load():
-    config = AIIAConfig()
-    model = AIIAmoe(config, num_experts=5)
-    save_path = "test_aiiamoe_save_load"
+def test_aiiamoe_save_pretrained_from_pretrained():
+    config = AIIAConfig(num_experts=3)
+    model = AIIAmoe(config)
+    save_pretrained_path = "test_aiiamoe_save_pretrained_load"
     # Save the model
-    model.save(save_path)
-    assert os.path.exists(os.path.join(save_path, "model.pth"))
-    assert os.path.exists(os.path.join(save_path, "config.json"))
+    model.save_pretrained(save_pretrained_path)
+    assert os.path.exists(os.path.join(save_pretrained_path, "model.safetensors"))
+    assert os.path.exists(os.path.join(save_pretrained_path, "config.json"))
     # Load the model
-    loaded_model = AIIAmoe.load(save_path)
+    loaded_model = AIIAmoe.from_pretrained(save_pretrained_path)
     # Check if the loaded model is an instance of AIIAmoe
     assert isinstance(loaded_model, AIIAmoe)
     # Clean up
-    os.remove(os.path.join(save_path, "model.pth"))
-    os.remove(os.path.join(save_path, "config.json"))
-    os.rmdir(save_path)
+    os.remove(os.path.join(save_pretrained_path, "model.safetensors"))
+    os.remove(os.path.join(save_pretrained_path, "config.json"))
+    os.rmdir(save_pretrained_path)

 def test_aiiasparsemoe_creation():
-    config = AIIAConfig()
-    model = AIIASparseMoe(config, num_experts=5, top_k=2)
+    config = AIIAConfig(num_experts=5, top_k=2)
+    model = AIIASparseMoe(config, base_class=AIIABaseShared)
     assert isinstance(model, AIIASparseMoe)

-def test_aiiasparsemoe_save_load():
-    config = AIIAConfig()
-    model = AIIASparseMoe(config, num_experts=3, top_k=1)
-    save_path = "test_aiiasparsemoe_save_load"
+def test_aiiasparsemoe_save_pretrained_from_pretrained():
+    config = AIIAConfig(num_experts=3, top_k=1)
+    model = AIIASparseMoe(config)
+    save_pretrained_path = "test_aiiasparsemoe_save_pretrained_load"
     # Save the model
-    model.save(save_path)
-    assert os.path.exists(os.path.join(save_path, "model.pth"))
-    assert os.path.exists(os.path.join(save_path, "config.json"))
+    model.save_pretrained(save_pretrained_path)
+    assert os.path.exists(os.path.join(save_pretrained_path, "model.safetensors"))
+    assert os.path.exists(os.path.join(save_pretrained_path, "config.json"))
     # Load the model
-    loaded_model = AIIASparseMoe.load(save_path)
+    loaded_model = AIIASparseMoe.from_pretrained(save_pretrained_path)
     # Check if the loaded model is an instance of AIIASparseMoe
     assert isinstance(loaded_model, AIIASparseMoe)
     # Clean up
-    os.remove(os.path.join(save_path, "model.pth"))
-    os.remove(os.path.join(save_path, "config.json"))
-    os.rmdir(save_path)
+    os.remove(os.path.join(save_pretrained_path, "model.safetensors"))
+    os.remove(os.path.join(save_pretrained_path, "config.json"))
+    os.rmdir(save_pretrained_path)

-def test_aiiachunked_creation():
-    config = AIIAConfig()
-    model = AIIAchunked(config)
-    assert isinstance(model, AIIAchunked)

-def test_aiiachunked_save_load():
-    config = AIIAConfig()
-    model = AIIAchunked(config)
-    save_path = "test_aiiachunked_save_load"
-    # Save the model
-    model.save(save_path)
-    assert os.path.exists(os.path.join(save_path, "model.pth"))
-    assert os.path.exists(os.path.join(save_path, "config.json"))
-    # Load the model
-    loaded_model = AIIAchunked.load(save_path)
-    # Check if the loaded model is an instance of AIIAchunked
-    assert isinstance(loaded_model, AIIAchunked)
-    # Clean up
-    os.remove(os.path.join(save_path, "model.pth"))
-    os.remove(os.path.join(save_path, "config.json"))
-    os.rmdir(save_path)

View File

@@ -1,75 +1,77 @@
 import os
 import tempfile
 import pytest
-import torch.nn as nn
 from aiia import AIIAConfig

 def test_aiia_config_initialization():
     config = AIIAConfig()
-    assert config.model_name == "AIIA"
+    assert config.model_type == "AIIA"
     assert config.kernel_size == 3
     assert config.activation_function == "GELU"
     assert config.hidden_size == 512
     assert config.num_hidden_layers == 12
     assert config.num_channels == 3
-    assert config.learning_rate == 5e-5

 def test_aiia_config_custom_initialization():
     config = AIIAConfig(
-        model_name="CustomModel",
+        model_type="CustomModel",
         kernel_size=5,
         activation_function="ReLU",
         hidden_size=1024,
         num_hidden_layers=8,
-        num_channels=1,
-        learning_rate=1e-4
+        num_channels=1
     )
-    assert config.model_name == "CustomModel"
+    assert config.model_type == "CustomModel"
     assert config.kernel_size == 5
     assert config.activation_function == "ReLU"
     assert config.hidden_size == 1024
     assert config.num_hidden_layers == 8
     assert config.num_channels == 1
-    assert config.learning_rate == 1e-4

 def test_aiia_config_invalid_activation_function():
     with pytest.raises(ValueError):
         AIIAConfig(activation_function="InvalidFunction")

 def test_aiia_config_to_dict():
     config = AIIAConfig()
     config_dict = config.to_dict()
     assert isinstance(config_dict, dict)
-    assert config_dict["model_name"] == "AIIA"
     assert config_dict["kernel_size"] == 3

-def test_aiia_config_save_and_load():
-    with tempfile.TemporaryDirectory() as tmpdir:
-        config = AIIAConfig(model_name="TempModel")
-        save_path = os.path.join(tmpdir, "config")
-        config.save(save_path)
-        loaded_config = AIIAConfig.load(save_path)
-        assert loaded_config.model_name == "TempModel"
+def test_aiia_config_save_pretrained_and_from_pretrained():
+    with tempfile.TemporaryDirectory() as tmpdir:
+        config = AIIAConfig(model_type="TempModel")
+        save_pretrained_path = os.path.join(tmpdir, "config")
+        config.save_pretrained(save_pretrained_path)
+        loaded_config = AIIAConfig.from_pretrained(save_pretrained_path)
+        assert loaded_config.model_type == "TempModel"
         assert loaded_config.kernel_size == 3
         assert loaded_config.activation_function == "GELU"

-def test_aiia_config_save_and_load_with_custom_attributes():
-    with tempfile.TemporaryDirectory() as tmpdir:
-        config = AIIAConfig(model_name="TempModel", custom_attr="value")
-        save_path = os.path.join(tmpdir, "config")
-        config.save(save_path)
-        loaded_config = AIIAConfig.load(save_path)
-        assert loaded_config.model_name == "TempModel"
+def test_aiia_config_save_pretrained_and_load_with_custom_attributes():
+    with tempfile.TemporaryDirectory() as tmpdir:
+        config = AIIAConfig(model_type="TempModel", custom_attr="value")
+        save_pretrained_path = os.path.join(tmpdir, "config")
+        config.save_pretrained(save_pretrained_path)
+        loaded_config = AIIAConfig.from_pretrained(save_pretrained_path)
+        assert loaded_config.model_type == "TempModel"
         assert loaded_config.custom_attr == "value"

-def test_aiia_config_save_and_load_with_nested_attributes():
-    with tempfile.TemporaryDirectory() as tmpdir:
-        config = AIIAConfig(model_name="TempModel", nested={"key": "value"})
-        save_path = os.path.join(tmpdir, "config")
-        config.save(save_path)
-        loaded_config = AIIAConfig.load(save_path)
-        assert loaded_config.model_name == "TempModel"
+def test_aiia_config_save_pretrained_and_load_with_nested_attributes():
+    with tempfile.TemporaryDirectory() as tmpdir:
+        config = AIIAConfig(model_type="TempModel", nested={"key": "value"})
+        save_pretrained_path = os.path.join(tmpdir, "config")
+        config.save_pretrained(save_pretrained_path)
+        loaded_config = AIIAConfig.from_pretrained(save_pretrained_path)
+        assert loaded_config.model_type == "TempModel"
         assert loaded_config.nested == {"key": "value"}

View File

@@ -3,6 +3,8 @@ import torch
 from unittest.mock import MagicMock, patch, MagicMock, mock_open, call
 from aiia import Pretrainer, ProjectionHead, AIIABase, AIIAConfig
 import pandas as pd
+import os
+import datetime

 # Test the ProjectionHead class
 def test_projection_head():
@@ -53,11 +55,94 @@ def test_process_batch(mock_process_batch):
     loss = pretrainer._process_batch(batch_data, criterion_denoise, criterion_rotate)
     assert loss == 0.5

-# Error cases
+# New tests for checkpoint handling
@patch('torch.save')
@patch('os.path.join')
def test_save_checkpoint(mock_join, mock_save):
"""Test checkpoint saving functionality."""
pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
pretrainer.projection_head = MagicMock()
pretrainer.optimizer = MagicMock()
checkpoint_dir = "checkpoints"
epoch = 1
batch_count = 100
checkpoint_name = "test_checkpoint.pt"
mock_join.return_value = os.path.join(checkpoint_dir, checkpoint_name)
path = pretrainer._save_checkpoint(checkpoint_dir, epoch, batch_count, checkpoint_name)
assert path == os.path.join(checkpoint_dir, checkpoint_name)
mock_save.assert_called_once()
@patch('os.makedirs')
@patch('os.path.exists')
@patch('torch.load')
def test_load_checkpoint_specific(mock_load, mock_exists, mock_makedirs):
"""Test loading a specific checkpoint."""
pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
pretrainer.projection_head = MagicMock()
pretrainer.optimizer = MagicMock()
checkpoint_dir = "checkpoints"
specific_checkpoint = "specific_checkpoint.pt"
mock_exists.return_value = True
mock_load.return_value = {
'epoch': 2,
'batch': 150,
'model_state_dict': {},
'projection_head_state_dict': {},
'optimizer_state_dict': {},
'train_losses': [],
'val_losses': []
}
result = pretrainer.load_checkpoint(checkpoint_dir, specific_checkpoint)
assert result == (2, 150)
@patch('os.listdir')
@patch('os.path.getmtime')
def test_load_checkpoint_most_recent(mock_getmtime, mock_listdir):
"""Test loading the most recent checkpoint."""
pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
checkpoint_dir = "checkpoints"
mock_listdir.return_value = ["checkpoint_1.pt", "checkpoint_2.pt"]
mock_getmtime.side_effect = [100, 200] # checkpoint_2.pt is more recent
with patch.object(pretrainer, '_load_checkpoint_file', return_value=(2, 150)):
result = pretrainer.load_checkpoint(checkpoint_dir)
assert result == (2, 150)
def test_initialize_checkpoint_variables():
"""Test initialization of checkpoint variables."""
pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
pretrainer._initialize_checkpoint_variables()
assert hasattr(pretrainer, 'last_checkpoint_time')
assert pretrainer.checkpoint_interval == 2 * 60 * 60
assert pretrainer.last_22_date is None
assert pretrainer.recent_checkpoints == []
@patch('torch.nn.MSELoss')
@patch('torch.nn.CrossEntropyLoss')
def test_initialize_loss_functions(mock_ce_loss, mock_mse_loss):
"""Test loss function initialization."""
pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
criterion_denoise, criterion_rotate, best_val_loss = pretrainer._initialize_loss_functions()
assert mock_mse_loss.called
assert mock_ce_loss.called
assert best_val_loss == float('inf')
 @patch('pandas.concat')
 @patch('pandas.read_parquet')
 @patch('aiia.pretrain.pretrainer.AIIADataLoader')
 @patch('os.path.join', return_value='mocked/path/model.pt')
-@patch('builtins.print')  # Add this to mock the print function
+@patch('builtins.print')
 def test_train_happy_path(mock_print, mock_path_join, mock_data_loader, mock_read_parquet, mock_concat):
     """Test the train method under normal conditions with comprehensive verification."""
     # Setup test data and mocks
@@ -73,6 +158,7 @@ def test_train_happy_path(mock_print, mock_path_join, mock_data_loader, mock_rea
     pretrainer = Pretrainer(model=mock_model, config=AIIAConfig())
     pretrainer.projection_head = mock_projection_head
     pretrainer.optimizer = MagicMock()
+    pretrainer.checkpoint_dir = None  # Initialize checkpoint_dir

     # Setup dataset paths and mock batch data
     dataset_paths = ['path/to/dataset1.parquet', 'path/to/dataset2.parquet']
@@ -104,185 +190,118 @@ def test_train_happy_path(mock_print, mock_path_join, mock_data_loader, mock_rea
     assert mock_process_batch.call_count == 2
     assert mock_validate.call_count == 2
-    # Check for "Best model saved!" instead of model.save()
     mock_print.assert_any_call("Best model saved!")
     mock_save_losses.assert_called_once()
-    # Verify state changes
     assert len(pretrainer.train_losses) == 2
     assert pretrainer.train_losses == [0.5, 0.5]
@patch('datetime.datetime')
# Error cases @patch('time.time')
def test_train_no_dataset_paths(): def test_handle_checkpoints(mock_time, mock_datetime):
"""Test ValueError when no dataset paths are provided.""" """Test checkpoint handling logic."""
pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig()) pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
pretrainer.checkpoint_dir = "checkpoints"
pretrainer.current_epoch = 1
pretrainer._initialize_checkpoint_variables()
with pytest.raises(ValueError, match="No dataset paths provided"): # Set a base time value
pretrainer.train([]) base_time = 1000
# Set the last checkpoint time to base_time
pretrainer.last_checkpoint_time = base_time
@patch('pandas.read_parquet') # Mock time to return base_time + interval + 1 to trigger checkpoint save
def test_train_all_datasets_fail(mock_read_parquet): mock_time.return_value = base_time + pretrainer.checkpoint_interval + 1
"""Test handling when all datasets fail to load."""
mock_read_parquet.side_effect = Exception("Failed to load dataset")
# Mock datetime for 22:00 checkpoint
mock_dt = MagicMock()
mock_dt.hour = 22
mock_dt.minute = 0
mock_dt.date.return_value = datetime.date(2023, 1, 1)
mock_datetime.now.return_value = mock_dt
with patch.object(pretrainer, '_save_checkpoint') as mock_save:
pretrainer._handle_checkpoints(100)
# Should be called twice - once for regular interval and once for 22:00
assert mock_save.call_count == 2
def test_training_phase():
"""Test the training phase logic."""
pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig()) pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
dataset_paths = ['path/to/dataset1.parquet', 'path/to/dataset2.parquet']
with pytest.raises(ValueError, match="No valid datasets could be loaded"):
pretrainer.train(dataset_paths)
# Edge cases
@patch('pandas.concat')
@patch('pandas.read_parquet')
@patch('aiia.pretrain.pretrainer.AIIADataLoader')
def test_train_empty_loaders(mock_data_loader, mock_read_parquet, mock_concat):
"""Test behavior with empty data loaders."""
real_df = pd.DataFrame({'image_bytes': [torch.randn(1, 3, 224, 224).tolist()]})
mock_read_parquet.return_value.head.return_value = real_df
mock_concat.return_value = real_df
loader_instance = MagicMock()
loader_instance.train_loader = [] # Empty train loader
loader_instance.val_loader = [] # Empty val loader
mock_data_loader.return_value = loader_instance
mock_model = MagicMock()
pretrainer = Pretrainer(model=mock_model, config=AIIAConfig())
pretrainer.projection_head = MagicMock()
pretrainer.optimizer = MagicMock() pretrainer.optimizer = MagicMock()
pretrainer.checkpoint_dir = None # Initialize checkpoint_dir
pretrainer._initialize_checkpoint_variables()
pretrainer.current_epoch = 0
with patch.object(Pretrainer, 'save_losses') as mock_save_losses: # Create mock batch data with requires_grad=True
pretrainer.train(['path/to/dataset.parquet'], num_epochs=1)
# Verify empty loader behavior
assert len(pretrainer.train_losses) == 1
assert pretrainer.train_losses[0] == 0.0
mock_save_losses.assert_called_once()
@patch('pandas.concat')
@patch('pandas.read_parquet')
@patch('aiia.pretrain.pretrainer.AIIADataLoader')
def test_train_none_batch_data(mock_data_loader, mock_read_parquet, mock_concat):
"""Test behavior when batch_data is None."""
real_df = pd.DataFrame({'image_bytes': [torch.randn(1, 3, 224, 224).tolist()]})
mock_read_parquet.return_value.head.return_value = real_df
mock_concat.return_value = real_df
loader_instance = MagicMock()
loader_instance.train_loader = [None] # Loader returns None
loader_instance.val_loader = []
mock_data_loader.return_value = loader_instance
pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
pretrainer.projection_head = MagicMock()
pretrainer.optimizer = MagicMock()
with patch.object(Pretrainer, '_process_batch') as mock_process_batch, \
patch.object(Pretrainer, 'save_losses'):
pretrainer.train(['path/to/dataset.parquet'], num_epochs=1)
# Verify None batch handling
mock_process_batch.assert_not_called()
assert pretrainer.train_losses[0] == 0.0
# Parameter variations
@patch('pandas.concat')
@patch('pandas.read_parquet')
@patch('aiia.pretrain.pretrainer.AIIADataLoader')
def test_train_with_custom_parameters(mock_data_loader, mock_read_parquet, mock_concat):
"""Test that custom parameters are properly passed through."""
real_df = pd.DataFrame({'image_bytes': [torch.randn(1, 3, 224, 224).tolist()]})
mock_read_parquet.return_value.head.return_value = real_df
mock_concat.return_value = real_df
loader_instance = MagicMock()
loader_instance.train_loader = []
loader_instance.val_loader = []
mock_data_loader.return_value = loader_instance
pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
pretrainer.projection_head = MagicMock()
pretrainer.optimizer = MagicMock()
# Custom parameters
custom_output_path = "custom/output/path"
custom_column = "custom_column"
custom_batch_size = 16
custom_sample_size = 5000
with patch.object(Pretrainer, 'save_losses'):
pretrainer.train(
['path/to/dataset.parquet'],
output_path=custom_output_path,
column=custom_column,
batch_size=custom_batch_size,
sample_size=custom_sample_size
)
# Verify custom parameters were used
mock_read_parquet.return_value.head.assert_called_once_with(custom_sample_size)
assert mock_data_loader.call_args[1]['column'] == custom_column
assert mock_data_loader.call_args[1]['batch_size'] == custom_batch_size
@patch('pandas.concat')
@patch('pandas.read_parquet')
@patch('aiia.pretrain.pretrainer.AIIADataLoader')
@patch('builtins.print') # Add this to mock the print function
def test_train_validation_loss_improvement(mock_print, mock_data_loader, mock_read_parquet, mock_concat):
"""Test that model is saved only when validation loss improves."""
real_df = pd.DataFrame({'image_bytes': [torch.randn(1, 3, 224, 224).tolist()]})
mock_read_parquet.return_value.head.return_value = real_df
mock_concat.return_value = real_df
# Create mock batch data with proper structure
mock_batch_data = { mock_batch_data = {
'denoise': (torch.randn(2, 3, 32, 32), torch.randn(2, 3, 32, 32)), 'denoise': (
'rotate': (torch.randn(2, 3, 32, 32), torch.tensor([0, 1])) torch.randn(2, 3, 32, 32, requires_grad=True),
torch.randn(2, 3, 32, 32, requires_grad=True)
),
'rotate': (
torch.randn(2, 3, 32, 32, requires_grad=True),
torch.tensor([0, 1], dtype=torch.long) # Labels typically don't need gradients
)
}
mock_train_loader = [(0, mock_batch_data)] # Include batch index
# Mock the loss functions to return tensors that require gradients
criterion_denoise = MagicMock(return_value=torch.tensor(0.5, requires_grad=True))
criterion_rotate = MagicMock(return_value=torch.tensor(0.5, requires_grad=True))
with patch.object(pretrainer, '_process_batch', return_value=torch.tensor(0.5, requires_grad=True)), \
patch.object(pretrainer, '_handle_checkpoints') as mock_handle_checkpoints:
total_loss, batch_count = pretrainer._training_phase(
mock_train_loader, 0, criterion_denoise, criterion_rotate)
assert total_loss == 0.5
assert batch_count == 1
mock_handle_checkpoints.assert_called_once_with(1) # Check if checkpoint handling was called
def test_validation_phase():
"""Test the validation phase logic."""
pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
pretrainer.projection_head = MagicMock()
mock_val_loader = [MagicMock()]
criterion_denoise = MagicMock()
criterion_rotate = MagicMock()
with patch.object(pretrainer, '_validate', return_value=0.4):
val_loss = pretrainer._validation_phase(
mock_val_loader, criterion_denoise, criterion_rotate)
assert val_loss == 0.4
@patch('pandas.read_parquet')
def test_load_and_merge_datasets(mock_read_parquet):
"""Test dataset loading and merging."""
pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
mock_df = pd.DataFrame({'col': [1, 2, 3]})
mock_read_parquet.return_value.head.return_value = mock_df
result = pretrainer._load_and_merge_datasets(['path1.parquet', 'path2.parquet'], 1000)
assert len(result) == 6 # 2 datasets * 3 rows each
def test_process_batch_none_tasks():
"""Test processing batch with no tasks."""
pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
batch_data = {
'denoise': None,
'rotate': None
} }
loader_instance = MagicMock() loss = pretrainer._process_batch(
loader_instance.train_loader = [mock_batch_data] batch_data,
loader_instance.val_loader = [mock_batch_data] criterion_denoise=MagicMock(),
mock_data_loader.return_value = loader_instance criterion_rotate=MagicMock()
)
mock_model = MagicMock() assert loss == 0
pretrainer = Pretrainer(model=mock_model, config=AIIAConfig())
pretrainer.projection_head = MagicMock()
pretrainer.optimizer = MagicMock()
# Initialize the best validation loss
pretrainer.best_val_loss = float('inf')
mock_batch_loss = torch.tensor(0.5, requires_grad=True)
# Test improving validation loss
with patch.object(Pretrainer, '_process_batch', return_value=mock_batch_loss), \
patch.object(Pretrainer, '_validate', side_effect=[3.0, 2.0, 1.0]), \
patch.object(Pretrainer, 'save_losses'):
pretrainer.train(['path/to/dataset.parquet'], num_epochs=3)
# Check for "Best model saved!" 3 times
assert mock_print.call_args_list.count(call("Best model saved!")) == 3
# Reset for next test
mock_print.reset_mock()
pretrainer.train_losses = []
# Reset best validation loss for the second test
pretrainer.best_val_loss = float('inf')
# Test fluctuating validation loss
with patch.object(Pretrainer, '_process_batch', return_value=mock_batch_loss), \
patch.object(Pretrainer, '_validate', side_effect=[3.0, 4.0, 2.0]), \
patch.object(Pretrainer, 'save_losses'):
pretrainer.train(['path/to/dataset.parquet'], num_epochs=3)
# Should print "Best model saved!" only on first and third epochs
assert mock_print.call_args_list.count(call("Best model saved!")) == 2
@patch('aiia.pretrain.pretrainer.Pretrainer._process_batch') @patch('aiia.pretrain.pretrainer.Pretrainer._process_batch')