aiuNN/tests/inference/test_inference.py

# tests/inference/test_inference.py
import json
import os
from unittest.mock import patch, MagicMock, mock_open

import numpy as np
import pytest
import torch
from PIL import Image

from aiia import AIIABase, AIIAConfig
from aiunn import aiuNNInference
from aiunn.upsampler.aiunn import aiuNN, aiuNNConfig


@pytest.fixture
def real_model(tmp_path):
    # Create a temporary directory for the model
    model_dir = tmp_path / "model"
    model_dir.mkdir()

    config = AIIAConfig()
    ai_config = aiuNNConfig()
    base_model = AIIABase(config)

    # Make sure aiuNN is properly configured with all required attributes
    upsampler = aiuNN(base_model, config=ai_config)
    # Ensure the upsample attribute is properly set if needed
    # upsampler.upsample = ...  # Add any necessary initialization

    # Save the model and config to the temporary directory
    save_path = str(model_dir / "save")
    os.makedirs(save_path, exist_ok=True)

    # Save config file
    config_data = {
        "model_type": "test_model",
        "scale": 4,
        "in_channels": 3,
        "out_channels": 3
    }
    with open(os.path.join(save_path, "config.json"), "w") as f:
        json.dump(config_data, f)

    # Save model
    upsampler.save(save_path)

    # Load model in inference mode
    inference_model = aiuNNInference(model_path=save_path, precision='fp16', device='cpu')
    return inference_model


@pytest.fixture
def inference(real_model):
    return real_model


def test_preprocess_image(inference):
    # Create a small test image
    test_array = np.zeros((100, 100, 3), dtype=np.uint8)
    test_image = Image.fromarray(test_array)

    # Test with a PIL Image
    result = inference.preprocess_image(test_image)

    assert isinstance(result, torch.Tensor)
    assert result.shape[0] == 1  # batch dimension
    assert result.shape[1] == 3  # channels
    assert result.shape[2:] == (100, 100)  # height, width


def test_postprocess_tensor(inference):
    # Create a test tensor
    test_tensor = torch.zeros(1, 3, 100, 100)

    result = inference.postprocess_tensor(test_tensor)

    assert isinstance(result, Image.Image)
    assert result.size == (100, 100)
    assert result.mode == 'RGB'


def test_save(inference):
    # Create a test image
    test_image = Image.fromarray(np.zeros((100, 100, 3), dtype=np.uint8))
    output_path = "test_output.png"

    with patch('os.makedirs') as mock_makedirs:
        inference.save(test_image, output_path)
        mock_makedirs.assert_called_with(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)
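

# The test above patches os.makedirs but still lets aiuNNInference.save write
# test_output.png into the working directory. The hedged variant below also
# stubs PIL's Image.save so nothing touches disk; it assumes (not confirmed by
# this file) that aiuNNInference.save ultimately delegates to the PIL image's
# own save method.
def test_save_does_not_touch_disk(inference):
    test_image = Image.fromarray(np.zeros((100, 100, 3), dtype=np.uint8))

    with patch('os.makedirs'), patch.object(Image.Image, 'save') as mock_save:
        inference.save(test_image, "test_output.png")

    # If the delegation assumption holds, the stubbed save was invoked instead
    # of a real file being written.
    mock_save.assert_called()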


def test_convert_to_binary(inference):
    # Create a test image
    test_image = Image.fromarray(np.zeros((100, 100, 3), dtype=np.uint8))

    result = inference.convert_to_binary(test_image)

    assert isinstance(result, bytes)
    assert len(result) > 0


def test_process_batch(inference):
    # Create test images
    test_array = np.zeros((100, 100, 3), dtype=np.uint8)
    test_images = [Image.fromarray(test_array) for _ in range(2)]

    results = inference.process_batch(test_images)

    assert len(results) == 2
    assert all(isinstance(img, Image.Image) for img in results)
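

# Hedged round-trip sketch combining the two helpers exercised above: the
# shapes asserted in test_preprocess_image (1, 3, 100, 100) and in
# test_postprocess_tensor suggest that feeding the preprocessed tensor straight
# back into postprocess_tensor should yield an RGB image of the original size.
# This assumes postprocess_tensor accepts whatever dtype/normalization
# preprocess_image produces, which this file does not state explicitly.
def test_preprocess_postprocess_roundtrip(inference):
    test_image = Image.fromarray(np.zeros((100, 100, 3), dtype=np.uint8))

    tensor = inference.preprocess_image(test_image)
    restored = inference.postprocess_tensor(tensor)

    assert isinstance(restored, Image.Image)
    assert restored.size == (100, 100)
    assert restored.mode == 'RGB'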