diff --git a/src/aiia/data/DataLoader.py b/src/aiia/data/DataLoader.py index 223d146..1f5d75a 100644 --- a/src/aiia/data/DataLoader.py +++ b/src/aiia/data/DataLoader.py @@ -152,10 +152,13 @@ class AIIADataLoader: def _create_subset(self, indices): subset_items = [self.items[i] for i in indices] return AIIADataset(subset_items) - + class AIIADataset(torch.utils.data.Dataset): def __init__(self, items): self.items = items + self.transform = transforms.Compose([ + transforms.ToTensor() + ]) def __len__(self): return len(self.items) @@ -164,9 +167,14 @@ class AIIADataset(torch.utils.data.Dataset): item = self.items[idx] if isinstance(item, tuple) and len(item) == 2: image, label = item + # Convert PIL image to tensor + image = self.transform(image) return (image, label) elif isinstance(item, tuple) and len(item) == 3: image, task, label = item + # Convert PIL image to tensor first + image = self.transform(image) + if task == 'denoise': noise_std = 0.1 noisy_img = image + torch.randn_like(image) * noise_std @@ -182,6 +190,7 @@ class AIIADataset(torch.utils.data.Dataset): raise ValueError(f"Unknown task: {task}") else: if isinstance(item, Image.Image): - return item + # Convert single PIL image to tensor + return self.transform(item) else: - raise ValueError("Invalid item format.") \ No newline at end of file + raise ValueError("Invalid item format.")