# File: AIIA/tests/pretrain/test_pretrainer.py (viewer metadata: 115 lines, 4.5 KiB, Python)

import pytest
import torch
from unittest.mock import MagicMock, patch
from aiia import Pretrainer, ProjectionHead, AIIABase, AIIAConfig
from unittest.mock import patch, MagicMock
import pandas as pd
# Test the ProjectionHead class
def test_projection_head():
    """ProjectionHead produces the expected output shape for each pretraining task.

    A (1, 512, 32, 32) feature map should become a (1, 3, 32, 32) tensor for
    the 'denoise' task and a (1, 4) tensor for the 'rotate' task.
    """
    projection = ProjectionHead(hidden_size=512)
    features = torch.randn(1, 512, 32, 32)

    # Checked in insertion order: 'denoise' first, then 'rotate'.
    expected_shapes = {
        'denoise': (1, 3, 32, 32),
        'rotate': (1, 4),
    }
    for task_name, expected in expected_shapes.items():
        result = projection(features, task=task_name)
        assert result.shape == expected
# Test the Pretrainer class initialization
def test_pretrainer_initialization():
    """A newly built Pretrainer exposes a device string, a ProjectionHead and an AdamW optimizer."""
    cfg = AIIAConfig()
    trainer = Pretrainer(model=AIIABase(config=cfg), learning_rate=0.001, config=cfg)

    assert trainer.device in ("cuda", "cpu")
    assert isinstance(trainer.projection_head, ProjectionHead)
    assert isinstance(trainer.optimizer, torch.optim.AdamW)
# Test the safe_collate method
def test_safe_collate():
    """safe_collate keeps both task types present in a mixed denoise/rotate batch."""
    trainer = Pretrainer(model=AIIABase(config=AIIAConfig()), config=AIIAConfig())

    # One sample per task: (input, target, task_name).
    denoise_sample = (torch.randn(3, 32, 32), torch.randn(3, 32, 32), 'denoise')
    rotate_sample = (torch.randn(3, 32, 32), torch.tensor(1), 'rotate')

    collated = trainer.safe_collate([denoise_sample, rotate_sample])

    for task_name in ('denoise', 'rotate'):
        assert task_name in collated
# Test the _process_batch method
@patch('aiia.pretrain.pretrainer.Pretrainer._process_batch')
def test_process_batch(mock_process_batch):
    """Call Pretrainer._process_batch on a two-task batch and check the returned loss.

    NOTE(review): the decorator patches ``_process_batch`` itself, so this test
    only observes the mock's configured return value; the real implementation
    never runs. Replace the patch with real (or lightweight) model inputs to
    make this test exercise actual behavior.
    """
    trainer = Pretrainer(model=AIIABase(config=AIIAConfig()), config=AIIAConfig())

    # Each task maps to an (inputs, targets) pair for a batch of two samples.
    denoise_pair = (torch.randn(2, 3, 32, 32), torch.randn(2, 3, 32, 32))
    rotate_pair = (torch.randn(2, 3, 32, 32), torch.tensor([0, 1]))
    batch_data = {'denoise': denoise_pair, 'rotate': rotate_pair}

    denoise_loss_fn = MagicMock()
    rotate_loss_fn = MagicMock()
    mock_process_batch.return_value = 0.5

    result = trainer._process_batch(batch_data, denoise_loss_fn, rotate_loss_fn)

    assert result == 0.5
@patch('pandas.concat')
@patch('pandas.read_parquet')
@patch('aiia.pretrain.pretrainer.AIIADataLoader')
def test_train(mock_data_loader, mock_read_parquet, mock_concat):
    """Run one epoch of Pretrainer.train with parquet I/O and data loading mocked out.

    The mocked data loader yields no batches, so both the training and the
    validation loops are skipped. The test checks that AIIADataLoader is built
    exactly once and that ``pd.read_parquet`` is called once per dataset path.
    """
    # A real DataFrame (not a mock) handed back from the patched pandas calls.
    sample_frame = pd.DataFrame({
        'image_bytes': [torch.randn(1, 3, 224, 224).tolist()]
    })
    # Every read_parquet(path).head(...) returns the sample frame ...
    mock_read_parquet.return_value.head.return_value = sample_frame
    # ... and merging the per-path frames returns it unchanged, skipping type checks.
    mock_concat.return_value = sample_frame

    trainer = Pretrainer(model=AIIABase(config=AIIAConfig()), config=AIIAConfig())
    dataset_paths = ['path/to/dataset1.parquet', 'path/to/dataset2.parquet']

    # Empty loaders mean neither the train loop nor the validation loop iterates.
    empty_loaders = MagicMock()
    empty_loaders.train_loader = []
    empty_loaders.val_loader = []
    mock_data_loader.return_value = empty_loaders

    # Stub validation so no real loss computation happens.
    with patch.object(Pretrainer, '_validate', return_value=0.5):
        trainer.train(dataset_paths, num_epochs=1)

    mock_data_loader.assert_called_once()

    expected_reads = len(dataset_paths)
    actual_reads = mock_read_parquet.call_count
    assert actual_reads == expected_reads, (
        f"Expected {expected_reads} calls to pd.read_parquet, got {actual_reads}"
    )
@patch('aiia.pretrain.pretrainer.Pretrainer._process_batch')
def test_validate(mock_process_batch):
    """_validate over a single-batch loader reports the mocked per-batch loss of 0.5."""
    trainer = Pretrainer(model=AIIABase(config=AIIAConfig()), config=AIIAConfig())

    # _process_batch is stubbed, so the batch content itself is irrelevant.
    single_batch_loader = [MagicMock()]
    denoise_loss_fn, rotate_loss_fn = MagicMock(), MagicMock()
    mock_process_batch.return_value = torch.tensor(0.5)

    result = trainer._validate(single_batch_loader, denoise_loss_fn, rotate_loss_fn)

    assert result == 0.5
# Test the save_losses method
def test_save_losses(tmp_path):
    """save_losses should write the recorded train/validation losses to the given path.

    Runs the real ``Pretrainer.save_losses`` (no mock) and writes into pytest's
    per-test ``tmp_path``, so no ``losses.csv`` is left in the working directory.
    """
    trainer = Pretrainer(model=AIIABase(config=AIIAConfig()), config=AIIAConfig())
    trainer.train_losses = [0.1, 0.2]
    trainer.val_losses = [0.3, 0.4]
    csv_path = tmp_path / 'losses.csv'

    # Pass a plain string, matching how the previous version of this test called it.
    trainer.save_losses(str(csv_path))

    # NOTE(review): only existence and non-emptiness are checked; tighten this
    # to the exact column layout once save_losses' CSV format is pinned down.
    assert csv_path.is_file()
    assert csv_path.stat().st_size > 0