# Source: aiuNN/tests/inference/test_inference.py (90 lines, 2.7 KiB, Python)

# tests/inference/test_inference.py
import pytest
import numpy as np
import torch
from PIL import Image
from aiunn import aiuNNInference
from aiunn.upsampler.aiunn import aiuNN, aiuNNConfig
from aiia import AIIABase, AIIAConfig
import os
import json
from unittest.mock import patch, MagicMock, mock_open
@pytest.fixture
def real_model(tmp_path):
    """Build a small aiuNN upsampler, save it to disk and reload it for inference.

    A default-configured ``AIIABase`` is attached to a default-configured
    ``aiuNN``; the result is written under ``tmp_path / "model" / "save"`` and
    loaded back through ``aiuNNInference`` on the CPU, so tests exercise the
    real save/load path instead of a mock.

    Returns:
        The ``aiuNNInference`` instance created from the saved directory.
    """
    model_root = tmp_path / "model"
    model_root.mkdir()

    # Upsampler wrapping a freshly initialised base model (default configs).
    base_cfg = AIIAConfig()
    upsampler_cfg = aiuNNConfig()
    backbone = AIIABase(base_cfg)
    upsampler = aiuNN(config=upsampler_cfg)
    upsampler.load_base_model(backbone)

    save_dir = model_root / "save"
    save_dir.mkdir(parents=True, exist_ok=True)

    # NOTE(review): this hand-written config.json is written *before*
    # save_pretrained(); if save_pretrained() emits its own config.json it
    # will replace this one. Confirm whether this file is still needed.
    manual_config = {
        "model_type": "test_model",
        "scale": 4,
        "in_channels": 3,
        "out_channels": 3,
    }
    with open(save_dir / "config.json", "w") as fh:
        json.dump(manual_config, fh)

    upsampler.save_pretrained(str(save_dir))

    # Reload from disk in inference mode.
    return aiuNNInference(model_path=str(save_dir), device='cpu')
@pytest.fixture
def inference(real_model):
    """Expose the reloaded ``aiuNNInference`` under the name the tests request."""
    loaded = real_model
    return loaded
def test_preprocess_image(inference):
    """preprocess_image turns a PIL RGB image into a (1, 3, H, W) tensor.

    A non-square image is used on purpose: with a square input, a
    height/width transposition inside the preprocessing would go unnoticed.
    """
    height, width = 64, 100
    # numpy arrays are (H, W, C); Image.fromarray yields an RGB image of size (W, H).
    test_image = Image.fromarray(np.zeros((height, width, 3), dtype=np.uint8))

    result = inference.preprocess_image(test_image)

    assert isinstance(result, torch.Tensor)
    assert result.shape[0] == 1  # batch dimension
    assert result.shape[1] == 3  # channels
    assert tuple(result.shape[2:]) == (height, width)  # height, width (not swapped)
def test_postprocess_tensor(inference):
    """postprocess_tensor turns a (1, 3, H, W) tensor into an RGB PIL image.

    A non-square tensor is used so that swapped spatial axes are detected;
    note that PIL reports ``size`` as (width, height).
    """
    height, width = 64, 100
    test_tensor = torch.zeros(1, 3, height, width)

    result = inference.postprocess_tensor(test_tensor)

    assert isinstance(result, Image.Image)
    assert result.size == (width, height)  # PIL order is (W, H)
    assert result.mode == 'RGB'
def test_save(inference, tmp_path):
    """save() ensures the parent directory exists and writes the image there.

    The output goes to pytest's ``tmp_path`` so the test leaves no artifact
    in the current working directory.
    """
    test_image = Image.fromarray(np.zeros((100, 100, 3), dtype=np.uint8))
    output_path = str(tmp_path / "test_output.png")

    # os.makedirs is patched only to inspect its arguments; tmp_path already
    # exists, so the actual write can still succeed without it.
    with patch('os.makedirs') as mock_makedirs:
        inference.save(test_image, output_path)
        mock_makedirs.assert_called_with(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)

    # The image must actually have been written, and be readable back.
    assert os.path.isfile(output_path)
    with Image.open(output_path) as saved:
        assert saved.size == (100, 100)
def test_convert_to_binary(inference):
    """convert_to_binary returns a non-empty ``bytes`` payload for an RGB image."""
    blank_pixels = np.zeros((100, 100, 3), dtype=np.uint8)

    payload = inference.convert_to_binary(Image.fromarray(blank_pixels))

    assert isinstance(payload, bytes)
    assert payload  # non-empty