diff --git a/src/aiia/data/DataLoader.py b/src/aiia/data/DataLoader.py index cbb22a6..4ba5032 100644 --- a/src/aiia/data/DataLoader.py +++ b/src/aiia/data/DataLoader.py @@ -189,7 +189,12 @@ class AIIADataset(torch.utils.data.Dataset): if self.pretraining: image, task, label = item + if not isinstance(image, Image.Image): + raise ValueError(f"Invalid image at index {idx}") + image = self.transform(image) + if image.shape != (3, 224, 224): + raise ValueError(f"Invalid image shape at index {idx}: {image.shape}") if task == 'denoise': noise_std = 0.1 @@ -202,13 +207,22 @@ class AIIADataset(torch.utils.data.Dataset): rotated_img = transforms.functional.rotate(image, angle) target = torch.tensor(angle / 90).long() return rotated_img, target, task + else: + raise ValueError(f"Invalid task at index {idx}: {task}") else: if isinstance(item, tuple) and len(item) == 2: image, label = item + if not isinstance(image, Image.Image): + raise ValueError(f"Invalid image at index {idx}") image = self.transform(image) + if image.shape != (3, 224, 224): + raise ValueError(f"Invalid image shape at index {idx}: {image.shape}") return image, label else: if isinstance(item, Image.Image): - return self.transform(item) + image = self.transform(item) else: - return self.transform(item[0]) + image = self.transform(item[0]) + if image.shape != (3, 224, 224): + raise ValueError(f"Invalid image shape at index {idx}: {image.shape}") + return image