diff --git a/tests/pretrain/test_pretrainer.py b/tests/pretrain/test_pretrainer.py index d8543a6..db4c73f 100644 --- a/tests/pretrain/test_pretrainer.py +++ b/tests/pretrain/test_pretrainer.py @@ -1,11 +1,9 @@ import pytest import torch -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, patch, MagicMock, mock_open, call from aiia import Pretrainer, ProjectionHead, AIIABase, AIIAConfig -from unittest.mock import patch, MagicMock import pandas as pd - # Test the ProjectionHead class def test_projection_head(): head = ProjectionHead(hidden_size=512) @@ -58,38 +56,233 @@ def test_process_batch(mock_process_batch): @patch('pandas.concat') @patch('pandas.read_parquet') @patch('aiia.pretrain.pretrainer.AIIADataLoader') -def test_train(mock_data_loader, mock_read_parquet, mock_concat): - # Create a real DataFrame for testing +@patch('os.path.join', return_value='mocked/path/model.pt') +@patch('builtins.print') # Add this to mock the print function +def test_train_happy_path(mock_print, mock_path_join, mock_data_loader, mock_read_parquet, mock_concat): + """Test the train method under normal conditions with comprehensive verification.""" + # Setup test data and mocks real_df = pd.DataFrame({ - 'image_bytes': [torch.randn(1, 3, 224, 224).tolist()] + 'image_bytes': [torch.randn(1, 3, 224, 224).tolist()] }) - # Configure read_parquet so that for each dataset path, .head(10000) returns real_df mock_read_parquet.return_value.head.return_value = real_df - - # When merging DataFrames, bypass type-checks by letting pd.concat just return real_df mock_concat.return_value = real_df - # Create an instance of the Pretrainer - pretrainer = Pretrainer(model=AIIABase(config=AIIAConfig()), config=AIIAConfig()) - dataset_paths = ['path/to/dataset1.parquet', 'path/to/dataset2.parquet'] + # Mock the model and related components + mock_model = MagicMock() + mock_projection_head = MagicMock() + pretrainer = Pretrainer(model=mock_model, config=AIIAConfig()) + pretrainer.projection_head = mock_projection_head + pretrainer.optimizer = MagicMock() - # Setup the data loader mock instance with empty loaders + # Setup dataset paths and mock batch data + dataset_paths = ['path/to/dataset1.parquet', 'path/to/dataset2.parquet'] + output_path = "AIIA_test" + + # Create mock batch data with proper structure + mock_batch_data = { + 'denoise': (torch.randn(2, 3, 32, 32), torch.randn(2, 3, 32, 32)), + 'rotate': (torch.randn(2, 3, 32, 32), torch.tensor([0, 1])) + } + + # Configure batch loss + mock_batch_loss = torch.tensor(0.5, requires_grad=True) loader_instance = MagicMock() - loader_instance.train_loader = [] # so the training loop is immediately skipped - loader_instance.val_loader = [] # so the validation loop is also skipped + loader_instance.train_loader = [mock_batch_data] + loader_instance.val_loader = [mock_batch_data] mock_data_loader.return_value = loader_instance - # Patch _validate to avoid any actual validation computations. - with patch.object(Pretrainer, '_validate', return_value=0.5): - pretrainer.train(dataset_paths, num_epochs=1) + # Execute training with patched methods + with patch.object(Pretrainer, '_process_batch', return_value=mock_batch_loss) as mock_process_batch, \ + patch.object(Pretrainer, '_validate', side_effect=[0.8, 0.3]) as mock_validate, \ + patch.object(Pretrainer, 'save_losses') as mock_save_losses, \ + patch('builtins.open', mock_open()): - # Verify that AIIADataLoader was instantiated exactly once... - mock_data_loader.assert_called_once() - # ...and that pd.read_parquet was called once per dataset path (i.e. 2 times in this test) - expected_calls = len(dataset_paths) - assert mock_read_parquet.call_count == expected_calls, ( - f"Expected {expected_calls} calls to pd.read_parquet, got {mock_read_parquet.call_count}" - ) + pretrainer.train(dataset_paths, output_path=output_path, num_epochs=2) + + # Verify method calls + assert mock_read_parquet.call_count == len(dataset_paths) + assert mock_process_batch.call_count == 2 + assert mock_validate.call_count == 2 + + # Check for "Best model saved!" instead of model.save() + mock_print.assert_any_call("Best model saved!") + + mock_save_losses.assert_called_once() + + # Verify state changes + assert len(pretrainer.train_losses) == 2 + assert pretrainer.train_losses == [0.5, 0.5] + + +# Error cases +def test_train_no_dataset_paths(): + """Test ValueError when no dataset paths are provided.""" + pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig()) + + with pytest.raises(ValueError, match="No dataset paths provided"): + pretrainer.train([]) + +@patch('pandas.read_parquet') +def test_train_all_datasets_fail(mock_read_parquet): + """Test handling when all datasets fail to load.""" + mock_read_parquet.side_effect = Exception("Failed to load dataset") + + pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig()) + dataset_paths = ['path/to/dataset1.parquet', 'path/to/dataset2.parquet'] + + with pytest.raises(ValueError, match="No valid datasets could be loaded"): + pretrainer.train(dataset_paths) + +# Edge cases +@patch('pandas.concat') +@patch('pandas.read_parquet') +@patch('aiia.pretrain.pretrainer.AIIADataLoader') +def test_train_empty_loaders(mock_data_loader, mock_read_parquet, mock_concat): + """Test behavior with empty data loaders.""" + real_df = pd.DataFrame({'image_bytes': [torch.randn(1, 3, 224, 224).tolist()]}) + mock_read_parquet.return_value.head.return_value = real_df + mock_concat.return_value = real_df + + loader_instance = MagicMock() + loader_instance.train_loader = [] # Empty train loader + loader_instance.val_loader = [] # Empty val loader + mock_data_loader.return_value = loader_instance + + mock_model = MagicMock() + pretrainer = Pretrainer(model=mock_model, config=AIIAConfig()) + pretrainer.projection_head = MagicMock() + pretrainer.optimizer = MagicMock() + + with patch.object(Pretrainer, 'save_losses') as mock_save_losses: + pretrainer.train(['path/to/dataset.parquet'], num_epochs=1) + + # Verify empty loader behavior + assert len(pretrainer.train_losses) == 1 + assert pretrainer.train_losses[0] == 0.0 + mock_save_losses.assert_called_once() + +@patch('pandas.concat') +@patch('pandas.read_parquet') +@patch('aiia.pretrain.pretrainer.AIIADataLoader') +def test_train_none_batch_data(mock_data_loader, mock_read_parquet, mock_concat): + """Test behavior when batch_data is None.""" + real_df = pd.DataFrame({'image_bytes': [torch.randn(1, 3, 224, 224).tolist()]}) + mock_read_parquet.return_value.head.return_value = real_df + mock_concat.return_value = real_df + + loader_instance = MagicMock() + loader_instance.train_loader = [None] # Loader returns None + loader_instance.val_loader = [] + mock_data_loader.return_value = loader_instance + + pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig()) + pretrainer.projection_head = MagicMock() + pretrainer.optimizer = MagicMock() + + with patch.object(Pretrainer, '_process_batch') as mock_process_batch, \ + patch.object(Pretrainer, 'save_losses'): + pretrainer.train(['path/to/dataset.parquet'], num_epochs=1) + + # Verify None batch handling + mock_process_batch.assert_not_called() + assert pretrainer.train_losses[0] == 0.0 + +# Parameter variations +@patch('pandas.concat') +@patch('pandas.read_parquet') +@patch('aiia.pretrain.pretrainer.AIIADataLoader') +def test_train_with_custom_parameters(mock_data_loader, mock_read_parquet, mock_concat): + """Test that custom parameters are properly passed through.""" + real_df = pd.DataFrame({'image_bytes': [torch.randn(1, 3, 224, 224).tolist()]}) + mock_read_parquet.return_value.head.return_value = real_df + mock_concat.return_value = real_df + + loader_instance = MagicMock() + loader_instance.train_loader = [] + loader_instance.val_loader = [] + mock_data_loader.return_value = loader_instance + + pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig()) + pretrainer.projection_head = MagicMock() + pretrainer.optimizer = MagicMock() + + # Custom parameters + custom_output_path = "custom/output/path" + custom_column = "custom_column" + custom_batch_size = 16 + custom_sample_size = 5000 + + with patch.object(Pretrainer, 'save_losses'): + pretrainer.train( + ['path/to/dataset.parquet'], + output_path=custom_output_path, + column=custom_column, + batch_size=custom_batch_size, + sample_size=custom_sample_size + ) + + # Verify custom parameters were used + mock_read_parquet.return_value.head.assert_called_once_with(custom_sample_size) + assert mock_data_loader.call_args[1]['column'] == custom_column + assert mock_data_loader.call_args[1]['batch_size'] == custom_batch_size + + + +@patch('pandas.concat') +@patch('pandas.read_parquet') +@patch('aiia.pretrain.pretrainer.AIIADataLoader') +@patch('builtins.print') # Add this to mock the print function +def test_train_validation_loss_improvement(mock_print, mock_data_loader, mock_read_parquet, mock_concat): + """Test that model is saved only when validation loss improves.""" + real_df = pd.DataFrame({'image_bytes': [torch.randn(1, 3, 224, 224).tolist()]}) + mock_read_parquet.return_value.head.return_value = real_df + mock_concat.return_value = real_df + + # Create mock batch data with proper structure + mock_batch_data = { + 'denoise': (torch.randn(2, 3, 32, 32), torch.randn(2, 3, 32, 32)), + 'rotate': (torch.randn(2, 3, 32, 32), torch.tensor([0, 1])) + } + + loader_instance = MagicMock() + loader_instance.train_loader = [mock_batch_data] + loader_instance.val_loader = [mock_batch_data] + mock_data_loader.return_value = loader_instance + + mock_model = MagicMock() + pretrainer = Pretrainer(model=mock_model, config=AIIAConfig()) + pretrainer.projection_head = MagicMock() + pretrainer.optimizer = MagicMock() + + # Initialize the best validation loss + pretrainer.best_val_loss = float('inf') + + mock_batch_loss = torch.tensor(0.5, requires_grad=True) + + # Test improving validation loss + with patch.object(Pretrainer, '_process_batch', return_value=mock_batch_loss), \ + patch.object(Pretrainer, '_validate', side_effect=[3.0, 2.0, 1.0]), \ + patch.object(Pretrainer, 'save_losses'): + pretrainer.train(['path/to/dataset.parquet'], num_epochs=3) + + # Check for "Best model saved!" 3 times + assert mock_print.call_args_list.count(call("Best model saved!")) == 3 + + # Reset for next test + mock_print.reset_mock() + pretrainer.train_losses = [] + + # Reset best validation loss for the second test + pretrainer.best_val_loss = float('inf') + + # Test fluctuating validation loss + with patch.object(Pretrainer, '_process_batch', return_value=mock_batch_loss), \ + patch.object(Pretrainer, '_validate', side_effect=[3.0, 4.0, 2.0]), \ + patch.object(Pretrainer, 'save_losses'): + pretrainer.train(['path/to/dataset.parquet'], num_epochs=3) + + # Should print "Best model saved!" only on first and third epochs + assert mock_print.call_args_list.count(call("Best model saved!")) == 2 @patch('aiia.pretrain.pretrainer.Pretrainer._process_batch')