eorr handling because we have a tensor misshaping
This commit is contained in:
parent
8a809269e5
commit
2b55f02b50
|
@ -189,7 +189,12 @@ class AIIADataset(torch.utils.data.Dataset):
|
|||
|
||||
if self.pretraining:
|
||||
image, task, label = item
|
||||
if not isinstance(image, Image.Image):
|
||||
raise ValueError(f"Invalid image at index {idx}")
|
||||
|
||||
image = self.transform(image)
|
||||
if image.shape != (3, 224, 224):
|
||||
raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")
|
||||
|
||||
if task == 'denoise':
|
||||
noise_std = 0.1
|
||||
|
@ -202,13 +207,22 @@ class AIIADataset(torch.utils.data.Dataset):
|
|||
rotated_img = transforms.functional.rotate(image, angle)
|
||||
target = torch.tensor(angle / 90).long()
|
||||
return rotated_img, target, task
|
||||
else:
|
||||
raise ValueError(f"Invalid task at index {idx}: {task}")
|
||||
else:
|
||||
if isinstance(item, tuple) and len(item) == 2:
|
||||
image, label = item
|
||||
if not isinstance(image, Image.Image):
|
||||
raise ValueError(f"Invalid image at index {idx}")
|
||||
image = self.transform(image)
|
||||
if image.shape != (3, 224, 224):
|
||||
raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")
|
||||
return image, label
|
||||
else:
|
||||
if isinstance(item, Image.Image):
|
||||
return self.transform(item)
|
||||
image = self.transform(item)
|
||||
else:
|
||||
return self.transform(item[0])
|
||||
image = self.transform(item[0])
|
||||
if image.shape != (3, 224, 224):
|
||||
raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")
|
||||
return image
|
||||
|
|
Loading…
Reference in New Issue