aiuNN

Adaptive Image Upscaler using Neural Networks

Overview

aiuNN is an adaptive image upscaling model built on top of the Adaptive Image Intelligence Architecture (AIIA). This project provides AIIA models fine-tuned for image upscaling: given a low-resolution image, the network produces a higher-resolution version with more detail.

Features

  • High-Quality Upscaling: Produces detailed, sharp outputs from low-resolution inputs.
  • Fine-Tuned Models: Pre-trained on a diverse dataset.
  • Easy Integration: Simple API for integrating upscaling capabilities into your applications.
  • Customizable: Fine-tune the models further on your own datasets for specific use cases.

Installation

You can install aiuNN using pip. Run the following command:

pip install git+https://gitea.fabelous.app/Machine-Learning/aiuNN.git

Usage

Here's a basic example that loads an AIIA base model, wraps it in aiuNN, and fine-tunes it for image upscaling:

from aiia import AIIABase
from aiunn import aiuNN, aiuNNTrainer
import pandas as pd
from torchvision import transforms

# Load your base model and upscaler
pretrained_model_path = "path/to/aiia/model"
base_model = AIIABase.load(pretrained_model_path, precision="bf16")
upscaler = aiuNN(base_model)

# Create trainer with your dataset class
# (UpscaleDataset is the example class from the Dataset section below)
trainer = aiuNNTrainer(upscaler, dataset_class=UpscaleDataset)

# Load data using parameters for your dataset
dataset_params = {
    'parquet_files': [
        "path/to/dataset1",
        "path/to/dataset2"
    ],
    'transform': transforms.Compose([transforms.ToTensor()]),
    'samples_per_file': 5000  # Number of training samples to load per file
}
trainer.load_data(dataset_params=dataset_params, batch_size=1)

# Fine-tune the model
trainer.finetune(output_path="trained_model")

Dataset

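The UpscaleDataset class handles Parquet files containing image data. It loads up to samples_per_file rows from each file, reading a low-resolution column (image_410) and a high-resolution column (image_820), and checks that every value is either bytes or a string. Strings are decoded as base64.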
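The accepted cell formats can be illustrated in isolation. The sketch below is a minimal standalone version of the decoding rule, assuming string values hold base64-encoded image data as in the example class further down. `to_image_bytes` is a hypothetical helper name, not part of aiuNN, and `validate=True` is an added strictness choice rather than something the example class does.

```python
import base64


def to_image_bytes(value):
    """Return raw image bytes for a cell that holds bytes or a base64 string."""
    if isinstance(value, bytes):
        return value
    if isinstance(value, str):
        # validate=True rejects characters outside the base64 alphabet
        # (assumption: stricter than the plain b64decode call in the example class)
        return base64.b64decode(value, validate=True)
    raise ValueError(f"Unsupported data type: {type(value)}")


# The 8-byte PNG signature serves as stand-in image data
png_signature = b"\x89PNG\r\n\x1a\n"
as_text = base64.b64encode(png_signature).decode("ascii")

# Both representations resolve to the same raw bytes
assert to_image_bytes(png_signature) == png_signature
assert to_image_bytes(as_text) == png_signature
```

Any other cell type, such as an integer or None, raises ValueError, which is the condition the row validation guards against.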

This is an example dataset that you can use with the aiuNN model:

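import base64
import io

import pandas as pd
from PIL import Image, ImageFile
from torch.utils.data import Dataset


class UpscaleDataset(Dataset):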
    def __init__(self, parquet_files: list, transform=None, samples_per_file=10_000):
        combined_df = pd.DataFrame()
        for parquet_file in parquet_files:
            # Load a subset from each parquet file
            df = pd.read_parquet(parquet_file, columns=['image_410', 'image_820']).head(samples_per_file)
            combined_df = pd.concat([combined_df, df], ignore_index=True)

        # Validate rows (ensuring each value is bytes or str)
        self.df = combined_df.apply(self._validate_row, axis=1)
        self.transform = transform
        self.failed_indices = set()

    def _validate_row(self, row):
        for col in ['image_410', 'image_820']:
            if not isinstance(row[col], (bytes, str)):
                raise ValueError(f"Invalid data type in column {col}: {type(row[col])}")
        return row

    def _decode_image(self, data):
        try:
            if isinstance(data, str):
                return base64.b64decode(data)
            elif isinstance(data, bytes):
                return data
            raise ValueError(f"Unsupported data type: {type(data)}")
        except Exception as e:
            # Chain the original exception so the root cause stays visible
            raise RuntimeError(f"Decoding failed: {e}") from e

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # If previous call failed for this index, use a different index
        if idx in self.failed_indices:
            return self[(idx + 1) % len(self)]
        try:
            row = self.df.iloc[idx]
            low_res_bytes = self._decode_image(row['image_410'])
            high_res_bytes = self._decode_image(row['image_820'])
            # Let Pillow load images whose data is truncated instead of raising
            ImageFile.LOAD_TRUNCATED_IMAGES = True
            # Open image bytes with Pillow and convert to RGBA first
            low_res_rgba = Image.open(io.BytesIO(low_res_bytes)).convert('RGBA')
            high_res_rgba = Image.open(io.BytesIO(high_res_bytes)).convert('RGBA')
            
            # Create a new RGB image with black background
            low_res_rgb = Image.new("RGB", low_res_rgba.size, (0, 0, 0))
            high_res_rgb = Image.new("RGB", high_res_rgba.size, (0, 0, 0))
            
            # Composite the original image over the black background
            low_res_rgb.paste(low_res_rgba, mask=low_res_rgba.split()[3])
            high_res_rgb.paste(high_res_rgba, mask=high_res_rgba.split()[3])
            
            # Now we have true 3-channel RGB images with transparent areas converted to black
            low_res = low_res_rgb
            high_res = high_res_rgb
                        
            # If a transform is provided (e.g. conversion to Tensor), apply it
            if self.transform:
                low_res = self.transform(low_res)
                high_res = self.transform(high_res)
            return low_res, high_res
        except Exception as e:
            print(f"\nError at index {idx}: {str(e)}")
            self.failed_indices.add(idx)
            return self[(idx + 1) % len(self)]
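
One thing to be aware of in the fallback above: `__getitem__` retries with the next index recursively, so a dataset in which every sample fails would recurse until Python's recursion limit is hit. A bounded variant could look like the following sketch. `SkippingDataset` and `parse_positive` are hypothetical names used only for illustration, and the toy loader stands in for the image decoding step.

```python
class SkippingDataset:
    """Hypothetical stand-in for UpscaleDataset: items that fail to load are skipped."""

    def __init__(self, items, loader):
        self.items = list(items)
        self.loader = loader            # callable that may raise on a bad item
        self.failed_indices = set()

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        # Try at most len(self) consecutive indices, wrapping around at the end
        for offset in range(len(self)):
            candidate = (idx + offset) % len(self)
            if candidate in self.failed_indices:
                continue
            try:
                return self.loader(self.items[candidate])
            except Exception as exc:
                print(f"Error at index {candidate}: {exc}")
                self.failed_indices.add(candidate)
        # Every index was tried and none could be loaded
        raise RuntimeError("No loadable samples remain in the dataset")


def parse_positive(value):
    """Toy loader: accepts positive integers and rejects everything else."""
    if not isinstance(value, int) or value <= 0:
        raise ValueError(f"bad sample: {value!r}")
    return value


dataset = SkippingDataset([3, -1, "x", 7], parse_positive)
print(dataset[1])  # indices 1 and 2 fail, so the sample at index 3 is returned
```

The loop keeps the behavior of returning a neighboring sample on failure, but it terminates with a clear error once no valid sample is left.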