eorr handling because we have a tensor misshaping
This commit is contained in:
parent
8a809269e5
commit
2b55f02b50
|
@ -189,7 +189,12 @@ class AIIADataset(torch.utils.data.Dataset):
|
||||||
|
|
||||||
if self.pretraining:
|
if self.pretraining:
|
||||||
image, task, label = item
|
image, task, label = item
|
||||||
|
if not isinstance(image, Image.Image):
|
||||||
|
raise ValueError(f"Invalid image at index {idx}")
|
||||||
|
|
||||||
image = self.transform(image)
|
image = self.transform(image)
|
||||||
|
if image.shape != (3, 224, 224):
|
||||||
|
raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")
|
||||||
|
|
||||||
if task == 'denoise':
|
if task == 'denoise':
|
||||||
noise_std = 0.1
|
noise_std = 0.1
|
||||||
|
@ -202,13 +207,22 @@ class AIIADataset(torch.utils.data.Dataset):
|
||||||
rotated_img = transforms.functional.rotate(image, angle)
|
rotated_img = transforms.functional.rotate(image, angle)
|
||||||
target = torch.tensor(angle / 90).long()
|
target = torch.tensor(angle / 90).long()
|
||||||
return rotated_img, target, task
|
return rotated_img, target, task
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid task at index {idx}: {task}")
|
||||||
else:
|
else:
|
||||||
if isinstance(item, tuple) and len(item) == 2:
|
if isinstance(item, tuple) and len(item) == 2:
|
||||||
image, label = item
|
image, label = item
|
||||||
|
if not isinstance(image, Image.Image):
|
||||||
|
raise ValueError(f"Invalid image at index {idx}")
|
||||||
image = self.transform(image)
|
image = self.transform(image)
|
||||||
|
if image.shape != (3, 224, 224):
|
||||||
|
raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")
|
||||||
return image, label
|
return image, label
|
||||||
else:
|
else:
|
||||||
if isinstance(item, Image.Image):
|
if isinstance(item, Image.Image):
|
||||||
return self.transform(item)
|
image = self.transform(item)
|
||||||
else:
|
else:
|
||||||
return self.transform(item[0])
|
image = self.transform(item[0])
|
||||||
|
if image.shape != (3, 224, 224):
|
||||||
|
raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")
|
||||||
|
return image
|
||||||
|
|
Loading…
Reference in New Issue