class UpscaleDataset(Dataset):
    """Paired (512x512, 1024x1024) image dataset loaded from Parquet files.

    Each Parquet file must provide 'image_512' and 'image_1024' columns whose
    cells hold either raw image bytes or base64-encoded strings. Rows whose
    cells have any other type are rejected at load time; samples that fail to
    decode at access time are remembered in ``failed_indices`` and skipped.
    """

    def __init__(self, parquet_files: list, transform=None):
        """Load up to 10000 rows per Parquet file and validate cell types.

        :param parquet_files: list of Parquet file paths to combine.
        :param transform: optional callable applied to both PIL images.
        """
        # Collect per-file frames and concatenate ONCE. The original
        # concatenated inside the loop, which re-copies the accumulated
        # frame on every iteration (quadratic in total rows).
        frames = [
            pd.read_parquet(pf, columns=['image_512', 'image_1024']).head(10000)
            for pf in parquet_files
        ]
        if frames:
            combined_df = pd.concat(frames, ignore_index=True)
        else:
            # No input files: keep an empty frame with the expected columns.
            combined_df = pd.DataFrame(columns=['image_512', 'image_1024'])

        # Fail fast during construction on rows with unusable cell types.
        self.df = combined_df.apply(self._validate_row, axis=1)
        self.transform = transform
        # Indices that raised while decoding; __getitem__ skips these.
        self.failed_indices = set()
        # Process-global Pillow flag: set once here instead of re-assigning
        # it on every __getitem__ call as the original did.
        ImageFile.LOAD_TRUNCATED_IMAGES = True

    def _validate_row(self, row):
        """Raise ValueError unless both image cells are bytes or str."""
        for col in ('image_512', 'image_1024'):
            if not isinstance(row[col], (bytes, str)):
                raise ValueError(f"Invalid data type in column {col}: {type(row[col])}")
        return row

    def _decode_image(self, data):
        """Return raw image bytes; base64-decode str input.

        :raises RuntimeError: for unsupported types or base64 decode failures.
        """
        try:
            if isinstance(data, str):
                # Handle base64 encoded strings
                return base64.b64decode(data)
            if isinstance(data, bytes):
                return data
            raise ValueError(f"Unsupported data type: {type(data)}")
        except Exception as e:
            # Chain the original cause for debuggability (original dropped it).
            raise RuntimeError(f"Decoding failed: {str(e)}") from e

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return a (low_res, high_res) pair, skipping failed samples.

        The original recursed to ``(idx + 1) % len(self)`` on failure, which
        loops forever (and overflows the stack) once every sample has failed,
        and divides by zero on an empty dataset. This version walks forward
        iteratively and raises RuntimeError after one full pass finds no
        decodable sample.
        """
        n = len(self)
        for _ in range(n):
            if idx not in self.failed_indices:
                try:
                    return self._load_pair(self.df.iloc[idx])
                except Exception as e:
                    print(f"\nError at index {idx}: {str(e)}")
                    self.failed_indices.add(idx)
            idx = (idx + 1) % n  # try the next sample
        raise RuntimeError("UpscaleDataset: no decodable samples remain")

    def _load_pair(self, row):
        """Decode, size-check and (optionally) transform one sample row."""
        low_res_bytes = self._decode_image(row['image_512'])
        high_res_bytes = self._decode_image(row['image_1024'])

        low_res = Image.open(io.BytesIO(low_res_bytes)).convert('RGB')
        high_res = Image.open(io.BytesIO(high_res_bytes)).convert('RGB')

        # Reject pairs that are not exactly 512x512 / 1024x1024.
        if low_res.size != (512, 512) or high_res.size != (1024, 1024):
            raise ValueError(f"Size mismatch: LowRes={low_res.size}, HighRes={high_res.size}")

        if self.transform:
            low_res = self.transform(low_res)
            high_res = self.transform(high_res)
        return low_res, high_res