develop #11

Merged
Fabel merged 22 commits from develop into main 2025-04-04 20:12:21 +00:00
18 changed files with 613 additions and 52 deletions

10
.coveragerc Normal file
View File

@ -0,0 +1,10 @@
[run]
branch = True
source = src
omit =
*/tests/*
*/migrations/*
[report]
show_missing = True
fail_under = 80

View File

@ -0,0 +1,37 @@
name: Run VectorLoader Script
on:
push:
branches:
- main
jobs:
Explore-Gitea-Actions:
runs-on: ubuntu-latest
container: catthehacker/ubuntu:act-latest
steps:
- name: Check out repository code
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v3
with:
python-version: '3.11.7'
- name: Clone additional repository
run: |
git config --global credential.helper cache
git clone https://fabel:${{ secrets.CICD }}@gitea.fabelous.app/fabel/VectorLoader.git
- name: Install dependencies
run: |
cd VectorLoader
python -m pip install --upgrade pip
pip install -r requirements.txt
- name: Run vectorizing
env:
VECTORDB_TOKEN: ${{ secrets.VECTORDB_TOKEN }}
run: |
cd VectorLoader
python -m src.run --full

View File

@ -0,0 +1,37 @@
name: Gitea Actions For AIIA
run-name: ${{ gitea.actor }} is testing out Gitea Actions 🚀
on: [push]
jobs:
Explore-Gitea-Actions:
runs-on: ubuntu-latest
container: catthehacker/ubuntu:act-latest
steps:
- name: Check out repository code
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.11.7'
- name: Cache pip and model
uses: actions/cache@v3
with:
path: |
~/.cache/pip
./fabel
key: ${{ runner.os }}-pip-model-${{ hashFiles('requirements.txt', 'requirements-dev.txt') }}
restore-keys: |
${{ runner.os }}-pip-model-
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install -r requirements-dev.txt
pip install -e .
- name: Run tests
run: |
pytest tests/

4
MANIFEST.in Normal file
View File

@ -0,0 +1,4 @@
include LICENSE
include README.md
include requirements.txt
recursive-include src/aiia *

131
README.md
View File

@ -1,3 +1,130 @@
# aiunn # aiuNN
Advanced Image Upscaler using Neural Networks Adaptive Image Upscaler using Neural Networks
## Overview
`aiuNN` is an adaptive image upscaling model built on top of the Adaptive Image Intelligence Architecture (AIIA). This project provides fine-tuned versions of AIIA models specifically designed for high-quality image upscaling. By leveraging neural networks, `aiuNN` can significantly enhance the resolution and detail of images.
## Features
- **High-Quality Upscaling**: Achieve superior image quality with detailed and sharp outputs.
- **Fine-Tuned Models**: Pre-trained on a diverse dataset to ensure optimal performance.
- **Easy Integration**: Simple API for integrating upscaling capabilities into your applications.
- **Customizable**: Fine-tune the models further on your own datasets for specific use cases.
## Installation
You can install `aiuNN` using pip. Run the following command:
```sh
pip install git+https://gitea.fabelous.app/Machine-Learning/aiuNN.git
```
## Usage
Here's a basic example of how to use `aiuNN` for image upscaling:
```python src/main.py
from aiia import AIIABase
from aiunn import aiuNN, aiuNNTrainer
import pandas as pd
from torchvision import transforms
# Load your base model and upscaler
pretrained_model_path = "path/to/aiia/model"
base_model = AIIABase.load(pretrained_model_path, precision="bf16")
upscaler = aiuNN(base_model)
# Create trainer with your dataset class
trainer = aiuNNTrainer(upscaler, dataset_class=UpscaleDataset)
# Load data using parameters for your dataset
dataset_params = {
'parquet_files': [
"path/to/dataset1",
"path/to/dataset2"
],
'transform': transforms.Compose([transforms.ToTensor()]),
'samples_per_file': 5000 # Your training samples you want to load per file
}
trainer.load_data(dataset_params=dataset_params, batch_size=1)
# Fine-tune the model
trainer.finetune(output_path="trained_model")
```
## Dataset
The `UpscaleDataset` class is designed to handle Parquet files containing image data. It loads a subset of images from each file and validates the data types to ensure consistency.
This is an example dataset that you can use with the AIIUN model:
```python src/example.py
class UpscaleDataset(Dataset):
def __init__(self, parquet_files: list, transform=None, samples_per_file=10_000):
combined_df = pd.DataFrame()
for parquet_file in parquet_files:
# Load a subset from each parquet file
df = pd.read_parquet(parquet_file, columns=['image_410', 'image_820']).head(samples_per_file)
combined_df = pd.concat([combined_df, df], ignore_index=True)
# Validate rows (ensuring each value is bytes or str)
self.df = combined_df.apply(self._validate_row, axis=1)
self.transform = transform
self.failed_indices = set()
def _validate_row(self, row):
for col in ['image_410', 'image_820']:
if not isinstance(row[col], (bytes, str)):
raise ValueError(f"Invalid data type in column {col}: {type(row[col])}")
return row
def _decode_image(self, data):
try:
if isinstance(data, str):
return base64.b64decode(data)
elif isinstance(data, bytes):
return data
raise ValueError(f"Unsupported data type: {type(data)}")
except Exception as e:
raise RuntimeError(f"Decoding failed: {str(e)}")
def __len__(self):
return len(self.df)
def __getitem__(self, idx):
# If previous call failed for this index, use a different index
if idx in self.failed_indices:
return self[(idx + 1) % len(self)]
try:
row = self.df.iloc[idx]
low_res_bytes = self._decode_image(row['image_410'])
high_res_bytes = self._decode_image(row['image_820'])
ImageFile.LOAD_TRUNCATED_IMAGES = True
# Open image bytes with Pillow and convert to RGBA first
low_res_rgba = Image.open(io.BytesIO(low_res_bytes)).convert('RGBA')
high_res_rgba = Image.open(io.BytesIO(high_res_bytes)).convert('RGBA')
# Create a new RGB image with black background
low_res_rgb = Image.new("RGB", low_res_rgba.size, (0, 0, 0))
high_res_rgb = Image.new("RGB", high_res_rgba.size, (0, 0, 0))
# Composite the original image over the black background
low_res_rgb.paste(low_res_rgba, mask=low_res_rgba.split()[3])
high_res_rgb.paste(high_res_rgba, mask=high_res_rgba.split()[3])
# Now we have true 3-channel RGB images with transparent areas converted to black
low_res = low_res_rgb
high_res = high_res_rgb
# If a transform is provided (e.g. conversion to Tensor), apply it
if self.transform:
low_res = self.transform(low_res)
high_res = self.transform(high_res)
return low_res, high_res
except Exception as e:
print(f"\nError at index {idx}: {str(e)}")
self.failed_indices.add(idx)
return self[(idx + 1) % len(self)]
```

View File

@ -15,7 +15,7 @@ class UpscaleDataset(Dataset):
combined_df = pd.DataFrame() combined_df = pd.DataFrame()
for parquet_file in parquet_files: for parquet_file in parquet_files:
# Load a subset from each parquet file # Load a subset from each parquet file
df = pd.read_parquet(parquet_file, columns=['image_512', 'image_1024']).head(samples_per_file) df = pd.read_parquet(parquet_file, columns=['image_410', 'image_820']).head(samples_per_file)
combined_df = pd.concat([combined_df, df], ignore_index=True) combined_df = pd.concat([combined_df, df], ignore_index=True)
# Validate rows (ensuring each value is bytes or str) # Validate rows (ensuring each value is bytes or str)
@ -24,7 +24,7 @@ class UpscaleDataset(Dataset):
self.failed_indices = set() self.failed_indices = set()
def _validate_row(self, row): def _validate_row(self, row):
for col in ['image_512', 'image_1024']: for col in ['image_410', 'image_820']:
if not isinstance(row[col], (bytes, str)): if not isinstance(row[col], (bytes, str)):
raise ValueError(f"Invalid data type in column {col}: {type(row[col])}") raise ValueError(f"Invalid data type in column {col}: {type(row[col])}")
return row return row
@ -48,8 +48,8 @@ class UpscaleDataset(Dataset):
return self[(idx + 1) % len(self)] return self[(idx + 1) % len(self)]
try: try:
row = self.df.iloc[idx] row = self.df.iloc[idx]
low_res_bytes = self._decode_image(row['image_512']) low_res_bytes = self._decode_image(row['image_410'])
high_res_bytes = self._decode_image(row['image_1024']) high_res_bytes = self._decode_image(row['image_820'])
ImageFile.LOAD_TRUNCATED_IMAGES = True ImageFile.LOAD_TRUNCATED_IMAGES = True
# Open image bytes with Pillow and convert to RGBA first # Open image bytes with Pillow and convert to RGBA first
low_res_rgba = Image.open(io.BytesIO(low_res_bytes)).convert('RGBA') low_res_rgba = Image.open(io.BytesIO(low_res_bytes)).convert('RGBA')
@ -67,10 +67,6 @@ class UpscaleDataset(Dataset):
low_res = low_res_rgb low_res = low_res_rgb
high_res = high_res_rgb high_res = high_res_rgb
# Resize the images to reduce VRAM usage
low_res = low_res.resize((410, 410), Image.LANCZOS)
high_res = high_res.resize((820, 820), Image.LANCZOS)
# If a transform is provided (e.g. conversion to Tensor), apply it # If a transform is provided (e.g. conversion to Tensor), apply it
if self.transform: if self.transform:
low_res = self.transform(low_res) low_res = self.transform(low_res)
@ -98,7 +94,7 @@ if __name__ =="__main__":
"/root/training_data/vision-dataset/image_vec_upscaler.parquet" "/root/training_data/vision-dataset/image_vec_upscaler.parquet"
], ],
'transform': transforms.Compose([transforms.ToTensor()]), 'transform': transforms.Compose([transforms.ToTensor()]),
'samples_per_file': 5000 'samples_per_file': 20_000
} }
trainer.load_data(dataset_params=dataset_params, batch_size=1) trainer.load_data(dataset_params=dataset_params, batch_size=1)

BIN
input.jpg

Binary file not shown.

Before

Width:  |  Height:  |  Size: 157 KiB

View File

@ -1,7 +1,6 @@
[build-system] [build-system]
requires = ["setuptools>=45", "wheel"] requires = ["setuptools>=45", "wheel"]
build-backend = "setuptools.build_meta" build-backend = "setuptools.build_meta"
[project] [project]
name = "aiunn" name = "aiunn"
version = "0.1.1" version = "0.1.1"
@ -10,8 +9,7 @@ readme = "README.md"
requires-python = ">=3.10" requires-python = ">=3.10"
license = {file = "LICENSE"} license = {file = "LICENSE"}
authors = [ authors = [
{name = "Falko Habel", email = "falko.habel@gmx.de"}, {name = "Falko Habel", email = "falko.habel@gmx.de"},
] ]
[project.urls] [project.urls]
"Homepage" = "https://gitea.fabelous.app/Machine-Learning/aiuNN" "Homepage"="https://gitea.fabelous.app/Machine-Learning/aiuNN"

3
pytest.ini Normal file
View File

@ -0,0 +1,3 @@
[pytest]
testpaths = tests/
python_files = test_*.py

2
requirements-dev.txt Normal file
View File

@ -0,0 +1,2 @@
pytest
pytest-mock

View File

@ -2,4 +2,5 @@ torch
aiia aiia
pillow pillow
torchvision torchvision
sklearn scikit-learn
git+https://gitea.fabelous.app/Machine-Learning/AIIA.git

View File

@ -2,13 +2,18 @@ from setuptools import setup, find_packages
setup( setup(
name="aiunn", name="aiunn",
version="0.1.1", version="0.1.2",
packages=find_packages(where="src"), packages=find_packages(where="src"),
package_dir={"": "src"}, package_dir={"": "src"},
install_requires=[ install_requires=[
line.strip() "torch",
for line in open("requirements.txt") "aiia",
if line.strip() and not line.startswith("#") "pillow",
"torchvision",
"scikit-learn",
],
dependency_links=[
"git+https://gitea.fabelous.app/Machine-Learning/AIIA.git#egg=aiia"
], ],
python_requires=">=3.10", python_requires=">=3.10",
) )

View File

@ -3,4 +3,4 @@ from .upsampler.aiunn import aiuNN
from .upsampler.config import aiuNNConfig from .upsampler.config import aiuNNConfig
from .inference.inference import aiuNNInference from .inference.inference import aiuNNInference
__version__ = "0.1.1" __version__ = "0.1.2"

View File

@ -2,50 +2,120 @@ import os
import torch import torch
import torch.nn as nn import torch.nn as nn
import warnings import warnings
from aiia import AIIA, AIIAConfig, AIIABase from aiia.model.Model import AIIA, AIIAConfig, AIIABase
from .config import aiuNNConfig from .config import aiuNNConfig
import warnings import warnings
class aiuNN(AIIA): class aiuNN(AIIA):
def __init__(self, base_model: AIIABase): def __init__(self, base_model: AIIA, config:aiuNNConfig):
super().__init__(base_model.config) super().__init__(base_model.config)
self.base_model = base_model self.base_model = base_model
# Pass the unified base configuration using the new parameter. # Pass the unified base configuration using the new parameter.
self.config = aiuNNConfig(base_config=base_model.config) self.config = config
self.upsample = nn.Upsample( # Enhanced approach
scale_factor=self.config.upsample_scale, scale_factor = self.config.upsample_scale
mode=self.config.upsample_mode, out_channels = self.base_model.config.num_channels * (scale_factor ** 2)
align_corners=self.config.upsample_align_corners self.pixel_shuffle_conv = nn.Conv2d(
)
# Conversion layer: change from hidden size channels to 3 channels.
self.to_rgb = nn.Conv2d(
in_channels=self.base_model.config.hidden_size, in_channels=self.base_model.config.hidden_size,
out_channels=3, out_channels=out_channels,
kernel_size=1 kernel_size=self.base_model.config.kernel_size,
padding=1
) )
self.pixel_shuffle = nn.PixelShuffle(scale_factor)
def forward(self, x): def forward(self, x):
x = self.base_model(x) x = self.base_model(x) # Get base features
x = self.upsample(x) x = self.pixel_shuffle_conv(x) # Expand channels for shuffling
x = self.to_rgb(x) # Ensures output has 3 channels. x = self.pixel_shuffle(x) # Rearrange channels into spatial dimensions
return x return x
@classmethod
def load(cls, path, precision: str = None):
# Load the configuration from disk.
config = AIIAConfig.load(path)
# Reconstruct the base model from the loaded configuration.
base_model = AIIABase(config)
# Instantiate the Upsampler using the proper base model.
upsampler = cls(base_model)
# Load state dict and handle precision conversion if needed. @classmethod
def load(cls, path, precision: str = None, **kwargs):
"""
Load a aiuNN model from disk with automatic detection of base model type.
Args:
path (str): Directory containing the stored configuration and model parameters.
precision (str, optional): Desired precision for the model's parameters.
**kwargs: Additional keyword arguments to override configuration parameters.
Returns:
An instance of aiuNN with loaded weights.
"""
# Load the configuration
config = aiuNNConfig.load(path)
# Determine the device
device = 'cuda' if torch.cuda.is_available() else 'cpu' device = 'cuda' if torch.cuda.is_available() else 'cpu'
state_dict = torch.load(f"{path}/model.pth", map_location=device)
# Load the state dictionary
state_dict = torch.load(os.path.join(path, "model.pth"), map_location=device)
# Import all model types
from aiia.model.Model import AIIABase, AIIABaseShared, AIIAExpert, AIIAmoe, AIIAchunked, AIIArecursive
# Helper function to detect base class type from key patterns
def detect_base_class_type(keys_prefix):
if any(f"{keys_prefix}.shared_layer" in key for key in state_dict.keys()):
return AIIABaseShared
else:
return AIIABase
# Detect base model type
base_model = None
# Check for AIIAmoe with multiple experts
if any("base_model.experts" in key for key in state_dict.keys()):
# Count the number of experts
max_expert_idx = -1
for key in state_dict.keys():
if "base_model.experts." in key:
try:
parts = key.split("base_model.experts.")[1].split(".")
expert_idx = int(parts[0])
max_expert_idx = max(max_expert_idx, expert_idx)
except (ValueError, IndexError):
pass
if max_expert_idx >= 0:
# Determine the type of base_cnn each expert is using
base_class_for_experts = detect_base_class_type("base_model.experts.0.base_cnn")
# Create AIIAmoe with the detected expert count and base class
base_model = AIIAmoe(config, num_experts=max_expert_idx+1, base_class=base_class_for_experts, **kwargs)
# Check for AIIAchunked or AIIArecursive
elif any("base_model.chunked_cnn" in key for key in state_dict.keys()):
if any("recursion_depth" in key for key in state_dict.keys()):
# This is an AIIArecursive model
base_class = detect_base_class_type("base_model.chunked_cnn.base_cnn")
base_model = AIIArecursive(config, base_class=base_class, **kwargs)
else:
# This is an AIIAchunked model
base_class = detect_base_class_type("base_model.chunked_cnn.base_cnn")
base_model = AIIAchunked(config, base_class=base_class, **kwargs)
# Check for AIIAExpert
elif any("base_model.base_cnn" in key for key in state_dict.keys()):
# Determine which base class the expert is using
base_class = detect_base_class_type("base_model.base_cnn")
base_model = AIIAExpert(config, base_class=base_class, **kwargs)
# If none of the above, use AIIABase or AIIABaseShared directly
else:
base_class = detect_base_class_type("base_model")
base_model = base_class(config, **kwargs)
# Create the aiuNN model with the detected base model
model = cls(base_model, config=base_model.config)
# Handle precision conversion
dtype = None
if precision is not None: if precision is not None:
if precision.lower() == 'fp16': if precision.lower() == 'fp16':
dtype = torch.float16 dtype = torch.float16
@ -58,11 +128,14 @@ class aiuNN(AIIA):
else: else:
raise ValueError("Unsupported precision. Use 'fp16', 'bf16', or leave as None.") raise ValueError("Unsupported precision. Use 'fp16', 'bf16', or leave as None.")
if dtype is not None:
for key, param in state_dict.items(): for key, param in state_dict.items():
if torch.is_tensor(param): if torch.is_tensor(param):
state_dict[key] = param.to(dtype) state_dict[key] = param.to(dtype)
upsampler.load_state_dict(state_dict)
return upsampler # Load the state dict
model.load_state_dict(state_dict)
return model
@ -70,13 +143,14 @@ if __name__ == "__main__":
from aiia import AIIABase, AIIAConfig from aiia import AIIABase, AIIAConfig
# Create a configuration and build a base model. # Create a configuration and build a base model.
config = AIIAConfig() config = AIIAConfig()
ai_config = aiuNNConfig()
base_model = AIIABase(config) base_model = AIIABase(config)
# Instantiate Upsampler from the base model (works correctly). # Instantiate Upsampler from the base model (works correctly).
upsampler = aiuNN(base_model) upsampler = aiuNN(base_model, config=ai_config)
# Save the model (both configuration and weights). # Save the model (both configuration and weights).
upsampler.save("hehe") upsampler.save("aiunn")
# Now load using the overridden load method; this will load the complete model. # Now load using the overridden load method; this will load the complete model.
upsampler_loaded = aiuNN.load("hehe", precision="bf16") upsampler_loaded = aiuNN.load("aiunn", precision="bf16")
print("Updated configuration:", upsampler_loaded.config.__dict__) print("Updated configuration:", upsampler_loaded.config.__dict__)

View File

@ -0,0 +1,75 @@
import pytest
import torch
import torch.nn as nn
from aiunn import aiuNNTrainer
# Simple mock dataset
class MockDataset(torch.utils.data.Dataset):
def __init__(self, num_samples=10):
self.num_samples = num_samples
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
return torch.randn(3, 64, 64), torch.randn(3, 128, 128)
# Simple mock model
class MockModel(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(3, 3, 3, padding=1)
def forward(self, x):
return self.conv(x)
def save(self, path):
pass # Mock save method
@pytest.fixture
def trainer():
model = MockModel()
return aiuNNTrainer(model, dataset_class=MockDataset)
def test_trainer_initialization(trainer):
"""Test basic trainer initialization"""
assert trainer.model is not None
assert isinstance(trainer.criterion, nn.MSELoss)
assert trainer.optimizer is None
assert trainer.device in [torch.device('cuda'), torch.device('cpu')]
def test_load_data_basic(trainer):
"""Test basic data loading"""
train_loader, val_loader = trainer.load_data(
dataset_params={'num_samples': 10},
batch_size=2,
validation_split=0.2
)
assert train_loader is not None
assert val_loader is not None
assert len(train_loader) > 0
assert len(val_loader) > 0
def test_load_custom_datasets(trainer):
"""Test loading custom datasets"""
train_dataset = MockDataset(num_samples=10)
val_dataset = MockDataset(num_samples=5)
train_loader, val_loader = trainer.load_data(
custom_train_dataset=train_dataset,
custom_val_dataset=val_dataset,
batch_size=2
)
assert train_loader is not None
assert val_loader is not None
assert len(train_loader) == 5 # 10 samples with batch size 2
assert len(val_loader) == 3 # 5 samples with batch size 2 (rounded up)
def test_error_no_dataset():
"""Test error when no dataset is provided"""
trainer = aiuNNTrainer(MockModel(), dataset_class=None)
with pytest.raises(ValueError):
trainer.load_data(dataset_params={})

View File

@ -0,0 +1,99 @@
# tests/inference/test_inference.py
import pytest
import numpy as np
import torch
from PIL import Image
from aiunn import aiuNNInference
from aiunn.upsampler.aiunn import aiuNN, aiuNNConfig
from aiia import AIIABase, AIIAConfig
import os
import json
from unittest.mock import patch, MagicMock, mock_open
@pytest.fixture
def real_model(tmp_path):
# Create temporary directory for model
model_dir = tmp_path / "model"
model_dir.mkdir()
config = AIIAConfig()
ai_config = aiuNNConfig()
base_model = AIIABase(config)
# Make sure aiuNN is properly configured with all required attributes
upsampler = aiuNN(base_model, config=ai_config)
# Ensure the upsample attribute is properly set if needed
# upsampler.upsample = ... # Add any necessary initialization
# Save the model and config to temporary directory
save_path = str(model_dir / "save")
os.makedirs(save_path, exist_ok=True)
# Save config file
config_data = {
"model_type": "test_model",
"scale": 4,
"in_channels": 3,
"out_channels": 3
}
with open(os.path.join(save_path, "config.json"), "w") as f:
json.dump(config_data, f)
# Save model
upsampler.save(save_path)
# Load model in inference mode
inference_model = aiuNNInference(model_path=save_path, precision='fp16', device='cpu')
return inference_model
@pytest.fixture
def inference(real_model):
return real_model
def test_preprocess_image(inference):
# Create a small test image
test_array = np.zeros((100, 100, 3), dtype=np.uint8)
test_image = Image.fromarray(test_array)
# Test with PIL Image
result = inference.preprocess_image(test_image)
assert isinstance(result, torch.Tensor)
assert result.shape[0] == 1 # batch dimension
assert result.shape[1] == 3 # channels
assert result.shape[2:] == (100, 100) # height, width
def test_postprocess_tensor(inference):
# Create a test tensor
test_tensor = torch.zeros(1, 3, 100, 100)
result = inference.postprocess_tensor(test_tensor)
assert isinstance(result, Image.Image)
assert result.size == (100, 100)
assert result.mode == 'RGB'
def test_save(inference):
# Create a test image
test_image = Image.fromarray(np.zeros((100, 100, 3), dtype=np.uint8))
output_path = "test_output.png"
with patch('os.makedirs') as mock_makedirs:
inference.save(test_image, output_path)
mock_makedirs.assert_called_with(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)
def test_convert_to_binary(inference):
# Create a test image
test_image = Image.fromarray(np.zeros((100, 100, 3), dtype=np.uint8))
result = inference.convert_to_binary(test_image)
assert isinstance(result, bytes)
assert len(result) > 0
def test_process_batch(inference):
# Create test images
test_array = np.zeros((100, 100, 3), dtype=np.uint8)
test_images = [Image.fromarray(test_array) for _ in range(2)]
results = inference.process_batch(test_images)
assert len(results) == 2
assert all(isinstance(img, Image.Image) for img in results)

View File

@ -0,0 +1,48 @@
import os
import tempfile
from aiia import AIIABase, AIIAConfig
from aiunn import aiuNN, aiuNNConfig
def test_save_and_load_model():
# Create a temporary directory to save the model
with tempfile.TemporaryDirectory() as tmpdirname:
# Create configurations and build a base model
config = AIIAConfig()
ai_config = aiuNNConfig()
base_model = AIIABase(config)
upsampler = aiuNN(base_model, config=ai_config)
# Save the model
save_path = os.path.join(tmpdirname, "model")
upsampler.save(save_path)
# Load the model
loaded_upsampler = aiuNN.load(save_path)
# Verify that the loaded model is the same as the original model
assert isinstance(loaded_upsampler, aiuNN)
assert loaded_upsampler.config.__dict__ == upsampler.config.__dict__
def test_save_and_load_model_with_precision():
# Create a temporary directory to save the model
with tempfile.TemporaryDirectory() as tmpdirname:
# Create configurations and build a base model
config = AIIAConfig()
ai_config = aiuNNConfig()
base_model = AIIABase(config)
upsampler = aiuNN(base_model, config=ai_config)
# Save the model
save_path = os.path.join(tmpdirname, "model")
upsampler.save(save_path)
# Load the model with precision 'bf16'
loaded_upsampler = aiuNN.load(save_path, precision="bf16")
# Verify that the loaded model is the same as the original model
assert isinstance(loaded_upsampler, aiuNN)
assert loaded_upsampler.config.__dict__ == upsampler.config.__dict__
if __name__ == "__main__":
test_save_and_load_model()
test_save_and_load_model_with_precision()

View File

@ -0,0 +1,45 @@
import pytest
from aiunn import aiuNNConfig
def test_default_initialization():
config = aiuNNConfig()
assert config.upsample_scale == 2
assert config.upsample_mode == 'bilinear'
assert not config.upsample_align_corners
assert len(config.layers) == 1
assert config.layers[0]['name'] == 'Upsample'
def test_custom_initialization():
custom_config = {
'some_key': 'some_value',
}
config = aiuNNConfig(base_config=custom_config, upsample_scale=3, upsample_mode='nearest', upsample_align_corners=True)
assert config.upsample_scale == 3
assert config.upsample_mode == 'nearest'
assert config.upsample_align_corners
assert len(config.layers) == 1
assert config.layers[0]['name'] == 'Upsample'
def test_add_upsample_layer():
config = aiuNNConfig()
config.add_upsample_layer()
assert len(config.layers) == 1
def test_upsample_layer_not_duplicated():
config = aiuNNConfig()
initial_length = len(config.layers)
# Remove all existing 'Upsample' layers.
for layer in list(config.layers):
if layer.get('name') == 'Upsample':
config.layers.remove(layer)
config.add_upsample_layer()
assert len(config.layers) == 1
def test_base_config_with_to_dict():
class MockBaseConfig:
def to_dict(self):
return {'base_key': 'base_value'}
base_config = MockBaseConfig()
config = aiuNNConfig(base_config=base_config)
assert config.to_dict()['base_key'] == 'base_value'