diff --git a/src/aiunn/finetune.py b/src/aiunn/finetune.py
index ce1b07a..7af2695 100644
--- a/src/aiunn/finetune.py
+++ b/src/aiunn/finetune.py
@@ -8,6 +8,8 @@ import torchvision.transforms as transforms
 from aiia.model import AIIABase, AIIA
 from sklearn.model_selection import train_test_split
 from typing import Dict, List, Union, Optional
+import base64
+
 
 class ImageDataset(Dataset):
     def __init__(self, dataframe, transform=None):
@@ -20,23 +22,30 @@ class ImageDataset(Dataset):
 
     def __getitem__(self, idx):
         row = self.dataframe.iloc[idx]
 
-        # Decode image_512 from bytes
-        img_bytes = row['image_512']
-        img_stream = io.BytesIO(img_bytes)
-        low_res_image = Image.open(img_stream).convert('RGB')
+        # Convert string to bytes and handle decoding
+        try:
+            # Decode base64 string to bytes
+            low_res_bytes = base64.b64decode(row['image_512'])
+            high_res_bytes = base64.b64decode(row['image_1024'])
+        except Exception as e:
+            raise ValueError(f"Error decoding base64 string: {str(e)}")
 
-        # Decode image_1024 from bytes
-        high_res_bytes = row['image_1024']
-        high_stream = io.BytesIO(high_res_bytes)
-        high_res_image = Image.open(high_stream).convert('RGB')
+        # Create image streams
+        low_res_stream = io.BytesIO(low_res_bytes)
+        high_res_stream = io.BytesIO(high_res_bytes)
 
-        # Apply transformations if specified
+        # Open images
+        low_res_image = Image.open(low_res_stream).convert('RGB')
+        high_res_image = Image.open(high_res_stream).convert('RGB')
+
+        # Apply transformations
         if self.transform:
             low_res_image = self.transform(low_res_image)
             high_res_image = self.transform(high_res_image)
 
         return {'low_ress': low_res_image, 'high_ress': high_res_image}
 
+
 class ModelTrainer:
     def __init__(self, model: AIIA,
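
Reviewer note: the sketch below is an assumption, not part of the patch. It shows how the patched ImageDataset appears to be driven after this change: each row of the DataFrame stores image_512 and image_1024 as base64-encoded strings, which the new __getitem__ decodes before opening them with PIL. The import path aiunn.finetune and the helper encode_png are illustrative guesses based on the file location.

# Minimal usage sketch (assumed API, for review only)
import base64
import io

import pandas as pd
from PIL import Image

from aiunn.finetune import ImageDataset  # import path assumed from src/aiunn/finetune.py


def encode_png(size: int) -> str:
    # Build a dummy RGB image and return it as a base64 string, mirroring
    # how image_512 / image_1024 rows are assumed to be stored.
    buf = io.BytesIO()
    Image.new('RGB', (size, size), color='gray').save(buf, format='PNG')
    return base64.b64encode(buf.getvalue()).decode('ascii')


df = pd.DataFrame([{'image_512': encode_png(512), 'image_1024': encode_png(1024)}])
dataset = ImageDataset(df)        # transform=None, so PIL images are returned
sample = dataset[0]
print(sample['low_ress'].size, sample['high_ress'].size)  # (512, 512) (1024, 1024)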