rgba conversion to rgb

This commit is contained in:
Falko Victor Habel 2025-01-26 22:17:37 +01:00
parent cae3fa7fb3
commit 3f6e6514a9
1 changed files with 14 additions and 5 deletions

View File

@ -21,7 +21,13 @@ class FilePathLoader:
def _get_image(self, item):
try:
path = item[self.file_path_column]
image = Image.open(path).convert("RGB")
image = Image.open(path)
if image.mode == 'RGBA':
background = Image.new('RGB', image.size, (0, 0, 0))
background.paste(image, mask=image.split()[3])
image = background
elif image.mode != 'RGB':
image = image.convert('RGB')
return image
except Exception as e:
print(f"Error loading image from {path}: {e}")
@ -69,7 +75,13 @@ class JPGImageLoader:
bytes_data = data
img_bytes = io.BytesIO(bytes_data)
image = Image.open(img_bytes).convert("RGB")
image = Image.open(img_bytes)
if image.mode == 'RGBA':
background = Image.new('RGB', image.size, (0, 0, 0))
background.paste(image, mask=image.split()[3])
image = background
elif image.mode != 'RGB':
image = image.convert('RGB')
return image
except Exception as e:
print(f"Error loading image from bytes: {e}")
@ -167,12 +179,10 @@ class AIIADataset(torch.utils.data.Dataset):
item = self.items[idx]
if isinstance(item, tuple) and len(item) == 2:
image, label = item
# Convert PIL image to tensor
image = self.transform(image)
return (image, label)
elif isinstance(item, tuple) and len(item) == 3:
image, task, label = item
# Convert PIL image to tensor first
image = self.transform(image)
if task == 'denoise':
@ -190,7 +200,6 @@ class AIIADataset(torch.utils.data.Dataset):
raise ValueError(f"Unknown task: {task}")
else:
if isinstance(item, Image.Image):
# Convert single PIL image to tensor
return self.transform(item)
else:
raise ValueError("Invalid item format.")