finetune_class #1
|
@ -11,8 +11,22 @@ from typing import Dict, List, Union, Optional
|
||||||
import base64
|
import base64
|
||||||
|
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from PIL import Image, ImageFile
|
||||||
|
import io
|
||||||
|
import base64
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
class ImageDataset(Dataset):
|
class ImageDataset(Dataset):
|
||||||
def __init__(self, dataframe, transform=None):
|
def __init__(self, dataframe, transform=None):
|
||||||
|
"""
|
||||||
|
Initialize the dataset with a dataframe containing image data
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataframe (pd.DataFrame): DataFrame containing 'image_512' and 'image_1024' columns
|
||||||
|
transform (callable, optional): Optional transform to be applied to both images
|
||||||
|
"""
|
||||||
self.dataframe = dataframe
|
self.dataframe = dataframe
|
||||||
self.transform = transform
|
self.transform = transform
|
||||||
|
|
||||||
|
@ -20,44 +34,75 @@ class ImageDataset(Dataset):
|
||||||
return len(self.dataframe)
|
return len(self.dataframe)
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
|
"""
|
||||||
|
Get a pair of low and high resolution images
|
||||||
|
|
||||||
|
Args:
|
||||||
|
idx (int): Index of the data point
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Contains 'low_ress' and 'high_ress' PIL images or transformed tensors
|
||||||
|
"""
|
||||||
row = self.dataframe.iloc[idx]
|
row = self.dataframe.iloc[idx]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# Handle both bytes and base64 encoded strings
|
||||||
|
low_res_data = row['image_512']
|
||||||
|
high_res_data = row['image_1024']
|
||||||
|
|
||||||
|
if isinstance(low_res_data, str):
|
||||||
|
# Decode base64 string to bytes
|
||||||
|
low_res_data = base64.b64decode(low_res_data)
|
||||||
|
high_res_data = base64.b64decode(high_res_data)
|
||||||
|
|
||||||
# Verify data is valid before creating BytesIO
|
# Verify data is valid before creating BytesIO
|
||||||
if not isinstance(row['image_512'], bytes) or not isinstance(row['image_1024'], bytes):
|
if not isinstance(low_res_data, bytes) or not isinstance(high_res_data, bytes):
|
||||||
raise ValueError("Image data must be in bytes format")
|
raise ValueError(f"Invalid image data format at index {idx}")
|
||||||
|
|
||||||
low_res_stream = io.BytesIO(row['image_512'])
|
# Create image streams
|
||||||
high_res_stream = io.BytesIO(row['image_1024'])
|
low_res_stream = io.BytesIO(low_res_data)
|
||||||
|
high_res_stream = io.BytesIO(high_res_data)
|
||||||
|
|
||||||
# Reset stream position
|
# Enable loading of truncated images
|
||||||
low_res_stream.seek(0)
|
|
||||||
high_res_stream.seek(0)
|
|
||||||
|
|
||||||
# Enable loading of truncated images if necessary
|
|
||||||
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||||
|
|
||||||
|
# Load and convert images to RGB
|
||||||
low_res_image = Image.open(low_res_stream).convert('RGB')
|
low_res_image = Image.open(low_res_stream).convert('RGB')
|
||||||
high_res_image = Image.open(high_res_stream).convert('RGB')
|
high_res_image = Image.open(high_res_stream).convert('RGB')
|
||||||
|
|
||||||
|
# Create fresh copies for verify() since it modifies the image object
|
||||||
|
low_res_verify = low_res_image.copy()
|
||||||
|
high_res_verify = high_res_image.copy()
|
||||||
|
|
||||||
# Verify images are valid
|
# Verify images are valid
|
||||||
low_res_image.verify()
|
try:
|
||||||
high_res_image.verify()
|
low_res_verify.verify()
|
||||||
|
high_res_verify.verify()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ValueError(f"Image loading failed: {str(e)}")
|
raise ValueError(f"Image verification failed at index {idx}: {str(e)}")
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
low_res_stream.close()
|
low_res_verify.close()
|
||||||
high_res_stream.close()
|
high_res_verify.close()
|
||||||
|
|
||||||
|
# Apply transforms if specified
|
||||||
if self.transform:
|
if self.transform:
|
||||||
low_res_image = self.transform(low_res_image)
|
low_res_image = self.transform(low_res_image)
|
||||||
high_res_image = self.transform(high_res_image)
|
high_res_image = self.transform(high_res_image)
|
||||||
|
|
||||||
return {'low_ress': low_res_image, 'high_ress': high_res_image}
|
return {
|
||||||
|
'low_ress': low_res_image, # Note: Using 'low_ress' to match ModelTrainer
|
||||||
|
'high_ress': high_res_image # Note: Using 'high_ress' to match ModelTrainer
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"Error loading images at index {idx}: {str(e)}")
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Ensure streams are closed
|
||||||
|
if 'low_res_stream' in locals():
|
||||||
|
low_res_stream.close()
|
||||||
|
if 'high_res_stream' in locals():
|
||||||
|
high_res_stream.close()
|
||||||
|
|
||||||
class ModelTrainer:
|
class ModelTrainer:
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
|
|
Loading…
Reference in New Issue