"""Unit tests for the ``aiia`` pretraining components (ProjectionHead, Pretrainer)."""
from unittest.mock import MagicMock, call, mock_open, patch

import pandas as pd
import pytest
import torch

from aiia import AIIABase, AIIAConfig, Pretrainer, ProjectionHead


# ---------------------------------------------------------------------------
# Private setup helpers (underscore-prefixed so pytest does not collect them)
# ---------------------------------------------------------------------------

def _make_batch_data():
    """Return a two-task batch dict shaped like the batches the train tests feed in."""
    return {
        'denoise': (torch.randn(2, 3, 32, 32), torch.randn(2, 3, 32, 32)),
        'rotate': (torch.randn(2, 3, 32, 32), torch.tensor([0, 1])),
    }


def _stub_parquet(mock_read_parquet, mock_concat):
    """Make patched ``pd.read_parquet(...).head(...)`` and ``pd.concat`` return a one-row frame."""
    real_df = pd.DataFrame({'image_bytes': [torch.randn(1, 3, 224, 224).tolist()]})
    mock_read_parquet.return_value.head.return_value = real_df
    mock_concat.return_value = real_df
    return real_df


def _stub_loader(mock_data_loader, train_loader, val_loader):
    """Make the patched ``AIIADataLoader`` expose the given train/val iterables."""
    loader_instance = MagicMock()
    loader_instance.train_loader = train_loader
    loader_instance.val_loader = val_loader
    mock_data_loader.return_value = loader_instance
    return loader_instance


def _make_mocked_pretrainer():
    """Build a Pretrainer whose model, projection head and optimizer are all mocks."""
    pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
    pretrainer.projection_head = MagicMock()
    pretrainer.optimizer = MagicMock()
    return pretrainer


# ---------------------------------------------------------------------------
# ProjectionHead
# ---------------------------------------------------------------------------

def test_projection_head():
    """ProjectionHead yields a 3-channel map for 'denoise' and 4 logits for 'rotate'."""
    head = ProjectionHead(hidden_size=512)
    x = torch.randn(1, 512, 32, 32)

    output_denoise = head(x, task='denoise')
    assert output_denoise.shape == (1, 3, 32, 32)

    output_rotate = head(x, task='rotate')
    assert output_rotate.shape == (1, 4)


# ---------------------------------------------------------------------------
# Pretrainer construction and helpers
# ---------------------------------------------------------------------------

def test_pretrainer_initialization():
    """A fresh Pretrainer picks a device and builds a ProjectionHead and AdamW optimizer."""
    config = AIIAConfig()
    model = AIIABase(config=config)
    pretrainer = Pretrainer(model=model, learning_rate=0.001, config=config)

    assert pretrainer.device in ["cuda", "cpu"]
    assert isinstance(pretrainer.projection_head, ProjectionHead)
    assert isinstance(pretrainer.optimizer, torch.optim.AdamW)


def test_safe_collate():
    """safe_collate groups a mixed-task batch under per-task keys."""
    pretrainer = Pretrainer(model=AIIABase(config=AIIAConfig()), config=AIIAConfig())
    batch = [
        (torch.randn(3, 32, 32), torch.randn(3, 32, 32), 'denoise'),
        (torch.randn(3, 32, 32), torch.tensor(1), 'rotate'),
    ]
    collated_batch = pretrainer.safe_collate(batch)
    assert 'denoise' in collated_batch
    assert 'rotate' in collated_batch


@patch('aiia.pretrain.pretrainer.Pretrainer._process_batch')
def test_process_batch(mock_process_batch):
    """Call _process_batch with a two-task batch and two criteria."""
    # NOTE(review): this patches the very method it calls, so it only checks that
    # the mock's return value round-trips; the real _process_batch is never run.
    # Replace with real criteria/model inputs once the expected loss is pinned down.
    pretrainer = Pretrainer(model=AIIABase(config=AIIAConfig()), config=AIIAConfig())
    batch_data = _make_batch_data()
    criterion_denoise = MagicMock()
    criterion_rotate = MagicMock()
    mock_process_batch.return_value = 0.5

    loss = pretrainer._process_batch(batch_data, criterion_denoise, criterion_rotate)
    assert loss == 0.5


# ---------------------------------------------------------------------------
# Pretrainer.train — happy path
# ---------------------------------------------------------------------------

@patch('pandas.concat')
@patch('pandas.read_parquet')
@patch('aiia.pretrain.pretrainer.AIIADataLoader')
@patch('os.path.join', return_value='mocked/path/model.pt')
@patch('builtins.print')  # captured so the "Best model saved!" message can be asserted
def test_train_happy_path(mock_print, mock_path_join, mock_data_loader, mock_read_parquet, mock_concat):
    """Test the train method under normal conditions with comprehensive verification."""
    _stub_parquet(mock_read_parquet, mock_concat)
    mock_batch_data = _make_batch_data()
    _stub_loader(mock_data_loader, [mock_batch_data], [mock_batch_data])
    pretrainer = _make_mocked_pretrainer()

    dataset_paths = ['path/to/dataset1.parquet', 'path/to/dataset2.parquet']
    output_path = "AIIA_test"
    # requires_grad so that train() can call backward() on the returned loss.
    mock_batch_loss = torch.tensor(0.5, requires_grad=True)

    # Validation loss decreases every epoch, so each epoch should save a new best model.
    with patch.object(Pretrainer, '_process_batch', return_value=mock_batch_loss) as mock_process_batch, \
         patch.object(Pretrainer, '_validate', side_effect=[0.8, 0.3]) as mock_validate, \
         patch.object(Pretrainer, 'save_losses') as mock_save_losses, \
         patch('builtins.open', mock_open()):
        pretrainer.train(dataset_paths, output_path=output_path, num_epochs=2)

    # One parquet read per dataset; one train batch and one validation per epoch.
    assert mock_read_parquet.call_count == len(dataset_paths)
    assert mock_process_batch.call_count == 2
    assert mock_validate.call_count == 2
    # Saving is observed through the printed message rather than a model.save() call.
    mock_print.assert_any_call("Best model saved!")
    mock_save_losses.assert_called_once()

    assert len(pretrainer.train_losses) == 2
    assert pretrainer.train_losses == [0.5, 0.5]


# ---------------------------------------------------------------------------
# Pretrainer.train — error cases
# ---------------------------------------------------------------------------

def test_train_no_dataset_paths():
    """Test ValueError when no dataset paths are provided."""
    pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())

    with pytest.raises(ValueError, match="No dataset paths provided"):
        pretrainer.train([])


@patch('pandas.read_parquet')
def test_train_all_datasets_fail(mock_read_parquet):
    """Test handling when all datasets fail to load."""
    mock_read_parquet.side_effect = Exception("Failed to load dataset")
    pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
    dataset_paths = ['path/to/dataset1.parquet', 'path/to/dataset2.parquet']

    with pytest.raises(ValueError, match="No valid datasets could be loaded"):
        pretrainer.train(dataset_paths)


# ---------------------------------------------------------------------------
# Pretrainer.train — edge cases
# ---------------------------------------------------------------------------

@patch('pandas.concat')
@patch('pandas.read_parquet')
@patch('aiia.pretrain.pretrainer.AIIADataLoader')
def test_train_empty_loaders(mock_data_loader, mock_read_parquet, mock_concat):
    """Test behavior with empty data loaders."""
    _stub_parquet(mock_read_parquet, mock_concat)
    _stub_loader(mock_data_loader, [], [])
    pretrainer = _make_mocked_pretrainer()

    with patch.object(Pretrainer, 'save_losses') as mock_save_losses:
        pretrainer.train(['path/to/dataset.parquet'], num_epochs=1)

    # An epoch with no batches still records a (zero) training loss.
    assert len(pretrainer.train_losses) == 1
    assert pretrainer.train_losses[0] == 0.0
    mock_save_losses.assert_called_once()


@patch('pandas.concat')
@patch('pandas.read_parquet')
@patch('aiia.pretrain.pretrainer.AIIADataLoader')
def test_train_none_batch_data(mock_data_loader, mock_read_parquet, mock_concat):
    """Test behavior when batch_data is None."""
    _stub_parquet(mock_read_parquet, mock_concat)
    _stub_loader(mock_data_loader, [None], [])
    pretrainer = _make_mocked_pretrainer()

    with patch.object(Pretrainer, '_process_batch') as mock_process_batch, \
         patch.object(Pretrainer, 'save_losses'):
        pretrainer.train(['path/to/dataset.parquet'], num_epochs=1)

    # A None batch must be skipped rather than forwarded to _process_batch.
    mock_process_batch.assert_not_called()
    assert pretrainer.train_losses[0] == 0.0


# ---------------------------------------------------------------------------
# Pretrainer.train — parameter variations
# ---------------------------------------------------------------------------

@patch('pandas.concat')
@patch('pandas.read_parquet')
@patch('aiia.pretrain.pretrainer.AIIADataLoader')
def test_train_with_custom_parameters(mock_data_loader, mock_read_parquet, mock_concat):
    """Test that custom parameters are properly passed through."""
    _stub_parquet(mock_read_parquet, mock_concat)
    _stub_loader(mock_data_loader, [], [])
    pretrainer = _make_mocked_pretrainer()

    custom_output_path = "custom/output/path"
    custom_column = "custom_column"
    custom_batch_size = 16
    custom_sample_size = 5000

    with patch.object(Pretrainer, 'save_losses'):
        pretrainer.train(
            ['path/to/dataset.parquet'],
            output_path=custom_output_path,
            column=custom_column,
            batch_size=custom_batch_size,
            sample_size=custom_sample_size,
        )

    # sample_size drives DataFrame.head(); column/batch_size reach the loader as kwargs.
    mock_read_parquet.return_value.head.assert_called_once_with(custom_sample_size)
    assert mock_data_loader.call_args[1]['column'] == custom_column
    assert mock_data_loader.call_args[1]['batch_size'] == custom_batch_size


@patch('pandas.concat')
@patch('pandas.read_parquet')
@patch('aiia.pretrain.pretrainer.AIIADataLoader')
@patch('builtins.print')  # captured so "Best model saved!" occurrences can be counted
def test_train_validation_loss_improvement(mock_print, mock_data_loader, mock_read_parquet, mock_concat):
    """Test that model is saved only when validation loss improves."""
    _stub_parquet(mock_read_parquet, mock_concat)
    mock_batch_data = _make_batch_data()
    _stub_loader(mock_data_loader, [mock_batch_data], [mock_batch_data])
    pretrainer = _make_mocked_pretrainer()
    pretrainer.best_val_loss = float('inf')
    mock_batch_loss = torch.tensor(0.5, requires_grad=True)

    # Strictly improving validation loss: every epoch is a new best.
    with patch.object(Pretrainer, '_process_batch', return_value=mock_batch_loss), \
         patch.object(Pretrainer, '_validate', side_effect=[3.0, 2.0, 1.0]), \
         patch.object(Pretrainer, 'save_losses'):
        pretrainer.train(['path/to/dataset.parquet'], num_epochs=3)

    assert mock_print.call_args_list.count(call("Best model saved!")) == 3

    # Reset tracked state before the second scenario.
    mock_print.reset_mock()
    pretrainer.train_losses = []
    pretrainer.best_val_loss = float('inf')

    # Fluctuating validation loss: only epochs 1 and 3 beat the running best.
    with patch.object(Pretrainer, '_process_batch', return_value=mock_batch_loss), \
         patch.object(Pretrainer, '_validate', side_effect=[3.0, 4.0, 2.0]), \
         patch.object(Pretrainer, 'save_losses'):
        pretrainer.train(['path/to/dataset.parquet'], num_epochs=3)

    assert mock_print.call_args_list.count(call("Best model saved!")) == 2


# ---------------------------------------------------------------------------
# Pretrainer._validate and save_losses
# ---------------------------------------------------------------------------

@patch('aiia.pretrain.pretrainer.Pretrainer._process_batch')
def test_validate(mock_process_batch):
    """_validate over a single batch reports that batch's loss."""
    pretrainer = Pretrainer(model=AIIABase(config=AIIAConfig()), config=AIIAConfig())
    val_loader = [MagicMock()]
    criterion_denoise = MagicMock()
    criterion_rotate = MagicMock()
    mock_process_batch.return_value = torch.tensor(0.5)

    loss = pretrainer._validate(val_loader, criterion_denoise, criterion_rotate)
    assert loss == 0.5


def test_save_losses(tmp_path):
    """save_losses writes the recorded losses to the given CSV path."""
    # Previously this test patched save_losses itself, so it only verified the mock.
    # Call the real method and check that it produced a non-empty file.
    pretrainer = Pretrainer(model=AIIABase(config=AIIAConfig()), config=AIIAConfig())
    pretrainer.train_losses = [0.1, 0.2]
    pretrainer.val_losses = [0.3, 0.4]
    csv_file = tmp_path / 'losses.csv'

    pretrainer.save_losses(str(csv_file))

    assert csv_file.exists()
    assert csv_file.stat().st_size > 0