working transform
This commit is contained in:
parent
1f33f22bea
commit
be5bb53620
|
@ -22,23 +22,23 @@ class ImageDataset(Dataset):
|
|||
def __getitem__(self, idx):
|
||||
row = self.dataframe.iloc[idx]
|
||||
|
||||
# Convert string to bytes and handle decoding
|
||||
try:
|
||||
# Decode base64 string to bytes
|
||||
low_res_bytes = base64.b64decode(row['image_512'])
|
||||
high_res_bytes = base64.b64decode(row['image_1024'])
|
||||
# Directly use bytes data for PNG images
|
||||
low_res_bytes = row['image_512']
|
||||
high_res_bytes = row['image_1024']
|
||||
|
||||
# Create in-memory streams
|
||||
low_res_stream = io.BytesIO(low_res_bytes)
|
||||
high_res_stream = io.BytesIO(high_res_bytes)
|
||||
|
||||
# Open images with explicit RGB conversion
|
||||
low_res_image = Image.open(low_res_stream).convert('RGB')
|
||||
high_res_image = Image.open(high_res_stream).convert('RGB')
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error decoding base64 string: {str(e)}")
|
||||
raise ValueError(f"Image loading failed: {str(e)}")
|
||||
|
||||
# Create image streams
|
||||
low_res_stream = io.BytesIO(low_res_bytes)
|
||||
high_res_stream = io.BytesIO(high_res_bytes)
|
||||
|
||||
# Open images
|
||||
low_res_image = Image.open(low_res_stream).convert('RGB')
|
||||
high_res_image = Image.open(high_res_stream).convert('RGB')
|
||||
|
||||
# Apply transformations
|
||||
# Apply transformations if specified
|
||||
if self.transform:
|
||||
low_res_image = self.transform(low_res_image)
|
||||
high_res_image = self.transform(high_res_image)
|
||||
|
@ -46,6 +46,7 @@ class ImageDataset(Dataset):
|
|||
return {'low_ress': low_res_image, 'high_ress': high_res_image}
|
||||
|
||||
|
||||
|
||||
class ModelTrainer:
|
||||
def __init__(self,
|
||||
model: AIIA,
|
||||
|
|
Loading…
Reference in New Issue