75 lines
2.2 KiB
Python
75 lines
2.2 KiB
Python
import pytest
|
|
import torch
|
|
import torch.nn as nn
|
|
from aiunn import aiuNNTrainer
|
|
|
|
# Simple mock dataset
|
|
class MockDataset(torch.utils.data.Dataset):
    """Fixed-length stand-in dataset that yields random tensor pairs.

    Every lookup returns a newly sampled ``(3, 64, 64)`` tensor together
    with a newly sampled ``(3, 128, 128)`` tensor; the requested index is
    not used to select the data.
    """

    def __init__(self, num_samples=10):
        """Remember how many items the dataset reports via ``len()``."""
        self.num_samples = num_samples

    def __len__(self):
        """Return the configured number of samples."""
        return self.num_samples

    def __getitem__(self, idx):
        """Return a ``(small, large)`` pair of standard-normal tensors.

        ``idx`` is accepted for protocol compatibility but ignored.
        """
        # Sample the smaller tensor first so the RNG stream is consumed in
        # a fixed order.
        small = torch.randn(3, 64, 64)
        large = torch.randn(3, 128, 128)
        return small, large
|
|
|
|
# Simple mock model
|
|
class MockModel(nn.Module):
    """Minimal module wrapping a single 3-channel 3x3 convolution.

    It also exposes a ``save`` method that does nothing.
    """

    def __init__(self):
        """Create the one convolution layer held in ``self.conv``."""
        super().__init__()
        # padding=1 with a 3x3 kernel keeps the spatial size unchanged.
        self.conv = nn.Conv2d(
            in_channels=3,
            out_channels=3,
            kernel_size=3,
            padding=1,
        )

    def forward(self, x):
        """Apply the convolution to ``x`` and return the result."""
        output = self.conv(x)
        return output

    def save(self, path):
        """Accept a path and do nothing (mock save method)."""
        return None
|
|
|
|
@pytest.fixture
def trainer():
    """Provide an ``aiuNNTrainer`` built from ``MockModel`` and ``MockDataset``."""
    return aiuNNTrainer(MockModel(), dataset_class=MockDataset)
|
|
|
|
def test_trainer_initialization(trainer):
    """Check the attributes exposed by a freshly constructed trainer."""
    # A model is attached and the loss is mean squared error.
    assert trainer.model is not None
    assert isinstance(trainer.criterion, nn.MSELoss)

    # No optimizer is set on the fixture-built trainer.
    assert trainer.optimizer is None

    # The trainer must have selected either the CUDA or the CPU device.
    accepted_devices = [torch.device('cuda'), torch.device('cpu')]
    assert trainer.device in accepted_devices
|
|
|
|
def test_load_data_basic(trainer):
    """Load data through the trainer's dataset class and expect two non-empty loaders."""
    load_kwargs = {
        'dataset_params': {'num_samples': 10},
        'batch_size': 2,
        'validation_split': 0.2,
    }
    loaders = trainer.load_data(**load_kwargs)
    train_loader, val_loader = loaders

    # Both loaders must exist ...
    for loader in (train_loader, val_loader):
        assert loader is not None

    # ... and each must yield at least one batch.
    for loader in (train_loader, val_loader):
        assert len(loader) > 0
|
|
|
|
def test_load_custom_datasets(trainer):
    """Pass explicit train/validation datasets and check the resulting batch counts."""
    train_set = MockDataset(num_samples=10)
    val_set = MockDataset(num_samples=5)

    loaders = trainer.load_data(
        custom_train_dataset=train_set,
        custom_val_dataset=val_set,
        batch_size=2,
    )
    train_loader, val_loader = loaders

    assert train_loader is not None
    assert val_loader is not None

    # 10 samples with batch size 2 -> 5 batches.
    expected_train_batches = 5
    assert len(train_loader) == expected_train_batches

    # 5 samples with batch size 2 -> 3 batches (the last one is partial).
    expected_val_batches = 3
    assert len(val_loader) == expected_val_batches
|
|
|
|
def test_error_no_dataset():
    """Expect ``load_data`` to raise ``ValueError`` when no dataset is available.

    The trainer is built with ``dataset_class=None`` and no custom datasets
    are passed to ``load_data``.
    """
    model = MockModel()
    bare_trainer = aiuNNTrainer(model, dataset_class=None)

    with pytest.raises(ValueError):
        bare_trainer.load_data(dataset_params={})