diff --git a/src/aiunn/finetune.py b/src/aiunn/finetune.py
index 05fde00..22a4efa 100644
--- a/src/aiunn/finetune.py
+++ b/src/aiunn/finetune.py
@@ -1,6 +1,6 @@
 import torch
 import pandas as pd
-from PIL import Image
+from PIL import Image, ImageFile
 import io
 from torch import nn
 from torch.utils.data import Dataset, DataLoader
@@ -21,31 +21,40 @@ class ImageDataset(Dataset):
     def __getitem__(self, idx):
         row = self.dataframe.iloc[idx]
-
+
         try:
-            # Convert string data to bytes if needed
-            low_res_bytes = row['image_512'].encode() if isinstance(row['image_512'], str) else row['image_512']
-            high_res_bytes = row['image_1024'].encode() if isinstance(row['image_1024'], str) else row['image_1024']
+            # Verify data is valid before creating BytesIO
+            if not isinstance(row['image_512'], bytes) or not isinstance(row['image_1024'], bytes):
+                raise ValueError("Image data must be in bytes format")
+
+            low_res_stream = io.BytesIO(row['image_512'])
+            high_res_stream = io.BytesIO(row['image_1024'])
+
+            # Reset stream position
+            low_res_stream.seek(0)
+            high_res_stream.seek(0)
+
+            # Enable loading of truncated images if necessary
+            ImageFile.LOAD_TRUNCATED_IMAGES = True
 
-            # Create BytesIO objects
-            low_res_stream = io.BytesIO(low_res_bytes)
-            high_res_stream = io.BytesIO(high_res_bytes)
-
-            # Open images
             low_res_image = Image.open(low_res_stream).convert('RGB')
             high_res_image = Image.open(high_res_stream).convert('RGB')
-
-            # Close the streams
-            low_res_stream.close()
-            high_res_stream.close()
-
+
+            # Verify images are valid
+            low_res_image.verify()
+            high_res_image.verify()
+
         except Exception as e:
             raise ValueError(f"Image loading failed: {str(e)}")
-
+
+        finally:
+            low_res_stream.close()
+            high_res_stream.close()
+
         if self.transform:
             low_res_image = self.transform(low_res_image)
             high_res_image = self.transform(high_res_image)
-
+
         return {'low_ress': low_res_image, 'high_ress': high_res_image}
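
Note: two details of the new error handling above may be worth double-checking. Pillow's Image.verify() is only meaningful on a freshly opened, not-yet-loaded image and invalidates it for further use, so calling it after convert('RGB') is effectively a no-op; and the finally block references low_res_stream/high_res_stream, which are never assigned if the isinstance check raises first, so the intended ValueError would surface as an UnboundLocalError. Below is a minimal sketch of one way to order the calls, assuming the same byte columns ('image_512', 'image_1024'); the helper name _decode_rgb is illustrative and not part of the patch.

import io
from PIL import Image

def _decode_rgb(data: bytes) -> Image.Image:
    # Illustrative helper, not part of the patch above.
    if not isinstance(data, bytes):
        raise ValueError("Image data must be in bytes format")
    # First pass: verify() must run on an unloaded image and leaves it unusable afterwards.
    with io.BytesIO(data) as stream:
        Image.open(stream).verify()
    # Second pass: reopen and decode; convert('RGB') forces a full load,
    # so the BytesIO can be closed safely once the call returns.
    with io.BytesIO(data) as stream:
        return Image.open(stream).convert('RGB')

In __getitem__ this would reduce to low_res_image = _decode_rgb(row['image_512']) and high_res_image = _decode_rgb(row['image_1024']), with no finally block needed because the context managers close the streams.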