feat/energy_efficenty #38

Merged
Fabel merged 6 commits from feat/energy_efficenty into develop 2025-04-17 10:52:25 +00:00
3 changed files with 426 additions and 233 deletions

View File

@@ -6,19 +6,15 @@ from aiia.pretrain import Pretrainer
config = AIIAConfig(model_name="AIIA-Base-512x20k")
model = AIIABase(config)
# Initialize pretrainer with the model
pretrainer = Pretrainer(model, learning_rate=1e-4, config=config)
# List of dataset paths
dataset_paths = [
    "/path/to/dataset1.parquet",
    "/path/to/dataset2.parquet"
]
# Set checkpoint directory
checkpoint_dir = "checkpoints/my_model"
# Start training with multiple datasets
# Start training (will automatically load checkpoint if available)
pretrainer.train(
    dataset_paths=dataset_paths,
    num_epochs=10,
    batch_size=2,
    sample_size=10000
)
pretrainer.train(
    dataset_paths=["path/to/dataset1.parquet", "path/to/dataset2.parquet"],
    output_path="trained_models/my_model",
    checkpoint_dir=checkpoint_dir,
    num_epochs=10
)
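When a checkpoint_dir is passed, train() automatically resumes from the most recent checkpoint found in that directory. The new load_checkpoint method can also restore a specific checkpoint file directly. A minimal sketch, assuming the named checkpoint already exists in checkpoint_dir (the file name below is illustrative, not one produced by this example):

# Restore state from a specific checkpoint rather than the most recent one
resumed = pretrainer.load_checkpoint(checkpoint_dir, specific_checkpoint="checkpoint_epoch2_batch100.pt")
if resumed is not None:
    epoch, batch = resumed
    print(f"Checkpoint state restored up to epoch {epoch}, batch {batch}")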

View File

@@ -1,6 +1,8 @@
import torch
from torch import nn
import csv
import datetime
import time
import pandas as pd
from tqdm import tqdm
from ..model.Model import AIIA
@@ -112,78 +114,135 @@ class Pretrainer:
return batch_loss
def train(self, dataset_paths,output_path:str="AIIA", column="image_bytes", num_epochs=3, batch_size=2, sample_size=10000):
"""
Train the model using multiple specified datasets.
def _save_checkpoint(self, checkpoint_dir, epoch, batch_count, checkpoint_name):
"""Save a model checkpoint.
Args:
dataset_paths (list): List of paths to parquet datasets
num_epochs (int): Number of training epochs
batch_size (int): Batch size for training
sample_size (int): Number of samples to use from each dataset
checkpoint_dir (str): Directory to save the checkpoint
epoch (int): Current epoch number
batch_count (int): Current batch count
checkpoint_name (str): Name for the checkpoint file
Returns:
str: Path to the saved checkpoint
"""
checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name)
checkpoint_data = {
'epoch': epoch + 1,
'batch': batch_count,
'model_state_dict': self.model.state_dict(),
'projection_head_state_dict': self.projection_head.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'train_losses': self.train_losses,
'val_losses': self.val_losses,
}
torch.save(checkpoint_data, checkpoint_path)
return checkpoint_path
def load_checkpoint(self, checkpoint_dir, specific_checkpoint=None):
"""
Check for checkpoints and load if available.
Args:
checkpoint_dir (str): Directory where checkpoints are stored
specific_checkpoint (str, optional): Specific checkpoint file to load. If None, loads the most recent.
Returns:
tuple: (loaded_epoch, loaded_batch) if checkpoint was loaded, None otherwise
"""
# Create checkpoint directory if it doesn't exist
os.makedirs(checkpoint_dir, exist_ok=True)
# If a specific checkpoint is requested
if specific_checkpoint:
checkpoint_path = os.path.join(checkpoint_dir, specific_checkpoint)
if os.path.exists(checkpoint_path):
return self._load_checkpoint_file(checkpoint_path)
else:
print(f"Specified checkpoint {specific_checkpoint} not found.")
return None
# Find all checkpoint files
checkpoint_files = [f for f in os.listdir(checkpoint_dir) if f.startswith("checkpoint_") and f.endswith(".pt")]
if not checkpoint_files:
print("No checkpoints found in directory.")
return None
# Find the most recent checkpoint
checkpoint_files.sort(key=lambda x: os.path.getmtime(os.path.join(checkpoint_dir, x)), reverse=True)
most_recent = checkpoint_files[0]
checkpoint_path = os.path.join(checkpoint_dir, most_recent)
return self._load_checkpoint_file(checkpoint_path)
def _load_checkpoint_file(self, checkpoint_path):
"""
Load a specific checkpoint file.
Args:
checkpoint_path (str): Path to the checkpoint file
Returns:
tuple: (loaded_epoch, loaded_batch) if checkpoint was loaded, None otherwise
"""
try:
checkpoint = torch.load(checkpoint_path, map_location=self.device)
# Load model state
self.model.load_state_dict(checkpoint['model_state_dict'])
# Load projection head state
self.projection_head.load_state_dict(checkpoint['projection_head_state_dict'])
# Load optimizer state
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# Load loss history
self.train_losses = checkpoint.get('train_losses', [])
self.val_losses = checkpoint.get('val_losses', [])
loaded_epoch = checkpoint['epoch']
loaded_batch = checkpoint['batch']
print(f"Checkpoint loaded from {checkpoint_path}")
print(f"Resuming from epoch {loaded_epoch}, batch {loaded_batch}")
return loaded_epoch, loaded_batch
except Exception as e:
print(f"Error loading checkpoint: {e}")
return None
def train(self, dataset_paths, output_path="AIIA", column="image_bytes",
num_epochs=3, batch_size=2, sample_size=10000, checkpoint_dir=None):
"""Train the model using multiple specified datasets with checkpoint resumption support."""
if not dataset_paths:
raise ValueError("No dataset paths provided")
# Read and merge all datasets
dataframes = []
for path in dataset_paths:
try:
df = pd.read_parquet(path).head(sample_size)
dataframes.append(df)
except Exception as e:
print(f"Error loading dataset {path}: {e}")
if not dataframes:
raise ValueError("No valid datasets could be loaded")
merged_df = pd.concat(dataframes, ignore_index=True)
self._initialize_checkpoint_variables()
start_epoch, start_batch, resume_training = self._load_checkpoints(checkpoint_dir)
# Initialize data loader
aiia_loader = AIIADataLoader(
merged_df,
column=column,
batch_size=batch_size,
pretraining=True,
collate_fn=self.safe_collate
)
dataframes = self._load_and_merge_datasets(dataset_paths, sample_size)
aiia_loader = self._initialize_data_loader(dataframes, column, batch_size)
criterion_denoise = nn.MSELoss()
criterion_rotate = nn.CrossEntropyLoss()
best_val_loss = float('inf')
criterion_denoise, criterion_rotate, best_val_loss = self._initialize_loss_functions()
for epoch in range(num_epochs):
for epoch in range(start_epoch, num_epochs):
print(f"\nEpoch {epoch+1}/{num_epochs}")
print("-" * 20)
# Training phase
self.model.train()
self.projection_head.train()
total_train_loss = 0.0
batch_count = 0
for batch_data in tqdm(aiia_loader.train_loader):
if batch_data is None:
continue
self.optimizer.zero_grad()
batch_loss = self._process_batch(batch_data, criterion_denoise, criterion_rotate)
if batch_loss > 0:
batch_loss.backward()
self.optimizer.step()
total_train_loss += batch_loss.item()
batch_count += 1
total_train_loss, batch_count = self._training_phase(aiia_loader.train_loader,
start_batch if (epoch == start_epoch and resume_training) else 0,
criterion_denoise,
criterion_rotate)
avg_train_loss = total_train_loss / max(batch_count, 1)
self.train_losses.append(avg_train_loss)
print(f"Training Loss: {avg_train_loss:.4f}")
# Validation phase
self.model.eval()
self.projection_head.eval()
val_loss = self._validate(aiia_loader.val_loader, criterion_denoise, criterion_rotate)
val_loss = self._validation_phase(aiia_loader.val_loader, criterion_denoise, criterion_rotate)
if val_loss < best_val_loss:
best_val_loss = val_loss
self.model.save(output_path)
@@ -192,6 +251,125 @@ class Pretrainer:
losses_path = os.path.join(os.path.dirname(output_path), 'losses.csv')
self.save_losses(losses_path)
def _initialize_checkpoint_variables(self):
"""Initialize checkpoint tracking variables."""
self.last_checkpoint_time = time.time()
self.checkpoint_interval = 2 * 60 * 60 # 2 hours in seconds
self.last_22_date = None
self.recent_checkpoints = []
def _load_checkpoints(self, checkpoint_dir):
"""Load checkpoints and return start epoch, batch, and resumption flag."""
start_epoch = 0
start_batch = 0
resume_training = False
if checkpoint_dir is not None:
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_info = self.load_checkpoint(checkpoint_dir)
if checkpoint_info:
start_epoch, start_batch = checkpoint_info
resume_training = True
# Adjust epoch to be 0-indexed for the loop
start_epoch -= 1
return start_epoch, start_batch, resume_training
def _load_and_merge_datasets(self, dataset_paths, sample_size):
"""Load and merge datasets."""
dataframes = []
for path in dataset_paths:
try:
df = pd.read_parquet(path).head(sample_size)
dataframes.append(df)
except Exception as e:
print(f"Error loading dataset {path}: {e}")
if not dataframes:
raise ValueError("No valid datasets could be loaded")
return pd.concat(dataframes, ignore_index=True)
def _initialize_data_loader(self, merged_df, column, batch_size):
"""Initialize the data loader."""
return AIIADataLoader(
merged_df,
column=column,
batch_size=batch_size,
pretraining=True,
collate_fn=self.safe_collate
)
def _initialize_loss_functions(self):
"""Initialize loss functions and tracking variables."""
criterion_denoise = nn.MSELoss()
criterion_rotate = nn.CrossEntropyLoss()
best_val_loss = float('inf')
return criterion_denoise, criterion_rotate, best_val_loss
def _training_phase(self, train_loader, skip_batches, criterion_denoise, criterion_rotate):
"""Handle the training phase."""
self.model.train()
self.projection_head.train()
total_train_loss = 0.0
batch_count = 0
train_batches = list(enumerate(train_loader))
for i, batch_data in tqdm(train_batches[skip_batches:],
initial=skip_batches,
total=len(train_batches)):
if batch_data is None:
continue
current_batch = i + 1
self._handle_checkpoints(current_batch)
self.optimizer.zero_grad()
batch_loss = self._process_batch(batch_data, criterion_denoise, criterion_rotate)
if batch_loss > 0:
batch_loss.backward()
self.optimizer.step()
total_train_loss += batch_loss.item()
batch_count += 1
return total_train_loss, batch_count
def _handle_checkpoints(self, current_batch):
"""Handle checkpoint saving logic."""
current_time = time.time()
current_dt = datetime.datetime.now(datetime.timezone(datetime.timedelta(hours=2))) # German time
today = current_dt.date()
if self.checkpoint_dir and (current_time - self.last_checkpoint_time) >= self.checkpoint_interval:
checkpoint_name = f"checkpoint_epoch{self.current_epoch+1}_batch{current_batch}.pt"
checkpoint_path = self._save_checkpoint(self.checkpoint_dir, self.current_epoch, current_batch, checkpoint_name)
# Track and maintain only 3 recent checkpoints
self.recent_checkpoints.append(checkpoint_path)
if len(self.recent_checkpoints) > 3:
oldest = self.recent_checkpoints.pop(0)
if os.path.exists(oldest):
os.remove(oldest)
self.last_checkpoint_time = current_time
print(f"Checkpoint saved at {checkpoint_path}")
# Special 22:00 checkpoint (considering it's currently 10:15 PM)
is_22_oclock = current_dt.hour == 22 and current_dt.minute < 15
if self.checkpoint_dir and is_22_oclock and self.last_22_date != today:
checkpoint_name = f"checkpoint_22h_{today.strftime('%Y%m%d')}.pt"
checkpoint_path = self._save_checkpoint(self.checkpoint_dir, self.current_epoch, current_batch, checkpoint_name)
self.last_22_date = today
print(f"22:00 Checkpoint saved at {checkpoint_path}")
def _validation_phase(self, val_loader, criterion_denoise, criterion_rotate):
"""Handle the validation phase."""
self.model.eval()
self.projection_head.eval()
return self._validate(val_loader, criterion_denoise, criterion_rotate)
def _validate(self, val_loader, criterion_denoise, criterion_rotate):
"""Perform validation and return average validation loss."""
val_loss = 0.0
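For context on what these checkpoints contain: _save_checkpoint above writes a single torch file holding the model, projection head, and optimizer state plus the recorded loss history. A minimal inspection sketch, assuming a checkpoint file already exists (the path below is illustrative):

import torch

ckpt = torch.load("checkpoints/my_model/checkpoint_epoch2_batch100.pt", map_location="cpu")
print(ckpt['epoch'], ckpt['batch'])          # epoch is stored as epoch + 1; batch is the counter at save time
print(len(ckpt['train_losses']), len(ckpt['val_losses']))  # loss history carried across resumes
# State dicts available for manual restoration:
#   ckpt['model_state_dict'], ckpt['projection_head_state_dict'], ckpt['optimizer_state_dict']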

View File

@@ -3,6 +3,8 @@ import torch
from unittest.mock import MagicMock, patch, mock_open, call
from aiia import Pretrainer, ProjectionHead, AIIABase, AIIAConfig
import pandas as pd
import os
import datetime
# Test the ProjectionHead class
def test_projection_head():
@@ -53,11 +55,94 @@ def test_process_batch(mock_process_batch):
loss = pretrainer._process_batch(batch_data, criterion_denoise, criterion_rotate)
assert loss == 0.5
# Error cases
# New tests for checkpoint handling
@patch('torch.save')
@patch('os.path.join')
def test_save_checkpoint(mock_join, mock_save):
"""Test checkpoint saving functionality."""
pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
pretrainer.projection_head = MagicMock()
pretrainer.optimizer = MagicMock()
checkpoint_dir = "checkpoints"
epoch = 1
batch_count = 100
checkpoint_name = "test_checkpoint.pt"
mock_join.return_value = os.path.join(checkpoint_dir, checkpoint_name)
path = pretrainer._save_checkpoint(checkpoint_dir, epoch, batch_count, checkpoint_name)
assert path == os.path.join(checkpoint_dir, checkpoint_name)
mock_save.assert_called_once()
@patch('os.makedirs')
@patch('os.path.exists')
@patch('torch.load')
def test_load_checkpoint_specific(mock_load, mock_exists, mock_makedirs):
"""Test loading a specific checkpoint."""
pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
pretrainer.projection_head = MagicMock()
pretrainer.optimizer = MagicMock()
checkpoint_dir = "checkpoints"
specific_checkpoint = "specific_checkpoint.pt"
mock_exists.return_value = True
mock_load.return_value = {
'epoch': 2,
'batch': 150,
'model_state_dict': {},
'projection_head_state_dict': {},
'optimizer_state_dict': {},
'train_losses': [],
'val_losses': []
}
result = pretrainer.load_checkpoint(checkpoint_dir, specific_checkpoint)
assert result == (2, 150)
@patch('os.listdir')
@patch('os.path.getmtime')
def test_load_checkpoint_most_recent(mock_getmtime, mock_listdir):
"""Test loading the most recent checkpoint."""
pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
checkpoint_dir = "checkpoints"
mock_listdir.return_value = ["checkpoint_1.pt", "checkpoint_2.pt"]
mock_getmtime.side_effect = [100, 200] # checkpoint_2.pt is more recent
with patch.object(pretrainer, '_load_checkpoint_file', return_value=(2, 150)):
result = pretrainer.load_checkpoint(checkpoint_dir)
assert result == (2, 150)
def test_initialize_checkpoint_variables():
"""Test initialization of checkpoint variables."""
pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
pretrainer._initialize_checkpoint_variables()
assert hasattr(pretrainer, 'last_checkpoint_time')
assert pretrainer.checkpoint_interval == 2 * 60 * 60
assert pretrainer.last_22_date is None
assert pretrainer.recent_checkpoints == []
@patch('torch.nn.MSELoss')
@patch('torch.nn.CrossEntropyLoss')
def test_initialize_loss_functions(mock_ce_loss, mock_mse_loss):
"""Test loss function initialization."""
pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
criterion_denoise, criterion_rotate, best_val_loss = pretrainer._initialize_loss_functions()
assert mock_mse_loss.called
assert mock_ce_loss.called
assert best_val_loss == float('inf')
@patch('pandas.concat')
@patch('pandas.read_parquet')
@patch('aiia.pretrain.pretrainer.AIIADataLoader')
@patch('os.path.join', return_value='mocked/path/model.pt')
@patch('builtins.print') # Add this to mock the print function
@patch('builtins.print')
def test_train_happy_path(mock_print, mock_path_join, mock_data_loader, mock_read_parquet, mock_concat):
"""Test the train method under normal conditions with comprehensive verification."""
# Setup test data and mocks
@@ -73,6 +158,7 @@ def test_train_happy_path(mock_print, mock_path_join, mock_data_loader, mock_rea
pretrainer = Pretrainer(model=mock_model, config=AIIAConfig())
pretrainer.projection_head = mock_projection_head
pretrainer.optimizer = MagicMock()
pretrainer.checkpoint_dir = None # Initialize checkpoint_dir
# Setup dataset paths and mock batch data
dataset_paths = ['path/to/dataset1.parquet', 'path/to/dataset2.parquet']
@@ -104,185 +190,118 @@ def test_train_happy_path(mock_print, mock_path_join, mock_data_loader, mock_rea
assert mock_process_batch.call_count == 2
assert mock_validate.call_count == 2
# Check for "Best model saved!" instead of model.save()
mock_print.assert_any_call("Best model saved!")
mock_save_losses.assert_called_once()
# Verify state changes
assert len(pretrainer.train_losses) == 2
assert pretrainer.train_losses == [0.5, 0.5]
# Error cases
def test_train_no_dataset_paths():
"""Test ValueError when no dataset paths are provided."""
@patch('datetime.datetime')
@patch('time.time')
def test_handle_checkpoints(mock_time, mock_datetime):
"""Test checkpoint handling logic."""
pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
pretrainer.checkpoint_dir = "checkpoints"
pretrainer.current_epoch = 1
pretrainer._initialize_checkpoint_variables()
# Set a base time value
base_time = 1000
# Set the last checkpoint time to base_time
pretrainer.last_checkpoint_time = base_time
# Mock time to return base_time + interval + 1 to trigger checkpoint save
mock_time.return_value = base_time + pretrainer.checkpoint_interval + 1
# Mock datetime for 22:00 checkpoint
mock_dt = MagicMock()
mock_dt.hour = 22
mock_dt.minute = 0
mock_dt.date.return_value = datetime.date(2023, 1, 1)
mock_datetime.now.return_value = mock_dt
with patch.object(pretrainer, '_save_checkpoint') as mock_save:
pretrainer._handle_checkpoints(100)
# Should be called twice - once for regular interval and once for 22:00
assert mock_save.call_count == 2
with pytest.raises(ValueError, match="No dataset paths provided"):
pretrainer.train([])
@patch('pandas.read_parquet')
def test_train_all_datasets_fail(mock_read_parquet):
"""Test handling when all datasets fail to load."""
mock_read_parquet.side_effect = Exception("Failed to load dataset")
def test_training_phase():
"""Test the training phase logic."""
pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
dataset_paths = ['path/to/dataset1.parquet', 'path/to/dataset2.parquet']
with pytest.raises(ValueError, match="No valid datasets could be loaded"):
pretrainer.train(dataset_paths)
# Edge cases
@patch('pandas.concat')
@patch('pandas.read_parquet')
@patch('aiia.pretrain.pretrainer.AIIADataLoader')
def test_train_empty_loaders(mock_data_loader, mock_read_parquet, mock_concat):
"""Test behavior with empty data loaders."""
real_df = pd.DataFrame({'image_bytes': [torch.randn(1, 3, 224, 224).tolist()]})
mock_read_parquet.return_value.head.return_value = real_df
mock_concat.return_value = real_df
loader_instance = MagicMock()
loader_instance.train_loader = [] # Empty train loader
loader_instance.val_loader = [] # Empty val loader
mock_data_loader.return_value = loader_instance
mock_model = MagicMock()
pretrainer = Pretrainer(model=mock_model, config=AIIAConfig())
pretrainer.projection_head = MagicMock()
pretrainer.optimizer = MagicMock()
with patch.object(Pretrainer, 'save_losses') as mock_save_losses:
pretrainer.train(['path/to/dataset.parquet'], num_epochs=1)
# Verify empty loader behavior
assert len(pretrainer.train_losses) == 1
assert pretrainer.train_losses[0] == 0.0
mock_save_losses.assert_called_once()
@patch('pandas.concat')
@patch('pandas.read_parquet')
@patch('aiia.pretrain.pretrainer.AIIADataLoader')
def test_train_none_batch_data(mock_data_loader, mock_read_parquet, mock_concat):
"""Test behavior when batch_data is None."""
real_df = pd.DataFrame({'image_bytes': [torch.randn(1, 3, 224, 224).tolist()]})
mock_read_parquet.return_value.head.return_value = real_df
mock_concat.return_value = real_df
loader_instance = MagicMock()
loader_instance.train_loader = [None] # Loader returns None
loader_instance.val_loader = []
mock_data_loader.return_value = loader_instance
pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
pretrainer.projection_head = MagicMock()
pretrainer.optimizer = MagicMock()
with patch.object(Pretrainer, '_process_batch') as mock_process_batch, \
patch.object(Pretrainer, 'save_losses'):
pretrainer.train(['path/to/dataset.parquet'], num_epochs=1)
# Verify None batch handling
mock_process_batch.assert_not_called()
assert pretrainer.train_losses[0] == 0.0
# Parameter variations
@patch('pandas.concat')
@patch('pandas.read_parquet')
@patch('aiia.pretrain.pretrainer.AIIADataLoader')
def test_train_with_custom_parameters(mock_data_loader, mock_read_parquet, mock_concat):
"""Test that custom parameters are properly passed through."""
real_df = pd.DataFrame({'image_bytes': [torch.randn(1, 3, 224, 224).tolist()]})
mock_read_parquet.return_value.head.return_value = real_df
mock_concat.return_value = real_df
loader_instance = MagicMock()
loader_instance.train_loader = []
loader_instance.val_loader = []
mock_data_loader.return_value = loader_instance
pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
pretrainer.projection_head = MagicMock()
pretrainer.optimizer = MagicMock()
# Custom parameters
custom_output_path = "custom/output/path"
custom_column = "custom_column"
custom_batch_size = 16
custom_sample_size = 5000
with patch.object(Pretrainer, 'save_losses'):
pretrainer.train(
['path/to/dataset.parquet'],
output_path=custom_output_path,
column=custom_column,
batch_size=custom_batch_size,
sample_size=custom_sample_size
)
# Verify custom parameters were used
mock_read_parquet.return_value.head.assert_called_once_with(custom_sample_size)
assert mock_data_loader.call_args[1]['column'] == custom_column
assert mock_data_loader.call_args[1]['batch_size'] == custom_batch_size
@patch('pandas.concat')
@patch('pandas.read_parquet')
@patch('aiia.pretrain.pretrainer.AIIADataLoader')
@patch('builtins.print') # Add this to mock the print function
def test_train_validation_loss_improvement(mock_print, mock_data_loader, mock_read_parquet, mock_concat):
"""Test that model is saved only when validation loss improves."""
real_df = pd.DataFrame({'image_bytes': [torch.randn(1, 3, 224, 224).tolist()]})
mock_read_parquet.return_value.head.return_value = real_df
mock_concat.return_value = real_df
# Create mock batch data with proper structure
pretrainer.checkpoint_dir = None # Initialize checkpoint_dir
pretrainer._initialize_checkpoint_variables()
pretrainer.current_epoch = 0
# Create mock batch data with requires_grad=True
mock_batch_data = {
'denoise': (torch.randn(2, 3, 32, 32), torch.randn(2, 3, 32, 32)),
'rotate': (torch.randn(2, 3, 32, 32), torch.tensor([0, 1]))
'denoise': (
torch.randn(2, 3, 32, 32, requires_grad=True),
torch.randn(2, 3, 32, 32, requires_grad=True)
),
'rotate': (
torch.randn(2, 3, 32, 32, requires_grad=True),
torch.tensor([0, 1], dtype=torch.long) # Labels typically don't need gradients
)
}
mock_train_loader = [(0, mock_batch_data)] # Include batch index
# Mock the loss functions to return tensors that require gradients
criterion_denoise = MagicMock(return_value=torch.tensor(0.5, requires_grad=True))
criterion_rotate = MagicMock(return_value=torch.tensor(0.5, requires_grad=True))
with patch.object(pretrainer, '_process_batch', return_value=torch.tensor(0.5, requires_grad=True)), \
patch.object(pretrainer, '_handle_checkpoints') as mock_handle_checkpoints:
total_loss, batch_count = pretrainer._training_phase(
mock_train_loader, 0, criterion_denoise, criterion_rotate)
assert total_loss == 0.5
assert batch_count == 1
mock_handle_checkpoints.assert_called_once_with(1) # Check if checkpoint handling was called
loader_instance = MagicMock()
loader_instance.train_loader = [mock_batch_data]
loader_instance.val_loader = [mock_batch_data]
mock_data_loader.return_value = loader_instance
mock_model = MagicMock()
pretrainer = Pretrainer(model=mock_model, config=AIIAConfig())
def test_validation_phase():
"""Test the validation phase logic."""
pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
pretrainer.projection_head = MagicMock()
pretrainer.optimizer = MagicMock()
mock_val_loader = [MagicMock()]
criterion_denoise = MagicMock()
criterion_rotate = MagicMock()
with patch.object(pretrainer, '_validate', return_value=0.4):
val_loss = pretrainer._validation_phase(
mock_val_loader, criterion_denoise, criterion_rotate)
assert val_loss == 0.4
# Initialize the best validation loss
pretrainer.best_val_loss = float('inf')
@patch('pandas.read_parquet')
def test_load_and_merge_datasets(mock_read_parquet):
"""Test dataset loading and merging."""
pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
mock_df = pd.DataFrame({'col': [1, 2, 3]})
mock_read_parquet.return_value.head.return_value = mock_df
result = pretrainer._load_and_merge_datasets(['path1.parquet', 'path2.parquet'], 1000)
assert len(result) == 6 # 2 datasets * 3 rows each
mock_batch_loss = torch.tensor(0.5, requires_grad=True)
# Test improving validation loss
with patch.object(Pretrainer, '_process_batch', return_value=mock_batch_loss), \
patch.object(Pretrainer, '_validate', side_effect=[3.0, 2.0, 1.0]), \
patch.object(Pretrainer, 'save_losses'):
pretrainer.train(['path/to/dataset.parquet'], num_epochs=3)
# Check for "Best model saved!" 3 times
assert mock_print.call_args_list.count(call("Best model saved!")) == 3
# Reset for next test
mock_print.reset_mock()
pretrainer.train_losses = []
# Reset best validation loss for the second test
pretrainer.best_val_loss = float('inf')
# Test fluctuating validation loss
with patch.object(Pretrainer, '_process_batch', return_value=mock_batch_loss), \
patch.object(Pretrainer, '_validate', side_effect=[3.0, 4.0, 2.0]), \
patch.object(Pretrainer, 'save_losses'):
pretrainer.train(['path/to/dataset.parquet'], num_epochs=3)
# Should print "Best model saved!" only on first and third epochs
assert mock_print.call_args_list.count(call("Best model saved!")) == 2
def test_process_batch_none_tasks():
"""Test processing batch with no tasks."""
pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
batch_data = {
'denoise': None,
'rotate': None
}
loss = pretrainer._process_batch(
batch_data,
criterion_denoise=MagicMock(),
criterion_rotate=MagicMock()
)
assert loss == 0
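The 22:00 branch of _handle_checkpoints is guarded by last_22_date, so it should fire at most once per calendar day. A possible companion test for that guard (a sketch, not part of this change set), reusing the mocking style of test_handle_checkpoints above:

@patch('datetime.datetime')
@patch('time.time')
def test_handle_checkpoints_22h_once_per_day(mock_time, mock_datetime):
    """Sketch: the 22:00 checkpoint is written at most once per calendar day."""
    pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
    pretrainer.checkpoint_dir = "checkpoints"
    pretrainer.current_epoch = 0
    pretrainer._initialize_checkpoint_variables()
    # Keep the two-hour interval branch quiet: "now" equals the last checkpoint time.
    pretrainer.last_checkpoint_time = 1000
    mock_time.return_value = 1000
    # Pretend it is 22:05 on a fixed date.
    mock_dt = MagicMock()
    mock_dt.hour = 22
    mock_dt.minute = 5
    mock_dt.date.return_value = datetime.date(2023, 1, 1)
    mock_datetime.now.return_value = mock_dt
    with patch.object(pretrainer, '_save_checkpoint', return_value="checkpoints/checkpoint_22h_20230101.pt") as mock_save:
        pretrainer._handle_checkpoints(10)
        pretrainer._handle_checkpoints(11)
    # Only the first call inside the 22:00 window saves; last_22_date blocks the second.
    assert mock_save.call_count == 1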
@patch('aiia.pretrain.pretrainer.Pretrainer._process_batch')