rgba conversion to rgb
This commit is contained in:
parent
cae3fa7fb3
commit
3f6e6514a9
src/aiia/data
|
@ -21,7 +21,13 @@ class FilePathLoader:
|
||||||
def _get_image(self, item):
|
def _get_image(self, item):
|
||||||
try:
|
try:
|
||||||
path = item[self.file_path_column]
|
path = item[self.file_path_column]
|
||||||
image = Image.open(path).convert("RGB")
|
image = Image.open(path)
|
||||||
|
if image.mode == 'RGBA':
|
||||||
|
background = Image.new('RGB', image.size, (0, 0, 0))
|
||||||
|
background.paste(image, mask=image.split()[3])
|
||||||
|
image = background
|
||||||
|
elif image.mode != 'RGB':
|
||||||
|
image = image.convert('RGB')
|
||||||
return image
|
return image
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error loading image from {path}: {e}")
|
print(f"Error loading image from {path}: {e}")
|
||||||
|
@ -69,7 +75,13 @@ class JPGImageLoader:
|
||||||
bytes_data = data
|
bytes_data = data
|
||||||
|
|
||||||
img_bytes = io.BytesIO(bytes_data)
|
img_bytes = io.BytesIO(bytes_data)
|
||||||
image = Image.open(img_bytes).convert("RGB")
|
image = Image.open(img_bytes)
|
||||||
|
if image.mode == 'RGBA':
|
||||||
|
background = Image.new('RGB', image.size, (0, 0, 0))
|
||||||
|
background.paste(image, mask=image.split()[3])
|
||||||
|
image = background
|
||||||
|
elif image.mode != 'RGB':
|
||||||
|
image = image.convert('RGB')
|
||||||
return image
|
return image
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error loading image from bytes: {e}")
|
print(f"Error loading image from bytes: {e}")
|
||||||
|
@ -167,12 +179,10 @@ class AIIADataset(torch.utils.data.Dataset):
|
||||||
item = self.items[idx]
|
item = self.items[idx]
|
||||||
if isinstance(item, tuple) and len(item) == 2:
|
if isinstance(item, tuple) and len(item) == 2:
|
||||||
image, label = item
|
image, label = item
|
||||||
# Convert PIL image to tensor
|
|
||||||
image = self.transform(image)
|
image = self.transform(image)
|
||||||
return (image, label)
|
return (image, label)
|
||||||
elif isinstance(item, tuple) and len(item) == 3:
|
elif isinstance(item, tuple) and len(item) == 3:
|
||||||
image, task, label = item
|
image, task, label = item
|
||||||
# Convert PIL image to tensor first
|
|
||||||
image = self.transform(image)
|
image = self.transform(image)
|
||||||
|
|
||||||
if task == 'denoise':
|
if task == 'denoise':
|
||||||
|
@ -190,7 +200,6 @@ class AIIADataset(torch.utils.data.Dataset):
|
||||||
raise ValueError(f"Unknown task: {task}")
|
raise ValueError(f"Unknown task: {task}")
|
||||||
else:
|
else:
|
||||||
if isinstance(item, Image.Image):
|
if isinstance(item, Image.Image):
|
||||||
# Convert single PIL image to tensor
|
|
||||||
return self.transform(item)
|
return self.transform(item)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Invalid item format.")
|
raise ValueError("Invalid item format.")
|
||||||
|
|
Loading…
Reference in New Issue