feat/energy_efficenty #38
|
@ -3,6 +3,8 @@ import torch
|
|||
from unittest.mock import MagicMock, patch, MagicMock, mock_open, call
|
||||
from aiia import Pretrainer, ProjectionHead, AIIABase, AIIAConfig
|
||||
import pandas as pd
|
||||
import os
|
||||
import datetime
|
||||
|
||||
# Test the ProjectionHead class
|
||||
def test_projection_head():
|
||||
|
@ -53,11 +55,94 @@ def test_process_batch(mock_process_batch):
|
|||
loss = pretrainer._process_batch(batch_data, criterion_denoise, criterion_rotate)
|
||||
assert loss == 0.5
|
||||
|
||||
# Error cases
|
||||
# New tests for checkpoint handling
|
||||
@patch('torch.save')
|
||||
@patch('os.path.join')
|
||||
def test_save_checkpoint(mock_join, mock_save):
|
||||
"""Test checkpoint saving functionality."""
|
||||
pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
|
||||
pretrainer.projection_head = MagicMock()
|
||||
pretrainer.optimizer = MagicMock()
|
||||
|
||||
checkpoint_dir = "checkpoints"
|
||||
epoch = 1
|
||||
batch_count = 100
|
||||
checkpoint_name = "test_checkpoint.pt"
|
||||
|
||||
mock_join.return_value = os.path.join(checkpoint_dir, checkpoint_name)
|
||||
|
||||
path = pretrainer._save_checkpoint(checkpoint_dir, epoch, batch_count, checkpoint_name)
|
||||
|
||||
assert path == os.path.join(checkpoint_dir, checkpoint_name)
|
||||
mock_save.assert_called_once()
|
||||
|
||||
@patch('os.makedirs')
|
||||
@patch('os.path.exists')
|
||||
@patch('torch.load')
|
||||
def test_load_checkpoint_specific(mock_load, mock_exists, mock_makedirs):
|
||||
"""Test loading a specific checkpoint."""
|
||||
pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
|
||||
pretrainer.projection_head = MagicMock()
|
||||
pretrainer.optimizer = MagicMock()
|
||||
|
||||
checkpoint_dir = "checkpoints"
|
||||
specific_checkpoint = "specific_checkpoint.pt"
|
||||
mock_exists.return_value = True
|
||||
|
||||
mock_load.return_value = {
|
||||
'epoch': 2,
|
||||
'batch': 150,
|
||||
'model_state_dict': {},
|
||||
'projection_head_state_dict': {},
|
||||
'optimizer_state_dict': {},
|
||||
'train_losses': [],
|
||||
'val_losses': []
|
||||
}
|
||||
|
||||
result = pretrainer.load_checkpoint(checkpoint_dir, specific_checkpoint)
|
||||
assert result == (2, 150)
|
||||
|
||||
@patch('os.listdir')
|
||||
@patch('os.path.getmtime')
|
||||
def test_load_checkpoint_most_recent(mock_getmtime, mock_listdir):
|
||||
"""Test loading the most recent checkpoint."""
|
||||
pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
|
||||
|
||||
checkpoint_dir = "checkpoints"
|
||||
mock_listdir.return_value = ["checkpoint_1.pt", "checkpoint_2.pt"]
|
||||
mock_getmtime.side_effect = [100, 200] # checkpoint_2.pt is more recent
|
||||
|
||||
with patch.object(pretrainer, '_load_checkpoint_file', return_value=(2, 150)):
|
||||
result = pretrainer.load_checkpoint(checkpoint_dir)
|
||||
assert result == (2, 150)
|
||||
|
||||
def test_initialize_checkpoint_variables():
|
||||
"""Test initialization of checkpoint variables."""
|
||||
pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
|
||||
pretrainer._initialize_checkpoint_variables()
|
||||
|
||||
assert hasattr(pretrainer, 'last_checkpoint_time')
|
||||
assert pretrainer.checkpoint_interval == 2 * 60 * 60
|
||||
assert pretrainer.last_22_date is None
|
||||
assert pretrainer.recent_checkpoints == []
|
||||
|
||||
@patch('torch.nn.MSELoss')
|
||||
@patch('torch.nn.CrossEntropyLoss')
|
||||
def test_initialize_loss_functions(mock_ce_loss, mock_mse_loss):
|
||||
"""Test loss function initialization."""
|
||||
pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
|
||||
criterion_denoise, criterion_rotate, best_val_loss = pretrainer._initialize_loss_functions()
|
||||
|
||||
assert mock_mse_loss.called
|
||||
assert mock_ce_loss.called
|
||||
assert best_val_loss == float('inf')
|
||||
|
||||
@patch('pandas.concat')
|
||||
@patch('pandas.read_parquet')
|
||||
@patch('aiia.pretrain.pretrainer.AIIADataLoader')
|
||||
@patch('os.path.join', return_value='mocked/path/model.pt')
|
||||
@patch('builtins.print') # Add this to mock the print function
|
||||
@patch('builtins.print')
|
||||
def test_train_happy_path(mock_print, mock_path_join, mock_data_loader, mock_read_parquet, mock_concat):
|
||||
"""Test the train method under normal conditions with comprehensive verification."""
|
||||
# Setup test data and mocks
|
||||
|
@ -73,6 +158,7 @@ def test_train_happy_path(mock_print, mock_path_join, mock_data_loader, mock_rea
|
|||
pretrainer = Pretrainer(model=mock_model, config=AIIAConfig())
|
||||
pretrainer.projection_head = mock_projection_head
|
||||
pretrainer.optimizer = MagicMock()
|
||||
pretrainer.checkpoint_dir = None # Initialize checkpoint_dir
|
||||
|
||||
# Setup dataset paths and mock batch data
|
||||
dataset_paths = ['path/to/dataset1.parquet', 'path/to/dataset2.parquet']
|
||||
|
@ -104,185 +190,118 @@ def test_train_happy_path(mock_print, mock_path_join, mock_data_loader, mock_rea
|
|||
assert mock_process_batch.call_count == 2
|
||||
assert mock_validate.call_count == 2
|
||||
|
||||
# Check for "Best model saved!" instead of model.save()
|
||||
mock_print.assert_any_call("Best model saved!")
|
||||
|
||||
mock_save_losses.assert_called_once()
|
||||
|
||||
# Verify state changes
|
||||
assert len(pretrainer.train_losses) == 2
|
||||
assert pretrainer.train_losses == [0.5, 0.5]
|
||||
|
||||
|
||||
# Error cases
|
||||
def test_train_no_dataset_paths():
|
||||
"""Test ValueError when no dataset paths are provided."""
|
||||
@patch('datetime.datetime')
|
||||
@patch('time.time')
|
||||
def test_handle_checkpoints(mock_time, mock_datetime):
|
||||
"""Test checkpoint handling logic."""
|
||||
pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
|
||||
pretrainer.checkpoint_dir = "checkpoints"
|
||||
pretrainer.current_epoch = 1
|
||||
pretrainer._initialize_checkpoint_variables()
|
||||
|
||||
with pytest.raises(ValueError, match="No dataset paths provided"):
|
||||
pretrainer.train([])
|
||||
# Set a base time value
|
||||
base_time = 1000
|
||||
# Set the last checkpoint time to base_time
|
||||
pretrainer.last_checkpoint_time = base_time
|
||||
|
||||
@patch('pandas.read_parquet')
|
||||
def test_train_all_datasets_fail(mock_read_parquet):
|
||||
"""Test handling when all datasets fail to load."""
|
||||
mock_read_parquet.side_effect = Exception("Failed to load dataset")
|
||||
# Mock time to return base_time + interval + 1 to trigger checkpoint save
|
||||
mock_time.return_value = base_time + pretrainer.checkpoint_interval + 1
|
||||
|
||||
# Mock datetime for 22:00 checkpoint
|
||||
mock_dt = MagicMock()
|
||||
mock_dt.hour = 22
|
||||
mock_dt.minute = 0
|
||||
mock_dt.date.return_value = datetime.date(2023, 1, 1)
|
||||
mock_datetime.now.return_value = mock_dt
|
||||
|
||||
with patch.object(pretrainer, '_save_checkpoint') as mock_save:
|
||||
pretrainer._handle_checkpoints(100)
|
||||
# Should be called twice - once for regular interval and once for 22:00
|
||||
assert mock_save.call_count == 2
|
||||
|
||||
def test_training_phase():
|
||||
"""Test the training phase logic."""
|
||||
pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
|
||||
dataset_paths = ['path/to/dataset1.parquet', 'path/to/dataset2.parquet']
|
||||
|
||||
with pytest.raises(ValueError, match="No valid datasets could be loaded"):
|
||||
pretrainer.train(dataset_paths)
|
||||
|
||||
# Edge cases
|
||||
@patch('pandas.concat')
|
||||
@patch('pandas.read_parquet')
|
||||
@patch('aiia.pretrain.pretrainer.AIIADataLoader')
|
||||
def test_train_empty_loaders(mock_data_loader, mock_read_parquet, mock_concat):
|
||||
"""Test behavior with empty data loaders."""
|
||||
real_df = pd.DataFrame({'image_bytes': [torch.randn(1, 3, 224, 224).tolist()]})
|
||||
mock_read_parquet.return_value.head.return_value = real_df
|
||||
mock_concat.return_value = real_df
|
||||
|
||||
loader_instance = MagicMock()
|
||||
loader_instance.train_loader = [] # Empty train loader
|
||||
loader_instance.val_loader = [] # Empty val loader
|
||||
mock_data_loader.return_value = loader_instance
|
||||
|
||||
mock_model = MagicMock()
|
||||
pretrainer = Pretrainer(model=mock_model, config=AIIAConfig())
|
||||
pretrainer.projection_head = MagicMock()
|
||||
pretrainer.optimizer = MagicMock()
|
||||
pretrainer.checkpoint_dir = None # Initialize checkpoint_dir
|
||||
pretrainer._initialize_checkpoint_variables()
|
||||
pretrainer.current_epoch = 0
|
||||
|
||||
with patch.object(Pretrainer, 'save_losses') as mock_save_losses:
|
||||
pretrainer.train(['path/to/dataset.parquet'], num_epochs=1)
|
||||
|
||||
# Verify empty loader behavior
|
||||
assert len(pretrainer.train_losses) == 1
|
||||
assert pretrainer.train_losses[0] == 0.0
|
||||
mock_save_losses.assert_called_once()
|
||||
|
||||
@patch('pandas.concat')
|
||||
@patch('pandas.read_parquet')
|
||||
@patch('aiia.pretrain.pretrainer.AIIADataLoader')
|
||||
def test_train_none_batch_data(mock_data_loader, mock_read_parquet, mock_concat):
|
||||
"""Test behavior when batch_data is None."""
|
||||
real_df = pd.DataFrame({'image_bytes': [torch.randn(1, 3, 224, 224).tolist()]})
|
||||
mock_read_parquet.return_value.head.return_value = real_df
|
||||
mock_concat.return_value = real_df
|
||||
|
||||
loader_instance = MagicMock()
|
||||
loader_instance.train_loader = [None] # Loader returns None
|
||||
loader_instance.val_loader = []
|
||||
mock_data_loader.return_value = loader_instance
|
||||
|
||||
pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
|
||||
pretrainer.projection_head = MagicMock()
|
||||
pretrainer.optimizer = MagicMock()
|
||||
|
||||
with patch.object(Pretrainer, '_process_batch') as mock_process_batch, \
|
||||
patch.object(Pretrainer, 'save_losses'):
|
||||
pretrainer.train(['path/to/dataset.parquet'], num_epochs=1)
|
||||
|
||||
# Verify None batch handling
|
||||
mock_process_batch.assert_not_called()
|
||||
assert pretrainer.train_losses[0] == 0.0
|
||||
|
||||
# Parameter variations
|
||||
@patch('pandas.concat')
|
||||
@patch('pandas.read_parquet')
|
||||
@patch('aiia.pretrain.pretrainer.AIIADataLoader')
|
||||
def test_train_with_custom_parameters(mock_data_loader, mock_read_parquet, mock_concat):
|
||||
"""Test that custom parameters are properly passed through."""
|
||||
real_df = pd.DataFrame({'image_bytes': [torch.randn(1, 3, 224, 224).tolist()]})
|
||||
mock_read_parquet.return_value.head.return_value = real_df
|
||||
mock_concat.return_value = real_df
|
||||
|
||||
loader_instance = MagicMock()
|
||||
loader_instance.train_loader = []
|
||||
loader_instance.val_loader = []
|
||||
mock_data_loader.return_value = loader_instance
|
||||
|
||||
pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
|
||||
pretrainer.projection_head = MagicMock()
|
||||
pretrainer.optimizer = MagicMock()
|
||||
|
||||
# Custom parameters
|
||||
custom_output_path = "custom/output/path"
|
||||
custom_column = "custom_column"
|
||||
custom_batch_size = 16
|
||||
custom_sample_size = 5000
|
||||
|
||||
with patch.object(Pretrainer, 'save_losses'):
|
||||
pretrainer.train(
|
||||
['path/to/dataset.parquet'],
|
||||
output_path=custom_output_path,
|
||||
column=custom_column,
|
||||
batch_size=custom_batch_size,
|
||||
sample_size=custom_sample_size
|
||||
)
|
||||
|
||||
# Verify custom parameters were used
|
||||
mock_read_parquet.return_value.head.assert_called_once_with(custom_sample_size)
|
||||
assert mock_data_loader.call_args[1]['column'] == custom_column
|
||||
assert mock_data_loader.call_args[1]['batch_size'] == custom_batch_size
|
||||
|
||||
|
||||
|
||||
@patch('pandas.concat')
|
||||
@patch('pandas.read_parquet')
|
||||
@patch('aiia.pretrain.pretrainer.AIIADataLoader')
|
||||
@patch('builtins.print') # Add this to mock the print function
|
||||
def test_train_validation_loss_improvement(mock_print, mock_data_loader, mock_read_parquet, mock_concat):
|
||||
"""Test that model is saved only when validation loss improves."""
|
||||
real_df = pd.DataFrame({'image_bytes': [torch.randn(1, 3, 224, 224).tolist()]})
|
||||
mock_read_parquet.return_value.head.return_value = real_df
|
||||
mock_concat.return_value = real_df
|
||||
|
||||
# Create mock batch data with proper structure
|
||||
# Create mock batch data with requires_grad=True
|
||||
mock_batch_data = {
|
||||
'denoise': (torch.randn(2, 3, 32, 32), torch.randn(2, 3, 32, 32)),
|
||||
'rotate': (torch.randn(2, 3, 32, 32), torch.tensor([0, 1]))
|
||||
'denoise': (
|
||||
torch.randn(2, 3, 32, 32, requires_grad=True),
|
||||
torch.randn(2, 3, 32, 32, requires_grad=True)
|
||||
),
|
||||
'rotate': (
|
||||
torch.randn(2, 3, 32, 32, requires_grad=True),
|
||||
torch.tensor([0, 1], dtype=torch.long) # Labels typically don't need gradients
|
||||
)
|
||||
}
|
||||
mock_train_loader = [(0, mock_batch_data)] # Include batch index
|
||||
|
||||
# Mock the loss functions to return tensors that require gradients
|
||||
criterion_denoise = MagicMock(return_value=torch.tensor(0.5, requires_grad=True))
|
||||
criterion_rotate = MagicMock(return_value=torch.tensor(0.5, requires_grad=True))
|
||||
|
||||
with patch.object(pretrainer, '_process_batch', return_value=torch.tensor(0.5, requires_grad=True)), \
|
||||
patch.object(pretrainer, '_handle_checkpoints') as mock_handle_checkpoints:
|
||||
|
||||
total_loss, batch_count = pretrainer._training_phase(
|
||||
mock_train_loader, 0, criterion_denoise, criterion_rotate)
|
||||
|
||||
assert total_loss == 0.5
|
||||
assert batch_count == 1
|
||||
mock_handle_checkpoints.assert_called_once_with(1) # Check if checkpoint handling was called
|
||||
|
||||
def test_validation_phase():
|
||||
"""Test the validation phase logic."""
|
||||
pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
|
||||
pretrainer.projection_head = MagicMock()
|
||||
|
||||
mock_val_loader = [MagicMock()]
|
||||
criterion_denoise = MagicMock()
|
||||
criterion_rotate = MagicMock()
|
||||
|
||||
with patch.object(pretrainer, '_validate', return_value=0.4):
|
||||
val_loss = pretrainer._validation_phase(
|
||||
mock_val_loader, criterion_denoise, criterion_rotate)
|
||||
|
||||
assert val_loss == 0.4
|
||||
|
||||
@patch('pandas.read_parquet')
|
||||
def test_load_and_merge_datasets(mock_read_parquet):
|
||||
"""Test dataset loading and merging."""
|
||||
pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
|
||||
|
||||
mock_df = pd.DataFrame({'col': [1, 2, 3]})
|
||||
mock_read_parquet.return_value.head.return_value = mock_df
|
||||
|
||||
result = pretrainer._load_and_merge_datasets(['path1.parquet', 'path2.parquet'], 1000)
|
||||
assert len(result) == 6 # 2 datasets * 3 rows each
|
||||
|
||||
def test_process_batch_none_tasks():
|
||||
"""Test processing batch with no tasks."""
|
||||
pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
|
||||
|
||||
batch_data = {
|
||||
'denoise': None,
|
||||
'rotate': None
|
||||
}
|
||||
|
||||
loader_instance = MagicMock()
|
||||
loader_instance.train_loader = [mock_batch_data]
|
||||
loader_instance.val_loader = [mock_batch_data]
|
||||
mock_data_loader.return_value = loader_instance
|
||||
loss = pretrainer._process_batch(
|
||||
batch_data,
|
||||
criterion_denoise=MagicMock(),
|
||||
criterion_rotate=MagicMock()
|
||||
)
|
||||
|
||||
mock_model = MagicMock()
|
||||
pretrainer = Pretrainer(model=mock_model, config=AIIAConfig())
|
||||
pretrainer.projection_head = MagicMock()
|
||||
pretrainer.optimizer = MagicMock()
|
||||
|
||||
# Initialize the best validation loss
|
||||
pretrainer.best_val_loss = float('inf')
|
||||
|
||||
mock_batch_loss = torch.tensor(0.5, requires_grad=True)
|
||||
|
||||
# Test improving validation loss
|
||||
with patch.object(Pretrainer, '_process_batch', return_value=mock_batch_loss), \
|
||||
patch.object(Pretrainer, '_validate', side_effect=[3.0, 2.0, 1.0]), \
|
||||
patch.object(Pretrainer, 'save_losses'):
|
||||
pretrainer.train(['path/to/dataset.parquet'], num_epochs=3)
|
||||
|
||||
# Check for "Best model saved!" 3 times
|
||||
assert mock_print.call_args_list.count(call("Best model saved!")) == 3
|
||||
|
||||
# Reset for next test
|
||||
mock_print.reset_mock()
|
||||
pretrainer.train_losses = []
|
||||
|
||||
# Reset best validation loss for the second test
|
||||
pretrainer.best_val_loss = float('inf')
|
||||
|
||||
# Test fluctuating validation loss
|
||||
with patch.object(Pretrainer, '_process_batch', return_value=mock_batch_loss), \
|
||||
patch.object(Pretrainer, '_validate', side_effect=[3.0, 4.0, 2.0]), \
|
||||
patch.object(Pretrainer, 'save_losses'):
|
||||
pretrainer.train(['path/to/dataset.parquet'], num_epochs=3)
|
||||
|
||||
# Should print "Best model saved!" only on first and third epochs
|
||||
assert mock_print.call_args_list.count(call("Best model saved!")) == 2
|
||||
assert loss == 0
|
||||
|
||||
|
||||
@patch('aiia.pretrain.pretrainer.Pretrainer._process_batch')
|
||||
|
|
Loading…
Reference in New Issue