From 09f196294c528a0f103a2dd99cbb75e57031aeeb Mon Sep 17 00:00:00 2001
From: Falko Habel
Date: Tue, 25 Feb 2025 15:47:10 +0100
Subject: [PATCH] new upsampler scripts

---
 example.py          | 106 ++++++++++++++++++++++++++++++++++++++++++++
 src/aiunn/config.py |   2 +-
 2 files changed, 107 insertions(+), 1 deletion(-)
 create mode 100644 example.py

diff --git a/example.py b/example.py
new file mode 100644
index 0000000..8492d51
--- /dev/null
+++ b/example.py
@@ -0,0 +1,106 @@
+from aiia import AIIABase
+from aiunn import aiuNN
+from aiunn import aiuNNTrainer
+import pandas as pd
+import io
+import base64
+from PIL import Image, ImageFile
+from torch.utils.data import Dataset
+from torchvision import transforms
+
+
+
+class UpscaleDataset(Dataset):
+    def __init__(self, parquet_files: list, transform=None, samples_per_file=2500):
+        combined_df = pd.DataFrame()
+        for parquet_file in parquet_files:
+            # Load a subset from each parquet file
+            df = pd.read_parquet(parquet_file, columns=['image_512', 'image_1024']).head(samples_per_file)
+            combined_df = pd.concat([combined_df, df], ignore_index=True)
+
+        # Validate rows (ensuring each value is bytes or str)
+        self.df = combined_df.apply(self._validate_row, axis=1)
+        self.transform = transform
+        self.failed_indices = set()
+
+    def _validate_row(self, row):
+        for col in ['image_512', 'image_1024']:
+            if not isinstance(row[col], (bytes, str)):
+                raise ValueError(f"Invalid data type in column {col}: {type(row[col])}")
+        return row
+
+    def _decode_image(self, data):
+        try:
+            if isinstance(data, str):
+                return base64.b64decode(data)
+            elif isinstance(data, bytes):
+                return data
+            raise ValueError(f"Unsupported data type: {type(data)}")
+        except Exception as e:
+            raise RuntimeError(f"Decoding failed: {str(e)}")
+
+    def __len__(self):
+        return len(self.df)
+
+    def __getitem__(self, idx):
+        # If a previous call failed for this index, use a different index
+        if idx in self.failed_indices:
+            return self[(idx + 1) % len(self)]
+        try:
+            row = self.df.iloc[idx]
+            low_res_bytes = self._decode_image(row['image_512'])
+            high_res_bytes = self._decode_image(row['image_1024'])
+            ImageFile.LOAD_TRUNCATED_IMAGES = True
+            # Open image bytes with Pillow and convert to RGBA first
+            low_res_rgba = Image.open(io.BytesIO(low_res_bytes)).convert('RGBA')
+            high_res_rgba = Image.open(io.BytesIO(high_res_bytes)).convert('RGBA')
+
+            # Create a new RGB image with a black background
+            low_res_rgb = Image.new("RGB", low_res_rgba.size, (0, 0, 0))
+            high_res_rgb = Image.new("RGB", high_res_rgba.size, (0, 0, 0))
+
+            # Composite the original image over the black background
+            low_res_rgb.paste(low_res_rgba, mask=low_res_rgba.split()[3])
+            high_res_rgb.paste(high_res_rgba, mask=high_res_rgba.split()[3])
+
+            # Now we have true 3-channel RGB images with transparent areas converted to black
+            low_res = low_res_rgb
+            high_res = high_res_rgb
+
+            # Resize the images to reduce VRAM usage
+            low_res = low_res.resize((384, 384), Image.LANCZOS)
+            high_res = high_res.resize((768, 768), Image.LANCZOS)
+
+            # If a transform is provided (e.g. conversion to Tensor), apply it
+            if self.transform:
+                low_res = self.transform(low_res)
+                high_res = self.transform(high_res)
+            return low_res, high_res
+        except Exception as e:
+            print(f"\nError at index {idx}: {str(e)}")
+            self.failed_indices.add(idx)
+            return self[(idx + 1) % len(self)]
+
+
+if __name__ == "__main__":
+    # Load your base model and upscaler
+    pretrained_model_path = "/root/vision/AIIA/AIIA-base-512"
+    base_model = AIIABase.load(pretrained_model_path, precision="bf16")
+    upscaler = aiuNN(base_model)
+
+    # Create trainer with your dataset class
+    trainer = aiuNNTrainer(upscaler, dataset_class=UpscaleDataset)
+
+    # Load data using parameters for your dataset
+    dataset_params = {
+        'parquet_files': [
+            "/root/training_data/vision-dataset/image_upscaler.parquet",
+            "/root/training_data/vision-dataset/image_vec_upscaler.parquet"
+        ],
+        'transform': transforms.Compose([transforms.ToTensor()]),
+        'samples_per_file': 2500
+    }
+    trainer.load_data(dataset_params=dataset_params, batch_size=1)
+
+    # Fine-tune the model
+    trainer.finetune(output_path="trained_models")
\ No newline at end of file
diff --git a/src/aiunn/config.py b/src/aiunn/config.py
index 6fc72f1..b56699b 100644
--- a/src/aiunn/config.py
+++ b/src/aiunn/config.py
@@ -1,7 +1,7 @@
 from aiia import AIIAConfig
 
 
-class UpsamplerConfig(AIIAConfig):
+class aiuNNConfig(AIIAConfig):
     def __init__(
         self,
         base_config=None,
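
Since the second hunk renames UpsamplerConfig to aiuNNConfig, any code that still imports the old name will break once this patch lands. Below is a minimal sketch of the adjusted usage; the module-level import path and the no-argument AIIAConfig constructor are assumptions, as neither is shown in this patch.

    from aiia import AIIAConfig
    from aiunn.config import aiuNNConfig  # formerly UpsamplerConfig

    # Hypothetical usage: wrap a base AIIAConfig (assumed here to accept
    # default arguments) in the renamed upscaler config.
    base = AIIAConfig()
    upscaler_config = aiuNNConfig(base_config=base)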