finetune_class #1

Merged
Fabel merged 96 commits from finetune_class into develop 2025-02-26 12:13:09 +00:00
1 changed files with 18 additions and 9 deletions
Showing only changes of commit 1f33f22bea - Show all commits

View File

@ -8,6 +8,8 @@ import torchvision.transforms as transforms
from aiia.model import AIIABase, AIIA
from sklearn.model_selection import train_test_split
from typing import Dict, List, Union, Optional
import base64
class ImageDataset(Dataset):
def __init__(self, dataframe, transform=None):
@ -20,23 +22,30 @@ class ImageDataset(Dataset):
def __getitem__(self, idx):
row = self.dataframe.iloc[idx]
# Decode image_512 from bytes
img_bytes = row['image_512']
img_stream = io.BytesIO(img_bytes)
low_res_image = Image.open(img_stream).convert('RGB')
# Convert string to bytes and handle decoding
try:
# Decode base64 string to bytes
low_res_bytes = base64.b64decode(row['image_512'])
high_res_bytes = base64.b64decode(row['image_1024'])
except Exception as e:
raise ValueError(f"Error decoding base64 string: {str(e)}")
# Decode image_1024 from bytes
high_res_bytes = row['image_1024']
high_stream = io.BytesIO(high_res_bytes)
high_res_image = Image.open(high_stream).convert('RGB')
# Create image streams
low_res_stream = io.BytesIO(low_res_bytes)
high_res_stream = io.BytesIO(high_res_bytes)
# Apply transformations if specified
# Open images
low_res_image = Image.open(low_res_stream).convert('RGB')
high_res_image = Image.open(high_res_stream).convert('RGB')
# Apply transformations
if self.transform:
low_res_image = self.transform(low_res_image)
high_res_image = self.transform(high_res_image)
return {'low_ress': low_res_image, 'high_ress': high_res_image}
class ModelTrainer:
def __init__(self,
model: AIIA,