Merge pull request 'feat/added_tests' (#30) from feat/added_tests into develop

Reviewed-on: #30
Falko Victor Habel 2025-03-16 11:26:34 +00:00
commit 64787cbffc
7 changed files with 692 additions and 41 deletions

.coveragerc Normal file (+10 lines)

@@ -0,0 +1,10 @@
[run]
branch = True
source = src
omit =
    */tests/*
    */migrations/*

[report]
show_missing = True
fail_under = 80
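Two settings here deserve a note: branch = True measures branch (decision) coverage on top of line coverage, and fail_under = 80 makes the coverage report exit non-zero when the total drops below 80%. A minimal sketch of why branch coverage is the stricter metric (the function is illustrative, not from this repository):

def normalize(path, strip=True):
    if strip:
        path = path.strip()
    return path

# Calling only normalize(' a ') executes every line (100% line coverage),
# but the False branch of `if strip` is never taken, so branch coverage
# stays below 100% until normalize('a', strip=False) is also exercised.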

pytest.ini Normal file (+3 lines)

@@ -0,0 +1,3 @@
[pytest]
testpaths = tests/
python_files = test_*.py
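With this file in place, pytest collects only modules under tests/ whose file names match test_*.py. A minimal conforming module (the path tests/test_example.py is illustrative):

# tests/test_example.py -- discovered because it sits under testpaths
# and matches the python_files pattern
def test_discovery_convention():
    assert 1 + 1 == 2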


@@ -15,7 +15,7 @@ class FilePathLoader:
         self.successful_count = 0
         self.skipped_count = 0

-        if self.file_path_column not in dataset.column_names:
+        if self.file_path_column not in dataset.columns:
             raise ValueError(f"Column '{self.file_path_column}' not found in dataset.")

     def _get_image(self, item):
@@ -106,7 +106,11 @@ class JPGImageLoader:
         print(f"Skipped {self.skipped_count} images due to errors.")

 class AIIADataLoader:
-    def __init__(self, dataset, batch_size=32, val_split=0.2, seed=42, column="file_path", label_column=None, pretraining=False, **dataloader_kwargs):
+    def __init__(self, dataset, batch_size=32, val_split=0.2, seed=42, column="file_path",
+                 label_column=None, pretraining=False, **dataloader_kwargs):
+        if column not in dataset.columns:
+            raise ValueError(f"Column '{column}' not found in dataset")
+
         self.batch_size = batch_size
         self.val_split = val_split
         self.seed = seed
@@ -145,7 +149,6 @@ class AIIADataLoader:
         if not self.items:
             raise ValueError("No valid items were loaded from the dataset")
-
         train_indices, val_indices = self._split_data()
         self.train_dataset = self._create_subset(train_indices)
@@ -192,9 +195,11 @@ class AIIADataset(torch.utils.data.Dataset):
             if not isinstance(image, Image.Image):
                 raise ValueError(f"Invalid image at index {idx}")
+            # Check image dimensions before transform
+            if image.size[0] < 224 or image.size[1] < 224:
+                raise ValueError("Invalid image dimensions")
             image = self.transform(image)
             if image.shape != (3, 224, 224):
                 raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")

             if task == 'denoise':
                 noise_std = 0.1
@@ -214,15 +219,20 @@ class AIIADataset(torch.utils.data.Dataset):
             image, label = item
             if not isinstance(image, Image.Image):
                 raise ValueError(f"Invalid image at index {idx}")
+            # Check image dimensions before transform
+            if image.size[0] < 224 or image.size[1] < 224:
+                raise ValueError("Invalid image dimensions")
             image = self.transform(image)
             if image.shape != (3, 224, 224):
                 raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")
             return image, label
         else:
-            if isinstance(item, Image.Image):
-                image = self.transform(item)
-            else:
-                image = self.transform(item[0])
-            if image.shape != (3, 224, 224):
-                raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")
+            image = item[0] if isinstance(item, tuple) else item
+            if not isinstance(image, Image.Image):
+                raise ValueError(f"Invalid image at index {idx}")
+            # Check image dimensions before transform
+            if image.size[0] < 224 or image.size[1] < 224:
+                raise ValueError("Invalid image dimensions")
+            image = self.transform(image)
             return image


@@ -0,0 +1,112 @@
import pytest
from PIL import Image
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
import pandas as pd
import numpy as np
from aiia.data.DataLoader import FilePathLoader, JPGImageLoader, AIIADataLoader, AIIADataset


def create_sample_dataset(file_paths=None):
    if file_paths is None:
        file_paths = ['path/to/image1.jpg', 'path/to/image2.png']

    data = {
        'file_path': file_paths,
        'label': [0] * len(file_paths)  # Match length of labels to file_paths
    }
    df = pd.DataFrame(data)
    return df


def create_sample_bytes_dataset(bytes_data=None):
    if bytes_data is None:
        bytes_data = [b'fake_image_data_1', b'fake_image_data_2']

    data = {
        'jpg': bytes_data,
        'label': [0] * len(bytes_data)  # Match length of labels to bytes_data
    }
    df = pd.DataFrame(data)
    return df


def test_file_path_loader(mocker):
    # Mock Image.open to return a fake image
    mock_image = Image.new('RGB', (224, 224))
    mocker.patch('PIL.Image.open', return_value=mock_image)

    dataset = create_sample_dataset()
    loader = FilePathLoader(dataset, label_column='label')  # Added label_column
    item = loader.get_item(0)
    assert isinstance(item[0], Image.Image)
    assert item[1] == 0

    loader.print_summary()


def test_jpg_image_loader(mocker):
    # Mock Image.open to return a fake image
    mock_image = Image.new('RGB', (224, 224))
    mocker.patch('PIL.Image.open', return_value=mock_image)

    dataset = create_sample_bytes_dataset()
    loader = JPGImageLoader(dataset, label_column='label')  # Added label_column
    item = loader.get_item(0)
    assert isinstance(item[0], Image.Image)
    assert item[1] == 0

    loader.print_summary()


def test_aiia_data_loader(mocker):
    # Mock Image.open to return a fake image
    mock_image = Image.new('RGB', (224, 224))
    mocker.patch('PIL.Image.open', return_value=mock_image)

    dataset = create_sample_dataset()
    data_loader = AIIADataLoader(dataset, batch_size=2, label_column='label')

    # Test train loader
    batch = next(iter(data_loader.train_loader))
    assert isinstance(batch, list)
    assert len(batch) == 2  # (images, labels)
    assert batch[0].shape[0] == 1  # batch size


def test_aiia_dataset():
    items = [(Image.new('RGB', (224, 224)), 0), (Image.new('RGB', (224, 224)), 1)]
    dataset = AIIADataset(items)
    assert len(dataset) == 2

    item = dataset[0]
    assert isinstance(item[0], torch.Tensor)
    assert item[1] == 0


def test_aiia_dataset_pre_training():
    items = [(Image.new('RGB', (224, 224)), 'denoise', Image.new('RGB', (224, 224)))]
    dataset = AIIADataset(items, pretraining=True)
    assert len(dataset) == 1

    item = dataset[0]
    assert isinstance(item[0], torch.Tensor)
    assert isinstance(item[2], str)


def test_aiia_dataset_invalid_image():
    items = [(Image.new('RGB', (50, 50)), 0)]  # Create small image
    dataset = AIIADataset(items)
    with pytest.raises(ValueError, match="Invalid image dimensions"):
        dataset[0]


def test_aiia_dataset_invalid_task():
    items = [(Image.new('RGB', (224, 224)), 'invalid_task', Image.new('RGB', (224, 224)))]
    dataset = AIIADataset(items, pretraining=True)
    with pytest.raises(ValueError):
        dataset[0]


def test_aiia_data_loader_invalid_column():
    dataset = create_sample_dataset()
    with pytest.raises(ValueError, match="Column 'invalid_column' not found"):
        AIIADataLoader(dataset, column='invalid_column')


if __name__ == "__main__":
    pytest.main(['-v'])
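The mocker fixture used above comes from the pytest-mock plugin, so it must be installed for these tests to run. Where the plugin is unavailable, the same patching works with the standard library; a sketch equivalent to test_file_path_loader:

from unittest.mock import patch

def test_file_path_loader_stdlib():
    mock_image = Image.new('RGB', (224, 224))
    with patch('PIL.Image.open', return_value=mock_image):
        dataset = create_sample_dataset()
        loader = FilePathLoader(dataset, label_column='label')
        item = loader.get_item(0)
        assert isinstance(item[0], Image.Image)
        assert item[1] == 0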

tests/model/test_aiia.py Normal file (+133 lines)

@@ -0,0 +1,133 @@
import os
import torch
from aiia import AIIABase, AIIABaseShared, AIIAExpert, AIIAmoe, AIIAchunked, AIIAConfig


def test_aiiabase_creation():
    config = AIIAConfig()
    model = AIIABase(config)
    assert isinstance(model, AIIABase)


def test_aiiabase_save_load():
    config = AIIAConfig()
    model = AIIABase(config)
    save_path = "test_aiiabase_save_load"

    # Save the model
    model.save(save_path)
    assert os.path.exists(os.path.join(save_path, "model.pth"))
    assert os.path.exists(os.path.join(save_path, "config.json"))

    # Load the model
    loaded_model = AIIABase.load(save_path)

    # Check if the loaded model is an instance of AIIABase
    assert isinstance(loaded_model, AIIABase)

    # Clean up
    os.remove(os.path.join(save_path, "model.pth"))
    os.remove(os.path.join(save_path, "config.json"))
    os.rmdir(save_path)


def test_aiiabase_shared_creation():
    config = AIIAConfig()
    model = AIIABaseShared(config)
    assert isinstance(model, AIIABaseShared)


def test_aiiabase_shared_save_load():
    config = AIIAConfig()
    model = AIIABaseShared(config)
    save_path = "test_aiiabase_shared_save_load"

    # Save the model
    model.save(save_path)
    assert os.path.exists(os.path.join(save_path, "model.pth"))
    assert os.path.exists(os.path.join(save_path, "config.json"))

    # Load the model
    loaded_model = AIIABaseShared.load(save_path)

    # Check if the loaded model is an instance of AIIABaseShared
    assert isinstance(loaded_model, AIIABaseShared)

    # Clean up
    os.remove(os.path.join(save_path, "model.pth"))
    os.remove(os.path.join(save_path, "config.json"))
    os.rmdir(save_path)


def test_aiiaexpert_creation():
    config = AIIAConfig()
    model = AIIAExpert(config)
    assert isinstance(model, AIIAExpert)


def test_aiiaexpert_save_load():
    config = AIIAConfig()
    model = AIIAExpert(config)
    save_path = "test_aiiaexpert_save_load"

    # Save the model
    model.save(save_path)
    assert os.path.exists(os.path.join(save_path, "model.pth"))
    assert os.path.exists(os.path.join(save_path, "config.json"))

    # Load the model
    loaded_model = AIIAExpert.load(save_path)

    # Check if the loaded model is an instance of AIIAExpert
    assert isinstance(loaded_model, AIIAExpert)

    # Clean up
    os.remove(os.path.join(save_path, "model.pth"))
    os.remove(os.path.join(save_path, "config.json"))
    os.rmdir(save_path)


def test_aiiamoe_creation():
    config = AIIAConfig()
    model = AIIAmoe(config, num_experts=5)
    assert isinstance(model, AIIAmoe)


def test_aiiamoe_save_load():
    config = AIIAConfig()
    model = AIIAmoe(config, num_experts=5)
    save_path = "test_aiiamoe_save_load"

    # Save the model
    model.save(save_path)
    assert os.path.exists(os.path.join(save_path, "model.pth"))
    assert os.path.exists(os.path.join(save_path, "config.json"))

    # Load the model
    loaded_model = AIIAmoe.load(save_path)

    # Check if the loaded model is an instance of AIIAmoe
    assert isinstance(loaded_model, AIIAmoe)

    # Clean up
    os.remove(os.path.join(save_path, "model.pth"))
    os.remove(os.path.join(save_path, "config.json"))
    os.rmdir(save_path)


def test_aiiachunked_creation():
    config = AIIAConfig()
    model = AIIAchunked(config)
    assert isinstance(model, AIIAchunked)


def test_aiiachunked_save_load():
    config = AIIAConfig()
    model = AIIAchunked(config)
    save_path = "test_aiiachunked_save_load"

    # Save the model
    model.save(save_path)
    assert os.path.exists(os.path.join(save_path, "model.pth"))
    assert os.path.exists(os.path.join(save_path, "config.json"))

    # Load the model
    loaded_model = AIIAchunked.load(save_path)

    # Check if the loaded model is an instance of AIIAchunked
    assert isinstance(loaded_model, AIIAchunked)

    # Clean up
    os.remove(os.path.join(save_path, "model.pth"))
    os.remove(os.path.join(save_path, "config.json"))
    os.rmdir(save_path)
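The five save/load tests are identical except for the model class and constructor arguments, so the same coverage could be expressed once with pytest.mark.parametrize and the built-in tmp_path fixture (which also removes the manual clean-up). A sketch, assuming the constructors used above:

import pytest

@pytest.mark.parametrize("model_cls, kwargs", [
    (AIIABase, {}),
    (AIIABaseShared, {}),
    (AIIAExpert, {}),
    (AIIAmoe, {"num_experts": 5}),
    (AIIAchunked, {}),
])
def test_model_save_load_roundtrip(tmp_path, model_cls, kwargs):
    model = model_cls(AIIAConfig(), **kwargs)
    save_path = str(tmp_path / model_cls.__name__)
    model.save(save_path)
    assert os.path.exists(os.path.join(save_path, "model.pth"))
    assert os.path.exists(os.path.join(save_path, "config.json"))
    loaded = model_cls.load(save_path)
    assert isinstance(loaded, model_cls)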


@@ -0,0 +1,75 @@
import os
import tempfile
import pytest
import torch.nn as nn
from aiia import AIIAConfig


def test_aiia_config_initialization():
    config = AIIAConfig()
    assert config.model_name == "AIIA"
    assert config.kernel_size == 3
    assert config.activation_function == "GELU"
    assert config.hidden_size == 512
    assert config.num_hidden_layers == 12
    assert config.num_channels == 3
    assert config.learning_rate == 5e-5


def test_aiia_config_custom_initialization():
    config = AIIAConfig(
        model_name="CustomModel",
        kernel_size=5,
        activation_function="ReLU",
        hidden_size=1024,
        num_hidden_layers=8,
        num_channels=1,
        learning_rate=1e-4
    )
    assert config.model_name == "CustomModel"
    assert config.kernel_size == 5
    assert config.activation_function == "ReLU"
    assert config.hidden_size == 1024
    assert config.num_hidden_layers == 8
    assert config.num_channels == 1
    assert config.learning_rate == 1e-4


def test_aiia_config_invalid_activation_function():
    with pytest.raises(ValueError):
        AIIAConfig(activation_function="InvalidFunction")


def test_aiia_config_to_dict():
    config = AIIAConfig()
    config_dict = config.to_dict()
    assert isinstance(config_dict, dict)
    assert config_dict["model_name"] == "AIIA"
    assert config_dict["kernel_size"] == 3


def test_aiia_config_save_and_load():
    with tempfile.TemporaryDirectory() as tmpdir:
        config = AIIAConfig(model_name="TempModel")
        save_path = os.path.join(tmpdir, "config")
        config.save(save_path)

        loaded_config = AIIAConfig.load(save_path)
        assert loaded_config.model_name == "TempModel"
        assert loaded_config.kernel_size == 3
        assert loaded_config.activation_function == "GELU"


def test_aiia_config_save_and_load_with_custom_attributes():
    with tempfile.TemporaryDirectory() as tmpdir:
        config = AIIAConfig(model_name="TempModel", custom_attr="value")
        save_path = os.path.join(tmpdir, "config")
        config.save(save_path)

        loaded_config = AIIAConfig.load(save_path)
        assert loaded_config.model_name == "TempModel"
        assert loaded_config.custom_attr == "value"


def test_aiia_config_save_and_load_with_nested_attributes():
    with tempfile.TemporaryDirectory() as tmpdir:
        config = AIIAConfig(model_name="TempModel", nested={"key": "value"})
        save_path = os.path.join(tmpdir, "config")
        config.save(save_path)

        loaded_config = AIIAConfig.load(save_path)
        assert loaded_config.model_name == "TempModel"
        assert loaded_config.nested == {"key": "value"}
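Taken together, the custom- and nested-attribute round-trips imply that AIIAConfig serializes its whole attribute dictionary to config.json and restores it verbatim on load. A standalone sketch of that contract (a hypothetical re-implementation for illustration, not the library's actual code):

import json
import os

class TinyConfig:
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    def save(self, path):
        os.makedirs(path, exist_ok=True)
        with open(os.path.join(path, "config.json"), "w") as f:
            json.dump(self.__dict__, f)

    @classmethod
    def load(cls, path):
        with open(os.path.join(path, "config.json")) as f:
            return cls(**json.load(f))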


@@ -0,0 +1,308 @@
import pytest
import torch
from unittest.mock import MagicMock, patch, mock_open, call
from aiia import Pretrainer, ProjectionHead, AIIABase, AIIAConfig
import pandas as pd
# Test the ProjectionHead class
def test_projection_head():
    head = ProjectionHead(hidden_size=512)
    x = torch.randn(1, 512, 32, 32)

    # Test denoise task
    output_denoise = head(x, task='denoise')
    assert output_denoise.shape == (1, 3, 32, 32)

    # Test rotate task
    output_rotate = head(x, task='rotate')
    assert output_rotate.shape == (1, 4)


# Test the Pretrainer class initialization
def test_pretrainer_initialization():
    config = AIIAConfig()
    model = AIIABase(config=config)
    pretrainer = Pretrainer(model=model, learning_rate=0.001, config=config)
    assert pretrainer.device in ["cuda", "cpu"]
    assert isinstance(pretrainer.projection_head, ProjectionHead)
    assert isinstance(pretrainer.optimizer, torch.optim.AdamW)


# Test the safe_collate method
def test_safe_collate():
    pretrainer = Pretrainer(model=AIIABase(config=AIIAConfig()), config=AIIAConfig())
    batch = [
        (torch.randn(3, 32, 32), torch.randn(3, 32, 32), 'denoise'),
        (torch.randn(3, 32, 32), torch.tensor(1), 'rotate')
    ]

    collated_batch = pretrainer.safe_collate(batch)
    assert 'denoise' in collated_batch
    assert 'rotate' in collated_batch
# Test the _process_batch method (the method is patched below, so this
# exercises only the call path, not the real batch processing)
@patch('aiia.pretrain.pretrainer.Pretrainer._process_batch')
def test_process_batch(mock_process_batch):
    pretrainer = Pretrainer(model=AIIABase(config=AIIAConfig()), config=AIIAConfig())
    batch_data = {
        'denoise': (torch.randn(2, 3, 32, 32), torch.randn(2, 3, 32, 32)),
        'rotate': (torch.randn(2, 3, 32, 32), torch.tensor([0, 1]))
    }
    criterion_denoise = MagicMock()
    criterion_rotate = MagicMock()
    mock_process_batch.return_value = 0.5

    loss = pretrainer._process_batch(batch_data, criterion_denoise, criterion_rotate)
    assert loss == 0.5
@patch('pandas.concat')
@patch('pandas.read_parquet')
@patch('aiia.pretrain.pretrainer.AIIADataLoader')
@patch('os.path.join', return_value='mocked/path/model.pt')
@patch('builtins.print')  # Mock the print function
def test_train_happy_path(mock_print, mock_path_join, mock_data_loader, mock_read_parquet, mock_concat):
    """Test the train method under normal conditions with comprehensive verification."""
    # Setup test data and mocks
    real_df = pd.DataFrame({
        'image_bytes': [torch.randn(1, 3, 224, 224).tolist()]
    })
    mock_read_parquet.return_value.head.return_value = real_df
    mock_concat.return_value = real_df

    # Mock the model and related components
    mock_model = MagicMock()
    mock_projection_head = MagicMock()
    pretrainer = Pretrainer(model=mock_model, config=AIIAConfig())
    pretrainer.projection_head = mock_projection_head
    pretrainer.optimizer = MagicMock()

    # Setup dataset paths and mock batch data
    dataset_paths = ['path/to/dataset1.parquet', 'path/to/dataset2.parquet']
    output_path = "AIIA_test"

    # Create mock batch data with proper structure
    mock_batch_data = {
        'denoise': (torch.randn(2, 3, 32, 32), torch.randn(2, 3, 32, 32)),
        'rotate': (torch.randn(2, 3, 32, 32), torch.tensor([0, 1]))
    }

    # Configure batch loss
    mock_batch_loss = torch.tensor(0.5, requires_grad=True)
    loader_instance = MagicMock()
    loader_instance.train_loader = [mock_batch_data]
    loader_instance.val_loader = [mock_batch_data]
    mock_data_loader.return_value = loader_instance

    # Execute training with patched methods
    with patch.object(Pretrainer, '_process_batch', return_value=mock_batch_loss) as mock_process_batch, \
         patch.object(Pretrainer, '_validate', side_effect=[0.8, 0.3]) as mock_validate, \
         patch.object(Pretrainer, 'save_losses') as mock_save_losses, \
         patch('builtins.open', mock_open()):
        pretrainer.train(dataset_paths, output_path=output_path, num_epochs=2)

    # Verify method calls
    assert mock_read_parquet.call_count == len(dataset_paths)
    assert mock_process_batch.call_count == 2
    assert mock_validate.call_count == 2

    # Check for "Best model saved!" instead of model.save()
    mock_print.assert_any_call("Best model saved!")
    mock_save_losses.assert_called_once()

    # Verify state changes
    assert len(pretrainer.train_losses) == 2
    assert pretrainer.train_losses == [0.5, 0.5]
# Error cases
def test_train_no_dataset_paths():
    """Test ValueError when no dataset paths are provided."""
    pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
    with pytest.raises(ValueError, match="No dataset paths provided"):
        pretrainer.train([])


@patch('pandas.read_parquet')
def test_train_all_datasets_fail(mock_read_parquet):
    """Test handling when all datasets fail to load."""
    mock_read_parquet.side_effect = Exception("Failed to load dataset")

    pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
    dataset_paths = ['path/to/dataset1.parquet', 'path/to/dataset2.parquet']

    with pytest.raises(ValueError, match="No valid datasets could be loaded"):
        pretrainer.train(dataset_paths)
# Edge cases
@patch('pandas.concat')
@patch('pandas.read_parquet')
@patch('aiia.pretrain.pretrainer.AIIADataLoader')
def test_train_empty_loaders(mock_data_loader, mock_read_parquet, mock_concat):
    """Test behavior with empty data loaders."""
    real_df = pd.DataFrame({'image_bytes': [torch.randn(1, 3, 224, 224).tolist()]})
    mock_read_parquet.return_value.head.return_value = real_df
    mock_concat.return_value = real_df

    loader_instance = MagicMock()
    loader_instance.train_loader = []  # Empty train loader
    loader_instance.val_loader = []    # Empty val loader
    mock_data_loader.return_value = loader_instance

    mock_model = MagicMock()
    pretrainer = Pretrainer(model=mock_model, config=AIIAConfig())
    pretrainer.projection_head = MagicMock()
    pretrainer.optimizer = MagicMock()

    with patch.object(Pretrainer, 'save_losses') as mock_save_losses:
        pretrainer.train(['path/to/dataset.parquet'], num_epochs=1)

    # Verify empty loader behavior
    assert len(pretrainer.train_losses) == 1
    assert pretrainer.train_losses[0] == 0.0
    mock_save_losses.assert_called_once()
@patch('pandas.concat')
@patch('pandas.read_parquet')
@patch('aiia.pretrain.pretrainer.AIIADataLoader')
def test_train_none_batch_data(mock_data_loader, mock_read_parquet, mock_concat):
    """Test behavior when batch_data is None."""
    real_df = pd.DataFrame({'image_bytes': [torch.randn(1, 3, 224, 224).tolist()]})
    mock_read_parquet.return_value.head.return_value = real_df
    mock_concat.return_value = real_df

    loader_instance = MagicMock()
    loader_instance.train_loader = [None]  # Loader returns None
    loader_instance.val_loader = []
    mock_data_loader.return_value = loader_instance

    pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
    pretrainer.projection_head = MagicMock()
    pretrainer.optimizer = MagicMock()

    with patch.object(Pretrainer, '_process_batch') as mock_process_batch, \
         patch.object(Pretrainer, 'save_losses'):
        pretrainer.train(['path/to/dataset.parquet'], num_epochs=1)

    # Verify None batch handling
    mock_process_batch.assert_not_called()
    assert pretrainer.train_losses[0] == 0.0
# Parameter variations
@patch('pandas.concat')
@patch('pandas.read_parquet')
@patch('aiia.pretrain.pretrainer.AIIADataLoader')
def test_train_with_custom_parameters(mock_data_loader, mock_read_parquet, mock_concat):
    """Test that custom parameters are properly passed through."""
    real_df = pd.DataFrame({'image_bytes': [torch.randn(1, 3, 224, 224).tolist()]})
    mock_read_parquet.return_value.head.return_value = real_df
    mock_concat.return_value = real_df

    loader_instance = MagicMock()
    loader_instance.train_loader = []
    loader_instance.val_loader = []
    mock_data_loader.return_value = loader_instance

    pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
    pretrainer.projection_head = MagicMock()
    pretrainer.optimizer = MagicMock()

    # Custom parameters
    custom_output_path = "custom/output/path"
    custom_column = "custom_column"
    custom_batch_size = 16
    custom_sample_size = 5000

    with patch.object(Pretrainer, 'save_losses'):
        pretrainer.train(
            ['path/to/dataset.parquet'],
            output_path=custom_output_path,
            column=custom_column,
            batch_size=custom_batch_size,
            sample_size=custom_sample_size
        )

    # Verify custom parameters were used
    mock_read_parquet.return_value.head.assert_called_once_with(custom_sample_size)
    assert mock_data_loader.call_args[1]['column'] == custom_column
    assert mock_data_loader.call_args[1]['batch_size'] == custom_batch_size
@patch('pandas.concat')
@patch('pandas.read_parquet')
@patch('aiia.pretrain.pretrainer.AIIADataLoader')
@patch('builtins.print')  # Mock the print function
def test_train_validation_loss_improvement(mock_print, mock_data_loader, mock_read_parquet, mock_concat):
    """Test that the model is saved only when validation loss improves."""
    real_df = pd.DataFrame({'image_bytes': [torch.randn(1, 3, 224, 224).tolist()]})
    mock_read_parquet.return_value.head.return_value = real_df
    mock_concat.return_value = real_df

    # Create mock batch data with proper structure
    mock_batch_data = {
        'denoise': (torch.randn(2, 3, 32, 32), torch.randn(2, 3, 32, 32)),
        'rotate': (torch.randn(2, 3, 32, 32), torch.tensor([0, 1]))
    }

    loader_instance = MagicMock()
    loader_instance.train_loader = [mock_batch_data]
    loader_instance.val_loader = [mock_batch_data]
    mock_data_loader.return_value = loader_instance

    mock_model = MagicMock()
    pretrainer = Pretrainer(model=mock_model, config=AIIAConfig())
    pretrainer.projection_head = MagicMock()
    pretrainer.optimizer = MagicMock()

    # Initialize the best validation loss
    pretrainer.best_val_loss = float('inf')

    mock_batch_loss = torch.tensor(0.5, requires_grad=True)

    # Test improving validation loss
    with patch.object(Pretrainer, '_process_batch', return_value=mock_batch_loss), \
         patch.object(Pretrainer, '_validate', side_effect=[3.0, 2.0, 1.0]), \
         patch.object(Pretrainer, 'save_losses'):
        pretrainer.train(['path/to/dataset.parquet'], num_epochs=3)

    # Loss improves every epoch, so "Best model saved!" is printed three times
    assert mock_print.call_args_list.count(call("Best model saved!")) == 3

    # Reset for the next scenario
    mock_print.reset_mock()
    pretrainer.train_losses = []
    # Reset best validation loss for the second test
    pretrainer.best_val_loss = float('inf')

    # Test fluctuating validation loss
    with patch.object(Pretrainer, '_process_batch', return_value=mock_batch_loss), \
         patch.object(Pretrainer, '_validate', side_effect=[3.0, 4.0, 2.0]), \
         patch.object(Pretrainer, 'save_losses'):
        pretrainer.train(['path/to/dataset.parquet'], num_epochs=3)

    # Should print "Best model saved!" only on the first and third epochs
    assert mock_print.call_args_list.count(call("Best model saved!")) == 2
@patch('aiia.pretrain.pretrainer.Pretrainer._process_batch')
def test_validate(mock_process_batch):
    pretrainer = Pretrainer(model=AIIABase(config=AIIAConfig()), config=AIIAConfig())
    val_loader = [MagicMock()]
    criterion_denoise = MagicMock()
    criterion_rotate = MagicMock()
    mock_process_batch.return_value = torch.tensor(0.5)

    loss = pretrainer._validate(val_loader, criterion_denoise, criterion_rotate)
    assert loss == 0.5
# Test the save_losses method (save_losses itself is patched, so this only
# verifies that the call is forwarded with the expected path)
@patch('aiia.pretrain.pretrainer.Pretrainer.save_losses')
def test_save_losses(mock_save_losses):
    pretrainer = Pretrainer(model=AIIABase(config=AIIAConfig()), config=AIIAConfig())
    pretrainer.train_losses = [0.1, 0.2]
    pretrainer.val_losses = [0.3, 0.4]

    csv_file = 'losses.csv'
    pretrainer.save_losses(csv_file)
    mock_save_losses.assert_called_once_with(csv_file)
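Because save_losses is patched above, the file writing itself is never exercised. A complementary sketch that calls the real method (assuming save_losses accepts a destination path and writes the CSV there):

def test_save_losses_writes_file(tmp_path):
    pretrainer = Pretrainer(model=AIIABase(config=AIIAConfig()), config=AIIAConfig())
    pretrainer.train_losses = [0.1, 0.2]
    pretrainer.val_losses = [0.3, 0.4]

    csv_file = tmp_path / "losses.csv"
    pretrainer.save_losses(str(csv_file))
    assert csv_file.exists()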