import pytest
import torch
from unittest.mock import MagicMock, patch, mock_open
from aiia import Pretrainer, ProjectionHead, AIIABase, AIIAConfig
import pandas as pd
import os
import datetime


# Test the ProjectionHead class
def test_projection_head():
    head = ProjectionHead(hidden_size=512)
    x = torch.randn(1, 512, 32, 32)

    # Test denoise task
    output_denoise = head(x, task='denoise')
    assert output_denoise.shape == (1, 3, 32, 32)

    # Test rotate task
    output_rotate = head(x, task='rotate')
    assert output_rotate.shape == (1, 4)


# Test the Pretrainer class initialization
def test_pretrainer_initialization():
    config = AIIAConfig()
    model = AIIABase(config=config)
    pretrainer = Pretrainer(model=model, learning_rate=0.001, config=config)
    assert pretrainer.device in ["cuda", "cpu"]
    assert isinstance(pretrainer.projection_head, ProjectionHead)
    assert isinstance(pretrainer.optimizer, torch.optim.AdamW)


# Test the safe_collate method
def test_safe_collate():
    pretrainer = Pretrainer(model=AIIABase(config=AIIAConfig()), config=AIIAConfig())
    batch = [
        (torch.randn(3, 32, 32), torch.randn(3, 32, 32), 'denoise'),
        (torch.randn(3, 32, 32), torch.tensor(1), 'rotate')
    ]

    collated_batch = pretrainer.safe_collate(batch)
    assert 'denoise' in collated_batch
    assert 'rotate' in collated_batch


# Test the _process_batch method
# NOTE: _process_batch itself is patched here, so this verifies only the
# call path and return-value plumbing, not the batch-processing logic.
@patch('aiia.pretrain.pretrainer.Pretrainer._process_batch')
def test_process_batch(mock_process_batch):
    pretrainer = Pretrainer(model=AIIABase(config=AIIAConfig()), config=AIIAConfig())
    batch_data = {
        'denoise': (torch.randn(2, 3, 32, 32), torch.randn(2, 3, 32, 32)),
        'rotate': (torch.randn(2, 3, 32, 32), torch.tensor([0, 1]))
    }
    criterion_denoise = MagicMock()
    criterion_rotate = MagicMock()
    mock_process_batch.return_value = 0.5

    loss = pretrainer._process_batch(batch_data, criterion_denoise, criterion_rotate)
    assert loss == 0.5


# Tests for checkpoint handling
@patch('torch.save')
@patch('os.path.join')
def test_save_checkpoint(mock_join, mock_save):
    """Test checkpoint saving functionality."""
    pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
    pretrainer.projection_head = MagicMock()
    pretrainer.optimizer = MagicMock()

    checkpoint_dir = "checkpoints"
    epoch = 1
    batch_count = 100
    checkpoint_name = "test_checkpoint.pt"
    mock_join.return_value = os.path.join(checkpoint_dir, checkpoint_name)

    path = pretrainer._save_checkpoint(checkpoint_dir, epoch, batch_count, checkpoint_name)
    assert path == os.path.join(checkpoint_dir, checkpoint_name)
    mock_save.assert_called_once()


@patch('os.makedirs')
@patch('os.path.exists')
@patch('torch.load')
def test_load_checkpoint_specific(mock_load, mock_exists, mock_makedirs):
    """Test loading a specific checkpoint."""
    pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
    pretrainer.projection_head = MagicMock()
    pretrainer.optimizer = MagicMock()

    checkpoint_dir = "checkpoints"
    specific_checkpoint = "specific_checkpoint.pt"
    mock_exists.return_value = True
    mock_load.return_value = {
        'epoch': 2,
        'batch': 150,
        'model_state_dict': {},
        'projection_head_state_dict': {},
        'optimizer_state_dict': {},
        'train_losses': [],
        'val_losses': []
    }

    result = pretrainer.load_checkpoint(checkpoint_dir, specific_checkpoint)
    assert result == (2, 150)


@patch('os.listdir')
@patch('os.path.getmtime')
def test_load_checkpoint_most_recent(mock_getmtime, mock_listdir):
    """Test loading the most recent checkpoint."""
    pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
    checkpoint_dir = "checkpoints"

    mock_listdir.return_value = ["checkpoint_1.pt", "checkpoint_2.pt"]
"checkpoint_2.pt"] mock_getmtime.side_effect = [100, 200] # checkpoint_2.pt is more recent with patch.object(pretrainer, '_load_checkpoint_file', return_value=(2, 150)): result = pretrainer.load_checkpoint(checkpoint_dir) assert result == (2, 150) def test_initialize_checkpoint_variables(): """Test initialization of checkpoint variables.""" pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig()) pretrainer._initialize_checkpoint_variables() assert hasattr(pretrainer, 'last_checkpoint_time') assert pretrainer.checkpoint_interval == 2 * 60 * 60 assert pretrainer.last_22_date is None assert pretrainer.recent_checkpoints == [] @patch('torch.nn.MSELoss') @patch('torch.nn.CrossEntropyLoss') def test_initialize_loss_functions(mock_ce_loss, mock_mse_loss): """Test loss function initialization.""" pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig()) criterion_denoise, criterion_rotate, best_val_loss = pretrainer._initialize_loss_functions() assert mock_mse_loss.called assert mock_ce_loss.called assert best_val_loss == float('inf') @patch('pandas.concat') @patch('pandas.read_parquet') @patch('aiia.pretrain.pretrainer.AIIADataLoader') @patch('os.path.join', return_value='mocked/path/model.pt') @patch('builtins.print') def test_train_happy_path(mock_print, mock_path_join, mock_data_loader, mock_read_parquet, mock_concat): """Test the train method under normal conditions with comprehensive verification.""" # Setup test data and mocks real_df = pd.DataFrame({ 'image_bytes': [torch.randn(1, 3, 224, 224).tolist()] }) mock_read_parquet.return_value.head.return_value = real_df mock_concat.return_value = real_df # Mock the model and related components mock_model = MagicMock() mock_projection_head = MagicMock() pretrainer = Pretrainer(model=mock_model, config=AIIAConfig()) pretrainer.projection_head = mock_projection_head pretrainer.optimizer = MagicMock() pretrainer.checkpoint_dir = None # Initialize checkpoint_dir # Setup dataset paths and mock batch data dataset_paths = ['path/to/dataset1.parquet', 'path/to/dataset2.parquet'] output_path = "AIIA_test" # Create mock batch data with proper structure mock_batch_data = { 'denoise': (torch.randn(2, 3, 32, 32), torch.randn(2, 3, 32, 32)), 'rotate': (torch.randn(2, 3, 32, 32), torch.tensor([0, 1])) } # Configure batch loss mock_batch_loss = torch.tensor(0.5, requires_grad=True) loader_instance = MagicMock() loader_instance.train_loader = [mock_batch_data] loader_instance.val_loader = [mock_batch_data] mock_data_loader.return_value = loader_instance # Execute training with patched methods with patch.object(Pretrainer, '_process_batch', return_value=mock_batch_loss) as mock_process_batch, \ patch.object(Pretrainer, '_validate', side_effect=[0.8, 0.3]) as mock_validate, \ patch.object(Pretrainer, 'save_losses') as mock_save_losses, \ patch('builtins.open', mock_open()): pretrainer.train(dataset_paths, output_path=output_path, num_epochs=2) # Verify method calls assert mock_read_parquet.call_count == len(dataset_paths) assert mock_process_batch.call_count == 2 assert mock_validate.call_count == 2 mock_print.assert_any_call("Best model saved!") mock_save_losses.assert_called_once() assert len(pretrainer.train_losses) == 2 assert pretrainer.train_losses == [0.5, 0.5] @patch('datetime.datetime') @patch('time.time') def test_handle_checkpoints(mock_time, mock_datetime): """Test checkpoint handling logic.""" pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig()) pretrainer.checkpoint_dir = "checkpoints" pretrainer.current_epoch = 1 
    pretrainer._initialize_checkpoint_variables()

    # Set a base time value and record it as the last checkpoint time
    base_time = 1000
    pretrainer.last_checkpoint_time = base_time

    # Mock time to return base_time + interval + 1, triggering an interval save
    mock_time.return_value = base_time + pretrainer.checkpoint_interval + 1

    # Mock datetime for the 22:00 checkpoint
    mock_dt = MagicMock()
    mock_dt.hour = 22
    mock_dt.minute = 0
    mock_dt.date.return_value = datetime.date(2023, 1, 1)
    mock_datetime.now.return_value = mock_dt

    with patch.object(pretrainer, '_save_checkpoint') as mock_save:
        pretrainer._handle_checkpoints(100)
        # Should be called twice: once for the regular interval and once for 22:00
        assert mock_save.call_count == 2


def test_training_phase():
    """Test the training phase logic."""
    pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
    pretrainer.optimizer = MagicMock()
    pretrainer.checkpoint_dir = None  # Initialize checkpoint_dir
    pretrainer._initialize_checkpoint_variables()
    pretrainer.current_epoch = 0

    # Create mock batch data with requires_grad=True
    mock_batch_data = {
        'denoise': (
            torch.randn(2, 3, 32, 32, requires_grad=True),
            torch.randn(2, 3, 32, 32, requires_grad=True)
        ),
        'rotate': (
            torch.randn(2, 3, 32, 32, requires_grad=True),
            torch.tensor([0, 1], dtype=torch.long)  # Labels typically don't need gradients
        )
    }
    mock_train_loader = [(0, mock_batch_data)]  # Include the batch index

    # Mock the loss functions to return tensors that require gradients
    criterion_denoise = MagicMock(return_value=torch.tensor(0.5, requires_grad=True))
    criterion_rotate = MagicMock(return_value=torch.tensor(0.5, requires_grad=True))

    with patch.object(pretrainer, '_process_batch', return_value=torch.tensor(0.5, requires_grad=True)), \
         patch.object(pretrainer, '_handle_checkpoints') as mock_handle_checkpoints:
        total_loss, batch_count = pretrainer._training_phase(
            mock_train_loader, 0, criterion_denoise, criterion_rotate)

        assert total_loss == 0.5
        assert batch_count == 1
        mock_handle_checkpoints.assert_called_once_with(1)  # Checkpoint handling ran for the batch


def test_validation_phase():
    """Test the validation phase logic."""
    pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
    pretrainer.projection_head = MagicMock()

    mock_val_loader = [MagicMock()]
    criterion_denoise = MagicMock()
    criterion_rotate = MagicMock()

    with patch.object(pretrainer, '_validate', return_value=0.4):
        val_loss = pretrainer._validation_phase(
            mock_val_loader, criterion_denoise, criterion_rotate)
        assert val_loss == 0.4


@patch('pandas.read_parquet')
def test_load_and_merge_datasets(mock_read_parquet):
    """Test dataset loading and merging."""
    pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())

    mock_df = pd.DataFrame({'col': [1, 2, 3]})
    mock_read_parquet.return_value.head.return_value = mock_df

    result = pretrainer._load_and_merge_datasets(['path1.parquet', 'path2.parquet'], 1000)
    assert len(result) == 6  # 2 datasets * 3 rows each


def test_process_batch_none_tasks():
    """Test processing a batch with no tasks."""
    pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
    batch_data = {
        'denoise': None,
        'rotate': None
    }

    loss = pretrainer._process_batch(
        batch_data,
        criterion_denoise=MagicMock(),
        criterion_rotate=MagicMock()
    )
    assert loss == 0


# NOTE: _process_batch is patched, so this checks only how _validate
# aggregates the mocked per-batch losses, not real validation.
@patch('aiia.pretrain.pretrainer.Pretrainer._process_batch')
def test_validate(mock_process_batch):
    pretrainer = Pretrainer(model=AIIABase(config=AIIAConfig()), config=AIIAConfig())
    val_loader = [MagicMock()]
    criterion_denoise = MagicMock()
    criterion_rotate = MagicMock()
    mock_process_batch.return_value = torch.tensor(0.5)

    loss = pretrainer._validate(val_loader, criterion_denoise, criterion_rotate)
    assert loss == 0.5


# Test the save_losses method
# NOTE: save_losses itself is patched, so this verifies only that the call
# is forwarded with the expected argument, not the CSV-writing logic.
@patch('aiia.pretrain.pretrainer.Pretrainer.save_losses')
def test_save_losses(mock_save_losses):
    pretrainer = Pretrainer(model=AIIABase(config=AIIAConfig()), config=AIIAConfig())
    pretrainer.train_losses = [0.1, 0.2]
    pretrainer.val_losses = [0.3, 0.4]

    csv_file = 'losses.csv'
    pretrainer.save_losses(csv_file)
    mock_save_losses.assert_called_once_with(csv_file)
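

# ---------------------------------------------------------------------------
# Optional refactor sketch (an assumption, not part of the tested API): most
# tests above construct a Pretrainer by hand. A shared pytest fixture such as
# the hypothetical `default_pretrainer` below could cut that repetition; it
# relies only on the constructor calls already exercised in this file.
# ---------------------------------------------------------------------------
@pytest.fixture
def default_pretrainer():
    """Pretrainer wired to a real AIIABase model with a default config."""
    config = AIIAConfig()
    return Pretrainer(model=AIIABase(config=config), config=config)

# Example usage: pytest injects the fixture by parameter name, e.g.
# def test_device_default(default_pretrainer):
#     assert default_pretrainer.device in ["cuda", "cpu"]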