From 78d88068a0fd6192f783fd3bb36ede8623bfc069 Mon Sep 17 00:00:00 2001 From: Falko Habel Date: Wed, 16 Apr 2025 22:58:59 +0200 Subject: [PATCH] updated tests to work with new pretrainer --- tests/pretrain/test_pretrainer.py | 347 ++++++++++++++++-------------- 1 file changed, 183 insertions(+), 164 deletions(-) diff --git a/tests/pretrain/test_pretrainer.py b/tests/pretrain/test_pretrainer.py index db4c73f..209fd9e 100644 --- a/tests/pretrain/test_pretrainer.py +++ b/tests/pretrain/test_pretrainer.py @@ -3,6 +3,8 @@ import torch from unittest.mock import MagicMock, patch, MagicMock, mock_open, call from aiia import Pretrainer, ProjectionHead, AIIABase, AIIAConfig import pandas as pd +import os +import datetime # Test the ProjectionHead class def test_projection_head(): @@ -53,11 +55,94 @@ def test_process_batch(mock_process_batch): loss = pretrainer._process_batch(batch_data, criterion_denoise, criterion_rotate) assert loss == 0.5 +# Error cases +# New tests for checkpoint handling +@patch('torch.save') +@patch('os.path.join') +def test_save_checkpoint(mock_join, mock_save): + """Test checkpoint saving functionality.""" + pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig()) + pretrainer.projection_head = MagicMock() + pretrainer.optimizer = MagicMock() + + checkpoint_dir = "checkpoints" + epoch = 1 + batch_count = 100 + checkpoint_name = "test_checkpoint.pt" + + mock_join.return_value = os.path.join(checkpoint_dir, checkpoint_name) + + path = pretrainer._save_checkpoint(checkpoint_dir, epoch, batch_count, checkpoint_name) + + assert path == os.path.join(checkpoint_dir, checkpoint_name) + mock_save.assert_called_once() + +@patch('os.makedirs') +@patch('os.path.exists') +@patch('torch.load') +def test_load_checkpoint_specific(mock_load, mock_exists, mock_makedirs): + """Test loading a specific checkpoint.""" + pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig()) + pretrainer.projection_head = MagicMock() + pretrainer.optimizer = MagicMock() + + checkpoint_dir = "checkpoints" + specific_checkpoint = "specific_checkpoint.pt" + mock_exists.return_value = True + + mock_load.return_value = { + 'epoch': 2, + 'batch': 150, + 'model_state_dict': {}, + 'projection_head_state_dict': {}, + 'optimizer_state_dict': {}, + 'train_losses': [], + 'val_losses': [] + } + + result = pretrainer.load_checkpoint(checkpoint_dir, specific_checkpoint) + assert result == (2, 150) + +@patch('os.listdir') +@patch('os.path.getmtime') +def test_load_checkpoint_most_recent(mock_getmtime, mock_listdir): + """Test loading the most recent checkpoint.""" + pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig()) + + checkpoint_dir = "checkpoints" + mock_listdir.return_value = ["checkpoint_1.pt", "checkpoint_2.pt"] + mock_getmtime.side_effect = [100, 200] # checkpoint_2.pt is more recent + + with patch.object(pretrainer, '_load_checkpoint_file', return_value=(2, 150)): + result = pretrainer.load_checkpoint(checkpoint_dir) + assert result == (2, 150) + +def test_initialize_checkpoint_variables(): + """Test initialization of checkpoint variables.""" + pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig()) + pretrainer._initialize_checkpoint_variables() + + assert hasattr(pretrainer, 'last_checkpoint_time') + assert pretrainer.checkpoint_interval == 2 * 60 * 60 + assert pretrainer.last_22_date is None + assert pretrainer.recent_checkpoints == [] + +@patch('torch.nn.MSELoss') +@patch('torch.nn.CrossEntropyLoss') +def test_initialize_loss_functions(mock_ce_loss, mock_mse_loss): + """Test loss function initialization.""" + pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig()) + criterion_denoise, criterion_rotate, best_val_loss = pretrainer._initialize_loss_functions() + + assert mock_mse_loss.called + assert mock_ce_loss.called + assert best_val_loss == float('inf') + @patch('pandas.concat') @patch('pandas.read_parquet') @patch('aiia.pretrain.pretrainer.AIIADataLoader') @patch('os.path.join', return_value='mocked/path/model.pt') -@patch('builtins.print') # Add this to mock the print function +@patch('builtins.print') def test_train_happy_path(mock_print, mock_path_join, mock_data_loader, mock_read_parquet, mock_concat): """Test the train method under normal conditions with comprehensive verification.""" # Setup test data and mocks @@ -73,6 +158,7 @@ def test_train_happy_path(mock_print, mock_path_join, mock_data_loader, mock_rea pretrainer = Pretrainer(model=mock_model, config=AIIAConfig()) pretrainer.projection_head = mock_projection_head pretrainer.optimizer = MagicMock() + pretrainer.checkpoint_dir = None # Initialize checkpoint_dir # Setup dataset paths and mock batch data dataset_paths = ['path/to/dataset1.parquet', 'path/to/dataset2.parquet'] @@ -104,185 +190,118 @@ def test_train_happy_path(mock_print, mock_path_join, mock_data_loader, mock_rea assert mock_process_batch.call_count == 2 assert mock_validate.call_count == 2 - # Check for "Best model saved!" instead of model.save() mock_print.assert_any_call("Best model saved!") - mock_save_losses.assert_called_once() - # Verify state changes assert len(pretrainer.train_losses) == 2 assert pretrainer.train_losses == [0.5, 0.5] - -# Error cases -def test_train_no_dataset_paths(): - """Test ValueError when no dataset paths are provided.""" +@patch('datetime.datetime') +@patch('time.time') +def test_handle_checkpoints(mock_time, mock_datetime): + """Test checkpoint handling logic.""" pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig()) + pretrainer.checkpoint_dir = "checkpoints" + pretrainer.current_epoch = 1 + pretrainer._initialize_checkpoint_variables() + + # Set a base time value + base_time = 1000 + # Set the last checkpoint time to base_time + pretrainer.last_checkpoint_time = base_time + + # Mock time to return base_time + interval + 1 to trigger checkpoint save + mock_time.return_value = base_time + pretrainer.checkpoint_interval + 1 + + # Mock datetime for 22:00 checkpoint + mock_dt = MagicMock() + mock_dt.hour = 22 + mock_dt.minute = 0 + mock_dt.date.return_value = datetime.date(2023, 1, 1) + mock_datetime.now.return_value = mock_dt + + with patch.object(pretrainer, '_save_checkpoint') as mock_save: + pretrainer._handle_checkpoints(100) + # Should be called twice - once for regular interval and once for 22:00 + assert mock_save.call_count == 2 - with pytest.raises(ValueError, match="No dataset paths provided"): - pretrainer.train([]) - -@patch('pandas.read_parquet') -def test_train_all_datasets_fail(mock_read_parquet): - """Test handling when all datasets fail to load.""" - mock_read_parquet.side_effect = Exception("Failed to load dataset") - +def test_training_phase(): + """Test the training phase logic.""" pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig()) - dataset_paths = ['path/to/dataset1.parquet', 'path/to/dataset2.parquet'] - - with pytest.raises(ValueError, match="No valid datasets could be loaded"): - pretrainer.train(dataset_paths) - -# Edge cases -@patch('pandas.concat') -@patch('pandas.read_parquet') -@patch('aiia.pretrain.pretrainer.AIIADataLoader') -def test_train_empty_loaders(mock_data_loader, mock_read_parquet, mock_concat): - """Test behavior with empty data loaders.""" - real_df = pd.DataFrame({'image_bytes': [torch.randn(1, 3, 224, 224).tolist()]}) - mock_read_parquet.return_value.head.return_value = real_df - mock_concat.return_value = real_df - - loader_instance = MagicMock() - loader_instance.train_loader = [] # Empty train loader - loader_instance.val_loader = [] # Empty val loader - mock_data_loader.return_value = loader_instance - - mock_model = MagicMock() - pretrainer = Pretrainer(model=mock_model, config=AIIAConfig()) - pretrainer.projection_head = MagicMock() pretrainer.optimizer = MagicMock() - - with patch.object(Pretrainer, 'save_losses') as mock_save_losses: - pretrainer.train(['path/to/dataset.parquet'], num_epochs=1) - - # Verify empty loader behavior - assert len(pretrainer.train_losses) == 1 - assert pretrainer.train_losses[0] == 0.0 - mock_save_losses.assert_called_once() - -@patch('pandas.concat') -@patch('pandas.read_parquet') -@patch('aiia.pretrain.pretrainer.AIIADataLoader') -def test_train_none_batch_data(mock_data_loader, mock_read_parquet, mock_concat): - """Test behavior when batch_data is None.""" - real_df = pd.DataFrame({'image_bytes': [torch.randn(1, 3, 224, 224).tolist()]}) - mock_read_parquet.return_value.head.return_value = real_df - mock_concat.return_value = real_df - - loader_instance = MagicMock() - loader_instance.train_loader = [None] # Loader returns None - loader_instance.val_loader = [] - mock_data_loader.return_value = loader_instance - - pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig()) - pretrainer.projection_head = MagicMock() - pretrainer.optimizer = MagicMock() - - with patch.object(Pretrainer, '_process_batch') as mock_process_batch, \ - patch.object(Pretrainer, 'save_losses'): - pretrainer.train(['path/to/dataset.parquet'], num_epochs=1) - - # Verify None batch handling - mock_process_batch.assert_not_called() - assert pretrainer.train_losses[0] == 0.0 - -# Parameter variations -@patch('pandas.concat') -@patch('pandas.read_parquet') -@patch('aiia.pretrain.pretrainer.AIIADataLoader') -def test_train_with_custom_parameters(mock_data_loader, mock_read_parquet, mock_concat): - """Test that custom parameters are properly passed through.""" - real_df = pd.DataFrame({'image_bytes': [torch.randn(1, 3, 224, 224).tolist()]}) - mock_read_parquet.return_value.head.return_value = real_df - mock_concat.return_value = real_df - - loader_instance = MagicMock() - loader_instance.train_loader = [] - loader_instance.val_loader = [] - mock_data_loader.return_value = loader_instance - - pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig()) - pretrainer.projection_head = MagicMock() - pretrainer.optimizer = MagicMock() - - # Custom parameters - custom_output_path = "custom/output/path" - custom_column = "custom_column" - custom_batch_size = 16 - custom_sample_size = 5000 - - with patch.object(Pretrainer, 'save_losses'): - pretrainer.train( - ['path/to/dataset.parquet'], - output_path=custom_output_path, - column=custom_column, - batch_size=custom_batch_size, - sample_size=custom_sample_size - ) - - # Verify custom parameters were used - mock_read_parquet.return_value.head.assert_called_once_with(custom_sample_size) - assert mock_data_loader.call_args[1]['column'] == custom_column - assert mock_data_loader.call_args[1]['batch_size'] == custom_batch_size - - - -@patch('pandas.concat') -@patch('pandas.read_parquet') -@patch('aiia.pretrain.pretrainer.AIIADataLoader') -@patch('builtins.print') # Add this to mock the print function -def test_train_validation_loss_improvement(mock_print, mock_data_loader, mock_read_parquet, mock_concat): - """Test that model is saved only when validation loss improves.""" - real_df = pd.DataFrame({'image_bytes': [torch.randn(1, 3, 224, 224).tolist()]}) - mock_read_parquet.return_value.head.return_value = real_df - mock_concat.return_value = real_df - - # Create mock batch data with proper structure + pretrainer.checkpoint_dir = None # Initialize checkpoint_dir + pretrainer._initialize_checkpoint_variables() + pretrainer.current_epoch = 0 + + # Create mock batch data with requires_grad=True mock_batch_data = { - 'denoise': (torch.randn(2, 3, 32, 32), torch.randn(2, 3, 32, 32)), - 'rotate': (torch.randn(2, 3, 32, 32), torch.tensor([0, 1])) + 'denoise': ( + torch.randn(2, 3, 32, 32, requires_grad=True), + torch.randn(2, 3, 32, 32, requires_grad=True) + ), + 'rotate': ( + torch.randn(2, 3, 32, 32, requires_grad=True), + torch.tensor([0, 1], dtype=torch.long) # Labels typically don't need gradients + ) } + mock_train_loader = [(0, mock_batch_data)] # Include batch index + + # Mock the loss functions to return tensors that require gradients + criterion_denoise = MagicMock(return_value=torch.tensor(0.5, requires_grad=True)) + criterion_rotate = MagicMock(return_value=torch.tensor(0.5, requires_grad=True)) + + with patch.object(pretrainer, '_process_batch', return_value=torch.tensor(0.5, requires_grad=True)), \ + patch.object(pretrainer, '_handle_checkpoints') as mock_handle_checkpoints: + + total_loss, batch_count = pretrainer._training_phase( + mock_train_loader, 0, criterion_denoise, criterion_rotate) + + assert total_loss == 0.5 + assert batch_count == 1 + mock_handle_checkpoints.assert_called_once_with(1) # Check if checkpoint handling was called - loader_instance = MagicMock() - loader_instance.train_loader = [mock_batch_data] - loader_instance.val_loader = [mock_batch_data] - mock_data_loader.return_value = loader_instance - - mock_model = MagicMock() - pretrainer = Pretrainer(model=mock_model, config=AIIAConfig()) +def test_validation_phase(): + """Test the validation phase logic.""" + pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig()) pretrainer.projection_head = MagicMock() - pretrainer.optimizer = MagicMock() + + mock_val_loader = [MagicMock()] + criterion_denoise = MagicMock() + criterion_rotate = MagicMock() + + with patch.object(pretrainer, '_validate', return_value=0.4): + val_loss = pretrainer._validation_phase( + mock_val_loader, criterion_denoise, criterion_rotate) + + assert val_loss == 0.4 - # Initialize the best validation loss - pretrainer.best_val_loss = float('inf') +@patch('pandas.read_parquet') +def test_load_and_merge_datasets(mock_read_parquet): + """Test dataset loading and merging.""" + pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig()) + + mock_df = pd.DataFrame({'col': [1, 2, 3]}) + mock_read_parquet.return_value.head.return_value = mock_df + + result = pretrainer._load_and_merge_datasets(['path1.parquet', 'path2.parquet'], 1000) + assert len(result) == 6 # 2 datasets * 3 rows each - mock_batch_loss = torch.tensor(0.5, requires_grad=True) - - # Test improving validation loss - with patch.object(Pretrainer, '_process_batch', return_value=mock_batch_loss), \ - patch.object(Pretrainer, '_validate', side_effect=[3.0, 2.0, 1.0]), \ - patch.object(Pretrainer, 'save_losses'): - pretrainer.train(['path/to/dataset.parquet'], num_epochs=3) - - # Check for "Best model saved!" 3 times - assert mock_print.call_args_list.count(call("Best model saved!")) == 3 - - # Reset for next test - mock_print.reset_mock() - pretrainer.train_losses = [] - - # Reset best validation loss for the second test - pretrainer.best_val_loss = float('inf') - - # Test fluctuating validation loss - with patch.object(Pretrainer, '_process_batch', return_value=mock_batch_loss), \ - patch.object(Pretrainer, '_validate', side_effect=[3.0, 4.0, 2.0]), \ - patch.object(Pretrainer, 'save_losses'): - pretrainer.train(['path/to/dataset.parquet'], num_epochs=3) - - # Should print "Best model saved!" only on first and third epochs - assert mock_print.call_args_list.count(call("Best model saved!")) == 2 +def test_process_batch_none_tasks(): + """Test processing batch with no tasks.""" + pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig()) + + batch_data = { + 'denoise': None, + 'rotate': None + } + + loss = pretrainer._process_batch( + batch_data, + criterion_denoise=MagicMock(), + criterion_rotate=MagicMock() + ) + + assert loss == 0 @patch('aiia.pretrain.pretrainer.Pretrainer._process_batch')