proper image transformation
This commit is contained in:
parent
7a1eb8bd30
commit
cae3fa7fb3
|
@ -152,10 +152,13 @@ class AIIADataLoader:
|
||||||
def _create_subset(self, indices):
|
def _create_subset(self, indices):
|
||||||
subset_items = [self.items[i] for i in indices]
|
subset_items = [self.items[i] for i in indices]
|
||||||
return AIIADataset(subset_items)
|
return AIIADataset(subset_items)
|
||||||
|
|
||||||
class AIIADataset(torch.utils.data.Dataset):
|
class AIIADataset(torch.utils.data.Dataset):
|
||||||
def __init__(self, items):
|
def __init__(self, items):
|
||||||
self.items = items
|
self.items = items
|
||||||
|
self.transform = transforms.Compose([
|
||||||
|
transforms.ToTensor()
|
||||||
|
])
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.items)
|
return len(self.items)
|
||||||
|
@ -164,9 +167,14 @@ class AIIADataset(torch.utils.data.Dataset):
|
||||||
item = self.items[idx]
|
item = self.items[idx]
|
||||||
if isinstance(item, tuple) and len(item) == 2:
|
if isinstance(item, tuple) and len(item) == 2:
|
||||||
image, label = item
|
image, label = item
|
||||||
|
# Convert PIL image to tensor
|
||||||
|
image = self.transform(image)
|
||||||
return (image, label)
|
return (image, label)
|
||||||
elif isinstance(item, tuple) and len(item) == 3:
|
elif isinstance(item, tuple) and len(item) == 3:
|
||||||
image, task, label = item
|
image, task, label = item
|
||||||
|
# Convert PIL image to tensor first
|
||||||
|
image = self.transform(image)
|
||||||
|
|
||||||
if task == 'denoise':
|
if task == 'denoise':
|
||||||
noise_std = 0.1
|
noise_std = 0.1
|
||||||
noisy_img = image + torch.randn_like(image) * noise_std
|
noisy_img = image + torch.randn_like(image) * noise_std
|
||||||
|
@ -182,6 +190,7 @@ class AIIADataset(torch.utils.data.Dataset):
|
||||||
raise ValueError(f"Unknown task: {task}")
|
raise ValueError(f"Unknown task: {task}")
|
||||||
else:
|
else:
|
||||||
if isinstance(item, Image.Image):
|
if isinstance(item, Image.Image):
|
||||||
return item
|
# Convert single PIL image to tensor
|
||||||
|
return self.transform(item)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Invalid item format.")
|
raise ValueError("Invalid item format.")
|
||||||
|
|
Loading…
Reference in New Issue