feat/tests #32

Merged
Fabel merged 15 commits from feat/tests into main 2025-03-16 12:10:20 +00:00
14 changed files with 801 additions and 51 deletions

.coveragerc Normal file

@@ -0,0 +1,10 @@
[run]
branch = True
source = src
omit =
    */tests/*
    */migrations/*
[report]
show_missing = True
fail_under = 80
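
With this configuration at the repository root, the coverage gate can be reproduced locally before pushing; a minimal sketch using the standard coverage.py CLI:

coverage run -m pytest tests/
coverage report   # exits non-zero when total coverage falls below fail_under = 80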


@@ -0,0 +1,36 @@
name: Run VectorLoader Script
on:
  push:
    branches:
      - main

jobs:
  Explore-Gitea-Actions:
    runs-on: ubuntu-latest
    container: catthehacker/ubuntu:act-latest
    steps:
      - name: Check out repository code
        uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v3
        with:
          python-version: '3.11.7'

      - name: Clone additional repository
        run: |
          git config --global credential.helper cache
          git clone https://fabel:${{ secrets.CICD }}@gitea.fabelous.app/fabel/VectorLoader.git

      - name: Install dependencies
        run: |
          cd VectorLoader
          python -m pip install --upgrade pip
          pip install -r requirements.txt

      - name: Run vectorizing
        env:
          VECTORDB_TOKEN: ${{ secrets.VECTORDB_TOKEN }}
        run: |
          python -m src.run --full


@@ -0,0 +1,37 @@
name: Gitea Actions For AIIA
run-name: ${{ gitea.actor }} is testing out Gitea Actions 🚀
on: [push]

jobs:
  Explore-Gitea-Actions:
    runs-on: ubuntu-latest
    container: catthehacker/ubuntu:act-latest
    steps:
      - name: Check out repository code
        uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.11.7'

      - name: Cache pip and model
        uses: actions/cache@v3
        with:
          path: |
            ~/.cache/pip
            ./fabel
          key: ${{ runner.os }}-pip-model-${{ hashFiles('requirements.txt', 'requirements-dev.txt') }}
          restore-keys: |
            ${{ runner.os }}-pip-model-

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -r requirements.txt
          pip install -r requirements-dev.txt
          pip install -e .

      - name: Run tests
        run: |
          pytest tests/

pyproject.toml

@@ -10,7 +10,7 @@ include = '\.pyi?$'
[project]
name = "aiia"
version = "0.1.6"
version = "0.2.0"
description = "AIIA Deep Learning Model Implementation"
readme = "README.md"
authors = [

pytest.ini Normal file

@@ -0,0 +1,3 @@
[pytest]
testpaths = tests/
python_files = test_*.py
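
With this file at the repository root, pytest discovers test_*.py files under tests/ without extra arguments; a minimal check:

pytest -q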

requirements-dev.txt Normal file

@@ -0,0 +1,2 @@
pytest
pytest-mock

setup.cfg

@@ -1,6 +1,6 @@
[metadata]
name = aiia
-version = 0.1.6
+version = 0.2.0
author = Falko Habel
author_email = falko.habel@gmx.de
description = AIIA deep learning model implementation

src/aiia/__init__.py

@@ -4,4 +4,4 @@ from .data.DataLoader import DataLoader
from .pretrain.pretrainer import Pretrainer, ProjectionHead
__version__ = "0.1.6"
__version__ = "0.2.0"

src/aiia/data/DataLoader.py

@@ -14,10 +14,10 @@ class FilePathLoader:
        self.label_column = label_column
        self.successful_count = 0
        self.skipped_count = 0

-        if self.file_path_column not in dataset.column_names:
+        if self.file_path_column not in dataset.columns:
            raise ValueError(f"Column '{self.file_path_column}' not found in dataset.")

    def _get_image(self, item):
        try:
            path = item[self.file_path_column]
@@ -32,7 +32,7 @@
        except Exception as e:
            print(f"Error loading image from {path}: {e}")
            return None

    def get_item(self, idx):
        item = self.dataset.iloc[idx]
        image = self._get_image(item)
@@ -46,7 +46,7 @@
        else:
            self.skipped_count += 1
            return None

    def print_summary(self):
        print(f"Successfully converted {self.successful_count} images.")
        print(f"Skipped {self.skipped_count} images due to errors.")
@@ -58,14 +58,14 @@ class JPGImageLoader:
        self.label_column = label_column
        self.successful_count = 0
        self.skipped_count = 0

        if self.bytes_column not in dataset.columns:
            raise ValueError(f"Column '{self.bytes_column}' not found in dataset.")

    def _get_image(self, item):
        try:
            data = item[self.bytes_column]

            if isinstance(data, str) and data.startswith("b'"):
                cleaned_data = data[2:-1].encode('latin1').decode('unicode-escape').encode('latin1')
                bytes_data = cleaned_data
@@ -73,7 +73,7 @@
                bytes_data = base64.b64decode(data)
            else:
                bytes_data = data

            img_bytes = io.BytesIO(bytes_data)
            image = Image.open(img_bytes)
            if image.mode == 'RGBA':
@@ -86,7 +86,7 @@
        except Exception as e:
            print(f"Error loading image from bytes: {e}")
            return None

    def get_item(self, idx):
        item = self.dataset.iloc[idx]
        image = self._get_image(item)
@@ -100,37 +100,41 @@ class JPGImageLoader:
        else:
            self.skipped_count += 1
            return None

    def print_summary(self):
        print(f"Successfully converted {self.successful_count} images.")
        print(f"Skipped {self.skipped_count} images due to errors.")

class AIIADataLoader:
-    def __init__(self, dataset, batch_size=32, val_split=0.2, seed=42, column="file_path", label_column=None, pretraining=False, **dataloader_kwargs):
+    def __init__(self, dataset, batch_size=32, val_split=0.2, seed=42, column="file_path",
+                 label_column=None, pretraining=False, **dataloader_kwargs):
        if column not in dataset.columns:
            raise ValueError(f"Column '{column}' not found in dataset")

        self.batch_size = batch_size
        self.val_split = val_split
        self.seed = seed
        self.pretraining = pretraining
        random.seed(seed)

        sample_value = dataset[column].iloc[0]
        is_bytes_or_bytestring = isinstance(sample_value, (bytes, str)) and (
            isinstance(sample_value, bytes) or
            sample_value.startswith("b'") or
            sample_value.startswith(('b"', 'data:image'))
        )

        if is_bytes_or_bytestring:
            self.loader = JPGImageLoader(dataset, bytes_column=column, label_column=label_column)
        else:
            sample_paths = dataset[column].dropna().head(1).astype(str)
            filepath_pattern = r'.*(?:/|\\).*\.([jJ][pP][gG]|[pP][nN][gG]|[gG][iI][fF])$'

            if any(re.match(filepath_pattern, path, flags=re.IGNORECASE) for path in sample_paths):
                self.loader = FilePathLoader(dataset, file_path_column=column, label_column=label_column)
            else:
                self.loader = JPGImageLoader(dataset, bytes_column=column, label_column=label_column)

        self.items = []
        for idx in range(len(dataset)):
            item = self.loader.get_item(idx)
@@ -141,33 +145,32 @@ class AIIADataLoader:
                self.items.append((img, 'rotate', 0))
            else:
                self.items.append(item)

        if not self.items:
            raise ValueError("No valid items were loaded from the dataset")

        train_indices, val_indices = self._split_data()
        self.train_dataset = self._create_subset(train_indices)
        self.val_dataset = self._create_subset(val_indices)

        self.train_loader = DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True, **dataloader_kwargs)
        self.val_loader = DataLoader(self.val_dataset, batch_size=batch_size, shuffle=False, **dataloader_kwargs)

    def _split_data(self):
        if len(self.items) == 0:
            raise ValueError("No items to split")

        num_samples = len(self.items)
        indices = list(range(num_samples))
        random.shuffle(indices)

        split_idx = int((1 - self.val_split) * num_samples)
        train_indices = indices[:split_idx]
        val_indices = indices[split_idx:]
        return train_indices, val_indices

    def _create_subset(self, indices):
        subset_items = [self.items[i] for i in indices]
        return AIIADataset(subset_items, pretraining=self.pretraining)
@@ -180,22 +183,24 @@ class AIIADataset(torch.utils.data.Dataset):
            transforms.Resize((224, 224)),
            transforms.ToTensor()
        ])

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        item = self.items[idx]
        if self.pretraining:
            image, task, label = item
            if not isinstance(image, Image.Image):
                raise ValueError(f"Invalid image at index {idx}")

+            # Check image dimensions before transform
+            if image.size[0] < 224 or image.size[1] < 224:
+                raise ValueError("Invalid image dimensions")

            image = self.transform(image)
            if image.shape != (3, 224, 224):
                raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")

            if task == 'denoise':
                noise_std = 0.1
                noisy_img = image + torch.randn_like(image) * noise_std
@@ -214,15 +219,20 @@
            image, label = item
            if not isinstance(image, Image.Image):
                raise ValueError(f"Invalid image at index {idx}")

+            # Check image dimensions before transform
+            if image.size[0] < 224 or image.size[1] < 224:
+                raise ValueError("Invalid image dimensions")

            image = self.transform(image)
            if image.shape != (3, 224, 224):
                raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")
            return image, label
        else:
-            if isinstance(item, Image.Image):
-                image = self.transform(item)
-            else:
-                image = self.transform(item[0])
-            if image.shape != (3, 224, 224):
-                raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")
+            image = item[0] if isinstance(item, tuple) else item
+            if not isinstance(image, Image.Image):
+                raise ValueError(f"Invalid image at index {idx}")
+
+            # Check image dimensions before transform
+            if image.size[0] < 224 or image.size[1] < 224:
+                raise ValueError("Invalid image dimensions")
+
+            image = self.transform(image)
            return image
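
Taken together, AIIADataLoader now sniffs the first value of the chosen column to pick JPGImageLoader (raw or encoded bytes) or FilePathLoader (paths matching the image-extension pattern), then exposes ready train/val loaders. A minimal usage sketch, assuming the paths point at real images of at least 224x224 (file names here are hypothetical):

import pandas as pd
from aiia.data.DataLoader import AIIADataLoader

df = pd.DataFrame({
    "file_path": ["images/cat.jpg", "images/dog.jpg"],  # hypothetical files on disk
    "label": [0, 1],
})
loader = AIIADataLoader(df, batch_size=2, val_split=0.5, label_column="label")
for images, labels in loader.train_loader:
    print(images.shape)  # transformed batches of (3, 224, 224) tensors
    break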


@@ -23,12 +23,36 @@ class AIIA(nn.Module):
        self.config.save(path)

    @classmethod
-    def load(cls, path, precision: str = None, **kwargs):
+    def load(cls, path, precision: str = None, strict: bool = True, **kwargs):
        config = AIIAConfig.load(path)
-        model = cls(config, **kwargs)  # Pass kwargs here!
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

+        # Load the state dict to analyze structure
+        model_dict = torch.load(f"{path}/model.pth", map_location=device)
+
+        # Special handling for AIIAmoe - detect number of experts from state_dict
+        if cls.__name__ == "AIIAmoe" and "num_experts" not in kwargs:
+            # Find maximum expert index
+            max_expert_idx = -1
+            for key in model_dict.keys():
+                if key.startswith("experts."):
+                    parts = key.split(".")
+                    if len(parts) > 1:
+                        try:
+                            expert_idx = int(parts[1])
+                            max_expert_idx = max(max_expert_idx, expert_idx)
+                        except ValueError:
+                            pass
+
+            if max_expert_idx >= 0:
+                # experts.X keys found, use max_expert_idx + 1 as num_experts
+                kwargs["num_experts"] = max_expert_idx + 1
+
+        # Create model with detected structural parameters
+        model = cls(config, **kwargs)
+
+        # Handle precision conversion
        dtype = None
        if precision is not None:
            if precision.lower() == 'fp16':
                dtype = torch.float16
@@ -40,14 +64,14 @@
                dtype = torch.bfloat16
            else:
                raise ValueError("Unsupported precision. Use 'fp16', 'bf16', or leave as None.")

-        model_dict = torch.load(f"{path}/model.pth", map_location=device)
        if dtype is not None:
            for key, param in model_dict.items():
                if torch.is_tensor(param):
                    model_dict[key] = param.to(dtype)

-        model.load_state_dict(model_dict)
+        # Load state dict with strict parameter for flexibility
+        model.load_state_dict(model_dict, strict=strict)
        return model
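
The extended signature stays backward compatible: strict defaults to True, and num_experts is only inferred when the caller does not supply it. A usage sketch with a hypothetical checkpoint path:

from aiia import AIIAmoe

# num_experts is recovered from the highest "experts.<i>." key in model.pth,
# and strict=False tolerates missing or renamed keys in the state dict
model = AIIAmoe.load("checkpoints/aiia-moe", strict=False)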


@@ -0,0 +1,112 @@
import pytest
from PIL import Image
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
import pandas as pd
import numpy as np
from aiia.data.DataLoader import FilePathLoader, JPGImageLoader, AIIADataLoader, AIIADataset


def create_sample_dataset(file_paths=None):
    if file_paths is None:
        file_paths = ['path/to/image1.jpg', 'path/to/image2.png']

    data = {
        'file_path': file_paths,
        'label': [0] * len(file_paths)  # Match length of labels to file_paths
    }
    df = pd.DataFrame(data)
    return df


def create_sample_bytes_dataset(bytes_data=None):
    if bytes_data is None:
        bytes_data = [b'fake_image_data_1', b'fake_image_data_2']

    data = {
        'jpg': bytes_data,
        'label': [0] * len(bytes_data)  # Match length of labels to bytes_data
    }
    df = pd.DataFrame(data)
    return df


def test_file_path_loader(mocker):
    # Mock Image.open to return a fake image
    mock_image = Image.new('RGB', (224, 224))
    mocker.patch('PIL.Image.open', return_value=mock_image)

    dataset = create_sample_dataset()
    loader = FilePathLoader(dataset, label_column='label')  # Added label_column
    item = loader.get_item(0)
    assert isinstance(item[0], Image.Image)
    assert item[1] == 0

    loader.print_summary()


def test_jpg_image_loader(mocker):
    # Mock Image.open to return a fake image
    mock_image = Image.new('RGB', (224, 224))
    mocker.patch('PIL.Image.open', return_value=mock_image)

    dataset = create_sample_bytes_dataset()
    loader = JPGImageLoader(dataset, label_column='label')  # Added label_column
    item = loader.get_item(0)
    assert isinstance(item[0], Image.Image)
    assert item[1] == 0

    loader.print_summary()


def test_aiia_data_loader(mocker):
    # Mock Image.open to return a fake image
    mock_image = Image.new('RGB', (224, 224))
    mocker.patch('PIL.Image.open', return_value=mock_image)

    dataset = create_sample_dataset()
    data_loader = AIIADataLoader(dataset, batch_size=2, label_column='label')

    # Test train loader
    batch = next(iter(data_loader.train_loader))
    assert isinstance(batch, list)
    assert len(batch) == 2  # (images, labels)
    assert batch[0].shape[0] == 1  # batch size


def test_aiia_dataset():
    items = [(Image.new('RGB', (224, 224)), 0), (Image.new('RGB', (224, 224)), 1)]
    dataset = AIIADataset(items)
    assert len(dataset) == 2
    item = dataset[0]
    assert isinstance(item[0], torch.Tensor)
    assert item[1] == 0


def test_aiia_dataset_pre_training():
    items = [(Image.new('RGB', (224, 224)), 'denoise', Image.new('RGB', (224, 224)))]
    dataset = AIIADataset(items, pretraining=True)
    assert len(dataset) == 1
    item = dataset[0]
    assert isinstance(item[0], torch.Tensor)
    assert isinstance(item[2], str)


def test_aiia_dataset_invalid_image():
    items = [(Image.new('RGB', (50, 50)), 0)]  # Create small image
    dataset = AIIADataset(items)
    with pytest.raises(ValueError, match="Invalid image dimensions"):
        dataset[0]


def test_aiia_dataset_invalid_task():
    items = [(Image.new('RGB', (224, 224)), 'invalid_task', Image.new('RGB', (224, 224)))]
    dataset = AIIADataset(items, pretraining=True)
    with pytest.raises(ValueError):
        dataset[0]


def test_aiia_data_loader_invalid_column():
    dataset = create_sample_dataset()
    with pytest.raises(ValueError, match="Column 'invalid_column' not found"):
        AIIADataLoader(dataset, column='invalid_column')


if __name__ == "__main__":
    pytest.main(['-v'])

tests/model/test_aiia.py Normal file

@@ -0,0 +1,133 @@
import os
import torch
from aiia import AIIABase, AIIABaseShared, AIIAExpert, AIIAmoe, AIIAchunked, AIIAConfig


def test_aiiabase_creation():
    config = AIIAConfig()
    model = AIIABase(config)
    assert isinstance(model, AIIABase)


def test_aiiabase_save_load():
    config = AIIAConfig()
    model = AIIABase(config)
    save_path = "test_aiiabase_save_load"

    # Save the model
    model.save(save_path)
    assert os.path.exists(os.path.join(save_path, "model.pth"))
    assert os.path.exists(os.path.join(save_path, "config.json"))

    # Load the model
    loaded_model = AIIABase.load(save_path)

    # Check if the loaded model is an instance of AIIABase
    assert isinstance(loaded_model, AIIABase)

    # Clean up
    os.remove(os.path.join(save_path, "model.pth"))
    os.remove(os.path.join(save_path, "config.json"))
    os.rmdir(save_path)


def test_aiiabase_shared_creation():
    config = AIIAConfig()
    model = AIIABaseShared(config)
    assert isinstance(model, AIIABaseShared)


def test_aiiabase_shared_save_load():
    config = AIIAConfig()
    model = AIIABaseShared(config)
    save_path = "test_aiiabase_shared_save_load"

    # Save the model
    model.save(save_path)
    assert os.path.exists(os.path.join(save_path, "model.pth"))
    assert os.path.exists(os.path.join(save_path, "config.json"))

    # Load the model
    loaded_model = AIIABaseShared.load(save_path)

    # Check if the loaded model is an instance of AIIABaseShared
    assert isinstance(loaded_model, AIIABaseShared)

    # Clean up
    os.remove(os.path.join(save_path, "model.pth"))
    os.remove(os.path.join(save_path, "config.json"))
    os.rmdir(save_path)


def test_aiiaexpert_creation():
    config = AIIAConfig()
    model = AIIAExpert(config)
    assert isinstance(model, AIIAExpert)


def test_aiiaexpert_save_load():
    config = AIIAConfig()
    model = AIIAExpert(config)
    save_path = "test_aiiaexpert_save_load"

    # Save the model
    model.save(save_path)
    assert os.path.exists(os.path.join(save_path, "model.pth"))
    assert os.path.exists(os.path.join(save_path, "config.json"))

    # Load the model
    loaded_model = AIIAExpert.load(save_path)

    # Check if the loaded model is an instance of AIIAExpert
    assert isinstance(loaded_model, AIIAExpert)

    # Clean up
    os.remove(os.path.join(save_path, "model.pth"))
    os.remove(os.path.join(save_path, "config.json"))
    os.rmdir(save_path)


def test_aiiamoe_creation():
    config = AIIAConfig()
    model = AIIAmoe(config, num_experts=5)
    assert isinstance(model, AIIAmoe)


def test_aiiamoe_save_load():
    config = AIIAConfig()
    model = AIIAmoe(config, num_experts=5)
    save_path = "test_aiiamoe_save_load"

    # Save the model
    model.save(save_path)
    assert os.path.exists(os.path.join(save_path, "model.pth"))
    assert os.path.exists(os.path.join(save_path, "config.json"))

    # Load the model
    loaded_model = AIIAmoe.load(save_path)

    # Check if the loaded model is an instance of AIIAmoe
    assert isinstance(loaded_model, AIIAmoe)

    # Clean up
    os.remove(os.path.join(save_path, "model.pth"))
    os.remove(os.path.join(save_path, "config.json"))
    os.rmdir(save_path)


def test_aiiachunked_creation():
    config = AIIAConfig()
    model = AIIAchunked(config)
    assert isinstance(model, AIIAchunked)


def test_aiiachunked_save_load():
    config = AIIAConfig()
    model = AIIAchunked(config)
    save_path = "test_aiiachunked_save_load"

    # Save the model
    model.save(save_path)
    assert os.path.exists(os.path.join(save_path, "model.pth"))
    assert os.path.exists(os.path.join(save_path, "config.json"))

    # Load the model
    loaded_model = AIIAchunked.load(save_path)

    # Check if the loaded model is an instance of AIIAchunked
    assert isinstance(loaded_model, AIIAchunked)

    # Clean up
    os.remove(os.path.join(save_path, "model.pth"))
    os.remove(os.path.join(save_path, "config.json"))
    os.rmdir(save_path)


@@ -0,0 +1,75 @@
import os
import tempfile
import pytest
import torch.nn as nn
from aiia import AIIAConfig


def test_aiia_config_initialization():
    config = AIIAConfig()
    assert config.model_name == "AIIA"
    assert config.kernel_size == 3
    assert config.activation_function == "GELU"
    assert config.hidden_size == 512
    assert config.num_hidden_layers == 12
    assert config.num_channels == 3
    assert config.learning_rate == 5e-5


def test_aiia_config_custom_initialization():
    config = AIIAConfig(
        model_name="CustomModel",
        kernel_size=5,
        activation_function="ReLU",
        hidden_size=1024,
        num_hidden_layers=8,
        num_channels=1,
        learning_rate=1e-4
    )
    assert config.model_name == "CustomModel"
    assert config.kernel_size == 5
    assert config.activation_function == "ReLU"
    assert config.hidden_size == 1024
    assert config.num_hidden_layers == 8
    assert config.num_channels == 1
    assert config.learning_rate == 1e-4


def test_aiia_config_invalid_activation_function():
    with pytest.raises(ValueError):
        AIIAConfig(activation_function="InvalidFunction")


def test_aiia_config_to_dict():
    config = AIIAConfig()
    config_dict = config.to_dict()
    assert isinstance(config_dict, dict)
    assert config_dict["model_name"] == "AIIA"
    assert config_dict["kernel_size"] == 3


def test_aiia_config_save_and_load():
    with tempfile.TemporaryDirectory() as tmpdir:
        config = AIIAConfig(model_name="TempModel")
        save_path = os.path.join(tmpdir, "config")
        config.save(save_path)

        loaded_config = AIIAConfig.load(save_path)
        assert loaded_config.model_name == "TempModel"
        assert loaded_config.kernel_size == 3
        assert loaded_config.activation_function == "GELU"


def test_aiia_config_save_and_load_with_custom_attributes():
    with tempfile.TemporaryDirectory() as tmpdir:
        config = AIIAConfig(model_name="TempModel", custom_attr="value")
        save_path = os.path.join(tmpdir, "config")
        config.save(save_path)

        loaded_config = AIIAConfig.load(save_path)
        assert loaded_config.model_name == "TempModel"
        assert loaded_config.custom_attr == "value"


def test_aiia_config_save_and_load_with_nested_attributes():
    with tempfile.TemporaryDirectory() as tmpdir:
        config = AIIAConfig(model_name="TempModel", nested={"key": "value"})
        save_path = os.path.join(tmpdir, "config")
        config.save(save_path)

        loaded_config = AIIAConfig.load(save_path)
        assert loaded_config.model_name == "TempModel"
        assert loaded_config.nested == {"key": "value"}


@@ -0,0 +1,308 @@
import pytest
import torch
from unittest.mock import MagicMock, patch, mock_open, call
from aiia import Pretrainer, ProjectionHead, AIIABase, AIIAConfig
import pandas as pd


# Test the ProjectionHead class
def test_projection_head():
    head = ProjectionHead(hidden_size=512)
    x = torch.randn(1, 512, 32, 32)

    # Test denoise task
    output_denoise = head(x, task='denoise')
    assert output_denoise.shape == (1, 3, 32, 32)

    # Test rotate task
    output_rotate = head(x, task='rotate')
    assert output_rotate.shape == (1, 4)


# Test the Pretrainer class initialization
def test_pretrainer_initialization():
    config = AIIAConfig()
    model = AIIABase(config=config)
    pretrainer = Pretrainer(model=model, learning_rate=0.001, config=config)
    assert pretrainer.device in ["cuda", "cpu"]
    assert isinstance(pretrainer.projection_head, ProjectionHead)
    assert isinstance(pretrainer.optimizer, torch.optim.AdamW)


# Test the safe_collate method
def test_safe_collate():
    pretrainer = Pretrainer(model=AIIABase(config=AIIAConfig()), config=AIIAConfig())
    batch = [
        (torch.randn(3, 32, 32), torch.randn(3, 32, 32), 'denoise'),
        (torch.randn(3, 32, 32), torch.tensor(1), 'rotate')
    ]

    collated_batch = pretrainer.safe_collate(batch)
    assert 'denoise' in collated_batch
    assert 'rotate' in collated_batch


# Test the _process_batch method
@patch('aiia.pretrain.pretrainer.Pretrainer._process_batch')
def test_process_batch(mock_process_batch):
    pretrainer = Pretrainer(model=AIIABase(config=AIIAConfig()), config=AIIAConfig())
    batch_data = {
        'denoise': (torch.randn(2, 3, 32, 32), torch.randn(2, 3, 32, 32)),
        'rotate': (torch.randn(2, 3, 32, 32), torch.tensor([0, 1]))
    }
    criterion_denoise = MagicMock()
    criterion_rotate = MagicMock()
    mock_process_batch.return_value = 0.5

    loss = pretrainer._process_batch(batch_data, criterion_denoise, criterion_rotate)
    assert loss == 0.5


@patch('pandas.concat')
@patch('pandas.read_parquet')
@patch('aiia.pretrain.pretrainer.AIIADataLoader')
@patch('os.path.join', return_value='mocked/path/model.pt')
@patch('builtins.print')  # Add this to mock the print function
def test_train_happy_path(mock_print, mock_path_join, mock_data_loader, mock_read_parquet, mock_concat):
    """Test the train method under normal conditions with comprehensive verification."""
    # Setup test data and mocks
    real_df = pd.DataFrame({
        'image_bytes': [torch.randn(1, 3, 224, 224).tolist()]
    })
    mock_read_parquet.return_value.head.return_value = real_df
    mock_concat.return_value = real_df

    # Mock the model and related components
    mock_model = MagicMock()
    mock_projection_head = MagicMock()
    pretrainer = Pretrainer(model=mock_model, config=AIIAConfig())
    pretrainer.projection_head = mock_projection_head
    pretrainer.optimizer = MagicMock()

    # Setup dataset paths and mock batch data
    dataset_paths = ['path/to/dataset1.parquet', 'path/to/dataset2.parquet']
    output_path = "AIIA_test"

    # Create mock batch data with proper structure
    mock_batch_data = {
        'denoise': (torch.randn(2, 3, 32, 32), torch.randn(2, 3, 32, 32)),
        'rotate': (torch.randn(2, 3, 32, 32), torch.tensor([0, 1]))
    }

    # Configure batch loss
    mock_batch_loss = torch.tensor(0.5, requires_grad=True)
    loader_instance = MagicMock()
    loader_instance.train_loader = [mock_batch_data]
    loader_instance.val_loader = [mock_batch_data]
    mock_data_loader.return_value = loader_instance

    # Execute training with patched methods
    with patch.object(Pretrainer, '_process_batch', return_value=mock_batch_loss) as mock_process_batch, \
         patch.object(Pretrainer, '_validate', side_effect=[0.8, 0.3]) as mock_validate, \
         patch.object(Pretrainer, 'save_losses') as mock_save_losses, \
         patch('builtins.open', mock_open()):
        pretrainer.train(dataset_paths, output_path=output_path, num_epochs=2)

    # Verify method calls
    assert mock_read_parquet.call_count == len(dataset_paths)
    assert mock_process_batch.call_count == 2
    assert mock_validate.call_count == 2

    # Check for "Best model saved!" instead of model.save()
    mock_print.assert_any_call("Best model saved!")

    mock_save_losses.assert_called_once()

    # Verify state changes
    assert len(pretrainer.train_losses) == 2
    assert pretrainer.train_losses == [0.5, 0.5]


# Error cases
def test_train_no_dataset_paths():
    """Test ValueError when no dataset paths are provided."""
    pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())

    with pytest.raises(ValueError, match="No dataset paths provided"):
        pretrainer.train([])


@patch('pandas.read_parquet')
def test_train_all_datasets_fail(mock_read_parquet):
    """Test handling when all datasets fail to load."""
    mock_read_parquet.side_effect = Exception("Failed to load dataset")

    pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
    dataset_paths = ['path/to/dataset1.parquet', 'path/to/dataset2.parquet']

    with pytest.raises(ValueError, match="No valid datasets could be loaded"):
        pretrainer.train(dataset_paths)


# Edge cases
@patch('pandas.concat')
@patch('pandas.read_parquet')
@patch('aiia.pretrain.pretrainer.AIIADataLoader')
def test_train_empty_loaders(mock_data_loader, mock_read_parquet, mock_concat):
    """Test behavior with empty data loaders."""
    real_df = pd.DataFrame({'image_bytes': [torch.randn(1, 3, 224, 224).tolist()]})
    mock_read_parquet.return_value.head.return_value = real_df
    mock_concat.return_value = real_df

    loader_instance = MagicMock()
    loader_instance.train_loader = []  # Empty train loader
    loader_instance.val_loader = []  # Empty val loader
    mock_data_loader.return_value = loader_instance

    mock_model = MagicMock()
    pretrainer = Pretrainer(model=mock_model, config=AIIAConfig())
    pretrainer.projection_head = MagicMock()
    pretrainer.optimizer = MagicMock()

    with patch.object(Pretrainer, 'save_losses') as mock_save_losses:
        pretrainer.train(['path/to/dataset.parquet'], num_epochs=1)

    # Verify empty loader behavior
    assert len(pretrainer.train_losses) == 1
    assert pretrainer.train_losses[0] == 0.0
    mock_save_losses.assert_called_once()


@patch('pandas.concat')
@patch('pandas.read_parquet')
@patch('aiia.pretrain.pretrainer.AIIADataLoader')
def test_train_none_batch_data(mock_data_loader, mock_read_parquet, mock_concat):
    """Test behavior when batch_data is None."""
    real_df = pd.DataFrame({'image_bytes': [torch.randn(1, 3, 224, 224).tolist()]})
    mock_read_parquet.return_value.head.return_value = real_df
    mock_concat.return_value = real_df

    loader_instance = MagicMock()
    loader_instance.train_loader = [None]  # Loader returns None
    loader_instance.val_loader = []
    mock_data_loader.return_value = loader_instance

    pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
    pretrainer.projection_head = MagicMock()
    pretrainer.optimizer = MagicMock()

    with patch.object(Pretrainer, '_process_batch') as mock_process_batch, \
         patch.object(Pretrainer, 'save_losses'):
        pretrainer.train(['path/to/dataset.parquet'], num_epochs=1)

    # Verify None batch handling
    mock_process_batch.assert_not_called()
    assert pretrainer.train_losses[0] == 0.0


# Parameter variations
@patch('pandas.concat')
@patch('pandas.read_parquet')
@patch('aiia.pretrain.pretrainer.AIIADataLoader')
def test_train_with_custom_parameters(mock_data_loader, mock_read_parquet, mock_concat):
    """Test that custom parameters are properly passed through."""
    real_df = pd.DataFrame({'image_bytes': [torch.randn(1, 3, 224, 224).tolist()]})
    mock_read_parquet.return_value.head.return_value = real_df
    mock_concat.return_value = real_df

    loader_instance = MagicMock()
    loader_instance.train_loader = []
    loader_instance.val_loader = []
    mock_data_loader.return_value = loader_instance

    pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
    pretrainer.projection_head = MagicMock()
    pretrainer.optimizer = MagicMock()

    # Custom parameters
    custom_output_path = "custom/output/path"
    custom_column = "custom_column"
    custom_batch_size = 16
    custom_sample_size = 5000

    with patch.object(Pretrainer, 'save_losses'):
        pretrainer.train(
            ['path/to/dataset.parquet'],
            output_path=custom_output_path,
            column=custom_column,
            batch_size=custom_batch_size,
            sample_size=custom_sample_size
        )

    # Verify custom parameters were used
    mock_read_parquet.return_value.head.assert_called_once_with(custom_sample_size)
    assert mock_data_loader.call_args[1]['column'] == custom_column
    assert mock_data_loader.call_args[1]['batch_size'] == custom_batch_size


@patch('pandas.concat')
@patch('pandas.read_parquet')
@patch('aiia.pretrain.pretrainer.AIIADataLoader')
@patch('builtins.print')  # Add this to mock the print function
def test_train_validation_loss_improvement(mock_print, mock_data_loader, mock_read_parquet, mock_concat):
    """Test that model is saved only when validation loss improves."""
    real_df = pd.DataFrame({'image_bytes': [torch.randn(1, 3, 224, 224).tolist()]})
    mock_read_parquet.return_value.head.return_value = real_df
    mock_concat.return_value = real_df

    # Create mock batch data with proper structure
    mock_batch_data = {
        'denoise': (torch.randn(2, 3, 32, 32), torch.randn(2, 3, 32, 32)),
        'rotate': (torch.randn(2, 3, 32, 32), torch.tensor([0, 1]))
    }

    loader_instance = MagicMock()
    loader_instance.train_loader = [mock_batch_data]
    loader_instance.val_loader = [mock_batch_data]
    mock_data_loader.return_value = loader_instance

    mock_model = MagicMock()
    pretrainer = Pretrainer(model=mock_model, config=AIIAConfig())
    pretrainer.projection_head = MagicMock()
    pretrainer.optimizer = MagicMock()

    # Initialize the best validation loss
    pretrainer.best_val_loss = float('inf')

    mock_batch_loss = torch.tensor(0.5, requires_grad=True)

    # Test improving validation loss
    with patch.object(Pretrainer, '_process_batch', return_value=mock_batch_loss), \
         patch.object(Pretrainer, '_validate', side_effect=[3.0, 2.0, 1.0]), \
         patch.object(Pretrainer, 'save_losses'):
        pretrainer.train(['path/to/dataset.parquet'], num_epochs=3)

    # Check for "Best model saved!" 3 times
    assert mock_print.call_args_list.count(call("Best model saved!")) == 3

    # Reset for next test
    mock_print.reset_mock()
    pretrainer.train_losses = []

    # Reset best validation loss for the second test
    pretrainer.best_val_loss = float('inf')

    # Test fluctuating validation loss
    with patch.object(Pretrainer, '_process_batch', return_value=mock_batch_loss), \
         patch.object(Pretrainer, '_validate', side_effect=[3.0, 4.0, 2.0]), \
         patch.object(Pretrainer, 'save_losses'):
        pretrainer.train(['path/to/dataset.parquet'], num_epochs=3)

    # Should print "Best model saved!" only on first and third epochs
    assert mock_print.call_args_list.count(call("Best model saved!")) == 2


@patch('aiia.pretrain.pretrainer.Pretrainer._process_batch')
def test_validate(mock_process_batch):
    pretrainer = Pretrainer(model=AIIABase(config=AIIAConfig()), config=AIIAConfig())
    val_loader = [MagicMock()]
    criterion_denoise = MagicMock()
    criterion_rotate = MagicMock()
    mock_process_batch.return_value = torch.tensor(0.5)

    loss = pretrainer._validate(val_loader, criterion_denoise, criterion_rotate)
    assert loss == 0.5


# Test the save_losses method
@patch('aiia.pretrain.pretrainer.Pretrainer.save_losses')
def test_save_losses(mock_save_losses):
    pretrainer = Pretrainer(model=AIIABase(config=AIIAConfig()), config=AIIAConfig())
    pretrainer.train_losses = [0.1, 0.2]
    pretrainer.val_losses = [0.3, 0.4]

    csv_file = 'losses.csv'
    pretrainer.save_losses(csv_file)
    mock_save_losses.assert_called_once_with(csv_file)