develop #41
|
@ -29,7 +29,7 @@ from aiia.model import AIIAConfig
|
||||||
from aiia.pretrain import Pretrainer
|
from aiia.pretrain import Pretrainer
|
||||||
|
|
||||||
# Create your model
|
# Create your model
|
||||||
config = AIIAConfig(model_name="AIIA-Base-512x20k")
|
config = AIIAConfig(model_type="AIIA-Base-512x20k")
|
||||||
model = AIIABase(config)
|
model = AIIABase(config)
|
||||||
|
|
||||||
# Initialize pretrainer with the model
|
# Initialize pretrainer with the model
|
||||||
|
|
30
example.py
30
example.py
|
@ -1,24 +1,22 @@
|
||||||
from aiia.model import AIIABase
|
from src.aiia.model import AIIAmoe
|
||||||
from aiia.model import AIIAConfig
|
from src.aiia.model import AIIAConfig
|
||||||
from aiia.pretrain import Pretrainer
|
from src.aiia.pretrain import Pretrainer
|
||||||
|
|
||||||
# Create your model
|
# Create your model
|
||||||
config = AIIAConfig(model_name="AIIA-Base-512x20k")
|
config = AIIAConfig(num_experts=5)
|
||||||
model = AIIABase(config)
|
model = AIIAmoe(config)
|
||||||
|
model.save_pretrained("test")
|
||||||
|
model = AIIAmoe.from_pretrained("test")
|
||||||
|
|
||||||
# Initialize pretrainer with the model
|
|
||||||
pretrainer = Pretrainer(model, learning_rate=1e-4, config=config)
|
pretrainer = Pretrainer(model, learning_rate=1e-4, config=config)
|
||||||
|
|
||||||
# List of dataset paths
|
# Set checkpoint directory
|
||||||
dataset_paths = [
|
checkpoint_dir = "checkpoints/my_model"
|
||||||
"/path/to/dataset1.parquet",
|
|
||||||
"/path/to/dataset2.parquet"
|
|
||||||
]
|
|
||||||
|
|
||||||
# Start training with multiple datasets
|
# Start training (will automatically load checkpoint if available)
|
||||||
pretrainer.train(
|
pretrainer.train(
|
||||||
dataset_paths=dataset_paths,
|
dataset_paths=["path/to/dataset1.parquet", "path/to/dataset2.parquet"],
|
||||||
num_epochs=10,
|
output_path="trained_models/my_model",
|
||||||
batch_size=2,
|
checkpoint_dir=checkpoint_dir,
|
||||||
sample_size=10000
|
num_epochs=10
|
||||||
)
|
)
|
|
@ -10,7 +10,7 @@ include = '\.pyi?$'
|
||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "aiia"
|
name = "aiia"
|
||||||
version = "0.2.1"
|
version = "0.3.1"
|
||||||
description = "AIIA Deep Learning Model Implementation"
|
description = "AIIA Deep Learning Model Implementation"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
authors = [
|
authors = [
|
||||||
|
|
|
@ -6,3 +6,4 @@ pillow
|
||||||
pandas
|
pandas
|
||||||
torchvision
|
torchvision
|
||||||
pyarrow
|
pyarrow
|
||||||
|
transformers>=4.48.0
|
|
@ -1,6 +1,6 @@
|
||||||
[metadata]
|
[metadata]
|
||||||
name = aiia
|
name = aiia
|
||||||
version = 0.2.1
|
version = 0.3.1
|
||||||
author = Falko Habel
|
author = Falko Habel
|
||||||
author_email = falko.habel@gmx.de
|
author_email = falko.habel@gmx.de
|
||||||
description = AIIA deep learning model implementation
|
description = AIIA deep learning model implementation
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
from .model.Model import AIIABase, AIIABaseShared, AIIAchunked, AIIAmoe, AIIASparseMoe, AIIArecursive
|
from .model.Model import AIIABase, AIIABaseShared, AIIAmoe, AIIASparseMoe
|
||||||
from .model.config import AIIAConfig
|
from .model.config import AIIAConfig
|
||||||
from .data.DataLoader import DataLoader
|
from .data.DataLoader import DataLoader
|
||||||
from .pretrain.pretrainer import Pretrainer, ProjectionHead
|
from .pretrain.pretrainer import Pretrainer, ProjectionHead
|
||||||
|
|
||||||
|
|
||||||
__version__ = "0.2.1"
|
__version__ = "0.3.1"
|
||||||
|
|
|
@ -1,119 +1,48 @@
|
||||||
from .config import AIIAConfig
|
from .config import AIIAConfig
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
from transformers import PreTrainedModel
|
||||||
import torch
|
import torch
|
||||||
import os
|
|
||||||
import copy
|
import copy
|
||||||
import warnings
|
|
||||||
|
|
||||||
|
|
||||||
class AIIA(nn.Module):
|
class AIIABase(PreTrainedModel):
|
||||||
def __init__(self, config: AIIAConfig, **kwargs):
|
config_class = AIIAConfig
|
||||||
super(AIIA, self).__init__()
|
base_model_prefix = "AIIA"
|
||||||
# Create a deep copy of the configuration to avoid sharing
|
|
||||||
self.config = copy.deepcopy(config)
|
|
||||||
|
|
||||||
# Update the config with any additional keyword arguments
|
def __init__(self, config: AIIAConfig):
|
||||||
for key, value in kwargs.items():
|
super().__init__(config)
|
||||||
setattr(self.config, key, value)
|
|
||||||
|
|
||||||
def save(self, path: str):
|
|
||||||
if not os.path.exists(path):
|
|
||||||
os.makedirs(path, exist_ok=True)
|
|
||||||
torch.save(self.state_dict(), f"{path}/model.pth")
|
|
||||||
self.config.save(path)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def load(cls, path, precision: str = None, strict: bool = True, **kwargs):
|
|
||||||
config = AIIAConfig.load(path)
|
|
||||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|
||||||
|
|
||||||
# Load the state dict to analyze structure
|
|
||||||
model_dict = torch.load(f"{path}/model.pth", map_location=device)
|
|
||||||
|
|
||||||
# Special handling for AIIAmoe - detect number of experts from state_dict
|
|
||||||
if cls.__name__ == "AIIAmoe" and "num_experts" not in kwargs:
|
|
||||||
# Find maximum expert index
|
|
||||||
max_expert_idx = -1
|
|
||||||
for key in model_dict.keys():
|
|
||||||
if key.startswith("experts."):
|
|
||||||
parts = key.split(".")
|
|
||||||
if len(parts) > 1:
|
|
||||||
try:
|
|
||||||
expert_idx = int(parts[1])
|
|
||||||
max_expert_idx = max(max_expert_idx, expert_idx)
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
if max_expert_idx >= 0:
|
|
||||||
# experts.X keys found, use max_expert_idx + 1 as num_experts
|
|
||||||
kwargs["num_experts"] = max_expert_idx + 1
|
|
||||||
|
|
||||||
# Create model with detected structural parameters
|
|
||||||
model = cls(config, **kwargs)
|
|
||||||
|
|
||||||
# Handle precision conversion
|
|
||||||
dtype = None
|
|
||||||
if precision is not None:
|
|
||||||
if precision.lower() == 'fp16':
|
|
||||||
dtype = torch.float16
|
|
||||||
elif precision.lower() == 'bf16':
|
|
||||||
if device == 'cuda' and not torch.cuda.is_bf16_supported():
|
|
||||||
warnings.warn("BF16 is not supported on this GPU. Falling back to FP16.")
|
|
||||||
dtype = torch.float16
|
|
||||||
else:
|
|
||||||
dtype = torch.bfloat16
|
|
||||||
else:
|
|
||||||
raise ValueError("Unsupported precision. Use 'fp16', 'bf16', or leave as None.")
|
|
||||||
|
|
||||||
if dtype is not None:
|
|
||||||
for key, param in model_dict.items():
|
|
||||||
if torch.is_tensor(param):
|
|
||||||
model_dict[key] = param.to(dtype)
|
|
||||||
|
|
||||||
# Load state dict with strict parameter for flexibility
|
|
||||||
model.load_state_dict(model_dict, strict=strict)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
class AIIABase(AIIA):
|
|
||||||
def __init__(self, config: AIIAConfig, **kwargs):
|
|
||||||
super().__init__(config=config, **kwargs)
|
|
||||||
self.config = self.config
|
|
||||||
|
|
||||||
# Initialize layers based on configuration
|
# Initialize layers based on configuration
|
||||||
layers = []
|
layers = []
|
||||||
in_channels = self.config.num_channels
|
in_channels = config.num_channels
|
||||||
|
|
||||||
for _ in range(self.config.num_hidden_layers):
|
for _ in range(config.num_hidden_layers):
|
||||||
layers.extend([
|
layers.extend([
|
||||||
nn.Conv2d(in_channels, self.config.hidden_size,
|
nn.Conv2d(in_channels, config.hidden_size,
|
||||||
kernel_size=self.config.kernel_size, padding=1),
|
kernel_size=config.kernel_size, padding=1),
|
||||||
getattr(nn, self.config.activation_function)(),
|
getattr(nn, config.activation_function)(),
|
||||||
nn.MaxPool2d(kernel_size=1, stride=1)
|
nn.MaxPool2d(kernel_size=1, stride=1)
|
||||||
])
|
])
|
||||||
in_channels = self.config.hidden_size
|
in_channels = config.hidden_size
|
||||||
|
|
||||||
self.cnn = nn.Sequential(*layers)
|
self.cnn = nn.Sequential(*layers)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return self.cnn(x)
|
return self.cnn(x)
|
||||||
|
|
||||||
class AIIABaseShared(AIIA):
|
class AIIABaseShared(PreTrainedModel):
|
||||||
def __init__(self, config: AIIAConfig, **kwargs):
|
config_class = AIIAConfig
|
||||||
|
base_model_prefix = "AIIA"
|
||||||
|
|
||||||
|
def __init__(self, config: AIIAConfig):
|
||||||
|
super().__init__(config)
|
||||||
"""
|
"""
|
||||||
Initialize the AIIABaseShared model.
|
Initialize the AIIABaseShared model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
config (AIIAConfig): Configuration object containing model parameters.
|
config (AIIAConfig): Configuration object containing model parameters.
|
||||||
**kwargs: Additional keyword arguments to override configuration settings.
|
|
||||||
"""
|
"""
|
||||||
super().__init__(config=config, **kwargs)
|
super().__init__(config=config)
|
||||||
|
|
||||||
# Update configuration with new parameters if provided
|
|
||||||
self. config = copy.deepcopy(config)
|
|
||||||
|
|
||||||
for key, value in kwargs.items():
|
|
||||||
setattr(self.config, key, value)
|
|
||||||
|
|
||||||
# Initialize the network components
|
# Initialize the network components
|
||||||
self._initialize_network()
|
self._initialize_network()
|
||||||
|
@ -172,16 +101,17 @@ class AIIABaseShared(AIIA):
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
class AIIAExpert(AIIA):
|
class AIIAExpert(PreTrainedModel):
|
||||||
def __init__(self, config: AIIAConfig, base_class=AIIABase, **kwargs):
|
config_class = AIIAConfig
|
||||||
super().__init__(config=config, **kwargs)
|
base_model_prefix = "AIIA"
|
||||||
self.config = self.config
|
def __init__(self, config: AIIAConfig, base_class=AIIABase):
|
||||||
|
super().__init__(config=config)
|
||||||
|
|
||||||
# Initialize base CNN with configuration and chosen base class
|
# Initialize base CNN with configuration and chosen base class
|
||||||
if issubclass(base_class, AIIABase):
|
if issubclass(base_class, AIIABase):
|
||||||
self.base_cnn = AIIABase(self.config, **kwargs)
|
self.base_cnn = AIIABase(self.config)
|
||||||
elif issubclass(base_class, AIIABaseShared):
|
elif issubclass(base_class, AIIABaseShared):
|
||||||
self.base_cnn = AIIABaseShared(self.config, **kwargs)
|
self.base_cnn = AIIABaseShared(self.config)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Invalid base class")
|
raise ValueError("Invalid base class")
|
||||||
|
|
||||||
|
@ -198,26 +128,26 @@ class AIIAExpert(AIIA):
|
||||||
# Process input through the base CNN
|
# Process input through the base CNN
|
||||||
return self.base_cnn(x)
|
return self.base_cnn(x)
|
||||||
|
|
||||||
class AIIAmoe(AIIA):
|
class AIIAmoe(PreTrainedModel):
|
||||||
def __init__(self, config: AIIAConfig, num_experts: int = 3, base_class=AIIABase, **kwargs):
|
config_class = AIIAConfig
|
||||||
super().__init__(config=config, **kwargs)
|
base_model_prefix = "AIIA"
|
||||||
|
|
||||||
|
def __init__(self, config: AIIAConfig, base_class=AIIABase):
|
||||||
|
super().__init__(config=config)
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
# Update the config to include the number of experts.
|
# Get num_experts directly from config instead of parameter
|
||||||
self.config.num_experts = num_experts
|
num_experts = getattr(config, "num_experts", 3) # default to 3 if not in config
|
||||||
|
|
||||||
# Initialize multiple experts from the chosen base class.
|
# Initialize multiple experts from the chosen base class
|
||||||
self.experts = nn.ModuleList([
|
self.experts = nn.ModuleList([
|
||||||
AIIAExpert(self.config, base_class=base_class, **kwargs)
|
AIIAExpert(self.config, base_class=base_class)
|
||||||
for _ in range(num_experts)
|
for _ in range(num_experts)
|
||||||
])
|
])
|
||||||
|
|
||||||
# To generate gating weights, we first need to determine the feature dimension.
|
gate_in_features = self.config.hidden_size
|
||||||
# Each expert is assumed to return an output of shape (B, C, H, W); after averaging over H and W,
|
|
||||||
# we obtain a tensor of shape (B, C) where C is the number of channels (here assumed to be 224).
|
|
||||||
gate_in_features = 512 # Adjust this if your expert output changes.
|
|
||||||
|
|
||||||
# Create a gating network that maps the aggregated features to num_experts weights.
|
# Create a gating network that maps the aggregated features to num_experts weights
|
||||||
self.gate = nn.Sequential(
|
self.gate = nn.Sequential(
|
||||||
nn.Linear(gate_in_features, num_experts),
|
nn.Linear(gate_in_features, num_experts),
|
||||||
nn.Softmax(dim=1)
|
nn.Softmax(dim=1)
|
||||||
|
@ -261,9 +191,10 @@ class AIIAmoe(AIIA):
|
||||||
|
|
||||||
|
|
||||||
class AIIASparseMoe(AIIAmoe):
|
class AIIASparseMoe(AIIAmoe):
|
||||||
def __init__(self, config: AIIAConfig, num_experts: int = 3, top_k: int = 2, base_class=AIIABase, **kwargs):
|
config_class = AIIAConfig
|
||||||
super().__init__(config=config, num_experts=num_experts, base_class=base_class, **kwargs)
|
base_model_prefix = "AIIA"
|
||||||
self.top_k = top_k
|
def __init__(self, config: AIIAConfig, base_class=AIIABase):
|
||||||
|
super().__init__(config=config, base_class=base_class)
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
# Compute the gate_weights similar to standard moe.
|
# Compute the gate_weights similar to standard moe.
|
||||||
|
@ -273,7 +204,7 @@ class AIIASparseMoe(AIIAmoe):
|
||||||
gate_weights = self.gate(gate_input)
|
gate_weights = self.gate(gate_input)
|
||||||
|
|
||||||
# Select the top-k experts for each input based on gating weights.
|
# Select the top-k experts for each input based on gating weights.
|
||||||
_, top_k_indices = gate_weights.topk(self.top_k, dim=-1)
|
_, top_k_indices = gate_weights.topk(self.config.top_k, dim=-1)
|
||||||
|
|
||||||
# Initialize a list to store outputs from selected experts.
|
# Initialize a list to store outputs from selected experts.
|
||||||
merged_outputs = []
|
merged_outputs = []
|
||||||
|
@ -294,64 +225,7 @@ class AIIASparseMoe(AIIAmoe):
|
||||||
return torch.cat(merged_outputs, dim=0)
|
return torch.cat(merged_outputs, dim=0)
|
||||||
|
|
||||||
|
|
||||||
class AIIAchunked(AIIA):
|
|
||||||
def __init__(self, config: AIIAConfig, patch_size: int = 16, base_class=AIIABase, **kwargs):
|
|
||||||
super().__init__(config=config, **kwargs)
|
|
||||||
self.config = self.config
|
|
||||||
|
|
||||||
# Update config with new parameters if provided
|
|
||||||
self.config.patch_size = patch_size
|
|
||||||
|
|
||||||
# Initialize base CNN for processing each patch using the specified base class
|
|
||||||
if issubclass(base_class, AIIABase):
|
|
||||||
self.base_cnn = AIIABase(self.config, **kwargs)
|
|
||||||
elif issubclass(base_class, AIIABaseShared): # Add support for AIIABaseShared
|
|
||||||
self.base_cnn = AIIABaseShared(self.config, **kwargs)
|
|
||||||
else:
|
|
||||||
raise ValueError("Invalid base class")
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
patches = x.unfold(2, self.patch_size, self.patch_size).unfold(3, self.patch_size, self.patch_size)
|
|
||||||
patches = patches.contiguous().view(patches.size(0), patches.size(1), -1, self.patch_size, self.patch_size)
|
|
||||||
patch_outputs = []
|
|
||||||
|
|
||||||
for p in torch.split(patches, 1, dim=2):
|
|
||||||
p = p.squeeze(2)
|
|
||||||
po = self.base_cnn(p)
|
|
||||||
patch_outputs.append(po)
|
|
||||||
|
|
||||||
combined_output = torch.mean(torch.stack(patch_outputs, dim=0), dim=0)
|
|
||||||
return combined_output
|
|
||||||
|
|
||||||
class AIIArecursive(AIIA):
|
|
||||||
def __init__(self, config: AIIAConfig, recursion_depth: int = 3, base_class=AIIABase, **kwargs):
|
|
||||||
|
|
||||||
super().__init__(config=config, **kwargs)
|
|
||||||
self.config = self.config
|
|
||||||
|
|
||||||
# Pass recursion_depth as a kwarg to the config
|
|
||||||
self.config.recursion_depth = recursion_depth
|
|
||||||
|
|
||||||
# Initialize chunked CNN with updated config
|
|
||||||
self.chunked_cnn = AIIAchunked(self.config, base_class, **kwargs)
|
|
||||||
|
|
||||||
def forward(self, x, depth=0):
|
|
||||||
if depth == self.recursion_depth:
|
|
||||||
return self.chunked_cnn(x)
|
|
||||||
else:
|
|
||||||
patches = x.unfold(2, 16, 16).unfold(3, 16, 16)
|
|
||||||
patches = patches.contiguous().view(patches.size(0), patches.size(1), -1, 16, 16)
|
|
||||||
processed_patches = []
|
|
||||||
|
|
||||||
for p in torch.split(patches, 1, dim=2):
|
|
||||||
p = p.squeeze(2)
|
|
||||||
pp = self.forward(p, depth + 1)
|
|
||||||
processed_patches.append(pp)
|
|
||||||
|
|
||||||
combined_output = torch.mean(torch.stack(processed_patches, dim=0), dim=0)
|
|
||||||
return combined_output
|
|
||||||
|
|
||||||
if __name__ =="__main__":
|
if __name__ =="__main__":
|
||||||
config = AIIAConfig()
|
config = AIIAConfig()
|
||||||
model = AIIAmoe(config, num_experts=5)
|
model = AIIAmoe(config, num_experts=5)
|
||||||
model.save("test")
|
model.save_pretrained("test")
|
|
@ -1,20 +1,15 @@
|
||||||
from .Model import (
|
from .Model import (
|
||||||
AIIABase,
|
AIIABase,
|
||||||
AIIABaseShared,
|
AIIABaseShared,
|
||||||
AIIAchunked,
|
|
||||||
AIIAmoe,
|
AIIAmoe,
|
||||||
AIIASparseMoe,
|
AIIASparseMoe,
|
||||||
AIIArecursive
|
|
||||||
)
|
)
|
||||||
from .config import AIIAConfig
|
from .config import AIIAConfig
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AIIABase",
|
"AIIABase",
|
||||||
"AIIABaseShared",
|
"AIIABaseShared",
|
||||||
"AIIAchunked",
|
|
||||||
"AIIAmoe",
|
"AIIAmoe",
|
||||||
"AIIASparseMoe",
|
"AIIASparseMoe",
|
||||||
"AIIArecursive",
|
|
||||||
"AIIAConfig",
|
"AIIAConfig",
|
||||||
|
|
||||||
]
|
]
|
|
@ -1,28 +1,24 @@
|
||||||
import torch
|
from transformers import PretrainedConfig
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import json
|
|
||||||
import os
|
|
||||||
|
|
||||||
|
class AIIAConfig(PretrainedConfig):
|
||||||
|
model_type = "AIIA" # Add this class attribute
|
||||||
|
|
||||||
class AIIAConfig:
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_name: str = "AIIA",
|
|
||||||
kernel_size: int = 3,
|
kernel_size: int = 3,
|
||||||
activation_function: str = "GELU",
|
activation_function: str = "GELU",
|
||||||
hidden_size: int = 512,
|
hidden_size: int = 512,
|
||||||
num_hidden_layers: int = 12,
|
num_hidden_layers: int = 12,
|
||||||
num_channels: int = 3,
|
num_channels: int = 3,
|
||||||
learning_rate: float = 5e-5,
|
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
self.model_name = model_name
|
super().__init__(**kwargs)
|
||||||
self.kernel_size = kernel_size
|
self.kernel_size = kernel_size
|
||||||
self.activation_function = activation_function
|
self.activation_function = activation_function
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
self.num_hidden_layers = num_hidden_layers
|
self.num_hidden_layers = num_hidden_layers
|
||||||
self.num_channels = num_channels
|
self.num_channels = num_channels
|
||||||
self.learning_rate = learning_rate
|
|
||||||
|
|
||||||
# Store additional keyword arguments as attributes
|
# Store additional keyword arguments as attributes
|
||||||
for key, value in kwargs.items():
|
for key, value in kwargs.items():
|
||||||
|
@ -51,16 +47,3 @@ class AIIAConfig:
|
||||||
return {k: serialize(v) for k, v in value.items()}
|
return {k: serialize(v) for k, v in value.items()}
|
||||||
return value
|
return value
|
||||||
return {k: serialize(v) for k, v in self.__dict__.items()}
|
return {k: serialize(v) for k, v in self.__dict__.items()}
|
||||||
|
|
||||||
def save(self, file_path):
|
|
||||||
if not os.path.exists(file_path):
|
|
||||||
os.makedirs(file_path, exist_ok=True)
|
|
||||||
with open(os.path.join(file_path, "config.json"), "w") as f:
|
|
||||||
# Save the recursively converted dictionary.
|
|
||||||
json.dump(self.to_dict(), f, indent=4)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def load(cls, file_path):
|
|
||||||
with open(os.path.join(file_path, "config.json"), "r") as f:
|
|
||||||
config_dict = json.load(f)
|
|
||||||
return cls(**config_dict)
|
|
|
@ -1,9 +1,11 @@
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
import csv
|
import csv
|
||||||
|
import datetime
|
||||||
|
import time
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from ..model.Model import AIIA
|
from transformers import PreTrainedModel
|
||||||
from ..model.config import AIIAConfig
|
from ..model.config import AIIAConfig
|
||||||
from ..data.DataLoader import AIIADataLoader
|
from ..data.DataLoader import AIIADataLoader
|
||||||
import os
|
import os
|
||||||
|
@ -21,7 +23,7 @@ class ProjectionHead(nn.Module):
|
||||||
return self.conv_rotate(x).mean(dim=(2, 3)) # Global average pooling for rotation task
|
return self.conv_rotate(x).mean(dim=(2, 3)) # Global average pooling for rotation task
|
||||||
|
|
||||||
class Pretrainer:
|
class Pretrainer:
|
||||||
def __init__(self, model: AIIA, learning_rate=1e-4, config: AIIAConfig=None):
|
def __init__(self, model: PreTrainedModel, learning_rate=1e-4, config: AIIAConfig=None):
|
||||||
"""
|
"""
|
||||||
Initialize the pretrainer with a model.
|
Initialize the pretrainer with a model.
|
||||||
|
|
||||||
|
@ -112,20 +114,169 @@ class Pretrainer:
|
||||||
|
|
||||||
return batch_loss
|
return batch_loss
|
||||||
|
|
||||||
def train(self, dataset_paths,output_path:str="AIIA", column="image_bytes", num_epochs=3, batch_size=2, sample_size=10000):
|
def _save_checkpoint(self, checkpoint_dir, epoch, batch_count, checkpoint_name):
|
||||||
"""
|
"""Save a model checkpoint.
|
||||||
Train the model using multiple specified datasets.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dataset_paths (list): List of paths to parquet datasets
|
checkpoint_dir (str): Directory to save the checkpoint
|
||||||
num_epochs (int): Number of training epochs
|
epoch (int): Current epoch number
|
||||||
batch_size (int): Batch size for training
|
batch_count (int): Current batch count
|
||||||
sample_size (int): Number of samples to use from each dataset
|
checkpoint_name (str): Name for the checkpoint file
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Path to the saved checkpoint
|
||||||
"""
|
"""
|
||||||
|
checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name)
|
||||||
|
checkpoint_data = {
|
||||||
|
'epoch': epoch + 1,
|
||||||
|
'batch': batch_count,
|
||||||
|
'model_state_dict': self.model.state_dict(),
|
||||||
|
'projection_head_state_dict': self.projection_head.state_dict(),
|
||||||
|
'optimizer_state_dict': self.optimizer.state_dict(),
|
||||||
|
'train_losses': self.train_losses,
|
||||||
|
'val_losses': self.val_losses,
|
||||||
|
}
|
||||||
|
torch.save(checkpoint_data, checkpoint_path)
|
||||||
|
return checkpoint_path
|
||||||
|
|
||||||
|
def load_checkpoint(self, checkpoint_dir, specific_checkpoint=None):
|
||||||
|
"""
|
||||||
|
Check for checkpoints and load if available.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
checkpoint_dir (str): Directory where checkpoints are stored
|
||||||
|
specific_checkpoint (str, optional): Specific checkpoint file to load. If None, loads the most recent.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: (loaded_epoch, loaded_batch) if checkpoint was loaded, None otherwise
|
||||||
|
"""
|
||||||
|
# Create checkpoint directory if it doesn't exist
|
||||||
|
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# If a specific checkpoint is requested
|
||||||
|
if specific_checkpoint:
|
||||||
|
checkpoint_path = os.path.join(checkpoint_dir, specific_checkpoint)
|
||||||
|
if os.path.exists(checkpoint_path):
|
||||||
|
return self._load_checkpoint_file(checkpoint_path)
|
||||||
|
else:
|
||||||
|
print(f"Specified checkpoint {specific_checkpoint} not found.")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Find all checkpoint files
|
||||||
|
checkpoint_files = [f for f in os.listdir(checkpoint_dir) if f.startswith("checkpoint_") and f.endswith(".pt")]
|
||||||
|
|
||||||
|
if not checkpoint_files:
|
||||||
|
print("No checkpoints found in directory.")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Find the most recent checkpoint
|
||||||
|
checkpoint_files.sort(key=lambda x: os.path.getmtime(os.path.join(checkpoint_dir, x)), reverse=True)
|
||||||
|
most_recent = checkpoint_files[0]
|
||||||
|
checkpoint_path = os.path.join(checkpoint_dir, most_recent)
|
||||||
|
|
||||||
|
return self._load_checkpoint_file(checkpoint_path)
|
||||||
|
|
||||||
|
def _load_checkpoint_file(self, checkpoint_path):
|
||||||
|
"""
|
||||||
|
Load a specific checkpoint file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
checkpoint_path (str): Path to the checkpoint file
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: (loaded_epoch, loaded_batch) if checkpoint was loaded, None otherwise
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
checkpoint = torch.load(checkpoint_path, map_location=self.device)
|
||||||
|
|
||||||
|
# Load model state
|
||||||
|
self.model.load_state_dict(checkpoint['model_state_dict'])
|
||||||
|
|
||||||
|
# Load projection head state
|
||||||
|
self.projection_head.load_state_dict(checkpoint['projection_head_state_dict'])
|
||||||
|
|
||||||
|
# Load optimizer state
|
||||||
|
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
||||||
|
|
||||||
|
# Load loss history
|
||||||
|
self.train_losses = checkpoint.get('train_losses', [])
|
||||||
|
self.val_losses = checkpoint.get('val_losses', [])
|
||||||
|
|
||||||
|
loaded_epoch = checkpoint['epoch']
|
||||||
|
loaded_batch = checkpoint['batch']
|
||||||
|
|
||||||
|
print(f"Checkpoint loaded from {checkpoint_path}")
|
||||||
|
print(f"Resuming from epoch {loaded_epoch}, batch {loaded_batch}")
|
||||||
|
|
||||||
|
return loaded_epoch, loaded_batch
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error loading checkpoint: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def train(self, dataset_paths, output_path="AIIA", column="image_bytes",
|
||||||
|
num_epochs=3, batch_size=2, sample_size=10000, checkpoint_dir=None):
|
||||||
|
"""Train the model using multiple specified datasets with checkpoint resumption support."""
|
||||||
if not dataset_paths:
|
if not dataset_paths:
|
||||||
raise ValueError("No dataset paths provided")
|
raise ValueError("No dataset paths provided")
|
||||||
|
|
||||||
# Read and merge all datasets
|
self._initialize_checkpoint_variables()
|
||||||
|
start_epoch, start_batch, resume_training = self._load_checkpoints(checkpoint_dir)
|
||||||
|
|
||||||
|
dataframes = self._load_and_merge_datasets(dataset_paths, sample_size)
|
||||||
|
aiia_loader = self._initialize_data_loader(dataframes, column, batch_size)
|
||||||
|
|
||||||
|
criterion_denoise, criterion_rotate, best_val_loss = self._initialize_loss_functions()
|
||||||
|
|
||||||
|
for epoch in range(start_epoch, num_epochs):
|
||||||
|
print(f"\nEpoch {epoch+1}/{num_epochs}")
|
||||||
|
print("-" * 20)
|
||||||
|
total_train_loss, batch_count = self._training_phase(aiia_loader.train_loader,
|
||||||
|
start_batch if (epoch == start_epoch and resume_training) else 0,
|
||||||
|
criterion_denoise,
|
||||||
|
criterion_rotate)
|
||||||
|
|
||||||
|
avg_train_loss = total_train_loss / max(batch_count, 1)
|
||||||
|
self.train_losses.append(avg_train_loss)
|
||||||
|
print(f"Training Loss: {avg_train_loss:.4f}")
|
||||||
|
|
||||||
|
val_loss = self._validation_phase(aiia_loader.val_loader, criterion_denoise, criterion_rotate)
|
||||||
|
|
||||||
|
if val_loss < best_val_loss:
|
||||||
|
best_val_loss = val_loss
|
||||||
|
self.model.save(output_path)
|
||||||
|
print("Best model saved!")
|
||||||
|
|
||||||
|
losses_path = os.path.join(os.path.dirname(output_path), 'losses.csv')
|
||||||
|
self.save_losses(losses_path)
|
||||||
|
|
||||||
|
def _initialize_checkpoint_variables(self):
|
||||||
|
"""Initialize checkpoint tracking variables."""
|
||||||
|
self.last_checkpoint_time = time.time()
|
||||||
|
self.checkpoint_interval = 2 * 60 * 60 # 2 hours in seconds
|
||||||
|
self.last_22_date = None
|
||||||
|
self.recent_checkpoints = []
|
||||||
|
|
||||||
|
def _load_checkpoints(self, checkpoint_dir):
|
||||||
|
"""Load checkpoints and return start epoch, batch, and resumption flag."""
|
||||||
|
start_epoch = 0
|
||||||
|
start_batch = 0
|
||||||
|
resume_training = False
|
||||||
|
|
||||||
|
if checkpoint_dir is not None:
|
||||||
|
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||||
|
checkpoint_info = self.load_checkpoint(checkpoint_dir)
|
||||||
|
if checkpoint_info:
|
||||||
|
start_epoch, start_batch = checkpoint_info
|
||||||
|
resume_training = True
|
||||||
|
# Adjust epoch to be 0-indexed for the loop
|
||||||
|
start_epoch -= 1
|
||||||
|
|
||||||
|
return start_epoch, start_batch, resume_training
|
||||||
|
|
||||||
|
def _load_and_merge_datasets(self, dataset_paths, sample_size):
|
||||||
|
"""Load and merge datasets."""
|
||||||
dataframes = []
|
dataframes = []
|
||||||
for path in dataset_paths:
|
for path in dataset_paths:
|
||||||
try:
|
try:
|
||||||
|
@ -137,10 +288,11 @@ class Pretrainer:
|
||||||
if not dataframes:
|
if not dataframes:
|
||||||
raise ValueError("No valid datasets could be loaded")
|
raise ValueError("No valid datasets could be loaded")
|
||||||
|
|
||||||
merged_df = pd.concat(dataframes, ignore_index=True)
|
return pd.concat(dataframes, ignore_index=True)
|
||||||
|
|
||||||
# Initialize data loader
|
def _initialize_data_loader(self, merged_df, column, batch_size):
|
||||||
aiia_loader = AIIADataLoader(
|
"""Initialize the data loader."""
|
||||||
|
return AIIADataLoader(
|
||||||
merged_df,
|
merged_df,
|
||||||
column=column,
|
column=column,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
|
@ -148,24 +300,30 @@ class Pretrainer:
|
||||||
collate_fn=self.safe_collate
|
collate_fn=self.safe_collate
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _initialize_loss_functions(self):
|
||||||
|
"""Initialize loss functions and tracking variables."""
|
||||||
criterion_denoise = nn.MSELoss()
|
criterion_denoise = nn.MSELoss()
|
||||||
criterion_rotate = nn.CrossEntropyLoss()
|
criterion_rotate = nn.CrossEntropyLoss()
|
||||||
best_val_loss = float('inf')
|
best_val_loss = float('inf')
|
||||||
|
return criterion_denoise, criterion_rotate, best_val_loss
|
||||||
|
|
||||||
for epoch in range(num_epochs):
|
def _training_phase(self, train_loader, skip_batches, criterion_denoise, criterion_rotate):
|
||||||
print(f"\nEpoch {epoch+1}/{num_epochs}")
|
"""Handle the training phase."""
|
||||||
print("-" * 20)
|
|
||||||
|
|
||||||
# Training phase
|
|
||||||
self.model.train()
|
self.model.train()
|
||||||
self.projection_head.train()
|
self.projection_head.train()
|
||||||
total_train_loss = 0.0
|
total_train_loss = 0.0
|
||||||
batch_count = 0
|
batch_count = 0
|
||||||
|
|
||||||
for batch_data in tqdm(aiia_loader.train_loader):
|
train_batches = list(enumerate(train_loader))
|
||||||
|
for i, batch_data in tqdm(train_batches[skip_batches:],
|
||||||
|
initial=skip_batches,
|
||||||
|
total=len(train_batches)):
|
||||||
if batch_data is None:
|
if batch_data is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
current_batch = i + 1
|
||||||
|
self._handle_checkpoints(current_batch)
|
||||||
|
|
||||||
self.optimizer.zero_grad()
|
self.optimizer.zero_grad()
|
||||||
batch_loss = self._process_batch(batch_data, criterion_denoise, criterion_rotate)
|
batch_loss = self._process_batch(batch_data, criterion_denoise, criterion_rotate)
|
||||||
|
|
||||||
|
@ -175,22 +333,42 @@ class Pretrainer:
|
||||||
total_train_loss += batch_loss.item()
|
total_train_loss += batch_loss.item()
|
||||||
batch_count += 1
|
batch_count += 1
|
||||||
|
|
||||||
avg_train_loss = total_train_loss / max(batch_count, 1)
|
return total_train_loss, batch_count
|
||||||
self.train_losses.append(avg_train_loss)
|
|
||||||
print(f"Training Loss: {avg_train_loss:.4f}")
|
|
||||||
|
|
||||||
# Validation phase
|
def _handle_checkpoints(self, current_batch):
|
||||||
|
"""Handle checkpoint saving logic."""
|
||||||
|
current_time = time.time()
|
||||||
|
current_dt = datetime.datetime.now(datetime.timezone(datetime.timedelta(hours=2))) # German time
|
||||||
|
today = current_dt.date()
|
||||||
|
|
||||||
|
if self.checkpoint_dir and (current_time - self.last_checkpoint_time) >= self.checkpoint_interval:
|
||||||
|
checkpoint_name = f"checkpoint_epoch{self.current_epoch+1}_batch{current_batch}.pt"
|
||||||
|
checkpoint_path = self._save_checkpoint(self.checkpoint_dir, self.current_epoch, current_batch, checkpoint_name)
|
||||||
|
|
||||||
|
# Track and maintain only 3 recent checkpoints
|
||||||
|
self.recent_checkpoints.append(checkpoint_path)
|
||||||
|
if len(self.recent_checkpoints) > 3:
|
||||||
|
oldest = self.recent_checkpoints.pop(0)
|
||||||
|
if os.path.exists(oldest):
|
||||||
|
os.remove(oldest)
|
||||||
|
|
||||||
|
self.last_checkpoint_time = current_time
|
||||||
|
print(f"Checkpoint saved at {checkpoint_path}")
|
||||||
|
|
||||||
|
# Special 22:00 checkpoint (considering it's currently 10:15 PM)
|
||||||
|
is_22_oclock = current_dt.hour == 22 and current_dt.minute < 15
|
||||||
|
|
||||||
|
if self.checkpoint_dir and is_22_oclock and self.last_22_date != today:
|
||||||
|
checkpoint_name = f"checkpoint_22h_{today.strftime('%Y%m%d')}.pt"
|
||||||
|
checkpoint_path = self._save_checkpoint(self.checkpoint_dir, self.current_epoch, current_batch, checkpoint_name)
|
||||||
|
self.last_22_date = today
|
||||||
|
print(f"22:00 Checkpoint saved at {checkpoint_path}")
|
||||||
|
|
||||||
|
def _validation_phase(self, val_loader, criterion_denoise, criterion_rotate):
|
||||||
|
"""Handle the validation phase."""
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
self.projection_head.eval()
|
self.projection_head.eval()
|
||||||
val_loss = self._validate(aiia_loader.val_loader, criterion_denoise, criterion_rotate)
|
return self._validate(val_loader, criterion_denoise, criterion_rotate)
|
||||||
|
|
||||||
if val_loss < best_val_loss:
|
|
||||||
best_val_loss = val_loss
|
|
||||||
self.model.save(output_path)
|
|
||||||
print("Best model saved!")
|
|
||||||
|
|
||||||
losses_path = os.path.join(os.path.dirname(output_path), 'losses.csv')
|
|
||||||
self.save_losses(losses_path)
|
|
||||||
|
|
||||||
def _validate(self, val_loader, criterion_denoise, criterion_rotate):
|
def _validate(self, val_loader, criterion_denoise, criterion_rotate):
|
||||||
"""Perform validation and return average validation loss."""
|
"""Perform validation and return average validation loss."""
|
||||||
|
|
|
@ -1,159 +1,133 @@
|
||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
from aiia.model.Model import AIIABase, AIIABaseShared, AIIAExpert, AIIAmoe, AIIAchunked, AIIAConfig, AIIASparseMoe
|
from aiia.model.Model import AIIABase, AIIABaseShared, AIIAExpert, AIIAmoe, AIIAConfig, AIIASparseMoe
|
||||||
|
|
||||||
def test_aiiabase_creation():
|
def test_aiiabase_creation():
|
||||||
config = AIIAConfig()
|
config = AIIAConfig()
|
||||||
model = AIIABase(config)
|
model = AIIABase(config)
|
||||||
assert isinstance(model, AIIABase)
|
assert isinstance(model, AIIABase)
|
||||||
|
|
||||||
def test_aiiabase_save_load():
|
def test_aiiabase_save_pretrained_from_pretrained():
|
||||||
config = AIIAConfig()
|
config = AIIAConfig()
|
||||||
model = AIIABase(config)
|
model = AIIABase(config)
|
||||||
save_path = "test_aiiabase_save_load"
|
save_pretrained_path = "test_aiiabase_save_pretrained_load"
|
||||||
|
|
||||||
# Save the model
|
# Save the model
|
||||||
model.save(save_path)
|
model.save_pretrained(save_pretrained_path)
|
||||||
assert os.path.exists(os.path.join(save_path, "model.pth"))
|
assert os.path.exists(os.path.join(save_pretrained_path, "model.safetensors"))
|
||||||
assert os.path.exists(os.path.join(save_path, "config.json"))
|
assert os.path.exists(os.path.join(save_pretrained_path, "config.json"))
|
||||||
|
|
||||||
# Load the model
|
# Load the model
|
||||||
loaded_model = AIIABase.load(save_path)
|
loaded_model = AIIABase.from_pretrained(save_pretrained_path)
|
||||||
|
|
||||||
# Check if the loaded model is an instance of AIIABase
|
# Check if the loaded model is an instance of AIIABase
|
||||||
assert isinstance(loaded_model, AIIABase)
|
assert isinstance(loaded_model, AIIABase)
|
||||||
|
|
||||||
# Clean up
|
# Clean up
|
||||||
os.remove(os.path.join(save_path, "model.pth"))
|
os.remove(os.path.join(save_pretrained_path, "model.safetensors"))
|
||||||
os.remove(os.path.join(save_path, "config.json"))
|
os.remove(os.path.join(save_pretrained_path, "config.json"))
|
||||||
os.rmdir(save_path)
|
os.rmdir(save_pretrained_path)
|
||||||
|
|
||||||
def test_aiiabase_shared_creation():
|
def test_aiiabase_shared_creation():
|
||||||
config = AIIAConfig()
|
config = AIIAConfig()
|
||||||
model = AIIABaseShared(config)
|
model = AIIABaseShared(config)
|
||||||
assert isinstance(model, AIIABaseShared)
|
assert isinstance(model, AIIABaseShared)
|
||||||
|
|
||||||
def test_aiiabase_shared_save_load():
|
def test_aiiabase_shared_save_pretrained_from_pretrained():
|
||||||
config = AIIAConfig()
|
config = AIIAConfig()
|
||||||
model = AIIABaseShared(config)
|
model = AIIABaseShared(config)
|
||||||
save_path = "test_aiiabase_shared_save_load"
|
save_pretrained_path = "test_aiiabase_shared_save_pretrained_load"
|
||||||
|
|
||||||
# Save the model
|
# Save the model
|
||||||
model.save(save_path)
|
model.save_pretrained(save_pretrained_path)
|
||||||
assert os.path.exists(os.path.join(save_path, "model.pth"))
|
assert os.path.exists(os.path.join(save_pretrained_path, "model.safetensors"))
|
||||||
assert os.path.exists(os.path.join(save_path, "config.json"))
|
assert os.path.exists(os.path.join(save_pretrained_path, "config.json"))
|
||||||
|
|
||||||
# Load the model
|
# Load the model
|
||||||
loaded_model = AIIABaseShared.load(save_path)
|
loaded_model = AIIABaseShared.from_pretrained(save_pretrained_path)
|
||||||
|
|
||||||
# Check if the loaded model is an instance of AIIABaseShared
|
# Check if the loaded model is an instance of AIIABaseShared
|
||||||
assert isinstance(loaded_model, AIIABaseShared)
|
assert isinstance(loaded_model, AIIABaseShared)
|
||||||
|
|
||||||
# Clean up
|
# Clean up
|
||||||
os.remove(os.path.join(save_path, "model.pth"))
|
os.remove(os.path.join(save_pretrained_path, "model.safetensors"))
|
||||||
os.remove(os.path.join(save_path, "config.json"))
|
os.remove(os.path.join(save_pretrained_path, "config.json"))
|
||||||
os.rmdir(save_path)
|
os.rmdir(save_pretrained_path)
|
||||||
|
|
||||||
def test_aiiaexpert_creation():
|
def test_aiiaexpert_creation():
|
||||||
config = AIIAConfig()
|
config = AIIAConfig()
|
||||||
model = AIIAExpert(config)
|
model = AIIAExpert(config)
|
||||||
assert isinstance(model, AIIAExpert)
|
assert isinstance(model, AIIAExpert)
|
||||||
|
|
||||||
def test_aiiaexpert_save_load():
|
def test_aiiaexpert_save_pretrained_from_pretrained():
|
||||||
config = AIIAConfig()
|
config = AIIAConfig()
|
||||||
model = AIIAExpert(config)
|
model = AIIAExpert(config)
|
||||||
save_path = "test_aiiaexpert_save_load"
|
save_pretrained_path = "test_aiiaexpert_save_pretrained_load"
|
||||||
|
|
||||||
# Save the model
|
# Save the model
|
||||||
model.save(save_path)
|
model.save_pretrained(save_pretrained_path)
|
||||||
assert os.path.exists(os.path.join(save_path, "model.pth"))
|
assert os.path.exists(os.path.join(save_pretrained_path, "model.safetensors"))
|
||||||
assert os.path.exists(os.path.join(save_path, "config.json"))
|
assert os.path.exists(os.path.join(save_pretrained_path, "config.json"))
|
||||||
|
|
||||||
# Load the model
|
# Load the model
|
||||||
loaded_model = AIIAExpert.load(save_path)
|
loaded_model = AIIAExpert.from_pretrained(save_pretrained_path)
|
||||||
|
|
||||||
# Check if the loaded model is an instance of AIIAExpert
|
# Check if the loaded model is an instance of AIIAExpert
|
||||||
assert isinstance(loaded_model, AIIAExpert)
|
assert isinstance(loaded_model, AIIAExpert)
|
||||||
|
|
||||||
# Clean up
|
# Clean up
|
||||||
os.remove(os.path.join(save_path, "model.pth"))
|
os.remove(os.path.join(save_pretrained_path, "model.safetensors"))
|
||||||
os.remove(os.path.join(save_path, "config.json"))
|
os.remove(os.path.join(save_pretrained_path, "config.json"))
|
||||||
os.rmdir(save_path)
|
os.rmdir(save_pretrained_path)
|
||||||
|
|
||||||
def test_aiiamoe_creation():
|
def test_aiiamoe_creation():
|
||||||
config = AIIAConfig()
|
config = AIIAConfig(num_experts=3)
|
||||||
model = AIIAmoe(config, num_experts=5)
|
model = AIIAmoe(config)
|
||||||
assert isinstance(model, AIIAmoe)
|
assert isinstance(model, AIIAmoe)
|
||||||
|
|
||||||
def test_aiiamoe_save_load():
|
def test_aiiamoe_save_pretrained_from_pretrained():
|
||||||
config = AIIAConfig()
|
config = AIIAConfig(num_experts=3)
|
||||||
model = AIIAmoe(config, num_experts=5)
|
model = AIIAmoe(config)
|
||||||
save_path = "test_aiiamoe_save_load"
|
save_pretrained_path = "test_aiiamoe_save_pretrained_load"
|
||||||
|
|
||||||
# Save the model
|
# Save the model
|
||||||
model.save(save_path)
|
model.save_pretrained(save_pretrained_path)
|
||||||
assert os.path.exists(os.path.join(save_path, "model.pth"))
|
assert os.path.exists(os.path.join(save_pretrained_path, "model.safetensors"))
|
||||||
assert os.path.exists(os.path.join(save_path, "config.json"))
|
assert os.path.exists(os.path.join(save_pretrained_path, "config.json"))
|
||||||
|
|
||||||
# Load the model
|
# Load the model
|
||||||
loaded_model = AIIAmoe.load(save_path)
|
loaded_model = AIIAmoe.from_pretrained(save_pretrained_path)
|
||||||
|
|
||||||
# Check if the loaded model is an instance of AIIAmoe
|
# Check if the loaded model is an instance of AIIAmoe
|
||||||
assert isinstance(loaded_model, AIIAmoe)
|
assert isinstance(loaded_model, AIIAmoe)
|
||||||
|
|
||||||
# Clean up
|
# Clean up
|
||||||
os.remove(os.path.join(save_path, "model.pth"))
|
os.remove(os.path.join(save_pretrained_path, "model.safetensors"))
|
||||||
os.remove(os.path.join(save_path, "config.json"))
|
os.remove(os.path.join(save_pretrained_path, "config.json"))
|
||||||
os.rmdir(save_path)
|
os.rmdir(save_pretrained_path)
|
||||||
|
|
||||||
def test_aiiasparsemoe_creation():
|
def test_aiiasparsemoe_creation():
|
||||||
config = AIIAConfig()
|
config = AIIAConfig(num_experts=5, top_k=2)
|
||||||
model = AIIASparseMoe(config, num_experts=5, top_k=2)
|
model = AIIASparseMoe(config, base_class=AIIABaseShared)
|
||||||
assert isinstance(model, AIIASparseMoe)
|
assert isinstance(model, AIIASparseMoe)
|
||||||
|
|
||||||
def test_aiiasparsemoe_save_load():
|
def test_aiiasparsemoe_save_pretrained_from_pretrained():
|
||||||
config = AIIAConfig()
|
config = AIIAConfig(num_experts=3, top_k=1)
|
||||||
model = AIIASparseMoe(config, num_experts=3, top_k=1)
|
model = AIIASparseMoe(config)
|
||||||
save_path = "test_aiiasparsemoe_save_load"
|
save_pretrained_path = "test_aiiasparsemoe_save_pretrained_load"
|
||||||
|
|
||||||
# Save the model
|
# Save the model
|
||||||
model.save(save_path)
|
model.save_pretrained(save_pretrained_path)
|
||||||
assert os.path.exists(os.path.join(save_path, "model.pth"))
|
assert os.path.exists(os.path.join(save_pretrained_path, "model.safetensors"))
|
||||||
assert os.path.exists(os.path.join(save_path, "config.json"))
|
assert os.path.exists(os.path.join(save_pretrained_path, "config.json"))
|
||||||
|
|
||||||
# Load the model
|
# Load the model
|
||||||
loaded_model = AIIASparseMoe.load(save_path)
|
loaded_model = AIIASparseMoe.from_pretrained(save_pretrained_path)
|
||||||
|
|
||||||
# Check if the loaded model is an instance of AIIASparseMoe
|
# Check if the loaded model is an instance of AIIASparseMoe
|
||||||
assert isinstance(loaded_model, AIIASparseMoe)
|
assert isinstance(loaded_model, AIIASparseMoe)
|
||||||
|
|
||||||
# Clean up
|
# Clean up
|
||||||
os.remove(os.path.join(save_path, "model.pth"))
|
os.remove(os.path.join(save_pretrained_path, "model.safetensors"))
|
||||||
os.remove(os.path.join(save_path, "config.json"))
|
os.remove(os.path.join(save_pretrained_path, "config.json"))
|
||||||
os.rmdir(save_path)
|
os.rmdir(save_pretrained_path)
|
||||||
|
|
||||||
def test_aiiachunked_creation():
|
|
||||||
config = AIIAConfig()
|
|
||||||
model = AIIAchunked(config)
|
|
||||||
assert isinstance(model, AIIAchunked)
|
|
||||||
|
|
||||||
def test_aiiachunked_save_load():
|
|
||||||
config = AIIAConfig()
|
|
||||||
model = AIIAchunked(config)
|
|
||||||
save_path = "test_aiiachunked_save_load"
|
|
||||||
|
|
||||||
# Save the model
|
|
||||||
model.save(save_path)
|
|
||||||
assert os.path.exists(os.path.join(save_path, "model.pth"))
|
|
||||||
assert os.path.exists(os.path.join(save_path, "config.json"))
|
|
||||||
|
|
||||||
# Load the model
|
|
||||||
loaded_model = AIIAchunked.load(save_path)
|
|
||||||
|
|
||||||
# Check if the loaded model is an instance of AIIAchunked
|
|
||||||
assert isinstance(loaded_model, AIIAchunked)
|
|
||||||
|
|
||||||
# Clean up
|
|
||||||
os.remove(os.path.join(save_path, "model.pth"))
|
|
||||||
os.remove(os.path.join(save_path, "config.json"))
|
|
||||||
os.rmdir(save_path)
|
|
|
@ -1,75 +1,77 @@
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
import pytest
|
import pytest
|
||||||
import torch.nn as nn
|
|
||||||
from aiia import AIIAConfig
|
from aiia import AIIAConfig
|
||||||
|
|
||||||
|
|
||||||
def test_aiia_config_initialization():
|
def test_aiia_config_initialization():
|
||||||
config = AIIAConfig()
|
config = AIIAConfig()
|
||||||
assert config.model_name == "AIIA"
|
assert config.model_type == "AIIA"
|
||||||
assert config.kernel_size == 3
|
assert config.kernel_size == 3
|
||||||
assert config.activation_function == "GELU"
|
assert config.activation_function == "GELU"
|
||||||
assert config.hidden_size == 512
|
assert config.hidden_size == 512
|
||||||
assert config.num_hidden_layers == 12
|
assert config.num_hidden_layers == 12
|
||||||
assert config.num_channels == 3
|
assert config.num_channels == 3
|
||||||
assert config.learning_rate == 5e-5
|
|
||||||
|
|
||||||
def test_aiia_config_custom_initialization():
|
def test_aiia_config_custom_initialization():
|
||||||
config = AIIAConfig(
|
config = AIIAConfig(
|
||||||
model_name="CustomModel",
|
model_type="CustomModel",
|
||||||
kernel_size=5,
|
kernel_size=5,
|
||||||
activation_function="ReLU",
|
activation_function="ReLU",
|
||||||
hidden_size=1024,
|
hidden_size=1024,
|
||||||
num_hidden_layers=8,
|
num_hidden_layers=8,
|
||||||
num_channels=1,
|
num_channels=1
|
||||||
learning_rate=1e-4
|
|
||||||
)
|
)
|
||||||
assert config.model_name == "CustomModel"
|
assert config.model_type == "CustomModel"
|
||||||
assert config.kernel_size == 5
|
assert config.kernel_size == 5
|
||||||
assert config.activation_function == "ReLU"
|
assert config.activation_function == "ReLU"
|
||||||
assert config.hidden_size == 1024
|
assert config.hidden_size == 1024
|
||||||
assert config.num_hidden_layers == 8
|
assert config.num_hidden_layers == 8
|
||||||
assert config.num_channels == 1
|
assert config.num_channels == 1
|
||||||
assert config.learning_rate == 1e-4
|
|
||||||
|
|
||||||
def test_aiia_config_invalid_activation_function():
|
def test_aiia_config_invalid_activation_function():
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
AIIAConfig(activation_function="InvalidFunction")
|
AIIAConfig(activation_function="InvalidFunction")
|
||||||
|
|
||||||
|
|
||||||
def test_aiia_config_to_dict():
|
def test_aiia_config_to_dict():
|
||||||
config = AIIAConfig()
|
config = AIIAConfig()
|
||||||
config_dict = config.to_dict()
|
config_dict = config.to_dict()
|
||||||
assert isinstance(config_dict, dict)
|
assert isinstance(config_dict, dict)
|
||||||
assert config_dict["model_name"] == "AIIA"
|
|
||||||
assert config_dict["kernel_size"] == 3
|
assert config_dict["kernel_size"] == 3
|
||||||
|
|
||||||
def test_aiia_config_save_and_load():
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdir:
|
|
||||||
config = AIIAConfig(model_name="TempModel")
|
|
||||||
save_path = os.path.join(tmpdir, "config")
|
|
||||||
config.save(save_path)
|
|
||||||
|
|
||||||
loaded_config = AIIAConfig.load(save_path)
|
def test_aiia_config_save_pretrained_and_from_pretrained():
|
||||||
assert loaded_config.model_name == "TempModel"
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
config = AIIAConfig(model_type="TempModel")
|
||||||
|
save_pretrained_path = os.path.join(tmpdir, "config")
|
||||||
|
config.save_pretrained(save_pretrained_path)
|
||||||
|
|
||||||
|
loaded_config = AIIAConfig.from_pretrained(save_pretrained_path)
|
||||||
|
assert loaded_config.model_type == "TempModel"
|
||||||
assert loaded_config.kernel_size == 3
|
assert loaded_config.kernel_size == 3
|
||||||
assert loaded_config.activation_function == "GELU"
|
assert loaded_config.activation_function == "GELU"
|
||||||
|
|
||||||
def test_aiia_config_save_and_load_with_custom_attributes():
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdir:
|
|
||||||
config = AIIAConfig(model_name="TempModel", custom_attr="value")
|
|
||||||
save_path = os.path.join(tmpdir, "config")
|
|
||||||
config.save(save_path)
|
|
||||||
|
|
||||||
loaded_config = AIIAConfig.load(save_path)
|
def test_aiia_config_save_pretrained_and_load_with_custom_attributes():
|
||||||
assert loaded_config.model_name == "TempModel"
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
config = AIIAConfig(model_type="TempModel", custom_attr="value")
|
||||||
|
save_pretrained_path = os.path.join(tmpdir, "config")
|
||||||
|
config.save_pretrained(save_pretrained_path)
|
||||||
|
|
||||||
|
loaded_config = AIIAConfig.from_pretrained(save_pretrained_path)
|
||||||
|
assert loaded_config.model_type == "TempModel"
|
||||||
assert loaded_config.custom_attr == "value"
|
assert loaded_config.custom_attr == "value"
|
||||||
|
|
||||||
def test_aiia_config_save_and_load_with_nested_attributes():
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdir:
|
|
||||||
config = AIIAConfig(model_name="TempModel", nested={"key": "value"})
|
|
||||||
save_path = os.path.join(tmpdir, "config")
|
|
||||||
config.save(save_path)
|
|
||||||
|
|
||||||
loaded_config = AIIAConfig.load(save_path)
|
def test_aiia_config_save_pretrained_and_load_with_nested_attributes():
|
||||||
assert loaded_config.model_name == "TempModel"
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
config = AIIAConfig(model_type="TempModel", nested={"key": "value"})
|
||||||
|
save_pretrained_path = os.path.join(tmpdir, "config")
|
||||||
|
config.save_pretrained(save_pretrained_path)
|
||||||
|
|
||||||
|
loaded_config = AIIAConfig.from_pretrained(save_pretrained_path)
|
||||||
|
assert loaded_config.model_type == "TempModel"
|
||||||
assert loaded_config.nested == {"key": "value"}
|
assert loaded_config.nested == {"key": "value"}
|
|
@ -3,6 +3,8 @@ import torch
|
||||||
from unittest.mock import MagicMock, patch, MagicMock, mock_open, call
|
from unittest.mock import MagicMock, patch, MagicMock, mock_open, call
|
||||||
from aiia import Pretrainer, ProjectionHead, AIIABase, AIIAConfig
|
from aiia import Pretrainer, ProjectionHead, AIIABase, AIIAConfig
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
import os
|
||||||
|
import datetime
|
||||||
|
|
||||||
# Test the ProjectionHead class
|
# Test the ProjectionHead class
|
||||||
def test_projection_head():
|
def test_projection_head():
|
||||||
|
@ -53,11 +55,94 @@ def test_process_batch(mock_process_batch):
|
||||||
loss = pretrainer._process_batch(batch_data, criterion_denoise, criterion_rotate)
|
loss = pretrainer._process_batch(batch_data, criterion_denoise, criterion_rotate)
|
||||||
assert loss == 0.5
|
assert loss == 0.5
|
||||||
|
|
||||||
|
# Error cases
|
||||||
|
# New tests for checkpoint handling
|
||||||
|
@patch('torch.save')
|
||||||
|
@patch('os.path.join')
|
||||||
|
def test_save_checkpoint(mock_join, mock_save):
|
||||||
|
"""Test checkpoint saving functionality."""
|
||||||
|
pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
|
||||||
|
pretrainer.projection_head = MagicMock()
|
||||||
|
pretrainer.optimizer = MagicMock()
|
||||||
|
|
||||||
|
checkpoint_dir = "checkpoints"
|
||||||
|
epoch = 1
|
||||||
|
batch_count = 100
|
||||||
|
checkpoint_name = "test_checkpoint.pt"
|
||||||
|
|
||||||
|
mock_join.return_value = os.path.join(checkpoint_dir, checkpoint_name)
|
||||||
|
|
||||||
|
path = pretrainer._save_checkpoint(checkpoint_dir, epoch, batch_count, checkpoint_name)
|
||||||
|
|
||||||
|
assert path == os.path.join(checkpoint_dir, checkpoint_name)
|
||||||
|
mock_save.assert_called_once()
|
||||||
|
|
||||||
|
@patch('os.makedirs')
|
||||||
|
@patch('os.path.exists')
|
||||||
|
@patch('torch.load')
|
||||||
|
def test_load_checkpoint_specific(mock_load, mock_exists, mock_makedirs):
|
||||||
|
"""Test loading a specific checkpoint."""
|
||||||
|
pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
|
||||||
|
pretrainer.projection_head = MagicMock()
|
||||||
|
pretrainer.optimizer = MagicMock()
|
||||||
|
|
||||||
|
checkpoint_dir = "checkpoints"
|
||||||
|
specific_checkpoint = "specific_checkpoint.pt"
|
||||||
|
mock_exists.return_value = True
|
||||||
|
|
||||||
|
mock_load.return_value = {
|
||||||
|
'epoch': 2,
|
||||||
|
'batch': 150,
|
||||||
|
'model_state_dict': {},
|
||||||
|
'projection_head_state_dict': {},
|
||||||
|
'optimizer_state_dict': {},
|
||||||
|
'train_losses': [],
|
||||||
|
'val_losses': []
|
||||||
|
}
|
||||||
|
|
||||||
|
result = pretrainer.load_checkpoint(checkpoint_dir, specific_checkpoint)
|
||||||
|
assert result == (2, 150)
|
||||||
|
|
||||||
|
@patch('os.listdir')
|
||||||
|
@patch('os.path.getmtime')
|
||||||
|
def test_load_checkpoint_most_recent(mock_getmtime, mock_listdir):
|
||||||
|
"""Test loading the most recent checkpoint."""
|
||||||
|
pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
|
||||||
|
|
||||||
|
checkpoint_dir = "checkpoints"
|
||||||
|
mock_listdir.return_value = ["checkpoint_1.pt", "checkpoint_2.pt"]
|
||||||
|
mock_getmtime.side_effect = [100, 200] # checkpoint_2.pt is more recent
|
||||||
|
|
||||||
|
with patch.object(pretrainer, '_load_checkpoint_file', return_value=(2, 150)):
|
||||||
|
result = pretrainer.load_checkpoint(checkpoint_dir)
|
||||||
|
assert result == (2, 150)
|
||||||
|
|
||||||
|
def test_initialize_checkpoint_variables():
|
||||||
|
"""Test initialization of checkpoint variables."""
|
||||||
|
pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
|
||||||
|
pretrainer._initialize_checkpoint_variables()
|
||||||
|
|
||||||
|
assert hasattr(pretrainer, 'last_checkpoint_time')
|
||||||
|
assert pretrainer.checkpoint_interval == 2 * 60 * 60
|
||||||
|
assert pretrainer.last_22_date is None
|
||||||
|
assert pretrainer.recent_checkpoints == []
|
||||||
|
|
||||||
|
@patch('torch.nn.MSELoss')
|
||||||
|
@patch('torch.nn.CrossEntropyLoss')
|
||||||
|
def test_initialize_loss_functions(mock_ce_loss, mock_mse_loss):
|
||||||
|
"""Test loss function initialization."""
|
||||||
|
pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
|
||||||
|
criterion_denoise, criterion_rotate, best_val_loss = pretrainer._initialize_loss_functions()
|
||||||
|
|
||||||
|
assert mock_mse_loss.called
|
||||||
|
assert mock_ce_loss.called
|
||||||
|
assert best_val_loss == float('inf')
|
||||||
|
|
||||||
@patch('pandas.concat')
|
@patch('pandas.concat')
|
||||||
@patch('pandas.read_parquet')
|
@patch('pandas.read_parquet')
|
||||||
@patch('aiia.pretrain.pretrainer.AIIADataLoader')
|
@patch('aiia.pretrain.pretrainer.AIIADataLoader')
|
||||||
@patch('os.path.join', return_value='mocked/path/model.pt')
|
@patch('os.path.join', return_value='mocked/path/model.pt')
|
||||||
@patch('builtins.print') # Add this to mock the print function
|
@patch('builtins.print')
|
||||||
def test_train_happy_path(mock_print, mock_path_join, mock_data_loader, mock_read_parquet, mock_concat):
|
def test_train_happy_path(mock_print, mock_path_join, mock_data_loader, mock_read_parquet, mock_concat):
|
||||||
"""Test the train method under normal conditions with comprehensive verification."""
|
"""Test the train method under normal conditions with comprehensive verification."""
|
||||||
# Setup test data and mocks
|
# Setup test data and mocks
|
||||||
|
@ -73,6 +158,7 @@ def test_train_happy_path(mock_print, mock_path_join, mock_data_loader, mock_rea
|
||||||
pretrainer = Pretrainer(model=mock_model, config=AIIAConfig())
|
pretrainer = Pretrainer(model=mock_model, config=AIIAConfig())
|
||||||
pretrainer.projection_head = mock_projection_head
|
pretrainer.projection_head = mock_projection_head
|
||||||
pretrainer.optimizer = MagicMock()
|
pretrainer.optimizer = MagicMock()
|
||||||
|
pretrainer.checkpoint_dir = None # Initialize checkpoint_dir
|
||||||
|
|
||||||
# Setup dataset paths and mock batch data
|
# Setup dataset paths and mock batch data
|
||||||
dataset_paths = ['path/to/dataset1.parquet', 'path/to/dataset2.parquet']
|
dataset_paths = ['path/to/dataset1.parquet', 'path/to/dataset2.parquet']
|
||||||
|
@ -104,185 +190,118 @@ def test_train_happy_path(mock_print, mock_path_join, mock_data_loader, mock_rea
|
||||||
assert mock_process_batch.call_count == 2
|
assert mock_process_batch.call_count == 2
|
||||||
assert mock_validate.call_count == 2
|
assert mock_validate.call_count == 2
|
||||||
|
|
||||||
# Check for "Best model saved!" instead of model.save()
|
|
||||||
mock_print.assert_any_call("Best model saved!")
|
mock_print.assert_any_call("Best model saved!")
|
||||||
|
|
||||||
mock_save_losses.assert_called_once()
|
mock_save_losses.assert_called_once()
|
||||||
|
|
||||||
# Verify state changes
|
|
||||||
assert len(pretrainer.train_losses) == 2
|
assert len(pretrainer.train_losses) == 2
|
||||||
assert pretrainer.train_losses == [0.5, 0.5]
|
assert pretrainer.train_losses == [0.5, 0.5]
|
||||||
|
|
||||||
|
@patch('datetime.datetime')
|
||||||
# Error cases
|
@patch('time.time')
|
||||||
def test_train_no_dataset_paths():
|
def test_handle_checkpoints(mock_time, mock_datetime):
|
||||||
"""Test ValueError when no dataset paths are provided."""
|
"""Test checkpoint handling logic."""
|
||||||
pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
|
pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
|
||||||
|
pretrainer.checkpoint_dir = "checkpoints"
|
||||||
|
pretrainer.current_epoch = 1
|
||||||
|
pretrainer._initialize_checkpoint_variables()
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="No dataset paths provided"):
|
# Set a base time value
|
||||||
pretrainer.train([])
|
base_time = 1000
|
||||||
|
# Set the last checkpoint time to base_time
|
||||||
|
pretrainer.last_checkpoint_time = base_time
|
||||||
|
|
||||||
@patch('pandas.read_parquet')
|
# Mock time to return base_time + interval + 1 to trigger checkpoint save
|
||||||
def test_train_all_datasets_fail(mock_read_parquet):
|
mock_time.return_value = base_time + pretrainer.checkpoint_interval + 1
|
||||||
"""Test handling when all datasets fail to load."""
|
|
||||||
mock_read_parquet.side_effect = Exception("Failed to load dataset")
|
|
||||||
|
|
||||||
|
# Mock datetime for 22:00 checkpoint
|
||||||
|
mock_dt = MagicMock()
|
||||||
|
mock_dt.hour = 22
|
||||||
|
mock_dt.minute = 0
|
||||||
|
mock_dt.date.return_value = datetime.date(2023, 1, 1)
|
||||||
|
mock_datetime.now.return_value = mock_dt
|
||||||
|
|
||||||
|
with patch.object(pretrainer, '_save_checkpoint') as mock_save:
|
||||||
|
pretrainer._handle_checkpoints(100)
|
||||||
|
# Should be called twice - once for regular interval and once for 22:00
|
||||||
|
assert mock_save.call_count == 2
|
||||||
|
|
||||||
|
def test_training_phase():
|
||||||
|
"""Test the training phase logic."""
|
||||||
pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
|
pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
|
||||||
dataset_paths = ['path/to/dataset1.parquet', 'path/to/dataset2.parquet']
|
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="No valid datasets could be loaded"):
|
|
||||||
pretrainer.train(dataset_paths)
|
|
||||||
|
|
||||||
# Edge cases
|
|
||||||
@patch('pandas.concat')
|
|
||||||
@patch('pandas.read_parquet')
|
|
||||||
@patch('aiia.pretrain.pretrainer.AIIADataLoader')
|
|
||||||
def test_train_empty_loaders(mock_data_loader, mock_read_parquet, mock_concat):
|
|
||||||
"""Test behavior with empty data loaders."""
|
|
||||||
real_df = pd.DataFrame({'image_bytes': [torch.randn(1, 3, 224, 224).tolist()]})
|
|
||||||
mock_read_parquet.return_value.head.return_value = real_df
|
|
||||||
mock_concat.return_value = real_df
|
|
||||||
|
|
||||||
loader_instance = MagicMock()
|
|
||||||
loader_instance.train_loader = [] # Empty train loader
|
|
||||||
loader_instance.val_loader = [] # Empty val loader
|
|
||||||
mock_data_loader.return_value = loader_instance
|
|
||||||
|
|
||||||
mock_model = MagicMock()
|
|
||||||
pretrainer = Pretrainer(model=mock_model, config=AIIAConfig())
|
|
||||||
pretrainer.projection_head = MagicMock()
|
|
||||||
pretrainer.optimizer = MagicMock()
|
pretrainer.optimizer = MagicMock()
|
||||||
|
pretrainer.checkpoint_dir = None # Initialize checkpoint_dir
|
||||||
|
pretrainer._initialize_checkpoint_variables()
|
||||||
|
pretrainer.current_epoch = 0
|
||||||
|
|
||||||
with patch.object(Pretrainer, 'save_losses') as mock_save_losses:
|
# Create mock batch data with requires_grad=True
|
||||||
pretrainer.train(['path/to/dataset.parquet'], num_epochs=1)
|
|
||||||
|
|
||||||
# Verify empty loader behavior
|
|
||||||
assert len(pretrainer.train_losses) == 1
|
|
||||||
assert pretrainer.train_losses[0] == 0.0
|
|
||||||
mock_save_losses.assert_called_once()
|
|
||||||
|
|
||||||
@patch('pandas.concat')
|
|
||||||
@patch('pandas.read_parquet')
|
|
||||||
@patch('aiia.pretrain.pretrainer.AIIADataLoader')
|
|
||||||
def test_train_none_batch_data(mock_data_loader, mock_read_parquet, mock_concat):
|
|
||||||
"""Test behavior when batch_data is None."""
|
|
||||||
real_df = pd.DataFrame({'image_bytes': [torch.randn(1, 3, 224, 224).tolist()]})
|
|
||||||
mock_read_parquet.return_value.head.return_value = real_df
|
|
||||||
mock_concat.return_value = real_df
|
|
||||||
|
|
||||||
loader_instance = MagicMock()
|
|
||||||
loader_instance.train_loader = [None] # Loader returns None
|
|
||||||
loader_instance.val_loader = []
|
|
||||||
mock_data_loader.return_value = loader_instance
|
|
||||||
|
|
||||||
pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
|
|
||||||
pretrainer.projection_head = MagicMock()
|
|
||||||
pretrainer.optimizer = MagicMock()
|
|
||||||
|
|
||||||
with patch.object(Pretrainer, '_process_batch') as mock_process_batch, \
|
|
||||||
patch.object(Pretrainer, 'save_losses'):
|
|
||||||
pretrainer.train(['path/to/dataset.parquet'], num_epochs=1)
|
|
||||||
|
|
||||||
# Verify None batch handling
|
|
||||||
mock_process_batch.assert_not_called()
|
|
||||||
assert pretrainer.train_losses[0] == 0.0
|
|
||||||
|
|
||||||
# Parameter variations
|
|
||||||
@patch('pandas.concat')
|
|
||||||
@patch('pandas.read_parquet')
|
|
||||||
@patch('aiia.pretrain.pretrainer.AIIADataLoader')
|
|
||||||
def test_train_with_custom_parameters(mock_data_loader, mock_read_parquet, mock_concat):
|
|
||||||
"""Test that custom parameters are properly passed through."""
|
|
||||||
real_df = pd.DataFrame({'image_bytes': [torch.randn(1, 3, 224, 224).tolist()]})
|
|
||||||
mock_read_parquet.return_value.head.return_value = real_df
|
|
||||||
mock_concat.return_value = real_df
|
|
||||||
|
|
||||||
loader_instance = MagicMock()
|
|
||||||
loader_instance.train_loader = []
|
|
||||||
loader_instance.val_loader = []
|
|
||||||
mock_data_loader.return_value = loader_instance
|
|
||||||
|
|
||||||
pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
|
|
||||||
pretrainer.projection_head = MagicMock()
|
|
||||||
pretrainer.optimizer = MagicMock()
|
|
||||||
|
|
||||||
# Custom parameters
|
|
||||||
custom_output_path = "custom/output/path"
|
|
||||||
custom_column = "custom_column"
|
|
||||||
custom_batch_size = 16
|
|
||||||
custom_sample_size = 5000
|
|
||||||
|
|
||||||
with patch.object(Pretrainer, 'save_losses'):
|
|
||||||
pretrainer.train(
|
|
||||||
['path/to/dataset.parquet'],
|
|
||||||
output_path=custom_output_path,
|
|
||||||
column=custom_column,
|
|
||||||
batch_size=custom_batch_size,
|
|
||||||
sample_size=custom_sample_size
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify custom parameters were used
|
|
||||||
mock_read_parquet.return_value.head.assert_called_once_with(custom_sample_size)
|
|
||||||
assert mock_data_loader.call_args[1]['column'] == custom_column
|
|
||||||
assert mock_data_loader.call_args[1]['batch_size'] == custom_batch_size
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@patch('pandas.concat')
|
|
||||||
@patch('pandas.read_parquet')
|
|
||||||
@patch('aiia.pretrain.pretrainer.AIIADataLoader')
|
|
||||||
@patch('builtins.print') # Add this to mock the print function
|
|
||||||
def test_train_validation_loss_improvement(mock_print, mock_data_loader, mock_read_parquet, mock_concat):
|
|
||||||
"""Test that model is saved only when validation loss improves."""
|
|
||||||
real_df = pd.DataFrame({'image_bytes': [torch.randn(1, 3, 224, 224).tolist()]})
|
|
||||||
mock_read_parquet.return_value.head.return_value = real_df
|
|
||||||
mock_concat.return_value = real_df
|
|
||||||
|
|
||||||
# Create mock batch data with proper structure
|
|
||||||
mock_batch_data = {
|
mock_batch_data = {
|
||||||
'denoise': (torch.randn(2, 3, 32, 32), torch.randn(2, 3, 32, 32)),
|
'denoise': (
|
||||||
'rotate': (torch.randn(2, 3, 32, 32), torch.tensor([0, 1]))
|
torch.randn(2, 3, 32, 32, requires_grad=True),
|
||||||
|
torch.randn(2, 3, 32, 32, requires_grad=True)
|
||||||
|
),
|
||||||
|
'rotate': (
|
||||||
|
torch.randn(2, 3, 32, 32, requires_grad=True),
|
||||||
|
torch.tensor([0, 1], dtype=torch.long) # Labels typically don't need gradients
|
||||||
|
)
|
||||||
|
}
|
||||||
|
mock_train_loader = [(0, mock_batch_data)] # Include batch index
|
||||||
|
|
||||||
|
# Mock the loss functions to return tensors that require gradients
|
||||||
|
criterion_denoise = MagicMock(return_value=torch.tensor(0.5, requires_grad=True))
|
||||||
|
criterion_rotate = MagicMock(return_value=torch.tensor(0.5, requires_grad=True))
|
||||||
|
|
||||||
|
with patch.object(pretrainer, '_process_batch', return_value=torch.tensor(0.5, requires_grad=True)), \
|
||||||
|
patch.object(pretrainer, '_handle_checkpoints') as mock_handle_checkpoints:
|
||||||
|
|
||||||
|
total_loss, batch_count = pretrainer._training_phase(
|
||||||
|
mock_train_loader, 0, criterion_denoise, criterion_rotate)
|
||||||
|
|
||||||
|
assert total_loss == 0.5
|
||||||
|
assert batch_count == 1
|
||||||
|
mock_handle_checkpoints.assert_called_once_with(1) # Check if checkpoint handling was called
|
||||||
|
|
||||||
|
def test_validation_phase():
|
||||||
|
"""Test the validation phase logic."""
|
||||||
|
pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
|
||||||
|
pretrainer.projection_head = MagicMock()
|
||||||
|
|
||||||
|
mock_val_loader = [MagicMock()]
|
||||||
|
criterion_denoise = MagicMock()
|
||||||
|
criterion_rotate = MagicMock()
|
||||||
|
|
||||||
|
with patch.object(pretrainer, '_validate', return_value=0.4):
|
||||||
|
val_loss = pretrainer._validation_phase(
|
||||||
|
mock_val_loader, criterion_denoise, criterion_rotate)
|
||||||
|
|
||||||
|
assert val_loss == 0.4
|
||||||
|
|
||||||
|
@patch('pandas.read_parquet')
|
||||||
|
def test_load_and_merge_datasets(mock_read_parquet):
|
||||||
|
"""Test dataset loading and merging."""
|
||||||
|
pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
|
||||||
|
|
||||||
|
mock_df = pd.DataFrame({'col': [1, 2, 3]})
|
||||||
|
mock_read_parquet.return_value.head.return_value = mock_df
|
||||||
|
|
||||||
|
result = pretrainer._load_and_merge_datasets(['path1.parquet', 'path2.parquet'], 1000)
|
||||||
|
assert len(result) == 6 # 2 datasets * 3 rows each
|
||||||
|
|
||||||
|
def test_process_batch_none_tasks():
|
||||||
|
"""Test processing batch with no tasks."""
|
||||||
|
pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
|
||||||
|
|
||||||
|
batch_data = {
|
||||||
|
'denoise': None,
|
||||||
|
'rotate': None
|
||||||
}
|
}
|
||||||
|
|
||||||
loader_instance = MagicMock()
|
loss = pretrainer._process_batch(
|
||||||
loader_instance.train_loader = [mock_batch_data]
|
batch_data,
|
||||||
loader_instance.val_loader = [mock_batch_data]
|
criterion_denoise=MagicMock(),
|
||||||
mock_data_loader.return_value = loader_instance
|
criterion_rotate=MagicMock()
|
||||||
|
)
|
||||||
|
|
||||||
mock_model = MagicMock()
|
assert loss == 0
|
||||||
pretrainer = Pretrainer(model=mock_model, config=AIIAConfig())
|
|
||||||
pretrainer.projection_head = MagicMock()
|
|
||||||
pretrainer.optimizer = MagicMock()
|
|
||||||
|
|
||||||
# Initialize the best validation loss
|
|
||||||
pretrainer.best_val_loss = float('inf')
|
|
||||||
|
|
||||||
mock_batch_loss = torch.tensor(0.5, requires_grad=True)
|
|
||||||
|
|
||||||
# Test improving validation loss
|
|
||||||
with patch.object(Pretrainer, '_process_batch', return_value=mock_batch_loss), \
|
|
||||||
patch.object(Pretrainer, '_validate', side_effect=[3.0, 2.0, 1.0]), \
|
|
||||||
patch.object(Pretrainer, 'save_losses'):
|
|
||||||
pretrainer.train(['path/to/dataset.parquet'], num_epochs=3)
|
|
||||||
|
|
||||||
# Check for "Best model saved!" 3 times
|
|
||||||
assert mock_print.call_args_list.count(call("Best model saved!")) == 3
|
|
||||||
|
|
||||||
# Reset for next test
|
|
||||||
mock_print.reset_mock()
|
|
||||||
pretrainer.train_losses = []
|
|
||||||
|
|
||||||
# Reset best validation loss for the second test
|
|
||||||
pretrainer.best_val_loss = float('inf')
|
|
||||||
|
|
||||||
# Test fluctuating validation loss
|
|
||||||
with patch.object(Pretrainer, '_process_batch', return_value=mock_batch_loss), \
|
|
||||||
patch.object(Pretrainer, '_validate', side_effect=[3.0, 4.0, 2.0]), \
|
|
||||||
patch.object(Pretrainer, 'save_losses'):
|
|
||||||
pretrainer.train(['path/to/dataset.parquet'], num_epochs=3)
|
|
||||||
|
|
||||||
# Should print "Best model saved!" only on first and third epochs
|
|
||||||
assert mock_print.call_args_list.count(call("Best model saved!")) == 2
|
|
||||||
|
|
||||||
|
|
||||||
@patch('aiia.pretrain.pretrainer.Pretrainer._process_batch')
|
@patch('aiia.pretrain.pretrainer.Pretrainer._process_batch')
|
||||||
|
|
Loading…
Reference in New Issue