"""Unit tests for ``aiunn.aiuNNTrainer`` using a lightweight mock dataset and model."""

import pytest
import torch
import torch.nn as nn

from aiunn import aiuNNTrainer


class MockDataset(torch.utils.data.Dataset):
    """Synthetic dataset of random (input, target) image-tensor pairs.

    Each item is a ``(3, 64, 64)`` input paired with a ``(3, 128, 128)``
    target, i.e. a 2x spatial upscaling pair.

    Args:
        num_samples: Number of items the dataset reports via ``len()``.
    """

    def __init__(self, num_samples=10):
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # Random content is sufficient: these tests only exercise dataset
        # wiring and batching, not numerical results.
        return torch.randn(3, 64, 64), torch.randn(3, 128, 128)


class MockModel(nn.Module):
    """Tiny stand-in model: a shape-preserving 3->3 conv followed by 2x upsampling."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 3, padding=1)
        # Upsample 2x so outputs match MockDataset's (3, 128, 128) targets;
        # without it, any forward + loss step would hit a shape mismatch.
        self.upsample = nn.Upsample(scale_factor=2, mode="nearest")

    def forward(self, x):
        return self.upsample(self.conv(x))

    def save(self, path):
        """No-op stand-in for the model's save hook; writes nothing to ``path``."""


@pytest.fixture
def trainer():
    """Return a trainer wrapping a fresh MockModel and the MockDataset class."""
    model = MockModel()
    return aiuNNTrainer(model, dataset_class=MockDataset)


def test_trainer_initialization(trainer):
    """Test basic trainer initialization."""
    assert trainer.model is not None
    assert isinstance(trainer.criterion, nn.MSELoss)
    assert trainer.optimizer is None
    # Compare the device *type*, not the device itself: torch.device('cuda:0')
    # does not equal torch.device('cuda'), so an index-qualified device would
    # spuriously fail an equality/membership check.
    assert trainer.device.type in ("cuda", "cpu")


def test_load_data_basic(trainer):
    """Test basic data loading from ``dataset_class`` with a validation split."""
    train_loader, val_loader = trainer.load_data(
        dataset_params={"num_samples": 10},
        batch_size=2,
        validation_split=0.2,
    )
    assert train_loader is not None
    assert val_loader is not None
    # 10 samples with a 0.2 split should leave both loaders non-empty.
    assert len(train_loader) > 0
    assert len(val_loader) > 0


def test_load_custom_datasets(trainer):
    """Test loading explicitly supplied train/validation datasets."""
    train_dataset = MockDataset(num_samples=10)
    val_dataset = MockDataset(num_samples=5)
    train_loader, val_loader = trainer.load_data(
        custom_train_dataset=train_dataset,
        custom_val_dataset=val_dataset,
        batch_size=2,
    )
    assert train_loader is not None
    assert val_loader is not None
    # Expected batch counts are ceil(num_samples / batch_size); this assumes the
    # loaders keep the final partial batch (drop_last=False) — confirm if changed.
    assert len(train_loader) == 5  # ceil(10 / 2)
    assert len(val_loader) == 3  # ceil(5 / 2)


def test_error_no_dataset():
    """Test that loading data without any dataset source raises ValueError."""
    trainer = aiuNNTrainer(MockModel(), dataset_class=None)
    with pytest.raises(ValueError):
        trainer.load_data(dataset_params={})