# tests/inference/test_inference.py
"""Tests for ``aiuNNInference`` pre/post-processing, saving and binary export."""
import json
import os
from unittest.mock import patch, MagicMock, mock_open  # noqa: F401  (MagicMock/mock_open currently unused)

import pytest
import numpy as np
import torch
from PIL import Image

from aiunn import aiuNNInference
from aiunn.upsampler.aiunn import aiuNN, aiuNNConfig
from aiia import AIIABase, AIIAConfig


@pytest.fixture
def real_model(tmp_path):
    """Build an untrained aiuNN, save it to disk, and reload it for inference.

    The model is saved under ``tmp_path / "model" / "save"`` and reloaded
    with ``aiuNNInference(model_path=..., device='cpu')``, which is returned.
    """
    model_dir = tmp_path / "model"
    model_dir.mkdir()

    # Default-configured AIIA base model wrapped by a default aiuNN upsampler.
    config = AIIAConfig()
    ai_config = aiuNNConfig()
    base_model = AIIABase(config)
    upsampler = aiuNN(config=ai_config)
    upsampler.load_base_model(base_model)

    save_path = str(model_dir / "save")
    os.makedirs(save_path, exist_ok=True)

    # NOTE(review): this hand-written config.json is written *before*
    # save_pretrained; if save_pretrained emits its own config.json it will
    # overwrite this one — confirm whether this file is still needed.
    config_data = {
        "model_type": "test_model",
        "scale": 4,
        "in_channels": 3,
        "out_channels": 3,
    }
    with open(os.path.join(save_path, "config.json"), "w", encoding="utf-8") as f:
        json.dump(config_data, f)

    upsampler.save_pretrained(save_path)

    # Reload from disk so the tests exercise the real loading path.
    inference_model = aiuNNInference(model_path=save_path, device='cpu')
    return inference_model


@pytest.fixture
def inference(real_model):
    """Alias for ``real_model`` used by the individual tests."""
    return real_model


def test_preprocess_image(inference):
    """A 100x100 RGB PIL image becomes a batched (1, 3, 100, 100) tensor."""
    test_array = np.zeros((100, 100, 3), dtype=np.uint8)
    test_image = Image.fromarray(test_array)

    result = inference.preprocess_image(test_image)

    assert isinstance(result, torch.Tensor)
    assert result.shape[0] == 1  # batch dimension
    assert result.shape[1] == 3  # channels
    assert result.shape[2:] == (100, 100)  # height, width


def test_postprocess_tensor(inference):
    """A (1, 3, 100, 100) tensor becomes a 100x100 RGB PIL image."""
    test_tensor = torch.zeros(1, 3, 100, 100)

    result = inference.postprocess_tensor(test_tensor)

    assert isinstance(result, Image.Image)
    assert result.size == (100, 100)
    assert result.mode == 'RGB'


def test_save(inference, tmp_path):
    """``save`` creates the output's parent directory with ``exist_ok=True``.

    The output path lives under ``tmp_path``. Only ``os.makedirs`` is patched,
    so anything ``save`` actually writes lands in a per-test temporary
    directory rather than the current working directory.
    """
    test_image = Image.fromarray(np.zeros((100, 100, 3), dtype=np.uint8))
    output_path = str(tmp_path / "test_output.png")

    with patch('os.makedirs') as mock_makedirs:
        inference.save(test_image, output_path)
        mock_makedirs.assert_called_with(
            os.path.dirname(os.path.abspath(output_path)), exist_ok=True
        )


def test_convert_to_binary(inference):
    """``convert_to_binary`` returns a non-empty ``bytes`` payload."""
    test_image = Image.fromarray(np.zeros((100, 100, 3), dtype=np.uint8))

    result = inference.convert_to_binary(test_image)

    assert isinstance(result, bytes)
    assert len(result) > 0