finetune_class #1

Merged
Fabel merged 96 commits from finetune_class into develop 2025-02-26 12:13:09 +00:00
1 changed file with 67 additions and 22 deletions
Showing only changes of commit 34c547fb23 - Show all commits

View File

@@ -11,8 +11,22 @@ from typing import Dict, List, Union, Optional
import base64 import base64
import torch
from torch.utils.data import Dataset
from PIL import Image, ImageFile
import io
import base64
import pandas as pd
class ImageDataset(Dataset): class ImageDataset(Dataset):
def __init__(self, dataframe, transform=None): def __init__(self, dataframe, transform=None):
"""
Initialize the dataset with a dataframe containing image data
Args:
dataframe (pd.DataFrame): DataFrame containing 'image_512' and 'image_1024' columns
transform (callable, optional): Optional transform to be applied to both images
"""
self.dataframe = dataframe self.dataframe = dataframe
self.transform = transform self.transform = transform
@@ -20,44 +34,75 @@ class ImageDataset(Dataset):
return len(self.dataframe) return len(self.dataframe)
def __getitem__(self, idx): def __getitem__(self, idx):
"""
Get a pair of low and high resolution images
Args:
idx (int): Index of the data point
Returns:
dict: Contains 'low_ress' and 'high_ress' PIL images or transformed tensors
"""
row = self.dataframe.iloc[idx] row = self.dataframe.iloc[idx]
try: try:
# Handle both bytes and base64 encoded strings
low_res_data = row['image_512']
high_res_data = row['image_1024']
if isinstance(low_res_data, str):
# Decode base64 string to bytes
low_res_data = base64.b64decode(low_res_data)
high_res_data = base64.b64decode(high_res_data)
# Verify data is valid before creating BytesIO # Verify data is valid before creating BytesIO
if not isinstance(row['image_512'], bytes) or not isinstance(row['image_1024'], bytes): if not isinstance(low_res_data, bytes) or not isinstance(high_res_data, bytes):
raise ValueError("Image data must be in bytes format") raise ValueError(f"Invalid image data format at index {idx}")
low_res_stream = io.BytesIO(row['image_512']) # Create image streams
high_res_stream = io.BytesIO(row['image_1024']) low_res_stream = io.BytesIO(low_res_data)
high_res_stream = io.BytesIO(high_res_data)
# Reset stream position # Enable loading of truncated images
low_res_stream.seek(0)
high_res_stream.seek(0)
# Enable loading of truncated images if necessary
ImageFile.LOAD_TRUNCATED_IMAGES = True ImageFile.LOAD_TRUNCATED_IMAGES = True
# Load and convert images to RGB
low_res_image = Image.open(low_res_stream).convert('RGB') low_res_image = Image.open(low_res_stream).convert('RGB')
high_res_image = Image.open(high_res_stream).convert('RGB') high_res_image = Image.open(high_res_stream).convert('RGB')
# Create fresh copies for verify() since it modifies the image object
low_res_verify = low_res_image.copy()
high_res_verify = high_res_image.copy()
# Verify images are valid # Verify images are valid
low_res_image.verify() try:
high_res_image.verify() low_res_verify.verify()
high_res_verify.verify()
except Exception as e: except Exception as e:
raise ValueError(f"Image loading failed: {str(e)}") raise ValueError(f"Image verification failed at index {idx}: {str(e)}")
finally: finally:
low_res_stream.close() low_res_verify.close()
high_res_stream.close() high_res_verify.close()
# Apply transforms if specified
if self.transform: if self.transform:
low_res_image = self.transform(low_res_image) low_res_image = self.transform(low_res_image)
high_res_image = self.transform(high_res_image) high_res_image = self.transform(high_res_image)
return {'low_ress': low_res_image, 'high_ress': high_res_image} return {
'low_ress': low_res_image, # Note: Using 'low_ress' to match ModelTrainer
'high_ress': high_res_image # Note: Using 'high_ress' to match ModelTrainer
}
except Exception as e:
raise RuntimeError(f"Error loading images at index {idx}: {str(e)}")
finally:
# Ensure streams are closed
if 'low_res_stream' in locals():
low_res_stream.close()
if 'high_res_stream' in locals():
high_res_stream.close()
class ModelTrainer: class ModelTrainer:
def __init__(self, def __init__(self,