# aiuNN/tests/finetune/test_trainer.py
import pytest
import torch
import torch.nn as nn
from aiunn import aiuNNTrainer
class MockDataset(torch.utils.data.Dataset):
    """Synthetic dataset stub: random low-res/high-res tensor pairs.

    Each item is a (3, 64, 64) input tensor paired with a (3, 128, 128)
    target tensor, i.e. a 2x super-resolution sample.
    """

    def __init__(self, num_samples=10):
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # idx is ignored on purpose: every access yields fresh random data.
        low_res = torch.randn(3, 64, 64)
        high_res = torch.randn(3, 128, 128)
        return low_res, high_res
class MockModel(nn.Module):
    """Tiny stand-in model: a single 3x3 same-padding convolution."""

    def __init__(self):
        super().__init__()
        # 3 -> 3 channels, kernel 3, padding 1: output spatial size == input.
        self.conv = nn.Conv2d(3, 3, 3, padding=1)

    def forward(self, x):
        return self.conv(x)

    def save(self, path):
        # Intentional no-op: satisfies the trainer's save interface in tests.
        pass
@pytest.fixture
def trainer():
    """Provide an aiuNNTrainer wired to the mock model and mock dataset."""
    return aiuNNTrainer(MockModel(), dataset_class=MockDataset)
def test_trainer_initialization(trainer):
    """A fresh trainer exposes the expected default state."""
    expected_devices = (torch.device('cuda'), torch.device('cpu'))
    assert trainer.model is not None
    # MSE loss is the trainer's default criterion.
    assert isinstance(trainer.criterion, nn.MSELoss)
    # Optimizer is created lazily, so it starts out unset.
    assert trainer.optimizer is None
    assert trainer.device in expected_devices
def test_load_data_basic(trainer):
    """load_data with dataset_params yields non-empty train/val loaders."""
    loaders = trainer.load_data(
        dataset_params={'num_samples': 10},
        batch_size=2,
        validation_split=0.2,
    )
    train_loader, val_loader = loaders
    # Both loaders must exist and contain at least one batch each.
    for loader in (train_loader, val_loader):
        assert loader is not None
        assert len(loader) > 0
def test_load_custom_datasets(trainer):
    """Explicitly supplied train/val datasets are batched as expected."""
    train_loader, val_loader = trainer.load_data(
        custom_train_dataset=MockDataset(num_samples=10),
        custom_val_dataset=MockDataset(num_samples=5),
        batch_size=2,
    )
    assert train_loader is not None
    assert val_loader is not None
    # 10 samples / batch 2 -> 5 full batches.
    assert len(train_loader) == 5
    # 5 samples / batch 2 -> 3 batches (last one partial).
    assert len(val_loader) == 3
def test_error_no_dataset():
    """Calling load_data with no dataset source must raise ValueError."""
    bare_trainer = aiuNNTrainer(MockModel(), dataset_class=None)
    with pytest.raises(ValueError):
        bare_trainer.load_data(dataset_params={})