added tests for pretrainer
This commit is contained in:
parent cd317ce0e9
commit 24a3a7bf56
@@ -0,0 +1,115 @@
import pytest
import torch
from unittest.mock import MagicMock, patch
from aiia import Pretrainer, ProjectionHead, AIIABase, AIIAConfig
import pandas as pd


# Test the ProjectionHead class
def test_projection_head():
    head = ProjectionHead(hidden_size=512)
    x = torch.randn(1, 512, 32, 32)

    # Test denoise task
    output_denoise = head(x, task='denoise')
    assert output_denoise.shape == (1, 3, 32, 32)

    # Test rotate task
    output_rotate = head(x, task='rotate')
    assert output_rotate.shape == (1, 4)
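
# A compact alternative sketch of the two checks above using pytest.mark.parametrize.
# The test name is illustrative only and assumes nothing beyond the ProjectionHead
# behaviour already exercised in test_projection_head.
@pytest.mark.parametrize("task, expected_shape", [
    ("denoise", (1, 3, 32, 32)),
    ("rotate", (1, 4)),
])
def test_projection_head_parametrized(task, expected_shape):
    head = ProjectionHead(hidden_size=512)
    x = torch.randn(1, 512, 32, 32)
    assert head(x, task=task).shape == expected_shape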


# Test the Pretrainer class initialization
def test_pretrainer_initialization():
    config = AIIAConfig()
    model = AIIABase(config=config)
    pretrainer = Pretrainer(model=model, learning_rate=0.001, config=config)
    assert pretrainer.device in ["cuda", "cpu"]
    assert isinstance(pretrainer.projection_head, ProjectionHead)
    assert isinstance(pretrainer.optimizer, torch.optim.AdamW)


# Test the safe_collate method
def test_safe_collate():
    pretrainer = Pretrainer(model=AIIABase(config=AIIAConfig()), config=AIIAConfig())
    batch = [
        (torch.randn(3, 32, 32), torch.randn(3, 32, 32), 'denoise'),
        (torch.randn(3, 32, 32), torch.tensor(1), 'rotate')
    ]

    collated_batch = pretrainer.safe_collate(batch)
    assert 'denoise' in collated_batch
    assert 'rotate' in collated_batch


# Test the _process_batch method
@patch('aiia.pretrain.pretrainer.Pretrainer._process_batch')
def test_process_batch(mock_process_batch):
    pretrainer = Pretrainer(model=AIIABase(config=AIIAConfig()), config=AIIAConfig())
    batch_data = {
        'denoise': (torch.randn(2, 3, 32, 32), torch.randn(2, 3, 32, 32)),
        'rotate': (torch.randn(2, 3, 32, 32), torch.tensor([0, 1]))
    }
    criterion_denoise = MagicMock()
    criterion_rotate = MagicMock()

    mock_process_batch.return_value = 0.5
    loss = pretrainer._process_batch(batch_data, criterion_denoise, criterion_rotate)
    assert loss == 0.5


# Test the train method
@patch('pandas.concat')
@patch('pandas.read_parquet')
@patch('aiia.pretrain.pretrainer.AIIADataLoader')
def test_train(mock_data_loader, mock_read_parquet, mock_concat):
    # Create a real DataFrame for testing
    real_df = pd.DataFrame({
        'image_bytes': [torch.randn(1, 3, 224, 224).tolist()]
    })
    # Configure read_parquet so that for each dataset path, .head(10000) returns real_df
    mock_read_parquet.return_value.head.return_value = real_df

    # When merging DataFrames, bypass type checks by letting pd.concat just return real_df
    mock_concat.return_value = real_df

    # Create an instance of the Pretrainer
    pretrainer = Pretrainer(model=AIIABase(config=AIIAConfig()), config=AIIAConfig())
    dataset_paths = ['path/to/dataset1.parquet', 'path/to/dataset2.parquet']

    # Set up the data loader mock instance with empty loaders
    loader_instance = MagicMock()
    loader_instance.train_loader = []  # so the training loop is immediately skipped
    loader_instance.val_loader = []    # so the validation loop is also skipped
    mock_data_loader.return_value = loader_instance

    # Patch _validate to avoid any actual validation computations
    with patch.object(Pretrainer, '_validate', return_value=0.5):
        pretrainer.train(dataset_paths, num_epochs=1)

    # Verify that AIIADataLoader was instantiated exactly once...
    mock_data_loader.assert_called_once()
    # ...and that pd.read_parquet was called once per dataset path (i.e. 2 times in this test)
    expected_calls = len(dataset_paths)
    assert mock_read_parquet.call_count == expected_calls, (
        f"Expected {expected_calls} calls to pd.read_parquet, got {mock_read_parquet.call_count}"
    )


# Test the _validate method
@patch('aiia.pretrain.pretrainer.Pretrainer._process_batch')
def test_validate(mock_process_batch):
    pretrainer = Pretrainer(model=AIIABase(config=AIIAConfig()), config=AIIAConfig())
    val_loader = [MagicMock()]
    criterion_denoise = MagicMock()
    criterion_rotate = MagicMock()

    mock_process_batch.return_value = torch.tensor(0.5)
    loss = pretrainer._validate(val_loader, criterion_denoise, criterion_rotate)
    assert loss == 0.5


# Test the save_losses method
@patch('aiia.pretrain.pretrainer.Pretrainer.save_losses')
def test_save_losses(mock_save_losses):
    pretrainer = Pretrainer(model=AIIABase(config=AIIAConfig()), config=AIIAConfig())
    pretrainer.train_losses = [0.1, 0.2]
    pretrainer.val_losses = [0.3, 0.4]

    csv_file = 'losses.csv'
    pretrainer.save_losses(csv_file)
    mock_save_losses.assert_called_once_with(csv_file)
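
To run the new suite locally, a standard pytest invocation is enough; the path below assumes the file is saved as tests/test_pretrainer.py, which the commit does not state:

    pytest tests/test_pretrainer.py -v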