proper image transformation

This commit is contained in:
Falko Victor Habel 2025-01-26 22:08:59 +01:00
parent 7a1eb8bd30
commit cae3fa7fb3
1 changed files with 12 additions and 3 deletions

View File

@ -152,10 +152,13 @@ class AIIADataLoader:
def _create_subset(self, indices): def _create_subset(self, indices):
subset_items = [self.items[i] for i in indices] subset_items = [self.items[i] for i in indices]
return AIIADataset(subset_items) return AIIADataset(subset_items)
class AIIADataset(torch.utils.data.Dataset): class AIIADataset(torch.utils.data.Dataset):
def __init__(self, items): def __init__(self, items):
self.items = items self.items = items
self.transform = transforms.Compose([
transforms.ToTensor()
])
def __len__(self): def __len__(self):
return len(self.items) return len(self.items)
@ -164,9 +167,14 @@ class AIIADataset(torch.utils.data.Dataset):
item = self.items[idx] item = self.items[idx]
if isinstance(item, tuple) and len(item) == 2: if isinstance(item, tuple) and len(item) == 2:
image, label = item image, label = item
# Convert PIL image to tensor
image = self.transform(image)
return (image, label) return (image, label)
elif isinstance(item, tuple) and len(item) == 3: elif isinstance(item, tuple) and len(item) == 3:
image, task, label = item image, task, label = item
# Convert PIL image to tensor first
image = self.transform(image)
if task == 'denoise': if task == 'denoise':
noise_std = 0.1 noise_std = 0.1
noisy_img = image + torch.randn_like(image) * noise_std noisy_img = image + torch.randn_like(image) * noise_std
@ -182,6 +190,7 @@ class AIIADataset(torch.utils.data.Dataset):
raise ValueError(f"Unknown task: {task}") raise ValueError(f"Unknown task: {task}")
else: else:
if isinstance(item, Image.Image): if isinstance(item, Image.Image):
return item # Convert single PIL image to tensor
return self.transform(item)
else: else:
raise ValueError("Invalid item format.") raise ValueError("Invalid item format.")