develop #4
|
@ -1,6 +1,6 @@
|
||||||
import torch
|
import torch
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from PIL import Image
|
from PIL import Image, ImageFile
|
||||||
import io
|
import io
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.utils.data import Dataset, DataLoader
|
from torch.utils.data import Dataset, DataLoader
|
||||||
|
@ -23,25 +23,34 @@ class ImageDataset(Dataset):
|
||||||
row = self.dataframe.iloc[idx]
|
row = self.dataframe.iloc[idx]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Convert string data to bytes if needed
|
# Verify data is valid before creating BytesIO
|
||||||
low_res_bytes = row['image_512'].encode() if isinstance(row['image_512'], str) else row['image_512']
|
if not isinstance(row['image_512'], bytes) or not isinstance(row['image_1024'], bytes):
|
||||||
high_res_bytes = row['image_1024'].encode() if isinstance(row['image_1024'], str) else row['image_1024']
|
raise ValueError("Image data must be in bytes format")
|
||||||
|
|
||||||
# Create BytesIO objects
|
low_res_stream = io.BytesIO(row['image_512'])
|
||||||
low_res_stream = io.BytesIO(low_res_bytes)
|
high_res_stream = io.BytesIO(row['image_1024'])
|
||||||
high_res_stream = io.BytesIO(high_res_bytes)
|
|
||||||
|
# Reset stream position
|
||||||
|
low_res_stream.seek(0)
|
||||||
|
high_res_stream.seek(0)
|
||||||
|
|
||||||
|
# Enable loading of truncated images if necessary
|
||||||
|
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||||
|
|
||||||
# Open images
|
|
||||||
low_res_image = Image.open(low_res_stream).convert('RGB')
|
low_res_image = Image.open(low_res_stream).convert('RGB')
|
||||||
high_res_image = Image.open(high_res_stream).convert('RGB')
|
high_res_image = Image.open(high_res_stream).convert('RGB')
|
||||||
|
|
||||||
# Close the streams
|
# Verify images are valid
|
||||||
low_res_stream.close()
|
low_res_image.verify()
|
||||||
high_res_stream.close()
|
high_res_image.verify()
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ValueError(f"Image loading failed: {str(e)}")
|
raise ValueError(f"Image loading failed: {str(e)}")
|
||||||
|
|
||||||
|
finally:
|
||||||
|
low_res_stream.close()
|
||||||
|
high_res_stream.close()
|
||||||
|
|
||||||
if self.transform:
|
if self.transform:
|
||||||
low_res_image = self.transform(low_res_image)
|
low_res_image = self.transform(low_res_image)
|
||||||
high_res_image = self.transform(high_res_image)
|
high_res_image = self.transform(high_res_image)
|
||||||
|
|
Loading…
Reference in New Issue