From 34c547fb23dff78ff86f68d1f298c91fce48f4ed Mon Sep 17 00:00:00 2001 From: Falko Habel Date: Thu, 30 Jan 2025 13:14:45 +0100 Subject: [PATCH] next try --- src/aiunn/finetune.py | 89 ++++++++++++++++++++++++++++++++----------- 1 file changed, 67 insertions(+), 22 deletions(-) diff --git a/src/aiunn/finetune.py b/src/aiunn/finetune.py index 22a4efa..4bdf821 100644 --- a/src/aiunn/finetune.py +++ b/src/aiunn/finetune.py @@ -11,8 +11,22 @@ from typing import Dict, List, Union, Optional import base64 +import torch +from torch.utils.data import Dataset +from PIL import Image, ImageFile +import io +import base64 +import pandas as pd + class ImageDataset(Dataset): def __init__(self, dataframe, transform=None): + """ + Initialize the dataset with a dataframe containing image data + + Args: + dataframe (pd.DataFrame): DataFrame containing 'image_512' and 'image_1024' columns + transform (callable, optional): Optional transform to be applied to both images + """ self.dataframe = dataframe self.transform = transform @@ -20,44 +34,75 @@ class ImageDataset(Dataset): return len(self.dataframe) def __getitem__(self, idx): + """ + Get a pair of low and high resolution images + + Args: + idx (int): Index of the data point + + Returns: + dict: Contains 'low_ress' and 'high_ress' PIL images or transformed tensors + """ row = self.dataframe.iloc[idx] try: + # Handle both bytes and base64 encoded strings + low_res_data = row['image_512'] + high_res_data = row['image_1024'] + + if isinstance(low_res_data, str): + # Decode base64 string to bytes + low_res_data = base64.b64decode(low_res_data) + high_res_data = base64.b64decode(high_res_data) + # Verify data is valid before creating BytesIO - if not isinstance(row['image_512'], bytes) or not isinstance(row['image_1024'], bytes): - raise ValueError("Image data must be in bytes format") + if not isinstance(low_res_data, bytes) or not isinstance(high_res_data, bytes): + raise ValueError(f"Invalid image data format at index {idx}") - low_res_stream = io.BytesIO(row['image_512']) - high_res_stream = io.BytesIO(row['image_1024']) + # Create image streams + low_res_stream = io.BytesIO(low_res_data) + high_res_stream = io.BytesIO(high_res_data) - # Reset stream position - low_res_stream.seek(0) - high_res_stream.seek(0) - - # Enable loading of truncated images if necessary + # Enable loading of truncated images ImageFile.LOAD_TRUNCATED_IMAGES = True + # Load and convert images to RGB low_res_image = Image.open(low_res_stream).convert('RGB') high_res_image = Image.open(high_res_stream).convert('RGB') + # Create fresh copies for verify() since it modifies the image object + low_res_verify = low_res_image.copy() + high_res_verify = high_res_image.copy() + # Verify images are valid - low_res_image.verify() - high_res_image.verify() + try: + low_res_verify.verify() + high_res_verify.verify() + except Exception as e: + raise ValueError(f"Image verification failed at index {idx}: {str(e)}") + finally: + low_res_verify.close() + high_res_verify.close() + + # Apply transforms if specified + if self.transform: + low_res_image = self.transform(low_res_image) + high_res_image = self.transform(high_res_image) + + return { + 'low_ress': low_res_image, # Note: Using 'low_ress' to match ModelTrainer + 'high_ress': high_res_image # Note: Using 'high_ress' to match ModelTrainer + } except Exception as e: - raise ValueError(f"Image loading failed: {str(e)}") + raise RuntimeError(f"Error loading images at index {idx}: {str(e)}") finally: - low_res_stream.close() - high_res_stream.close() - - if self.transform: - low_res_image = self.transform(low_res_image) - high_res_image = self.transform(high_res_image) - - return {'low_ress': low_res_image, 'high_ress': high_res_image} - - + # Ensure streams are closed + if 'low_res_stream' in locals(): + low_res_stream.close() + if 'high_res_stream' in locals(): + high_res_stream.close() class ModelTrainer: def __init__(self,