feat/energy_efficenty #38

Merged
Fabel merged 6 commits from feat/energy_efficenty into develop 2025-04-17 10:52:25 +00:00
3 changed files with 426 additions and 233 deletions

View File

@ -6,19 +6,15 @@ from aiia.pretrain import Pretrainer
config = AIIAConfig(model_name="AIIA-Base-512x20k") config = AIIAConfig(model_name="AIIA-Base-512x20k")
model = AIIABase(config) model = AIIABase(config)
# Initialize pretrainer with the model
pretrainer = Pretrainer(model, learning_rate=1e-4, config=config) pretrainer = Pretrainer(model, learning_rate=1e-4, config=config)
# List of dataset paths # Set checkpoint directory
dataset_paths = [ checkpoint_dir = "checkpoints/my_model"
"/path/to/dataset1.parquet",
"/path/to/dataset2.parquet"
]
# Start training with multiple datasets # Start training (will automatically load checkpoint if available)
pretrainer.train( pretrainer.train(
dataset_paths=dataset_paths, dataset_paths=["path/to/dataset1.parquet", "path/to/dataset2.parquet"],
num_epochs=10, output_path="trained_models/my_model",
batch_size=2, checkpoint_dir=checkpoint_dir,
sample_size=10000 num_epochs=10
) )

View File

@ -1,6 +1,8 @@
import torch import torch
from torch import nn from torch import nn
import csv import csv
import datetime
import time
import pandas as pd import pandas as pd
from tqdm import tqdm from tqdm import tqdm
from ..model.Model import AIIA from ..model.Model import AIIA
@ -112,78 +114,135 @@ class Pretrainer:
return batch_loss return batch_loss
def train(self, dataset_paths,output_path:str="AIIA", column="image_bytes", num_epochs=3, batch_size=2, sample_size=10000): def _save_checkpoint(self, checkpoint_dir, epoch, batch_count, checkpoint_name):
""" """Save a model checkpoint.
Train the model using multiple specified datasets.
Args: Args:
dataset_paths (list): List of paths to parquet datasets checkpoint_dir (str): Directory to save the checkpoint
num_epochs (int): Number of training epochs epoch (int): Current epoch number
batch_size (int): Batch size for training batch_count (int): Current batch count
sample_size (int): Number of samples to use from each dataset checkpoint_name (str): Name for the checkpoint file
Returns:
str: Path to the saved checkpoint
""" """
checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name)
checkpoint_data = {
'epoch': epoch + 1,
'batch': batch_count,
'model_state_dict': self.model.state_dict(),
'projection_head_state_dict': self.projection_head.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'train_losses': self.train_losses,
'val_losses': self.val_losses,
}
torch.save(checkpoint_data, checkpoint_path)
return checkpoint_path
def load_checkpoint(self, checkpoint_dir, specific_checkpoint=None):
"""
Check for checkpoints and load if available.
Args:
checkpoint_dir (str): Directory where checkpoints are stored
specific_checkpoint (str, optional): Specific checkpoint file to load. If None, loads the most recent.
Returns:
tuple: (loaded_epoch, loaded_batch) if checkpoint was loaded, None otherwise
"""
# Create checkpoint directory if it doesn't exist
os.makedirs(checkpoint_dir, exist_ok=True)
# If a specific checkpoint is requested
if specific_checkpoint:
checkpoint_path = os.path.join(checkpoint_dir, specific_checkpoint)
if os.path.exists(checkpoint_path):
return self._load_checkpoint_file(checkpoint_path)
else:
print(f"Specified checkpoint {specific_checkpoint} not found.")
return None
# Find all checkpoint files
checkpoint_files = [f for f in os.listdir(checkpoint_dir) if f.startswith("checkpoint_") and f.endswith(".pt")]
if not checkpoint_files:
print("No checkpoints found in directory.")
return None
# Find the most recent checkpoint
checkpoint_files.sort(key=lambda x: os.path.getmtime(os.path.join(checkpoint_dir, x)), reverse=True)
most_recent = checkpoint_files[0]
checkpoint_path = os.path.join(checkpoint_dir, most_recent)
return self._load_checkpoint_file(checkpoint_path)
def _load_checkpoint_file(self, checkpoint_path):
"""
Load a specific checkpoint file.
Args:
checkpoint_path (str): Path to the checkpoint file
Returns:
tuple: (loaded_epoch, loaded_batch) if checkpoint was loaded, None otherwise
"""
try:
checkpoint = torch.load(checkpoint_path, map_location=self.device)
# Load model state
self.model.load_state_dict(checkpoint['model_state_dict'])
# Load projection head state
self.projection_head.load_state_dict(checkpoint['projection_head_state_dict'])
# Load optimizer state
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# Load loss history
self.train_losses = checkpoint.get('train_losses', [])
self.val_losses = checkpoint.get('val_losses', [])
loaded_epoch = checkpoint['epoch']
loaded_batch = checkpoint['batch']
print(f"Checkpoint loaded from {checkpoint_path}")
print(f"Resuming from epoch {loaded_epoch}, batch {loaded_batch}")
return loaded_epoch, loaded_batch
except Exception as e:
print(f"Error loading checkpoint: {e}")
return None
def train(self, dataset_paths, output_path="AIIA", column="image_bytes",
num_epochs=3, batch_size=2, sample_size=10000, checkpoint_dir=None):
"""Train the model using multiple specified datasets with checkpoint resumption support."""
if not dataset_paths: if not dataset_paths:
raise ValueError("No dataset paths provided") raise ValueError("No dataset paths provided")
# Read and merge all datasets self._initialize_checkpoint_variables()
dataframes = [] start_epoch, start_batch, resume_training = self._load_checkpoints(checkpoint_dir)
for path in dataset_paths:
try:
df = pd.read_parquet(path).head(sample_size)
dataframes.append(df)
except Exception as e:
print(f"Error loading dataset {path}: {e}")
if not dataframes:
raise ValueError("No valid datasets could be loaded")
merged_df = pd.concat(dataframes, ignore_index=True)
# Initialize data loader dataframes = self._load_and_merge_datasets(dataset_paths, sample_size)
aiia_loader = AIIADataLoader( aiia_loader = self._initialize_data_loader(dataframes, column, batch_size)
merged_df,
column=column,
batch_size=batch_size,
pretraining=True,
collate_fn=self.safe_collate
)
criterion_denoise = nn.MSELoss() criterion_denoise, criterion_rotate, best_val_loss = self._initialize_loss_functions()
criterion_rotate = nn.CrossEntropyLoss()
best_val_loss = float('inf')
for epoch in range(num_epochs): for epoch in range(start_epoch, num_epochs):
print(f"\nEpoch {epoch+1}/{num_epochs}") print(f"\nEpoch {epoch+1}/{num_epochs}")
print("-" * 20) print("-" * 20)
total_train_loss, batch_count = self._training_phase(aiia_loader.train_loader,
# Training phase start_batch if (epoch == start_epoch and resume_training) else 0,
self.model.train() criterion_denoise,
self.projection_head.train() criterion_rotate)
total_train_loss = 0.0
batch_count = 0
for batch_data in tqdm(aiia_loader.train_loader):
if batch_data is None:
continue
self.optimizer.zero_grad()
batch_loss = self._process_batch(batch_data, criterion_denoise, criterion_rotate)
if batch_loss > 0:
batch_loss.backward()
self.optimizer.step()
total_train_loss += batch_loss.item()
batch_count += 1
avg_train_loss = total_train_loss / max(batch_count, 1) avg_train_loss = total_train_loss / max(batch_count, 1)
self.train_losses.append(avg_train_loss) self.train_losses.append(avg_train_loss)
print(f"Training Loss: {avg_train_loss:.4f}") print(f"Training Loss: {avg_train_loss:.4f}")
# Validation phase val_loss = self._validation_phase(aiia_loader.val_loader, criterion_denoise, criterion_rotate)
self.model.eval()
self.projection_head.eval()
val_loss = self._validate(aiia_loader.val_loader, criterion_denoise, criterion_rotate)
if val_loss < best_val_loss: if val_loss < best_val_loss:
best_val_loss = val_loss best_val_loss = val_loss
self.model.save(output_path) self.model.save(output_path)
@ -192,6 +251,125 @@ class Pretrainer:
losses_path = os.path.join(os.path.dirname(output_path), 'losses.csv') losses_path = os.path.join(os.path.dirname(output_path), 'losses.csv')
self.save_losses(losses_path) self.save_losses(losses_path)
def _initialize_checkpoint_variables(self):
"""Initialize checkpoint tracking variables."""
self.last_checkpoint_time = time.time()
self.checkpoint_interval = 2 * 60 * 60 # 2 hours in seconds
self.last_22_date = None
self.recent_checkpoints = []
def _load_checkpoints(self, checkpoint_dir):
"""Load checkpoints and return start epoch, batch, and resumption flag."""
start_epoch = 0
start_batch = 0
resume_training = False
if checkpoint_dir is not None:
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_info = self.load_checkpoint(checkpoint_dir)
if checkpoint_info:
start_epoch, start_batch = checkpoint_info
resume_training = True
# Adjust epoch to be 0-indexed for the loop
start_epoch -= 1
return start_epoch, start_batch, resume_training
def _load_and_merge_datasets(self, dataset_paths, sample_size):
"""Load and merge datasets."""
dataframes = []
for path in dataset_paths:
try:
df = pd.read_parquet(path).head(sample_size)
dataframes.append(df)
except Exception as e:
print(f"Error loading dataset {path}: {e}")
if not dataframes:
raise ValueError("No valid datasets could be loaded")
return pd.concat(dataframes, ignore_index=True)
def _initialize_data_loader(self, merged_df, column, batch_size):
"""Initialize the data loader."""
return AIIADataLoader(
merged_df,
column=column,
batch_size=batch_size,
pretraining=True,
collate_fn=self.safe_collate
)
def _initialize_loss_functions(self):
"""Initialize loss functions and tracking variables."""
criterion_denoise = nn.MSELoss()
criterion_rotate = nn.CrossEntropyLoss()
best_val_loss = float('inf')
return criterion_denoise, criterion_rotate, best_val_loss
def _training_phase(self, train_loader, skip_batches, criterion_denoise, criterion_rotate):
"""Handle the training phase."""
self.model.train()
self.projection_head.train()
total_train_loss = 0.0
batch_count = 0
train_batches = list(enumerate(train_loader))
for i, batch_data in tqdm(train_batches[skip_batches:],
initial=skip_batches,
total=len(train_batches)):
if batch_data is None:
continue
current_batch = i + 1
self._handle_checkpoints(current_batch)
self.optimizer.zero_grad()
batch_loss = self._process_batch(batch_data, criterion_denoise, criterion_rotate)
if batch_loss > 0:
batch_loss.backward()
self.optimizer.step()
total_train_loss += batch_loss.item()
batch_count += 1
return total_train_loss, batch_count
def _handle_checkpoints(self, current_batch):
"""Handle checkpoint saving logic."""
current_time = time.time()
current_dt = datetime.datetime.now(datetime.timezone(datetime.timedelta(hours=2))) # German time
today = current_dt.date()
if self.checkpoint_dir and (current_time - self.last_checkpoint_time) >= self.checkpoint_interval:
checkpoint_name = f"checkpoint_epoch{self.current_epoch+1}_batch{current_batch}.pt"
checkpoint_path = self._save_checkpoint(self.checkpoint_dir, self.current_epoch, current_batch, checkpoint_name)
# Track and maintain only 3 recent checkpoints
self.recent_checkpoints.append(checkpoint_path)
if len(self.recent_checkpoints) > 3:
oldest = self.recent_checkpoints.pop(0)
if os.path.exists(oldest):
os.remove(oldest)
self.last_checkpoint_time = current_time
print(f"Checkpoint saved at {checkpoint_path}")
# Special 22:00 checkpoint (considering it's currently 10:15 PM)
is_22_oclock = current_dt.hour == 22 and current_dt.minute < 15
if self.checkpoint_dir and is_22_oclock and self.last_22_date != today:
checkpoint_name = f"checkpoint_22h_{today.strftime('%Y%m%d')}.pt"
checkpoint_path = self._save_checkpoint(self.checkpoint_dir, self.current_epoch, current_batch, checkpoint_name)
self.last_22_date = today
print(f"22:00 Checkpoint saved at {checkpoint_path}")
def _validation_phase(self, val_loader, criterion_denoise, criterion_rotate):
"""Handle the validation phase."""
self.model.eval()
self.projection_head.eval()
return self._validate(val_loader, criterion_denoise, criterion_rotate)
def _validate(self, val_loader, criterion_denoise, criterion_rotate): def _validate(self, val_loader, criterion_denoise, criterion_rotate):
"""Perform validation and return average validation loss.""" """Perform validation and return average validation loss."""
val_loss = 0.0 val_loss = 0.0

View File

@ -3,6 +3,8 @@ import torch
from unittest.mock import MagicMock, patch, MagicMock, mock_open, call from unittest.mock import MagicMock, patch, MagicMock, mock_open, call
from aiia import Pretrainer, ProjectionHead, AIIABase, AIIAConfig from aiia import Pretrainer, ProjectionHead, AIIABase, AIIAConfig
import pandas as pd import pandas as pd
import os
import datetime
# Test the ProjectionHead class # Test the ProjectionHead class
def test_projection_head(): def test_projection_head():
@ -53,11 +55,94 @@ def test_process_batch(mock_process_batch):
loss = pretrainer._process_batch(batch_data, criterion_denoise, criterion_rotate) loss = pretrainer._process_batch(batch_data, criterion_denoise, criterion_rotate)
assert loss == 0.5 assert loss == 0.5
# Error cases
# New tests for checkpoint handling
@patch('torch.save')
@patch('os.path.join')
def test_save_checkpoint(mock_join, mock_save):
"""Test checkpoint saving functionality."""
pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
pretrainer.projection_head = MagicMock()
pretrainer.optimizer = MagicMock()
checkpoint_dir = "checkpoints"
epoch = 1
batch_count = 100
checkpoint_name = "test_checkpoint.pt"
mock_join.return_value = os.path.join(checkpoint_dir, checkpoint_name)
path = pretrainer._save_checkpoint(checkpoint_dir, epoch, batch_count, checkpoint_name)
assert path == os.path.join(checkpoint_dir, checkpoint_name)
mock_save.assert_called_once()
@patch('os.makedirs')
@patch('os.path.exists')
@patch('torch.load')
def test_load_checkpoint_specific(mock_load, mock_exists, mock_makedirs):
"""Test loading a specific checkpoint."""
pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
pretrainer.projection_head = MagicMock()
pretrainer.optimizer = MagicMock()
checkpoint_dir = "checkpoints"
specific_checkpoint = "specific_checkpoint.pt"
mock_exists.return_value = True
mock_load.return_value = {
'epoch': 2,
'batch': 150,
'model_state_dict': {},
'projection_head_state_dict': {},
'optimizer_state_dict': {},
'train_losses': [],
'val_losses': []
}
result = pretrainer.load_checkpoint(checkpoint_dir, specific_checkpoint)
assert result == (2, 150)
@patch('os.listdir')
@patch('os.path.getmtime')
def test_load_checkpoint_most_recent(mock_getmtime, mock_listdir):
"""Test loading the most recent checkpoint."""
pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
checkpoint_dir = "checkpoints"
mock_listdir.return_value = ["checkpoint_1.pt", "checkpoint_2.pt"]
mock_getmtime.side_effect = [100, 200] # checkpoint_2.pt is more recent
with patch.object(pretrainer, '_load_checkpoint_file', return_value=(2, 150)):
result = pretrainer.load_checkpoint(checkpoint_dir)
assert result == (2, 150)
def test_initialize_checkpoint_variables():
"""Test initialization of checkpoint variables."""
pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
pretrainer._initialize_checkpoint_variables()
assert hasattr(pretrainer, 'last_checkpoint_time')
assert pretrainer.checkpoint_interval == 2 * 60 * 60
assert pretrainer.last_22_date is None
assert pretrainer.recent_checkpoints == []
@patch('torch.nn.MSELoss')
@patch('torch.nn.CrossEntropyLoss')
def test_initialize_loss_functions(mock_ce_loss, mock_mse_loss):
"""Test loss function initialization."""
pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
criterion_denoise, criterion_rotate, best_val_loss = pretrainer._initialize_loss_functions()
assert mock_mse_loss.called
assert mock_ce_loss.called
assert best_val_loss == float('inf')
@patch('pandas.concat') @patch('pandas.concat')
@patch('pandas.read_parquet') @patch('pandas.read_parquet')
@patch('aiia.pretrain.pretrainer.AIIADataLoader') @patch('aiia.pretrain.pretrainer.AIIADataLoader')
@patch('os.path.join', return_value='mocked/path/model.pt') @patch('os.path.join', return_value='mocked/path/model.pt')
@patch('builtins.print') # Add this to mock the print function @patch('builtins.print')
def test_train_happy_path(mock_print, mock_path_join, mock_data_loader, mock_read_parquet, mock_concat): def test_train_happy_path(mock_print, mock_path_join, mock_data_loader, mock_read_parquet, mock_concat):
"""Test the train method under normal conditions with comprehensive verification.""" """Test the train method under normal conditions with comprehensive verification."""
# Setup test data and mocks # Setup test data and mocks
@ -73,6 +158,7 @@ def test_train_happy_path(mock_print, mock_path_join, mock_data_loader, mock_rea
pretrainer = Pretrainer(model=mock_model, config=AIIAConfig()) pretrainer = Pretrainer(model=mock_model, config=AIIAConfig())
pretrainer.projection_head = mock_projection_head pretrainer.projection_head = mock_projection_head
pretrainer.optimizer = MagicMock() pretrainer.optimizer = MagicMock()
pretrainer.checkpoint_dir = None # Initialize checkpoint_dir
# Setup dataset paths and mock batch data # Setup dataset paths and mock batch data
dataset_paths = ['path/to/dataset1.parquet', 'path/to/dataset2.parquet'] dataset_paths = ['path/to/dataset1.parquet', 'path/to/dataset2.parquet']
@ -104,185 +190,118 @@ def test_train_happy_path(mock_print, mock_path_join, mock_data_loader, mock_rea
assert mock_process_batch.call_count == 2 assert mock_process_batch.call_count == 2
assert mock_validate.call_count == 2 assert mock_validate.call_count == 2
# Check for "Best model saved!" instead of model.save()
mock_print.assert_any_call("Best model saved!") mock_print.assert_any_call("Best model saved!")
mock_save_losses.assert_called_once() mock_save_losses.assert_called_once()
# Verify state changes
assert len(pretrainer.train_losses) == 2 assert len(pretrainer.train_losses) == 2
assert pretrainer.train_losses == [0.5, 0.5] assert pretrainer.train_losses == [0.5, 0.5]
@patch('datetime.datetime')
# Error cases @patch('time.time')
def test_train_no_dataset_paths(): def test_handle_checkpoints(mock_time, mock_datetime):
"""Test ValueError when no dataset paths are provided.""" """Test checkpoint handling logic."""
pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig()) pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
pretrainer.checkpoint_dir = "checkpoints"
pretrainer.current_epoch = 1
pretrainer._initialize_checkpoint_variables()
# Set a base time value
base_time = 1000
# Set the last checkpoint time to base_time
pretrainer.last_checkpoint_time = base_time
# Mock time to return base_time + interval + 1 to trigger checkpoint save
mock_time.return_value = base_time + pretrainer.checkpoint_interval + 1
# Mock datetime for 22:00 checkpoint
mock_dt = MagicMock()
mock_dt.hour = 22
mock_dt.minute = 0
mock_dt.date.return_value = datetime.date(2023, 1, 1)
mock_datetime.now.return_value = mock_dt
with patch.object(pretrainer, '_save_checkpoint') as mock_save:
pretrainer._handle_checkpoints(100)
# Should be called twice - once for regular interval and once for 22:00
assert mock_save.call_count == 2
with pytest.raises(ValueError, match="No dataset paths provided"): def test_training_phase():
pretrainer.train([]) """Test the training phase logic."""
@patch('pandas.read_parquet')
def test_train_all_datasets_fail(mock_read_parquet):
"""Test handling when all datasets fail to load."""
mock_read_parquet.side_effect = Exception("Failed to load dataset")
pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig()) pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
dataset_paths = ['path/to/dataset1.parquet', 'path/to/dataset2.parquet']
with pytest.raises(ValueError, match="No valid datasets could be loaded"):
pretrainer.train(dataset_paths)
# Edge cases
@patch('pandas.concat')
@patch('pandas.read_parquet')
@patch('aiia.pretrain.pretrainer.AIIADataLoader')
def test_train_empty_loaders(mock_data_loader, mock_read_parquet, mock_concat):
"""Test behavior with empty data loaders."""
real_df = pd.DataFrame({'image_bytes': [torch.randn(1, 3, 224, 224).tolist()]})
mock_read_parquet.return_value.head.return_value = real_df
mock_concat.return_value = real_df
loader_instance = MagicMock()
loader_instance.train_loader = [] # Empty train loader
loader_instance.val_loader = [] # Empty val loader
mock_data_loader.return_value = loader_instance
mock_model = MagicMock()
pretrainer = Pretrainer(model=mock_model, config=AIIAConfig())
pretrainer.projection_head = MagicMock()
pretrainer.optimizer = MagicMock() pretrainer.optimizer = MagicMock()
pretrainer.checkpoint_dir = None # Initialize checkpoint_dir
with patch.object(Pretrainer, 'save_losses') as mock_save_losses: pretrainer._initialize_checkpoint_variables()
pretrainer.train(['path/to/dataset.parquet'], num_epochs=1) pretrainer.current_epoch = 0
# Verify empty loader behavior # Create mock batch data with requires_grad=True
assert len(pretrainer.train_losses) == 1
assert pretrainer.train_losses[0] == 0.0
mock_save_losses.assert_called_once()
@patch('pandas.concat')
@patch('pandas.read_parquet')
@patch('aiia.pretrain.pretrainer.AIIADataLoader')
def test_train_none_batch_data(mock_data_loader, mock_read_parquet, mock_concat):
"""Test behavior when batch_data is None."""
real_df = pd.DataFrame({'image_bytes': [torch.randn(1, 3, 224, 224).tolist()]})
mock_read_parquet.return_value.head.return_value = real_df
mock_concat.return_value = real_df
loader_instance = MagicMock()
loader_instance.train_loader = [None] # Loader returns None
loader_instance.val_loader = []
mock_data_loader.return_value = loader_instance
pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
pretrainer.projection_head = MagicMock()
pretrainer.optimizer = MagicMock()
with patch.object(Pretrainer, '_process_batch') as mock_process_batch, \
patch.object(Pretrainer, 'save_losses'):
pretrainer.train(['path/to/dataset.parquet'], num_epochs=1)
# Verify None batch handling
mock_process_batch.assert_not_called()
assert pretrainer.train_losses[0] == 0.0
# Parameter variations
@patch('pandas.concat')
@patch('pandas.read_parquet')
@patch('aiia.pretrain.pretrainer.AIIADataLoader')
def test_train_with_custom_parameters(mock_data_loader, mock_read_parquet, mock_concat):
"""Test that custom parameters are properly passed through."""
real_df = pd.DataFrame({'image_bytes': [torch.randn(1, 3, 224, 224).tolist()]})
mock_read_parquet.return_value.head.return_value = real_df
mock_concat.return_value = real_df
loader_instance = MagicMock()
loader_instance.train_loader = []
loader_instance.val_loader = []
mock_data_loader.return_value = loader_instance
pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
pretrainer.projection_head = MagicMock()
pretrainer.optimizer = MagicMock()
# Custom parameters
custom_output_path = "custom/output/path"
custom_column = "custom_column"
custom_batch_size = 16
custom_sample_size = 5000
with patch.object(Pretrainer, 'save_losses'):
pretrainer.train(
['path/to/dataset.parquet'],
output_path=custom_output_path,
column=custom_column,
batch_size=custom_batch_size,
sample_size=custom_sample_size
)
# Verify custom parameters were used
mock_read_parquet.return_value.head.assert_called_once_with(custom_sample_size)
assert mock_data_loader.call_args[1]['column'] == custom_column
assert mock_data_loader.call_args[1]['batch_size'] == custom_batch_size
@patch('pandas.concat')
@patch('pandas.read_parquet')
@patch('aiia.pretrain.pretrainer.AIIADataLoader')
@patch('builtins.print') # Add this to mock the print function
def test_train_validation_loss_improvement(mock_print, mock_data_loader, mock_read_parquet, mock_concat):
"""Test that model is saved only when validation loss improves."""
real_df = pd.DataFrame({'image_bytes': [torch.randn(1, 3, 224, 224).tolist()]})
mock_read_parquet.return_value.head.return_value = real_df
mock_concat.return_value = real_df
# Create mock batch data with proper structure
mock_batch_data = { mock_batch_data = {
'denoise': (torch.randn(2, 3, 32, 32), torch.randn(2, 3, 32, 32)), 'denoise': (
'rotate': (torch.randn(2, 3, 32, 32), torch.tensor([0, 1])) torch.randn(2, 3, 32, 32, requires_grad=True),
torch.randn(2, 3, 32, 32, requires_grad=True)
),
'rotate': (
torch.randn(2, 3, 32, 32, requires_grad=True),
torch.tensor([0, 1], dtype=torch.long) # Labels typically don't need gradients
)
} }
mock_train_loader = [(0, mock_batch_data)] # Include batch index
# Mock the loss functions to return tensors that require gradients
criterion_denoise = MagicMock(return_value=torch.tensor(0.5, requires_grad=True))
criterion_rotate = MagicMock(return_value=torch.tensor(0.5, requires_grad=True))
with patch.object(pretrainer, '_process_batch', return_value=torch.tensor(0.5, requires_grad=True)), \
patch.object(pretrainer, '_handle_checkpoints') as mock_handle_checkpoints:
total_loss, batch_count = pretrainer._training_phase(
mock_train_loader, 0, criterion_denoise, criterion_rotate)
assert total_loss == 0.5
assert batch_count == 1
mock_handle_checkpoints.assert_called_once_with(1) # Check if checkpoint handling was called
loader_instance = MagicMock() def test_validation_phase():
loader_instance.train_loader = [mock_batch_data] """Test the validation phase logic."""
loader_instance.val_loader = [mock_batch_data] pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
mock_data_loader.return_value = loader_instance
mock_model = MagicMock()
pretrainer = Pretrainer(model=mock_model, config=AIIAConfig())
pretrainer.projection_head = MagicMock() pretrainer.projection_head = MagicMock()
pretrainer.optimizer = MagicMock()
mock_val_loader = [MagicMock()]
criterion_denoise = MagicMock()
criterion_rotate = MagicMock()
with patch.object(pretrainer, '_validate', return_value=0.4):
val_loss = pretrainer._validation_phase(
mock_val_loader, criterion_denoise, criterion_rotate)
assert val_loss == 0.4
# Initialize the best validation loss @patch('pandas.read_parquet')
pretrainer.best_val_loss = float('inf') def test_load_and_merge_datasets(mock_read_parquet):
"""Test dataset loading and merging."""
pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
mock_df = pd.DataFrame({'col': [1, 2, 3]})
mock_read_parquet.return_value.head.return_value = mock_df
result = pretrainer._load_and_merge_datasets(['path1.parquet', 'path2.parquet'], 1000)
assert len(result) == 6 # 2 datasets * 3 rows each
mock_batch_loss = torch.tensor(0.5, requires_grad=True) def test_process_batch_none_tasks():
"""Test processing batch with no tasks."""
# Test improving validation loss pretrainer = Pretrainer(model=MagicMock(), config=AIIAConfig())
with patch.object(Pretrainer, '_process_batch', return_value=mock_batch_loss), \
patch.object(Pretrainer, '_validate', side_effect=[3.0, 2.0, 1.0]), \ batch_data = {
patch.object(Pretrainer, 'save_losses'): 'denoise': None,
pretrainer.train(['path/to/dataset.parquet'], num_epochs=3) 'rotate': None
}
# Check for "Best model saved!" 3 times
assert mock_print.call_args_list.count(call("Best model saved!")) == 3 loss = pretrainer._process_batch(
batch_data,
# Reset for next test criterion_denoise=MagicMock(),
mock_print.reset_mock() criterion_rotate=MagicMock()
pretrainer.train_losses = [] )
# Reset best validation loss for the second test assert loss == 0
pretrainer.best_val_loss = float('inf')
# Test fluctuating validation loss
with patch.object(Pretrainer, '_process_batch', return_value=mock_batch_loss), \
patch.object(Pretrainer, '_validate', side_effect=[3.0, 4.0, 2.0]), \
patch.object(Pretrainer, 'save_losses'):
pretrainer.train(['path/to/dataset.parquet'], num_epochs=3)
# Should print "Best model saved!" only on first and third epochs
assert mock_print.call_args_list.count(call("Best model saved!")) == 2
@patch('aiia.pretrain.pretrainer.Pretrainer._process_batch') @patch('aiia.pretrain.pretrainer.Pretrainer._process_batch')