fixed dataloading
parent 00168af32d
commit b7dc835a86

@@ -85,21 +85,19 @@ class JPGImageLoader:
        print(f"Skipped {self.skipped_count} images due to errors.")


-class AIIADataLoader(DataLoader):
+class AIIADataLoader:
    def __init__(self, dataset,
                 batch_size=32,
                 val_split=0.2,
                 seed=42,
                 column="file_path",
-                label_column=None):
-        super().__init__(dataset)
+                label_column=None,
+                **dataloader_kwargs):
        self.batch_size = batch_size
        self.val_split = val_split
        self.seed = seed
        random.seed(seed)

        # Determine which loader to use based on the dataset's content
        # Check if any entry in bytes_column is a bytes or bytestring type
        is_bytes_or_bytestring = any(
            isinstance(value, (bytes, memoryview))
            for value in dataset[column].dropna().head(1).astype(str)
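
The selection logic above sniffs the first non-null value of `column` to choose a bytes-based or path-based backend. A minimal standalone sketch of that idea, assuming a pandas DataFrame; `pick_loader_kind` is a hypothetical helper name, and the raw value is inspected directly here because an `.astype(str)` conversion would turn every entry into a `str` before the `isinstance` test:

import pandas as pd

def pick_loader_kind(df: pd.DataFrame, column: str) -> str:
    # Inspect the first non-null entry without converting it, so that
    # bytes-like values are still recognisable as bytes.
    sample = df[column].dropna().head(1)
    if any(isinstance(v, (bytes, bytearray, memoryview)) for v in sample):
        return "bytes"   # raw image bytes -> JPGImageLoader-style backend
    return "path"        # otherwise treat the value as a file path

# Hypothetical usage:
# pick_loader_kind(pd.DataFrame({"file_path": ["images/cat.jpg"]}), "file_path")  -> "path"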

@@ -112,10 +110,8 @@ class AIIADataLoader(DataLoader):
                label_column=label_column
            )
        else:
            # Check if file_path column contains valid image file paths (at least one entry)
            sample_paths = dataset[column].dropna().head(1).astype(str)
-
            # Regex pattern for matching image file paths (adjust as needed)
            filepath_pattern = r'.*(?:/|\\).*\.([jJ][pP][gG]|png|gif)$'
-
            if any(

@@ -128,23 +124,33 @@ class AIIADataLoader(DataLoader):
                label_column=label_column
            )
        else:
            # If neither condition is met, default to JPGImageLoader (assuming bytes are stored as strings)
            self.loader = JPGImageLoader(
                dataset,
                bytes_column=column,
                label_column=label_column
            )

        # Get all items
        self.items = [self.loader.get_item(idx) for idx in range(len(dataset))]

        # Split into train and validation sets
        train_indices, val_indices = self._split_data()

        # Create datasets for training and validation
        self.train_dataset = self._create_subset(train_indices)
        self.val_dataset = self._create_subset(val_indices)

+        self.train_loader = DataLoader(
+            self.train_dataset,
+            batch_size=batch_size,
+            shuffle=True,
+            **dataloader_kwargs
+        )
+
+        self.val_loader = DataLoader(
+            self.val_dataset,
+            batch_size=batch_size,
+            shuffle=False,
+            **dataloader_kwargs
+        )
+
    def _split_data(self):
        if len(self.items) == 0:
            return [], []
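
With this hunk, AIIADataLoader stops subclassing DataLoader and instead owns two torch.utils.data.DataLoader instances exposed as `train_loader` and `val_loader`. A rough, self-contained sketch of that wrapper pattern, assuming `_split_data`/`_create_subset` amount to a seeded random index split (the names below are simplified stand-ins, not the actual class):

import random
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset

class SplitLoader:
    # Simplified stand-in for the wrapper pattern above.
    def __init__(self, dataset, batch_size=32, val_split=0.2, seed=42, **dataloader_kwargs):
        random.seed(seed)
        indices = list(range(len(dataset)))
        random.shuffle(indices)
        n_val = int(len(indices) * val_split)
        val_idx, train_idx = indices[:n_val], indices[n_val:]
        self.train_loader = DataLoader(Subset(dataset, train_idx), batch_size=batch_size,
                                       shuffle=True, **dataloader_kwargs)
        self.val_loader = DataLoader(Subset(dataset, val_idx), batch_size=batch_size,
                                     shuffle=False, **dataloader_kwargs)

# Example: a toy dataset split 80/20 and consumed through the two loaders.
ds = TensorDataset(torch.randn(10, 3), torch.randint(0, 2, (10,)))
loaders = SplitLoader(ds, batch_size=4)
for xb, yb in loaders.train_loader:
    pass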

@@ -184,7 +190,6 @@ class AIIADataset(torch.utils.data.Dataset):
            return (image, label)
        elif isinstance(item, tuple) and len(item) == 3:
            image, task, label = item
            # Handle tasks accordingly (e.g., apply different augmentations)
            if task == 'denoise':
                noise_std = 0.1
                noisy_img = image + torch.randn_like(image) * noise_std

@@ -199,7 +204,6 @@ class AIIADataset(torch.utils.data.Dataset):
            else:
                raise ValueError(f"Unknown task: {task}")
        else:
            # Handle single images without labels or tasks
            if isinstance(item, Image.Image):
                return item
            else:

@@ -29,29 +29,17 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
    model = AIIABase(config)

    # Create dataset loader with merged data
-    train_dataset = AIIADataLoader(
+    aiia_loader = AIIADataLoader(
        merged_df,
        batch_size=32,
        val_split=0.2,
        seed=42,
-        column="image_bytes",
-        label_column=None
+        column="image_bytes"
    )

-    # Create separate dataloaders for training and validation sets
-    train_dataloader = DataLoader(
-        train_dataset.train_dataset,
-        batch_size=train_dataset.batch_size,
-        shuffle=True,
-        num_workers=4
-    )
-
-    val_dataloader = DataLoader(
-        train_dataset.val_ataset,
-        batch_size=train_dataset.batch_size,
-        shuffle=False,
-        num_workers=4
-    )
+    # Access the train and validation loaders
+    train_loader = aiia_loader.train_loader
+    val_loader = aiia_loader.val_loader

    # Initialize loss functions and optimizer
    criterion_denoise = nn.MSELoss()
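
Since the manually built DataLoaders (including the removed `num_workers=4` and the broken `train_dataset.val_ataset` reference) go away here, any extra loader options would presumably be forwarded through the new `**dataloader_kwargs` instead. A hedged usage sketch against the call above; the import path and the placeholder frame are assumptions, only the constructor arguments come from the diff:

import pandas as pd
# from aiia.data import AIIADataLoader   # hypothetical import path, not shown in the diff

merged_df = pd.DataFrame({"image_bytes": [b"\xff\xd8\xff\xe0"]})  # placeholder frame for illustration
aiia_loader = AIIADataLoader(
    merged_df,
    batch_size=32,
    val_split=0.2,
    seed=42,
    column="image_bytes",
    num_workers=4,   # assumption: forwarded via **dataloader_kwargs to both internal DataLoaders
)
train_loader = aiia_loader.train_loader
val_loader = aiia_loader.val_loader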

@@ -72,7 +60,7 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
        model.train()
        total_train_loss = 0.0

-        for batch in train_dataloader:
+        for batch in train_loader:
            images, targets, tasks = zip(*batch)

            if device == "cuda":
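
The unpacking `images, targets, tasks = zip(*batch)` in this loop only works if each batch arrives as a list of per-sample tuples rather than default-collated tensors. One way to get that shape, sketched here as an assumption (the diff does not show the collate_fn actually used):

from torch.utils.data import DataLoader

def no_collate(samples):
    # Keep the list of (image, target, task) tuples as-is instead of stacking tensors.
    return samples

# loader = DataLoader(dataset, batch_size=32, collate_fn=no_collate)
# for batch in loader:
#     images, targets, tasks = zip(*batch)   # three tuples, each of length batch_size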

@@ -102,14 +90,14 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
            total_train_loss += avg_loss.item()
        # Separate losses for reporting (you'd need to track this based on tasks)

-    avg_total_train_loss = total_train_loss / len(train_dataloader)
+    avg_total_train_loss = total_train_loss / len(train_loader)
    print(f"Training Loss: {avg_total_train_loss:.4f}")

    # Validation phase
    model.eval()
    with torch.no_grad():
        val_losses = []
-        for batch in val_dataloader:
+        for batch in val_loader:
            images, targets, tasks = zip(*batch)

            if device == "cuda":

@@ -132,7 +120,7 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
                avg_val_loss = total_loss / len(images)
                val_losses.append(avg_val_loss.item())

-        avg_val_loss = sum(val_losses) / len(val_dataloader)
+        avg_val_loss = sum(val_losses) / len(val_loader)
        print(f"Validation Loss: {avg_val_loss:.4f}")

        # Save the best model
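
One arithmetic detail in this last hunk: `avg_val_loss` is first computed per batch as `total_loss / len(images)` and then averaged over `len(val_loader)` batches, i.e. a mean of per-batch means. That equals the true per-sample mean only when every batch has the same size, as a tiny example shows:

batch_losses = [[1.0, 1.0, 1.0, 1.0], [3.0]]             # a 4-sample batch and a 1-sample batch
per_batch_mean = [sum(b) / len(b) for b in batch_losses]
print(sum(per_batch_mean) / len(per_batch_mean))         # 2.0 -> mean of batch means
print(sum(sum(b) for b in batch_losses) / 5)             # 1.4 -> true per-sample mean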