develop #4
|
@ -8,6 +8,8 @@ import torchvision.transforms as transforms
|
||||||
from aiia.model import AIIABase, AIIA
|
from aiia.model import AIIABase, AIIA
|
||||||
from sklearn.model_selection import train_test_split
|
from sklearn.model_selection import train_test_split
|
||||||
from typing import Dict, List, Union, Optional
|
from typing import Dict, List, Union, Optional
|
||||||
|
import base64
|
||||||
|
|
||||||
|
|
||||||
class ImageDataset(Dataset):
|
class ImageDataset(Dataset):
|
||||||
def __init__(self, dataframe, transform=None):
|
def __init__(self, dataframe, transform=None):
|
||||||
|
@ -20,23 +22,30 @@ class ImageDataset(Dataset):
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
row = self.dataframe.iloc[idx]
|
row = self.dataframe.iloc[idx]
|
||||||
|
|
||||||
# Decode image_512 from bytes
|
# Convert string to bytes and handle decoding
|
||||||
img_bytes = row['image_512']
|
try:
|
||||||
img_stream = io.BytesIO(img_bytes)
|
# Decode base64 string to bytes
|
||||||
low_res_image = Image.open(img_stream).convert('RGB')
|
low_res_bytes = base64.b64decode(row['image_512'])
|
||||||
|
high_res_bytes = base64.b64decode(row['image_1024'])
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"Error decoding base64 string: {str(e)}")
|
||||||
|
|
||||||
# Decode image_1024 from bytes
|
# Create image streams
|
||||||
high_res_bytes = row['image_1024']
|
low_res_stream = io.BytesIO(low_res_bytes)
|
||||||
high_stream = io.BytesIO(high_res_bytes)
|
high_res_stream = io.BytesIO(high_res_bytes)
|
||||||
high_res_image = Image.open(high_stream).convert('RGB')
|
|
||||||
|
|
||||||
# Apply transformations if specified
|
# Open images
|
||||||
|
low_res_image = Image.open(low_res_stream).convert('RGB')
|
||||||
|
high_res_image = Image.open(high_res_stream).convert('RGB')
|
||||||
|
|
||||||
|
# Apply transformations
|
||||||
if self.transform:
|
if self.transform:
|
||||||
low_res_image = self.transform(low_res_image)
|
low_res_image = self.transform(low_res_image)
|
||||||
high_res_image = self.transform(high_res_image)
|
high_res_image = self.transform(high_res_image)
|
||||||
|
|
||||||
return {'low_ress': low_res_image, 'high_ress': high_res_image}
|
return {'low_ress': low_res_image, 'high_ress': high_res_image}
|
||||||
|
|
||||||
|
|
||||||
class ModelTrainer:
|
class ModelTrainer:
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
model: AIIA,
|
model: AIIA,
|
||||||
|
|
Loading…
Reference in New Issue