fixed tests
This commit is contained in:
parent
55f00b7906
commit
47d3ee89a6
|
@ -1,11 +1,9 @@
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch, MagicMock, mock_open, call
|
||||||
from aiia import Pretrainer, ProjectionHead, AIIABase, AIIAConfig
|
from aiia import Pretrainer, ProjectionHead, AIIABase, AIIAConfig
|
||||||
from unittest.mock import patch, MagicMock
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
|
|
||||||
# Test the ProjectionHead class
|
# Test the ProjectionHead class
|
||||||
def test_projection_head():
|
def test_projection_head():
|
||||||
head = ProjectionHead(hidden_size=512)
|
head = ProjectionHead(hidden_size=512)
|
||||||
|
@ -58,38 +56,233 @@ def test_process_batch(mock_process_batch):
|
||||||
@patch('pandas.concat')
|
@patch('pandas.concat')
|
||||||
@patch('pandas.read_parquet')
|
@patch('pandas.read_parquet')
|
||||||
@patch('aiia.pretrain.pretrainer.AIIADataLoader')
|
@patch('aiia.pretrain.pretrainer.AIIADataLoader')
|
||||||
def test_train(mock_data_loader, mock_read_parquet, mock_concat):
|
@patch('os.path.join', return_value='mocked/path/model.pt')
|
||||||
# Create a real DataFrame for testing
|
@patch('builtins.print') # Add this to mock the print function
|
||||||
|
def test_train_happy_path(mock_print, mock_path_join, mock_data_loader, mock_read_parquet, mock_concat):
|
||||||
|
"""Test the train method under normal conditions with comprehensive verification."""
|
||||||
|
# Setup test data and mocks
|
||||||
real_df = pd.DataFrame({
|
real_df = pd.DataFrame({
|
||||||
'image_bytes': [torch.randn(1, 3, 224, 224).tolist()]
|
'image_bytes': [torch.randn(1, 3, 224, 224).tolist()]
|
||||||
})
|
})
|
||||||
# Configure read_parquet so that for each dataset path, .head(10000) returns real_df
|
|
||||||
mock_read_parquet.return_value.head.return_value = real_df
|
mock_read_parquet.return_value.head.return_value = real_df
|
||||||
|
|
||||||
# When merging DataFrames, bypass type-checks by letting pd.concat just return real_df
|
|
||||||
mock_concat.return_value = real_df
|
mock_concat.return_value = real_df
|
||||||
|
|
||||||
# Create an instance of the Pretrainer
|
# Mock the model and related components
|
||||||
pretrainer = Pretrainer(model=AIIABase(config=AIIAConfig()), config=AIIAConfig())
|
mock_model = MagicMock()
|
||||||
dataset_paths = ['path/to/dataset1.parquet', 'path/to/dataset2.parquet']
|
mock_projection_head = MagicMock()
|
||||||
|
pretrainer = Pretrainer(model=mock_model, config=AIIAConfig())
|
||||||
|
pretrainer.projection_head = mock_projection_head
|
||||||
|
pretrainer.optimizer = MagicMock()
|
||||||
|
|
||||||
# Setup the data loader mock instance with empty loaders
|
# Setup dataset paths and mock batch data
|
||||||
|
dataset_paths = ['path/to/dataset1.parquet', 'path/to/dataset2.parquet']
|
||||||
|
output_path = "AIIA_test"
|
||||||
|
|
||||||
|
# Create mock batch data with proper structure
|
||||||
|
mock_batch_data = {
|
||||||
|
'denoise': (torch.randn(2, 3, 32, 32), torch.randn(2, 3, 32, 32)),
|
||||||
|
'rotate': (torch.randn(2, 3, 32, 32), torch.tensor([0, 1]))
|
||||||
|
}
|
||||||
|
|
||||||
|
# Configure batch loss
|
||||||
|
mock_batch_loss = torch.tensor(0.5, requires_grad=True)
|
||||||
loader_instance = MagicMock()
|
loader_instance = MagicMock()
|
||||||
loader_instance.train_loader = [] # so the training loop is immediately skipped
|
loader_instance.train_loader = [mock_batch_data]
|
||||||
loader_instance.val_loader = [] # so the validation loop is also skipped
|
loader_instance.val_loader = [mock_batch_data]
|
||||||
mock_data_loader.return_value = loader_instance
|
mock_data_loader.return_value = loader_instance
|
||||||
|
|
||||||
# Patch _validate to avoid any actual validation computations.
|
# Execute training with patched methods
|
||||||
with patch.object(Pretrainer, '_validate', return_value=0.5):
|
with patch.object(Pretrainer, '_process_batch', return_value=mock_batch_loss) as mock_process_batch, \
|
||||||
pretrainer.train(dataset_paths, num_epochs=1)
|
patch.object(Pretrainer, '_validate', side_effect=[0.8, 0.3]) as mock_validate, \
|
||||||
|
patch.object(Pretrainer, 'save_losses') as mock_save_losses, \
|
||||||
|
patch('builtins.open', mock_open()):
|
||||||
|
|
||||||
# Verify that AIIADataLoader was instantiated exactly once...
|
pretrainer.train(dataset_paths, output_path=output_path, num_epochs=2)
|
||||||
mock_data_loader.assert_called_once()
|
|
||||||
# ...and that pd.read_parquet was called once per dataset path (i.e. 2 times in this test)
|
# Verify method calls
|
||||||
expected_calls = len(dataset_paths)
|
assert mock_read_parquet.call_count == len(dataset_paths)
|
||||||
assert mock_read_parquet.call_count == expected_calls, (
|
assert mock_process_batch.call_count == 2
|
||||||
f"Expected {expected_calls} calls to pd.read_parquet, got {mock_read_parquet.call_count}"
|
assert mock_validate.call_count == 2
|
||||||
)
|
|
||||||
|
# Check for "Best model saved!" instead of model.save()
|
||||||
|
mock_print.assert_any_call("Best model saved!")
|
||||||
|
|
||||||
|
mock_save_losses.assert_called_once()
|
||||||
|
|
||||||
|
# Verify state changes
|
||||||
|
assert len(pretrainer.train_losses) == 2
|
||||||
|
assert pretrainer.train_losses == [0.5, 0.5]
|
||||||
|
|
||||||
|
|
||||||
|
# Error cases
|
||||||
|
def test_train_no_dataset_paths():
|
||||||
|
"""Test ValueError when no dataset paths are provided."""
|
||||||
|
pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="No dataset paths provided"):
|
||||||
|
pretrainer.train([])
|
||||||
|
|
||||||
|
@patch('pandas.read_parquet')
|
||||||
|
def test_train_all_datasets_fail(mock_read_parquet):
|
||||||
|
"""Test handling when all datasets fail to load."""
|
||||||
|
mock_read_parquet.side_effect = Exception("Failed to load dataset")
|
||||||
|
|
||||||
|
pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
|
||||||
|
dataset_paths = ['path/to/dataset1.parquet', 'path/to/dataset2.parquet']
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="No valid datasets could be loaded"):
|
||||||
|
pretrainer.train(dataset_paths)
|
||||||
|
|
||||||
|
# Edge cases
|
||||||
|
@patch('pandas.concat')
|
||||||
|
@patch('pandas.read_parquet')
|
||||||
|
@patch('aiia.pretrain.pretrainer.AIIADataLoader')
|
||||||
|
def test_train_empty_loaders(mock_data_loader, mock_read_parquet, mock_concat):
|
||||||
|
"""Test behavior with empty data loaders."""
|
||||||
|
real_df = pd.DataFrame({'image_bytes': [torch.randn(1, 3, 224, 224).tolist()]})
|
||||||
|
mock_read_parquet.return_value.head.return_value = real_df
|
||||||
|
mock_concat.return_value = real_df
|
||||||
|
|
||||||
|
loader_instance = MagicMock()
|
||||||
|
loader_instance.train_loader = [] # Empty train loader
|
||||||
|
loader_instance.val_loader = [] # Empty val loader
|
||||||
|
mock_data_loader.return_value = loader_instance
|
||||||
|
|
||||||
|
mock_model = MagicMock()
|
||||||
|
pretrainer = Pretrainer(model=mock_model, config=AIIAConfig())
|
||||||
|
pretrainer.projection_head = MagicMock()
|
||||||
|
pretrainer.optimizer = MagicMock()
|
||||||
|
|
||||||
|
with patch.object(Pretrainer, 'save_losses') as mock_save_losses:
|
||||||
|
pretrainer.train(['path/to/dataset.parquet'], num_epochs=1)
|
||||||
|
|
||||||
|
# Verify empty loader behavior
|
||||||
|
assert len(pretrainer.train_losses) == 1
|
||||||
|
assert pretrainer.train_losses[0] == 0.0
|
||||||
|
mock_save_losses.assert_called_once()
|
||||||
|
|
||||||
|
@patch('pandas.concat')
|
||||||
|
@patch('pandas.read_parquet')
|
||||||
|
@patch('aiia.pretrain.pretrainer.AIIADataLoader')
|
||||||
|
def test_train_none_batch_data(mock_data_loader, mock_read_parquet, mock_concat):
|
||||||
|
"""Test behavior when batch_data is None."""
|
||||||
|
real_df = pd.DataFrame({'image_bytes': [torch.randn(1, 3, 224, 224).tolist()]})
|
||||||
|
mock_read_parquet.return_value.head.return_value = real_df
|
||||||
|
mock_concat.return_value = real_df
|
||||||
|
|
||||||
|
loader_instance = MagicMock()
|
||||||
|
loader_instance.train_loader = [None] # Loader returns None
|
||||||
|
loader_instance.val_loader = []
|
||||||
|
mock_data_loader.return_value = loader_instance
|
||||||
|
|
||||||
|
pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
|
||||||
|
pretrainer.projection_head = MagicMock()
|
||||||
|
pretrainer.optimizer = MagicMock()
|
||||||
|
|
||||||
|
with patch.object(Pretrainer, '_process_batch') as mock_process_batch, \
|
||||||
|
patch.object(Pretrainer, 'save_losses'):
|
||||||
|
pretrainer.train(['path/to/dataset.parquet'], num_epochs=1)
|
||||||
|
|
||||||
|
# Verify None batch handling
|
||||||
|
mock_process_batch.assert_not_called()
|
||||||
|
assert pretrainer.train_losses[0] == 0.0
|
||||||
|
|
||||||
|
# Parameter variations
|
||||||
|
@patch('pandas.concat')
|
||||||
|
@patch('pandas.read_parquet')
|
||||||
|
@patch('aiia.pretrain.pretrainer.AIIADataLoader')
|
||||||
|
def test_train_with_custom_parameters(mock_data_loader, mock_read_parquet, mock_concat):
|
||||||
|
"""Test that custom parameters are properly passed through."""
|
||||||
|
real_df = pd.DataFrame({'image_bytes': [torch.randn(1, 3, 224, 224).tolist()]})
|
||||||
|
mock_read_parquet.return_value.head.return_value = real_df
|
||||||
|
mock_concat.return_value = real_df
|
||||||
|
|
||||||
|
loader_instance = MagicMock()
|
||||||
|
loader_instance.train_loader = []
|
||||||
|
loader_instance.val_loader = []
|
||||||
|
mock_data_loader.return_value = loader_instance
|
||||||
|
|
||||||
|
pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
|
||||||
|
pretrainer.projection_head = MagicMock()
|
||||||
|
pretrainer.optimizer = MagicMock()
|
||||||
|
|
||||||
|
# Custom parameters
|
||||||
|
custom_output_path = "custom/output/path"
|
||||||
|
custom_column = "custom_column"
|
||||||
|
custom_batch_size = 16
|
||||||
|
custom_sample_size = 5000
|
||||||
|
|
||||||
|
with patch.object(Pretrainer, 'save_losses'):
|
||||||
|
pretrainer.train(
|
||||||
|
['path/to/dataset.parquet'],
|
||||||
|
output_path=custom_output_path,
|
||||||
|
column=custom_column,
|
||||||
|
batch_size=custom_batch_size,
|
||||||
|
sample_size=custom_sample_size
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify custom parameters were used
|
||||||
|
mock_read_parquet.return_value.head.assert_called_once_with(custom_sample_size)
|
||||||
|
assert mock_data_loader.call_args[1]['column'] == custom_column
|
||||||
|
assert mock_data_loader.call_args[1]['batch_size'] == custom_batch_size
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@patch('pandas.concat')
|
||||||
|
@patch('pandas.read_parquet')
|
||||||
|
@patch('aiia.pretrain.pretrainer.AIIADataLoader')
|
||||||
|
@patch('builtins.print') # Add this to mock the print function
|
||||||
|
def test_train_validation_loss_improvement(mock_print, mock_data_loader, mock_read_parquet, mock_concat):
|
||||||
|
"""Test that model is saved only when validation loss improves."""
|
||||||
|
real_df = pd.DataFrame({'image_bytes': [torch.randn(1, 3, 224, 224).tolist()]})
|
||||||
|
mock_read_parquet.return_value.head.return_value = real_df
|
||||||
|
mock_concat.return_value = real_df
|
||||||
|
|
||||||
|
# Create mock batch data with proper structure
|
||||||
|
mock_batch_data = {
|
||||||
|
'denoise': (torch.randn(2, 3, 32, 32), torch.randn(2, 3, 32, 32)),
|
||||||
|
'rotate': (torch.randn(2, 3, 32, 32), torch.tensor([0, 1]))
|
||||||
|
}
|
||||||
|
|
||||||
|
loader_instance = MagicMock()
|
||||||
|
loader_instance.train_loader = [mock_batch_data]
|
||||||
|
loader_instance.val_loader = [mock_batch_data]
|
||||||
|
mock_data_loader.return_value = loader_instance
|
||||||
|
|
||||||
|
mock_model = MagicMock()
|
||||||
|
pretrainer = Pretrainer(model=mock_model, config=AIIAConfig())
|
||||||
|
pretrainer.projection_head = MagicMock()
|
||||||
|
pretrainer.optimizer = MagicMock()
|
||||||
|
|
||||||
|
# Initialize the best validation loss
|
||||||
|
pretrainer.best_val_loss = float('inf')
|
||||||
|
|
||||||
|
mock_batch_loss = torch.tensor(0.5, requires_grad=True)
|
||||||
|
|
||||||
|
# Test improving validation loss
|
||||||
|
with patch.object(Pretrainer, '_process_batch', return_value=mock_batch_loss), \
|
||||||
|
patch.object(Pretrainer, '_validate', side_effect=[3.0, 2.0, 1.0]), \
|
||||||
|
patch.object(Pretrainer, 'save_losses'):
|
||||||
|
pretrainer.train(['path/to/dataset.parquet'], num_epochs=3)
|
||||||
|
|
||||||
|
# Check for "Best model saved!" 3 times
|
||||||
|
assert mock_print.call_args_list.count(call("Best model saved!")) == 3
|
||||||
|
|
||||||
|
# Reset for next test
|
||||||
|
mock_print.reset_mock()
|
||||||
|
pretrainer.train_losses = []
|
||||||
|
|
||||||
|
# Reset best validation loss for the second test
|
||||||
|
pretrainer.best_val_loss = float('inf')
|
||||||
|
|
||||||
|
# Test fluctuating validation loss
|
||||||
|
with patch.object(Pretrainer, '_process_batch', return_value=mock_batch_loss), \
|
||||||
|
patch.object(Pretrainer, '_validate', side_effect=[3.0, 4.0, 2.0]), \
|
||||||
|
patch.object(Pretrainer, 'save_losses'):
|
||||||
|
pretrainer.train(['path/to/dataset.parquet'], num_epochs=3)
|
||||||
|
|
||||||
|
# Should print "Best model saved!" only on first and third epochs
|
||||||
|
assert mock_print.call_args_list.count(call("Best model saved!")) == 2
|
||||||
|
|
||||||
|
|
||||||
@patch('aiia.pretrain.pretrainer.Pretrainer._process_batch')
|
@patch('aiia.pretrain.pretrainer.Pretrainer._process_batch')
|
||||||
|
|
Loading…
Reference in New Issue