Compare commits
3 Commits: af8565bc68 ... 2b2ac57ddd

| Author | SHA1 | Date |
|---|---|---|
| | 2b2ac57ddd | |
| | cb7a3da8a4 | |
| | 2114d3adbb | |

README.md (131 changed lines)

```diff
@@ -1,3 +1,130 @@
-# aiunn
+# aiuNN
-Advanced Image Upscaler using Neural Networks
+Adaptive Image Upscaler using Neural Networks
```

## Overview

`aiuNN` is an adaptive image upscaling model built on top of the Adaptive Image Intelligence Architecture (AIIA). This project provides fine-tuned versions of AIIA models specifically designed for high-quality image upscaling. By leveraging neural networks, `aiuNN` can significantly enhance the resolution and detail of images.

## Features

- **High-Quality Upscaling**: Achieve superior image quality with detailed and sharp outputs.
- **Fine-Tuned Models**: Pre-trained on a diverse dataset to ensure optimal performance.
- **Easy Integration**: Simple API for integrating upscaling capabilities into your applications.
- **Customizable**: Fine-tune the models further on your own datasets for specific use cases.

## Installation

You can install `aiuNN` using pip. Run the following command:

```sh
pip install git+https://gitea.fabelous.app/Machine-Learning/aiuNN.git
```
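
If you plan to modify `aiuNN` itself, an editable install from a local clone is a common alternative. This is a minimal sketch, assuming the repository ships standard Python packaging metadata (the same metadata the `pip install git+...` command above relies on):

```sh
# Assumed development setup: clone the repository and install it in editable mode
git clone https://gitea.fabelous.app/Machine-Learning/aiuNN.git
cd aiuNN
pip install -e .
```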

## Usage

Here's a basic example of how to use `aiuNN` for image upscaling:

```python src/main.py
from aiia import AIIABase
from aiunn import aiuNN, aiuNNTrainer
from torchvision import transforms

# Load your base model and wrap it with the upscaler
pretrained_model_path = "path/to/aiia/model"
base_model = AIIABase.load(pretrained_model_path, precision="bf16")
upscaler = aiuNN(base_model)

# Create the trainer with your dataset class (UpscaleDataset is defined in the Dataset section below)
trainer = aiuNNTrainer(upscaler, dataset_class=UpscaleDataset)

# Load data using parameters for your dataset
dataset_params = {
    'parquet_files': [
        "path/to/dataset1",
        "path/to/dataset2"
    ],
    'transform': transforms.Compose([transforms.ToTensor()]),
    'samples_per_file': 5000  # number of training samples to load per file
}
trainer.load_data(dataset_params=dataset_params, batch_size=1)

# Fine-tune the model
trainer.finetune(output_path="trained_model")
```

## Dataset

The `UpscaleDataset` class is designed to handle Parquet files containing image data. It loads a subset of images from each file and validates the data types to ensure consistency.

Here is an example dataset class you can use with the `aiuNN` trainer:
```python src/example.py
import base64
import io

import pandas as pd
from PIL import Image, ImageFile
from torch.utils.data import Dataset


class UpscaleDataset(Dataset):
    def __init__(self, parquet_files: list, transform=None, samples_per_file=10_000):
        combined_df = pd.DataFrame()
        for parquet_file in parquet_files:
            # Load a subset from each parquet file
            df = pd.read_parquet(parquet_file, columns=['image_410', 'image_820']).head(samples_per_file)
            combined_df = pd.concat([combined_df, df], ignore_index=True)

        # Validate rows (ensuring each value is bytes or str)
        self.df = combined_df.apply(self._validate_row, axis=1)
        self.transform = transform
        self.failed_indices = set()

    def _validate_row(self, row):
        for col in ['image_410', 'image_820']:
            if not isinstance(row[col], (bytes, str)):
                raise ValueError(f"Invalid data type in column {col}: {type(row[col])}")
        return row

    def _decode_image(self, data):
        try:
            if isinstance(data, str):
                return base64.b64decode(data)
            elif isinstance(data, bytes):
                return data
            raise ValueError(f"Unsupported data type: {type(data)}")
        except Exception as e:
            raise RuntimeError(f"Decoding failed: {str(e)}")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # If a previous call failed for this index, use a different index
        if idx in self.failed_indices:
            return self[(idx + 1) % len(self)]
        try:
            row = self.df.iloc[idx]
            low_res_bytes = self._decode_image(row['image_410'])
            high_res_bytes = self._decode_image(row['image_820'])
            ImageFile.LOAD_TRUNCATED_IMAGES = True
            # Open image bytes with Pillow and convert to RGBA first
            low_res_rgba = Image.open(io.BytesIO(low_res_bytes)).convert('RGBA')
            high_res_rgba = Image.open(io.BytesIO(high_res_bytes)).convert('RGBA')

            # Create new RGB images with a black background
            low_res_rgb = Image.new("RGB", low_res_rgba.size, (0, 0, 0))
            high_res_rgb = Image.new("RGB", high_res_rgba.size, (0, 0, 0))

            # Composite the original images over the black background
            low_res_rgb.paste(low_res_rgba, mask=low_res_rgba.split()[3])
            high_res_rgb.paste(high_res_rgba, mask=high_res_rgba.split()[3])

            # Now we have true 3-channel RGB images with transparent areas converted to black
            low_res = low_res_rgb
            high_res = high_res_rgb

            # If a transform is provided (e.g. conversion to Tensor), apply it
            if self.transform:
                low_res = self.transform(low_res)
                high_res = self.transform(high_res)
            return low_res, high_res
        except Exception as e:
            print(f"\nError at index {idx}: {str(e)}")
            self.failed_indices.add(idx)
            return self[(idx + 1) % len(self)]
```
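
To sanity-check the dataset outside of the trainer, you can instantiate it directly with the same parameters that `trainer.load_data` receives. A minimal sketch, assuming the class above lives in an importable module named `example` and that the placeholder Parquet path points at a real file:

```python
from torchvision import transforms
from example import UpscaleDataset  # assumed module name for the file shown above

dataset = UpscaleDataset(
    parquet_files=["path/to/dataset1"],  # placeholder path
    transform=transforms.Compose([transforms.ToTensor()]),
    samples_per_file=100,
)

low_res, high_res = dataset[0]  # a (low-resolution, high-resolution) tensor pair
print(low_res.shape, high_res.shape)
```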

example.py (14 changed lines)

```diff
@@ -15,7 +15,7 @@ class UpscaleDataset(Dataset):
         combined_df = pd.DataFrame()
         for parquet_file in parquet_files:
             # Load a subset from each parquet file
-            df = pd.read_parquet(parquet_file, columns=['image_512', 'image_1024']).head(samples_per_file)
+            df = pd.read_parquet(parquet_file, columns=['image_410', 'image_820']).head(samples_per_file)
             combined_df = pd.concat([combined_df, df], ignore_index=True)
 
         # Validate rows (ensuring each value is bytes or str)
@@ -24,7 +24,7 @@ class UpscaleDataset(Dataset):
         self.failed_indices = set()
 
     def _validate_row(self, row):
-        for col in ['image_512', 'image_1024']:
+        for col in ['image_410', 'image_820']:
             if not isinstance(row[col], (bytes, str)):
                 raise ValueError(f"Invalid data type in column {col}: {type(row[col])}")
         return row
@@ -48,8 +48,8 @@ class UpscaleDataset(Dataset):
             return self[(idx + 1) % len(self)]
         try:
             row = self.df.iloc[idx]
-            low_res_bytes = self._decode_image(row['image_512'])
-            high_res_bytes = self._decode_image(row['image_1024'])
+            low_res_bytes = self._decode_image(row['image_410'])
+            high_res_bytes = self._decode_image(row['image_820'])
             ImageFile.LOAD_TRUNCATED_IMAGES = True
             # Open image bytes with Pillow and convert to RGBA first
             low_res_rgba = Image.open(io.BytesIO(low_res_bytes)).convert('RGBA')
@@ -67,10 +67,6 @@ class UpscaleDataset(Dataset):
             low_res = low_res_rgb
             high_res = high_res_rgb
 
-            # Resize the images to reduce VRAM usage
-            low_res = low_res.resize((410, 410), Image.LANCZOS)
-            high_res = high_res.resize((820, 820), Image.LANCZOS)
-
             # If a transform is provided (e.g. conversion to Tensor), apply it
             if self.transform:
                 low_res = self.transform(low_res)
@@ -98,7 +94,7 @@ if __name__ =="__main__":
             "/root/training_data/vision-dataset/image_vec_upscaler.parquet"
         ],
         'transform': transforms.Compose([transforms.ToTensor()]),
-        'samples_per_file': 5000
+        'samples_per_file': 20_000
     }
     trainer.load_data(dataset_params=dataset_params, batch_size=1)
```