develop #4
|
@ -11,8 +11,22 @@ from typing import Dict, List, Union, Optional
|
|||
import base64
|
||||
|
||||
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
from PIL import Image, ImageFile
|
||||
import io
|
||||
import base64
|
||||
import pandas as pd
|
||||
|
||||
class ImageDataset(Dataset):
    """Paired low/high resolution image dataset backed by a pandas DataFrame.

    Each row must provide the image payloads in the 'image_512' and
    'image_1024' columns, either as raw ``bytes`` or as base64-encoded
    ``str`` values.
    """

    def __init__(self, dataframe, transform=None):
        """
        Initialize the dataset with a dataframe containing image data.

        Args:
            dataframe (pd.DataFrame): DataFrame containing 'image_512' and
                'image_1024' columns.
            transform (callable, optional): Optional transform applied to
                both images.
        """
        self.dataframe = dataframe
        self.transform = transform

    def __len__(self):
        """Return the number of image pairs in the dataset."""
        return len(self.dataframe)

    @staticmethod
    def _to_bytes(data, idx):
        """Normalize one cell to raw bytes, decoding base64 strings.

        Raises:
            ValueError: If the cell is neither ``bytes`` nor a ``str``.
        """
        if isinstance(data, str):
            # Decode base64 string to bytes.
            data = base64.b64decode(data)
        if not isinstance(data, bytes):
            raise ValueError(f"Invalid image data format at index {idx}")
        return data

    def __getitem__(self, idx):
        """
        Get a pair of low and high resolution images.

        Args:
            idx (int): Index of the data point.

        Returns:
            dict: Contains 'low_ress' and 'high_ress' PIL images or
                transformed tensors. The keys intentionally use the
                doubled-s spelling to match what ModelTrainer reads.

        Raises:
            RuntimeError: If the image data is missing, malformed, fails
                verification, or the transform errors out.
        """
        row = self.dataframe.iloc[idx]

        try:
            low_res_data = self._to_bytes(row['image_512'], idx)
            high_res_data = self._to_bytes(row['image_1024'], idx)

            # NOTE(review): mutates global PIL state so partially written
            # files do not abort an entire epoch — assumed intentional.
            ImageFile.LOAD_TRUNCATED_IMAGES = True

            # Context managers guarantee the streams are closed on every
            # path (replaces the old `'x' in locals()` cleanup hack).
            with io.BytesIO(low_res_data) as low_stream, \
                 io.BytesIO(high_res_data) as high_stream:
                # Pillow requires verify() on a freshly opened, un-decoded
                # image; verifying a copy of an already-loaded image (as
                # the previous code did) checks nothing.
                Image.open(low_stream).verify()
                Image.open(high_stream).verify()

                # verify() consumes the image object, so rewind and
                # reopen for the actual decode.
                low_stream.seek(0)
                high_stream.seek(0)

                # convert() forces a full decode, so the images stay
                # usable after the streams are closed.
                low_res_image = Image.open(low_stream).convert('RGB')
                high_res_image = Image.open(high_stream).convert('RGB')

            # Apply transforms if specified.
            if self.transform:
                low_res_image = self.transform(low_res_image)
                high_res_image = self.transform(high_res_image)

            return {'low_ress': low_res_image, 'high_ress': high_res_image}

        except Exception as e:
            # Preserve the original cause for debugging via chaining.
            raise RuntimeError(
                f"Error loading images at index {idx}: {e}"
            ) from e
|
||||
|
||||
class ModelTrainer:
|
||||
def __init__(self,
|
||||
|
|
Loading…
Reference in New Issue