debug mode

This commit is contained in:
Falko Victor Habel 2025-01-30 12:52:57 +01:00
parent 4037a07764
commit 9b7a182782
1 changed files with 26 additions and 17 deletions

View File

@ -1,6 +1,6 @@
import torch import torch
import pandas as pd import pandas as pd
from PIL import Image from PIL import Image, ImageFile
import io import io
from torch import nn from torch import nn
from torch.utils.data import Dataset, DataLoader from torch.utils.data import Dataset, DataLoader
@ -23,25 +23,34 @@ class ImageDataset(Dataset):
row = self.dataframe.iloc[idx] row = self.dataframe.iloc[idx]
try: try:
# Convert string data to bytes if needed # Verify data is valid before creating BytesIO
low_res_bytes = row['image_512'].encode() if isinstance(row['image_512'], str) else row['image_512'] if not isinstance(row['image_512'], bytes) or not isinstance(row['image_1024'], bytes):
high_res_bytes = row['image_1024'].encode() if isinstance(row['image_1024'], str) else row['image_1024'] raise ValueError("Image data must be in bytes format")
# Create BytesIO objects low_res_stream = io.BytesIO(row['image_512'])
low_res_stream = io.BytesIO(low_res_bytes) high_res_stream = io.BytesIO(row['image_1024'])
high_res_stream = io.BytesIO(high_res_bytes)
# Reset stream position
low_res_stream.seek(0)
high_res_stream.seek(0)
# Enable loading of truncated images if necessary
ImageFile.LOAD_TRUNCATED_IMAGES = True
# Open images
low_res_image = Image.open(low_res_stream).convert('RGB') low_res_image = Image.open(low_res_stream).convert('RGB')
high_res_image = Image.open(high_res_stream).convert('RGB') high_res_image = Image.open(high_res_stream).convert('RGB')
# Close the streams # Verify images are valid
low_res_stream.close() low_res_image.verify()
high_res_stream.close() high_res_image.verify()
except Exception as e: except Exception as e:
raise ValueError(f"Image loading failed: {str(e)}") raise ValueError(f"Image loading failed: {str(e)}")
finally:
low_res_stream.close()
high_res_stream.close()
if self.transform: if self.transform:
low_res_image = self.transform(low_res_image) low_res_image = self.transform(low_res_image)
high_res_image = self.transform(high_res_image) high_res_image = self.transform(high_res_image)