diff --git a/src/aiia/data/DataLoader.py b/src/aiia/data/DataLoader.py index 6f3a334..89ba400 100644 --- a/src/aiia/data/DataLoader.py +++ b/src/aiia/data/DataLoader.py @@ -177,7 +177,7 @@ class AIIADataset(torch.utils.data.Dataset): self.items = items self.pretraining = pretraining self.transform = transforms.Compose([ - transforms.Resize((410, 410)), + transforms.Resize((400, 400)), transforms.ToTensor() ]) @@ -193,7 +193,7 @@ class AIIADataset(torch.utils.data.Dataset): raise ValueError(f"Invalid image at index {idx}") image = self.transform(image) - if image.shape != (3, 410, 410): + if image.shape != (3, 400, 400): raise ValueError(f"Invalid image shape at index {idx}: {image.shape}") if task == 'denoise': @@ -215,7 +215,7 @@ class AIIADataset(torch.utils.data.Dataset): if not isinstance(image, Image.Image): raise ValueError(f"Invalid image at index {idx}") image = self.transform(image) - if image.shape != (3, 410, 410): + if image.shape != (3, 400, 400): raise ValueError(f"Invalid image shape at index {idx}: {image.shape}") return image, label else: @@ -223,6 +223,6 @@ class AIIADataset(torch.utils.data.Dataset): image = self.transform(item) else: image = self.transform(item[0]) - if image.shape != (3, 410, 410): + if image.shape != (3, 400, 400): raise ValueError(f"Invalid image shape at index {idx}: {image.shape}") return image