aiuNN
Adaptive Image Upscaler using Neural Networks
Overview
aiuNN is an adaptive image upscaling model built on top of the Adaptive Image Intelligence Architecture (AIIA). This project provides versions of AIIA models fine-tuned specifically for image upscaling, using neural networks to increase the resolution of images and enhance their detail.
Features
- High-Quality Upscaling: Produces detailed, sharp upscaled outputs.
- Fine-Tuned Models: AIIA models fine-tuned for upscaling on a diverse dataset.
- Easy Integration: A simple API for adding upscaling capabilities to your applications.
- Customizable: Fine-tune the models further on your own datasets for specific use cases.
Installation
You can install aiuNN using pip. Run the following command:
pip install git+https://gitea.fabelous.app/Machine-Learning/aiuNN.git
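After installing, one quick way to confirm that the packages resolve is to look them up with the standard library's importlib.util.find_spec. This is a minimal sketch: the module names aiunn and aiia are taken from the imports in the Usage example below, and the helper name missing_modules is hypothetical.

```python
import importlib.util


def missing_modules(names):
    """Return the subset of top-level module names that cannot be found."""
    # find_spec returns None for a top-level module that is not installed
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    # Module names assumed from the imports used in the Usage example
    missing = missing_modules(["aiunn", "aiia"])
    if missing:
        print("Not installed:", ", ".join(missing))
    else:
        print("aiunn and aiia are importable")
```

An empty result means both imports in the Usage example should succeed.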
Usage
Here's a basic example of how to fine-tune aiuNN for image upscaling on your own data:
from aiia import AIIABase
from aiunn import aiuNN, aiuNNTrainer
from torchvision import transforms

# UpscaleDataset is the example dataset class shown in the Dataset section below

# Load your base model and upscaler
pretrained_model_path = "path/to/aiia/model"
base_model = AIIABase.load(pretrained_model_path, precision="bf16")
upscaler = aiuNN(base_model)

# Create a trainer with your dataset class
trainer = aiuNNTrainer(upscaler, dataset_class=UpscaleDataset)

# Load data using the parameters your dataset class expects
dataset_params = {
    'parquet_files': [
        "path/to/dataset1",
        "path/to/dataset2"
    ],
    'transform': transforms.Compose([transforms.ToTensor()]),
    'samples_per_file': 5000  # Number of training samples to load from each file
}
trainer.load_data(dataset_params=dataset_params, batch_size=1)

# Fine-tune the model
trainer.finetune(output_path="trained_model")
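Rather than listing each Parquet file by hand, the parquet_files list can be built from a directory. This is a minimal sketch, assuming your files use a .parquet extension and sit in a single directory; the helper collect_parquet_files and the data/ path are hypothetical.

```python
from pathlib import Path


def collect_parquet_files(directory, pattern="*.parquet"):
    """Return sorted paths (as strings) of files in `directory` matching `pattern`."""
    # Sorting keeps the file order, and therefore the loaded samples, deterministic
    return sorted(str(path) for path in Path(directory).glob(pattern) if path.is_file())


if __name__ == "__main__":
    # Hypothetical directory; point this at wherever your Parquet files live
    files = collect_parquet_files("data/")
    print(f"Found {len(files)} Parquet files")
```

The returned list can be passed directly as the 'parquet_files' entry of dataset_params.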
Dataset
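The UpscaleDataset class is designed to handle Parquet files containing paired image data: a low-resolution image in the image_410 column and its high-resolution counterpart in the image_820 column. It loads up to samples_per_file rows from each file and validates that every image value is either raw bytes or a base64-encoded string.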
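Because the loader keeps only the first samples_per_file rows of each file before concatenating them, the dataset length is the sum of min(rows_in_file, samples_per_file) over all files. The sketch below just does that arithmetic; the helper name expected_dataset_length and the per-file row counts are hypothetical.

```python
def expected_dataset_length(rows_per_file, samples_per_file):
    """Number of rows kept when at most `samples_per_file` are taken from each file."""
    # .head(n) returns every row when a file has fewer than n of them
    return sum(min(rows, samples_per_file) for rows in rows_per_file)


if __name__ == "__main__":
    # Hypothetical files with 12,000 and 3,000 rows, loaded with samples_per_file=5000
    print(expected_dataset_length([12_000, 3_000], 5_000))  # 8000
```

Rows are validated rather than filtered, so an invalid row raises an error instead of shrinking the dataset.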
This is an example dataset that you can use with the aiuNN model:
import base64
import io

import pandas as pd
from PIL import Image, ImageFile
from torch.utils.data import Dataset

# Allow Pillow to load partially truncated image files instead of failing
ImageFile.LOAD_TRUNCATED_IMAGES = True


class UpscaleDataset(Dataset):
    def __init__(self, parquet_files: list, transform=None, samples_per_file=10_000):
        # Load a subset (the first samples_per_file rows) from each Parquet file
        frames = [
            pd.read_parquet(parquet_file, columns=['image_410', 'image_820']).head(samples_per_file)
            for parquet_file in parquet_files
        ]
        combined_df = pd.concat(frames, ignore_index=True)
        # Validate rows (ensuring each value is bytes or str)
        self.df = combined_df.apply(self._validate_row, axis=1)
        self.transform = transform
        self.failed_indices = set()

    def _validate_row(self, row):
        for col in ['image_410', 'image_820']:
            if not isinstance(row[col], (bytes, str)):
                raise ValueError(f"Invalid data type in column {col}: {type(row[col])}")
        return row

    def _decode_image(self, data):
        # str values are treated as base64-encoded; bytes are used as-is
        if isinstance(data, str):
            try:
                return base64.b64decode(data)
            except Exception as e:
                raise RuntimeError(f"Decoding failed: {e}") from e
        if isinstance(data, bytes):
            return data
        raise ValueError(f"Unsupported data type: {type(data)}")

    @staticmethod
    def _to_rgb(image_bytes):
        # Open the bytes with Pillow and convert to RGBA first
        rgba = Image.open(io.BytesIO(image_bytes)).convert('RGBA')
        # Composite over a black background using the alpha channel as the mask,
        # giving a true 3-channel RGB image with transparent areas turned black
        rgb = Image.new("RGB", rgba.size, (0, 0, 0))
        rgb.paste(rgba, mask=rgba.split()[3])
        return rgb

    def __len__(self):
        return len(self.df)

    def _load_pair(self, idx):
        row = self.df.iloc[idx]
        low_res = self._to_rgb(self._decode_image(row['image_410']))
        high_res = self._to_rgb(self._decode_image(row['image_820']))
        # If a transform is provided (e.g. conversion to Tensor), apply it
        if self.transform:
            low_res = self.transform(low_res)
            high_res = self.transform(high_res)
        return low_res, high_res

    def __getitem__(self, idx):
        # Try idx first; if it has failed before or fails now, move on to the next index.
        # A bounded loop avoids unbounded recursion when many samples are broken.
        for offset in range(len(self)):
            current = (idx + offset) % len(self)
            if current in self.failed_indices:
                continue
            try:
                return self._load_pair(current)
            except Exception as e:
                print(f"\nError at index {current}: {e}")
                self.failed_indices.add(current)
        raise RuntimeError("No sample in the dataset could be loaded")
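The RGBA-to-RGB step pastes each image onto a black canvas using its alpha channel as the mask. Per channel this is ordinary alpha compositing over black, so each output value is roughly value * alpha / 255: fully opaque pixels keep their color and fully transparent pixels become black. The pure-Python sketch below illustrates that arithmetic for a single pixel; the helper name composite_over_black is hypothetical, and Pillow's integer implementation may differ from it by a rounding step for partial alpha values.

```python
def composite_over_black(pixel):
    """Blend an (R, G, B, A) pixel over a black background and return (R, G, B)."""
    r, g, b, a = pixel
    # Over black the background term is 0, leaving only channel * alpha / 255
    return tuple(round(channel * a / 255) for channel in (r, g, b))


if __name__ == "__main__":
    print(composite_over_black((200, 100, 50, 255)))  # (200, 100, 50): opaque pixel unchanged
    print(composite_over_black((200, 100, 50, 0)))    # (0, 0, 0): transparent pixel becomes black
```

Filling transparent regions with black keeps the low- and high-resolution images of a pair consistent, since both go through the same conversion.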