develop #4

Merged
Fabel merged 103 commits from develop into main 2025-03-01 21:47:17 +00:00
1 changed files with 67 additions and 22 deletions
Showing only changes of commit 34c547fb23 - Show all commits

View File

@ -11,8 +11,22 @@ from typing import Dict, List, Union, Optional
import base64
import torch
from torch.utils.data import Dataset
from PIL import Image, ImageFile
import io
import base64
import pandas as pd
class ImageDataset(Dataset):
def __init__(self, dataframe, transform=None):
"""
Initialize the dataset with a dataframe containing image data
Args:
dataframe (pd.DataFrame): DataFrame containing 'image_512' and 'image_1024' columns
transform (callable, optional): Optional transform to be applied to both images
"""
self.dataframe = dataframe
self.transform = transform
@ -20,44 +34,75 @@ class ImageDataset(Dataset):
return len(self.dataframe)
def __getitem__(self, idx):
"""
Get a pair of low and high resolution images
Args:
idx (int): Index of the data point
Returns:
dict: Contains 'low_ress' and 'high_ress' PIL images or transformed tensors
"""
row = self.dataframe.iloc[idx]
try:
# Handle both bytes and base64 encoded strings
low_res_data = row['image_512']
high_res_data = row['image_1024']
if isinstance(low_res_data, str):
# Decode base64 string to bytes
low_res_data = base64.b64decode(low_res_data)
high_res_data = base64.b64decode(high_res_data)
# Verify data is valid before creating BytesIO
if not isinstance(row['image_512'], bytes) or not isinstance(row['image_1024'], bytes):
raise ValueError("Image data must be in bytes format")
if not isinstance(low_res_data, bytes) or not isinstance(high_res_data, bytes):
raise ValueError(f"Invalid image data format at index {idx}")
low_res_stream = io.BytesIO(row['image_512'])
high_res_stream = io.BytesIO(row['image_1024'])
# Create image streams
low_res_stream = io.BytesIO(low_res_data)
high_res_stream = io.BytesIO(high_res_data)
# Reset stream position
low_res_stream.seek(0)
high_res_stream.seek(0)
# Enable loading of truncated images if necessary
# Enable loading of truncated images
ImageFile.LOAD_TRUNCATED_IMAGES = True
# Load and convert images to RGB
low_res_image = Image.open(low_res_stream).convert('RGB')
high_res_image = Image.open(high_res_stream).convert('RGB')
# Create fresh copies for verify() since it modifies the image object
low_res_verify = low_res_image.copy()
high_res_verify = high_res_image.copy()
# Verify images are valid
low_res_image.verify()
high_res_image.verify()
try:
low_res_verify.verify()
high_res_verify.verify()
except Exception as e:
raise ValueError(f"Image loading failed: {str(e)}")
raise ValueError(f"Image verification failed at index {idx}: {str(e)}")
finally:
low_res_stream.close()
high_res_stream.close()
low_res_verify.close()
high_res_verify.close()
# Apply transforms if specified
if self.transform:
low_res_image = self.transform(low_res_image)
high_res_image = self.transform(high_res_image)
return {'low_ress': low_res_image, 'high_ress': high_res_image}
return {
'low_ress': low_res_image, # Note: Using 'low_ress' to match ModelTrainer
'high_ress': high_res_image # Note: Using 'high_ress' to match ModelTrainer
}
except Exception as e:
raise RuntimeError(f"Error loading images at index {idx}: {str(e)}")
finally:
# Ensure streams are closed
if 'low_res_stream' in locals():
low_res_stream.close()
if 'high_res_stream' in locals():
high_res_stream.close()
class ModelTrainer:
def __init__(self,