finetune_class #1

Merged
Fabel merged 96 commits from finetune_class into develop 2025-02-26 12:13:09 +00:00
1 changed files with 15 additions and 14 deletions
Showing only changes of commit be5bb53620 - Show all commits

View File

@ -22,23 +22,23 @@ class ImageDataset(Dataset):
def __getitem__(self, idx): def __getitem__(self, idx):
row = self.dataframe.iloc[idx] row = self.dataframe.iloc[idx]
# Convert string to bytes and handle decoding
try: try:
# Decode base64 string to bytes # Directly use bytes data for PNG images
low_res_bytes = base64.b64decode(row['image_512']) low_res_bytes = row['image_512']
high_res_bytes = base64.b64decode(row['image_1024']) high_res_bytes = row['image_1024']
except Exception as e:
raise ValueError(f"Error decoding base64 string: {str(e)}")
# Create image streams # Create in-memory streams
low_res_stream = io.BytesIO(low_res_bytes) low_res_stream = io.BytesIO(low_res_bytes)
high_res_stream = io.BytesIO(high_res_bytes) high_res_stream = io.BytesIO(high_res_bytes)
# Open images # Open images with explicit RGB conversion
low_res_image = Image.open(low_res_stream).convert('RGB') low_res_image = Image.open(low_res_stream).convert('RGB')
high_res_image = Image.open(high_res_stream).convert('RGB') high_res_image = Image.open(high_res_stream).convert('RGB')
# Apply transformations except Exception as e:
raise ValueError(f"Image loading failed: {str(e)}")
# Apply transformations if specified
if self.transform: if self.transform:
low_res_image = self.transform(low_res_image) low_res_image = self.transform(low_res_image)
high_res_image = self.transform(high_res_image) high_res_image = self.transform(high_res_image)
@ -46,6 +46,7 @@ class ImageDataset(Dataset):
return {'low_ress': low_res_image, 'high_ress': high_res_image} return {'low_ress': low_res_image, 'high_ress': high_res_image}
class ModelTrainer: class ModelTrainer:
def __init__(self, def __init__(self,
model: AIIA, model: AIIA,