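"""Unit tests for the aiia pretraining components: ProjectionHead and the Pretrainer training helpers."""
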
import pytest
import torch
import pandas as pd
from unittest.mock import MagicMock, patch

from aiia import Pretrainer, ProjectionHead, AIIABase, AIIAConfig


# Test the ProjectionHead class
def test_projection_head():
    head = ProjectionHead(hidden_size=512)
    x = torch.randn(1, 512, 32, 32)

    # Test denoise task
    output_denoise = head(x, task='denoise')
    assert output_denoise.shape == (1, 3, 32, 32)

    # Test rotate task
    output_rotate = head(x, task='rotate')
    assert output_rotate.shape == (1, 4)


# Test the Pretrainer class initialization
def test_pretrainer_initialization():
    config = AIIAConfig()
    model = AIIABase(config=config)
    pretrainer = Pretrainer(model=model, learning_rate=0.001, config=config)
    assert pretrainer.device in ["cuda", "cpu"]
    assert isinstance(pretrainer.projection_head, ProjectionHead)
    assert isinstance(pretrainer.optimizer, torch.optim.AdamW)


# Test the safe_collate method
def test_safe_collate():
    pretrainer = Pretrainer(model=AIIABase(config=AIIAConfig()), config=AIIAConfig())
    batch = [
        (torch.randn(3, 32, 32), torch.randn(3, 32, 32), 'denoise'),
        (torch.randn(3, 32, 32), torch.tensor(1), 'rotate')
    ]

    collated_batch = pretrainer.safe_collate(batch)
    assert 'denoise' in collated_batch
    assert 'rotate' in collated_batch


# Test the _process_batch method (the method itself is patched, so only the call wiring is exercised)
@patch('aiia.pretrain.pretrainer.Pretrainer._process_batch')
def test_process_batch(mock_process_batch):
    pretrainer = Pretrainer(model=AIIABase(config=AIIAConfig()), config=AIIAConfig())
    batch_data = {
        'denoise': (torch.randn(2, 3, 32, 32), torch.randn(2, 3, 32, 32)),
        'rotate': (torch.randn(2, 3, 32, 32), torch.tensor([0, 1]))
    }
    criterion_denoise = MagicMock()
    criterion_rotate = MagicMock()

    mock_process_batch.return_value = 0.5
    loss = pretrainer._process_batch(batch_data, criterion_denoise, criterion_rotate)
    assert loss == 0.5


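# Test the train method with mocked dataset loading and empty data loaders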
@patch('pandas.concat')
@patch('pandas.read_parquet')
@patch('aiia.pretrain.pretrainer.AIIADataLoader')
def test_train(mock_data_loader, mock_read_parquet, mock_concat):
    # Create a real DataFrame for testing
    real_df = pd.DataFrame({
        'image_bytes': [torch.randn(1, 3, 224, 224).tolist()]
    })
    # Configure read_parquet so that for each dataset path, .head(10000) returns real_df
    mock_read_parquet.return_value.head.return_value = real_df

    # When merging DataFrames, bypass type checks by letting pd.concat just return real_df
    mock_concat.return_value = real_df

    # Create an instance of the Pretrainer
    pretrainer = Pretrainer(model=AIIABase(config=AIIAConfig()), config=AIIAConfig())
    dataset_paths = ['path/to/dataset1.parquet', 'path/to/dataset2.parquet']

    # Set up the data loader mock instance with empty loaders
    loader_instance = MagicMock()
    loader_instance.train_loader = []  # so the training loop is immediately skipped
    loader_instance.val_loader = []  # so the validation loop is also skipped
    mock_data_loader.return_value = loader_instance

    # Patch _validate to avoid any actual validation computations
    with patch.object(Pretrainer, '_validate', return_value=0.5):
        pretrainer.train(dataset_paths, num_epochs=1)

    # Verify that AIIADataLoader was instantiated exactly once...
    mock_data_loader.assert_called_once()
    # ...and that pd.read_parquet was called once per dataset path (twice in this test)
    expected_calls = len(dataset_paths)
    assert mock_read_parquet.call_count == expected_calls, (
        f"Expected {expected_calls} calls to pd.read_parquet, got {mock_read_parquet.call_count}"
    )


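# Test the _validate method (with _process_batch patched to return a fixed loss)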
@patch('aiia.pretrain.pretrainer.Pretrainer._process_batch')
def test_validate(mock_process_batch):
    pretrainer = Pretrainer(model=AIIABase(config=AIIAConfig()), config=AIIAConfig())
    val_loader = [MagicMock()]
    criterion_denoise = MagicMock()
    criterion_rotate = MagicMock()

    mock_process_batch.return_value = torch.tensor(0.5)
    loss = pretrainer._validate(val_loader, criterion_denoise, criterion_rotate)
    assert loss == 0.5


# Test the save_losses method (the method itself is patched, so only the call is verified)
@patch('aiia.pretrain.pretrainer.Pretrainer.save_losses')
def test_save_losses(mock_save_losses):
    pretrainer = Pretrainer(model=AIIABase(config=AIIAConfig()), config=AIIAConfig())
    pretrainer.train_losses = [0.1, 0.2]
    pretrainer.val_losses = [0.3, 0.4]

    csv_file = 'losses.csv'
    pretrainer.save_losses(csv_file)
    mock_save_losses.assert_called_once_with(csv_file)