eorr handling because we have a tensor misshaping

This commit is contained in:
Falko Victor Habel 2025-01-26 23:08:56 +01:00
parent 8a809269e5
commit 2b55f02b50
1 changed files with 16 additions and 2 deletions

View File

@ -189,7 +189,12 @@ class AIIADataset(torch.utils.data.Dataset):
if self.pretraining: if self.pretraining:
image, task, label = item image, task, label = item
if not isinstance(image, Image.Image):
raise ValueError(f"Invalid image at index {idx}")
image = self.transform(image) image = self.transform(image)
if image.shape != (3, 224, 224):
raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")
if task == 'denoise': if task == 'denoise':
noise_std = 0.1 noise_std = 0.1
@ -202,13 +207,22 @@ class AIIADataset(torch.utils.data.Dataset):
rotated_img = transforms.functional.rotate(image, angle) rotated_img = transforms.functional.rotate(image, angle)
target = torch.tensor(angle / 90).long() target = torch.tensor(angle / 90).long()
return rotated_img, target, task return rotated_img, target, task
else:
raise ValueError(f"Invalid task at index {idx}: {task}")
else: else:
if isinstance(item, tuple) and len(item) == 2: if isinstance(item, tuple) and len(item) == 2:
image, label = item image, label = item
if not isinstance(image, Image.Image):
raise ValueError(f"Invalid image at index {idx}")
image = self.transform(image) image = self.transform(image)
if image.shape != (3, 224, 224):
raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")
return image, label return image, label
else: else:
if isinstance(item, Image.Image): if isinstance(item, Image.Image):
return self.transform(item) image = self.transform(item)
else: else:
return self.transform(item[0]) image = self.transform(item[0])
if image.shape != (3, 224, 224):
raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")
return image