eorr handling because we have a tensor misshaping

This commit is contained in:
Falko Victor Habel 2025-01-26 23:08:56 +01:00
parent 8a809269e5
commit 2b55f02b50
1 changed files with 16 additions and 2 deletions

View File

@ -189,7 +189,12 @@ class AIIADataset(torch.utils.data.Dataset):
if self.pretraining:
image, task, label = item
if not isinstance(image, Image.Image):
raise ValueError(f"Invalid image at index {idx}")
image = self.transform(image)
if image.shape != (3, 224, 224):
raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")
if task == 'denoise':
noise_std = 0.1
@ -202,13 +207,22 @@ class AIIADataset(torch.utils.data.Dataset):
rotated_img = transforms.functional.rotate(image, angle)
target = torch.tensor(angle / 90).long()
return rotated_img, target, task
else:
raise ValueError(f"Invalid task at index {idx}: {task}")
else:
if isinstance(item, tuple) and len(item) == 2:
image, label = item
if not isinstance(image, Image.Image):
raise ValueError(f"Invalid image at index {idx}")
image = self.transform(image)
if image.shape != (3, 224, 224):
raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")
return image, label
else:
if isinstance(item, Image.Image):
return self.transform(item)
image = self.transform(item)
else:
return self.transform(item[0])
image = self.transform(item[0])
if image.shape != (3, 224, 224):
raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")
return image