import pytest
import torch
from unittest.mock import MagicMock, patch
from aiia import Pretrainer, ProjectionHead, AIIABase, AIIAConfig
import pandas as pd


# Test the ProjectionHead class
def test_projection_head():
    head = ProjectionHead(hidden_size=512)
    x = torch.randn(1, 512, 32, 32)

    # Test denoise task
    output_denoise = head(x, task='denoise')
    assert output_denoise.shape == (1, 3, 32, 32)

    # Test rotate task
    output_rotate = head(x, task='rotate')
    assert output_rotate.shape == (1, 4)


# Test the Pretrainer class initialization
def test_pretrainer_initialization():
    config = AIIAConfig()
    model = AIIABase(config=config)
    pretrainer = Pretrainer(model=model, learning_rate=0.001, config=config)
    assert pretrainer.device in ["cuda", "cpu"]
    assert isinstance(pretrainer.projection_head, ProjectionHead)
    assert isinstance(pretrainer.optimizer, torch.optim.AdamW)


# Test the safe_collate method
def test_safe_collate():
    pretrainer = Pretrainer(model=AIIABase(config=AIIAConfig()), config=AIIAConfig())
    batch = [
        (torch.randn(3, 32, 32), torch.randn(3, 32, 32), 'denoise'),
        (torch.randn(3, 32, 32), torch.tensor(1), 'rotate')
    ]

    collated_batch = pretrainer.safe_collate(batch)
    assert 'denoise' in collated_batch
    assert 'rotate' in collated_batch


# Test the _process_batch method
@patch('aiia.pretrain.pretrainer.Pretrainer._process_batch')
def test_process_batch(mock_process_batch):
    pretrainer = Pretrainer(model=AIIABase(config=AIIAConfig()), config=AIIAConfig())
    batch_data = {
        'denoise': (torch.randn(2, 3, 32, 32), torch.randn(2, 3, 32, 32)),
        'rotate': (torch.randn(2, 3, 32, 32), torch.tensor([0, 1]))
    }
    criterion_denoise = MagicMock()
    criterion_rotate = MagicMock()
    mock_process_batch.return_value = 0.5

    loss = pretrainer._process_batch(batch_data, criterion_denoise, criterion_rotate)
    assert loss == 0.5


# Test the train method
@patch('pandas.concat')
@patch('pandas.read_parquet')
@patch('aiia.pretrain.pretrainer.AIIADataLoader')
def test_train(mock_data_loader, mock_read_parquet, mock_concat):
    # Create a real DataFrame for testing
    real_df = pd.DataFrame({
        'image_bytes': [torch.randn(1, 3, 224, 224).tolist()]
    })
    # Configure read_parquet so that for each dataset path, .head(10000) returns real_df
    mock_read_parquet.return_value.head.return_value = real_df

    # When merging DataFrames, bypass type-checks by letting pd.concat just return real_df
    mock_concat.return_value = real_df

    # Create an instance of the Pretrainer
    pretrainer = Pretrainer(model=AIIABase(config=AIIAConfig()), config=AIIAConfig())
    dataset_paths = ['path/to/dataset1.parquet', 'path/to/dataset2.parquet']

    # Set up the data loader mock instance with empty loaders
    loader_instance = MagicMock()
    loader_instance.train_loader = []  # so the training loop is immediately skipped
    loader_instance.val_loader = []    # so the validation loop is also skipped
    mock_data_loader.return_value = loader_instance

    # Patch _validate to avoid any actual validation computations.
    with patch.object(Pretrainer, '_validate', return_value=0.5):
        pretrainer.train(dataset_paths, num_epochs=1)

    # Verify that AIIADataLoader was instantiated exactly once...
    mock_data_loader.assert_called_once()

    # ...and that pd.read_parquet was called once per dataset path (i.e. 2 times in this test)
    expected_calls = len(dataset_paths)
    assert mock_read_parquet.call_count == expected_calls, (
        f"Expected {expected_calls} calls to pd.read_parquet, got {mock_read_parquet.call_count}"
    )


# Test the _validate method
@patch('aiia.pretrain.pretrainer.Pretrainer._process_batch')
def test_validate(mock_process_batch):
    pretrainer = Pretrainer(model=AIIABase(config=AIIAConfig()), config=AIIAConfig())
    val_loader = [MagicMock()]
    criterion_denoise = MagicMock()
    criterion_rotate = MagicMock()
    mock_process_batch.return_value = torch.tensor(0.5)

    loss = pretrainer._validate(val_loader, criterion_denoise, criterion_rotate)
    assert loss == 0.5


# Test the save_losses method
@patch('aiia.pretrain.pretrainer.Pretrainer.save_losses')
def test_save_losses(mock_save_losses):
    pretrainer = Pretrainer(model=AIIABase(config=AIIAConfig()), config=AIIAConfig())
    pretrainer.train_losses = [0.1, 0.2]
    pretrainer.val_losses = [0.3, 0.4]

    csv_file = 'losses.csv'
    pretrainer.save_losses(csv_file)
    mock_save_losses.assert_called_once_with(csv_file)