from aiia import AIIABase
from aiunn import aiuNN, aiuNNTrainer
import pandas as pd
import io
import base64
from PIL import Image, ImageFile
from torch.utils.data import Dataset
from torchvision import transforms


class UpscaleDataset(Dataset):
    def __init__(self, parquet_files: list, transform=None, samples_per_file=10_000):
        combined_df = pd.DataFrame()
        for parquet_file in parquet_files:
            # Load a subset from each parquet file
            df = pd.read_parquet(parquet_file, columns=['image_410', 'image_820']).head(samples_per_file)
            combined_df = pd.concat([combined_df, df], ignore_index=True)

        # Validate rows (ensuring each value is bytes or str)
        self.df = combined_df.apply(self._validate_row, axis=1)
        self.transform = transform
        self.failed_indices = set()

    def _validate_row(self, row):
        for col in ['image_410', 'image_820']:
            if not isinstance(row[col], (bytes, str)):
                raise ValueError(f"Invalid data type in column {col}: {type(row[col])}")
        return row

    def _decode_image(self, data):
        try:
            if isinstance(data, str):
                return base64.b64decode(data)
            elif isinstance(data, bytes):
                return data
            raise ValueError(f"Unsupported data type: {type(data)}")
        except Exception as e:
            raise RuntimeError(f"Decoding failed: {str(e)}")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # If a previous call failed for this index, skip ahead to the next one.
        # (This assumes most rows are valid; if every row failed, the recursion
        # would not terminate.)
        if idx in self.failed_indices:
            return self[(idx + 1) % len(self)]
        try:
            row = self.df.iloc[idx]
            low_res_bytes = self._decode_image(row['image_410'])
            high_res_bytes = self._decode_image(row['image_820'])
            # Let Pillow load truncated files instead of raising
            ImageFile.LOAD_TRUNCATED_IMAGES = True
            # Open the image bytes with Pillow and convert to RGBA first
            low_res_rgba = Image.open(io.BytesIO(low_res_bytes)).convert('RGBA')
            high_res_rgba = Image.open(io.BytesIO(high_res_bytes)).convert('RGBA')
            # Create new RGB images with a black background
            low_res_rgb = Image.new("RGB", low_res_rgba.size, (0, 0, 0))
            high_res_rgb = Image.new("RGB", high_res_rgba.size, (0, 0, 0))
            # Composite each image over the black background, using the
            # alpha channel (band 3) as the paste mask
            low_res_rgb.paste(low_res_rgba, mask=low_res_rgba.split()[3])
            high_res_rgb.paste(high_res_rgba, mask=high_res_rgba.split()[3])
            # We now have true 3-channel RGB images with transparent areas rendered black
            low_res = low_res_rgb
            high_res = high_res_rgb
            # If a transform is provided (e.g. conversion to tensor), apply it
            if self.transform:
                low_res = self.transform(low_res)
                high_res = self.transform(high_res)
            return low_res, high_res
        except Exception as e:
            print(f"\nError at index {idx}: {str(e)}")
            self.failed_indices.add(idx)
            return self[(idx + 1) % len(self)]


if __name__ == "__main__":
    # Load the pretrained base model and wrap it in the upscaler
    pretrained_model_path = "/root/vision/AIIA/AIIA-base-512"
    base_model = AIIABase.load(pretrained_model_path, precision="bf16")
    upscaler = aiuNN(base_model)

    # Create a trainer with the custom dataset class
    trainer = aiuNNTrainer(upscaler, dataset_class=UpscaleDataset)

    # Load data using parameters for the dataset
    dataset_params = {
        'parquet_files': [
            "/root/training_data/vision-dataset/image_upscaler.parquet",
            "/root/training_data/vision-dataset/image_vec_upscaler.parquet"
        ],
        'transform': transforms.Compose([transforms.ToTensor()]),
        'samples_per_file': 20_000
    }
    trainer.load_data(dataset_params=dataset_params, batch_size=1)

    # Fine-tune the model
    trainer.finetune(output_path="trained_models")
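
# --- Optional sanity check (a minimal sketch, not part of the training run) ---
# Before launching a full fine-tune, it can be worth iterating a batch by hand
# to confirm that decoding, alpha-flattening, and the ToTensor transform produce
# the tensor shapes the upscaler expects. The parquet path reuses one of the
# training files above; the printed shapes assume 410x410 low-res and 820x820
# high-res source images (the column names suggest this, but it is an assumption).
# Uncomment to run:
#
# from torch.utils.data import DataLoader
#
# check_ds = UpscaleDataset(
#     parquet_files=["/root/training_data/vision-dataset/image_upscaler.parquet"],
#     transform=transforms.Compose([transforms.ToTensor()]),
#     samples_per_file=100,  # small subset keeps the check fast
# )
# check_loader = DataLoader(check_ds, batch_size=4)
# low, high = next(iter(check_loader))
# print(low.shape, high.shape)  # e.g. torch.Size([4, 3, 410, 410]) torch.Size([4, 3, 820, 820])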