"""Unit tests for ``aiia.Pretrainer`` and ``aiia.ProjectionHead``.

Location: tests/pretrain/test_pretrainer.py
"""
from unittest.mock import MagicMock, patch

import pandas as pd
import pytest  # noqa: F401  (kept: part of the module's original imports)
import torch

from aiia import AIIABase, AIIAConfig, Pretrainer, ProjectionHead


def _make_pretrainer():
    """Return a Pretrainer wrapping an AIIABase model, both built from default configs."""
    # Two separate AIIAConfig instances are created on purpose, mirroring
    # how every test constructed its Pretrainer before this helper existed.
    return Pretrainer(model=AIIABase(config=AIIAConfig()), config=AIIAConfig())


def test_projection_head():
    """ProjectionHead yields the expected output shape for each task."""
    head = ProjectionHead(hidden_size=512)
    x = torch.randn(1, 512, 32, 32)

    # Denoise task
    output_denoise = head(x, task='denoise')
    assert output_denoise.shape == (1, 3, 32, 32)

    # Rotate task
    output_rotate = head(x, task='rotate')
    assert output_rotate.shape == (1, 4)


def test_pretrainer_initialization():
    """A freshly built Pretrainer exposes a device, projection head and optimizer."""
    config = AIIAConfig()
    model = AIIABase(config=config)
    pretrainer = Pretrainer(model=model, learning_rate=0.001, config=config)

    assert pretrainer.device in ("cuda", "cpu")
    assert isinstance(pretrainer.projection_head, ProjectionHead)
    assert isinstance(pretrainer.optimizer, torch.optim.AdamW)


def test_safe_collate():
    """safe_collate returns a container holding both task keys for a mixed batch."""
    pretrainer = _make_pretrainer()
    batch = [
        (torch.randn(3, 32, 32), torch.randn(3, 32, 32), 'denoise'),
        (torch.randn(3, 32, 32), torch.tensor(1), 'rotate'),
    ]

    collated_batch = pretrainer.safe_collate(batch)

    assert 'denoise' in collated_batch
    assert 'rotate' in collated_batch


@patch('aiia.pretrain.pretrainer.Pretrainer._process_batch')
def test_process_batch(mock_process_batch):
    """Call the (patched) _process_batch and check the value and arguments.

    NOTE(review): ``_process_batch`` itself is replaced by a mock here, so
    this test only checks the mock's wiring, not the real loss computation.
    Consider patching the collaborators instead so the real method runs.
    """
    pretrainer = _make_pretrainer()
    batch_data = {
        'denoise': (torch.randn(2, 3, 32, 32), torch.randn(2, 3, 32, 32)),
        'rotate': (torch.randn(2, 3, 32, 32), torch.tensor([0, 1])),
    }
    criterion_denoise = MagicMock()
    criterion_rotate = MagicMock()
    mock_process_batch.return_value = 0.5

    loss = pretrainer._process_batch(batch_data, criterion_denoise, criterion_rotate)

    assert loss == 0.5
    # Same style of call-argument check as test_save_losses uses.
    mock_process_batch.assert_called_once_with(
        batch_data, criterion_denoise, criterion_rotate
    )


@patch('pandas.concat')
@patch('pandas.read_parquet')
@patch('aiia.pretrain.pretrainer.AIIADataLoader')
def test_train(mock_data_loader, mock_read_parquet, mock_concat):
    """train() reads each dataset path once and builds a single AIIADataLoader."""
    # Create a real DataFrame for testing
    real_df = pd.DataFrame({
        'image_bytes': [torch.randn(1, 3, 224, 224).tolist()]
    })
    # Configure read_parquet so that for each dataset path, .head(10000) returns real_df
    mock_read_parquet.return_value.head.return_value = real_df

    # When merging DataFrames, bypass type-checks by letting pd.concat just return real_df
    mock_concat.return_value = real_df

    pretrainer = _make_pretrainer()
    dataset_paths = ['path/to/dataset1.parquet', 'path/to/dataset2.parquet']

    # Data loader mock with empty loaders so both the training and the
    # validation loops are skipped immediately.
    loader_instance = MagicMock()
    loader_instance.train_loader = []
    loader_instance.val_loader = []
    mock_data_loader.return_value = loader_instance

    # Patch _validate to avoid any actual validation computations.
    with patch.object(Pretrainer, '_validate', return_value=0.5):
        pretrainer.train(dataset_paths, num_epochs=1)

    # AIIADataLoader must have been instantiated exactly once...
    mock_data_loader.assert_called_once()
    # ...and pd.read_parquet once per dataset path (i.e. 2 times in this test).
    expected_calls = len(dataset_paths)
    assert mock_read_parquet.call_count == expected_calls, (
        f"Expected {expected_calls} calls to pd.read_parquet, got {mock_read_parquet.call_count}"
    )


@patch('aiia.pretrain.pretrainer.Pretrainer._process_batch')
def test_validate(mock_process_batch):
    """_validate returns 0.5 when the single batch's (mocked) loss is 0.5."""
    pretrainer = _make_pretrainer()
    val_loader = [MagicMock()]
    criterion_denoise = MagicMock()
    criterion_rotate = MagicMock()
    mock_process_batch.return_value = torch.tensor(0.5)

    loss = pretrainer._validate(val_loader, criterion_denoise, criterion_rotate)

    assert loss == 0.5


@patch('aiia.pretrain.pretrainer.Pretrainer.save_losses')
def test_save_losses(mock_save_losses):
    """Call the (patched) save_losses and check it received the CSV path.

    NOTE(review): ``save_losses`` itself is mocked, so no file is written
    and the real implementation is never exercised by this test.
    """
    pretrainer = _make_pretrainer()
    pretrainer.train_losses = [0.1, 0.2]
    pretrainer.val_losses = [0.3, 0.4]
    csv_file = 'losses.csv'

    pretrainer.save_losses(csv_file)

    mock_save_losses.assert_called_once_with(csv_file)