# aiuNN/example.py
#
# 106 lines
# 4.2 KiB
# Python

from aiia import AIIABase
from aiunn import aiuNN
from aiunn import aiuNNTrainer
import pandas as pd
import io
import base64
from PIL import Image, ImageFile
from torch.utils.data import Dataset
from torchvision import transforms
class UpscaleDataset(Dataset):
    """Paired low-/high-resolution image dataset read from parquet files.

    Every parquet file must expose the columns ``image_512`` and
    ``image_1024``. Each cell holds one encoded image, either as raw
    ``bytes`` or as a base64-encoded ``str``.

    Samples that cannot be decoded or opened are remembered in
    ``failed_indices`` and transparently replaced by the next loadable
    sample, so a few corrupt rows do not abort a training run.

    Args:
        parquet_files: Paths of the parquet files to read.
        transform: Optional callable applied to both images of a pair
            (for example a conversion to tensors).
        samples_per_file: Maximum number of leading rows taken from each file.
        low_res_size: ``(width, height)`` the ``image_512`` image is resized to.
        high_res_size: ``(width, height)`` the ``image_1024`` image is resized to.

    Raises:
        ValueError: If any cell is neither ``bytes`` nor ``str``.
    """

    _COLUMNS = ('image_512', 'image_1024')

    def __init__(self, parquet_files: list, transform=None, samples_per_file=10_000,
                 low_res_size=(410, 410), high_res_size=(820, 820)):
        # Read a subset of every file first and concatenate once; calling
        # pd.concat inside the loop would re-copy the accumulated frame
        # on every iteration.
        frames = [
            pd.read_parquet(parquet_file, columns=list(self._COLUMNS)).head(samples_per_file)
            for parquet_file in parquet_files
        ]
        if frames:
            combined_df = pd.concat(frames, ignore_index=True)
        else:
            # pd.concat rejects an empty list; keep "no files" a valid,
            # empty dataset.
            combined_df = pd.DataFrame(columns=list(self._COLUMNS))

        # Validate rows (ensuring each value is bytes or str).
        for low, high in zip(combined_df['image_512'], combined_df['image_1024']):
            self._validate_row({'image_512': low, 'image_1024': high})

        self.df = combined_df
        self.transform = transform
        self.low_res_size = tuple(low_res_size)
        self.high_res_size = tuple(high_res_size)
        # Indices whose sample failed to load once; they are skipped afterwards.
        self.failed_indices = set()

    def _validate_row(self, row):
        """Return ``row`` unchanged after checking both image cells.

        ``row`` may be any mapping-like object indexable by column name.

        Raises:
            ValueError: If a cell is neither ``bytes`` nor ``str``.
        """
        for col in self._COLUMNS:
            if not isinstance(row[col], (bytes, str)):
                raise ValueError(f"Invalid data type in column {col}: {type(row[col])}")
        return row

    def _decode_image(self, data):
        """Return the raw image bytes for a cell value.

        ``str`` input is treated as base64 and decoded; ``bytes`` input is
        returned as is.

        Raises:
            RuntimeError: If the value has an unsupported type or cannot be
                base64-decoded. The original error is attached as the cause.
        """
        try:
            if isinstance(data, str):
                return base64.b64decode(data)
            elif isinstance(data, bytes):
                return data
            raise ValueError(f"Unsupported data type: {type(data)}")
        except Exception as e:
            raise RuntimeError(f"Decoding failed: {str(e)}") from e

    @staticmethod
    def _open_as_rgb(image_bytes):
        """Open encoded image bytes and flatten them onto a black RGB canvas."""
        # Convert to RGBA first so every input mode exposes an alpha channel.
        rgba = Image.open(io.BytesIO(image_bytes)).convert('RGBA')
        rgb = Image.new("RGB", rgba.size, (0, 0, 0))
        # Composite over the black background; transparent areas become black.
        rgb.paste(rgba, mask=rgba.split()[3])
        return rgb

    def _load_pair(self, idx):
        """Load, resize and transform the image pair stored at row ``idx``."""
        row = self.df.iloc[idx]
        low_res_bytes = self._decode_image(row['image_512'])
        high_res_bytes = self._decode_image(row['image_1024'])

        # Let Pillow load partially written files instead of raising.
        ImageFile.LOAD_TRUNCATED_IMAGES = True

        # Resize the images to reduce VRAM usage.
        low_res = self._open_as_rgb(low_res_bytes).resize(self.low_res_size, Image.LANCZOS)
        high_res = self._open_as_rgb(high_res_bytes).resize(self.high_res_size, Image.LANCZOS)

        # If a transform is provided (e.g. conversion to Tensor), apply it.
        if self.transform:
            low_res = self.transform(low_res)
            high_res = self.transform(high_res)
        return low_res, high_res

    def __len__(self):
        """Return the number of rows, including rows that failed to load."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return the ``(low_res, high_res)`` pair for ``idx``.

        If the requested sample is broken, the next loadable sample
        (wrapping around the end) is returned instead.

        Raises:
            IndexError: If the dataset is empty or ``idx`` is out of range.
            RuntimeError: If no sample in the dataset can be loaded.
        """
        total = len(self)
        if total == 0:
            raise IndexError("UpscaleDataset is empty")
        if not -total <= idx < total:
            raise IndexError(f"Index {idx} is out of range for dataset of size {total}")

        current = idx % total
        # Bounded scan instead of recursion: at most one attempt per row,
        # so a fully broken dataset ends in a clear error rather than
        # exhausting the call stack.
        for _ in range(total):
            if current not in self.failed_indices:
                try:
                    return self._load_pair(current)
                except Exception as e:
                    # Deliberately broad: any decoding/imaging failure only
                    # disqualifies this sample, not the whole run.
                    print(f"\nError at index {current}: {str(e)}")
                    self.failed_indices.add(current)
            current = (current + 1) % total
        raise RuntimeError("All samples in the dataset failed to load")
if __name__ == "__main__":
    # Example entry point: fine-tune an aiuNN upscaler on two parquet datasets.

    # Keyword arguments handed to the dataset class through the trainer.
    parquet_sources = [
        "/root/training_data/vision-dataset/image_upscaler.parquet",
        "/root/training_data/vision-dataset/image_vec_upscaler.parquet",
    ]
    to_tensor = transforms.Compose([transforms.ToTensor()])
    dataset_kwargs = dict(
        parquet_files=parquet_sources,
        transform=to_tensor,
        samples_per_file=5000,
    )

    # Load the pretrained base model and wrap it in the upscaler.
    backbone = AIIABase.load("/root/vision/AIIA/AIIA-base-512", precision="bf16")
    upscale_model = aiuNN(backbone)

    # The trainer receives the dataset class and the arguments to build it with.
    trainer = aiuNNTrainer(upscale_model, dataset_class=UpscaleDataset)
    trainer.load_data(dataset_params=dataset_kwargs, batch_size=1)

    # Fine-tune the model and write the result to the output directory.
    trainer.finetune(output_path="trained_models")