develop #4

Merged
Fabel merged 103 commits from develop into main 2025-03-01 21:47:17 +00:00
1 changed files with 26 additions and 17 deletions
Showing only changes of commit 9b7a182782 - Show all commits

View File

@ -1,6 +1,6 @@
import torch import torch
import pandas as pd import pandas as pd
from PIL import Image from PIL import Image, ImageFile
import io import io
from torch import nn from torch import nn
from torch.utils.data import Dataset, DataLoader from torch.utils.data import Dataset, DataLoader
@ -23,25 +23,34 @@ class ImageDataset(Dataset):
row = self.dataframe.iloc[idx] row = self.dataframe.iloc[idx]
try: try:
# Convert string data to bytes if needed # Verify data is valid before creating BytesIO
low_res_bytes = row['image_512'].encode() if isinstance(row['image_512'], str) else row['image_512'] if not isinstance(row['image_512'], bytes) or not isinstance(row['image_1024'], bytes):
high_res_bytes = row['image_1024'].encode() if isinstance(row['image_1024'], str) else row['image_1024'] raise ValueError("Image data must be in bytes format")
# Create BytesIO objects low_res_stream = io.BytesIO(row['image_512'])
low_res_stream = io.BytesIO(low_res_bytes) high_res_stream = io.BytesIO(row['image_1024'])
high_res_stream = io.BytesIO(high_res_bytes)
# Reset stream position
low_res_stream.seek(0)
high_res_stream.seek(0)
# Enable loading of truncated images if necessary
ImageFile.LOAD_TRUNCATED_IMAGES = True
# Open images
low_res_image = Image.open(low_res_stream).convert('RGB') low_res_image = Image.open(low_res_stream).convert('RGB')
high_res_image = Image.open(high_res_stream).convert('RGB') high_res_image = Image.open(high_res_stream).convert('RGB')
# Close the streams # Verify images are valid
low_res_stream.close() low_res_image.verify()
high_res_stream.close() high_res_image.verify()
except Exception as e: except Exception as e:
raise ValueError(f"Image loading failed: {str(e)}") raise ValueError(f"Image loading failed: {str(e)}")
finally:
low_res_stream.close()
high_res_stream.close()
if self.transform: if self.transform:
low_res_image = self.transform(low_res_image) low_res_image = self.transform(low_res_image)
high_res_image = self.transform(high_res_image) high_res_image = self.transform(high_res_image)