removed placeholder collate function
This commit is contained in:
parent
6c146f2767
commit
91a568731f
|
@ -26,22 +26,22 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
|
||||||
processed_batch = []
|
processed_batch = []
|
||||||
for sample in batch:
|
for sample in batch:
|
||||||
try:
|
try:
|
||||||
# Process each sample here (e.g., decode image, preprocess, etc.)
|
noisy_img, target, task = sample
|
||||||
# Replace with actual preprocessing steps
|
processed_batch.append({
|
||||||
processed_sample = {
|
'image': noisy_img,
|
||||||
'image': torch.randn(3, 224, 224), # Example tensor
|
'target': target,
|
||||||
'target': torch.randint(0, 10, (1,)), # Example target
|
'task': task
|
||||||
'task': 'denoise' # Example task
|
})
|
||||||
}
|
|
||||||
processed_batch.append(processed_sample)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Skipping sample due to error: {e}")
|
print(f"Skipping sample due to error: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
if not processed_batch:
|
if not processed_batch:
|
||||||
return None # Skip batch if all samples are invalid
|
return None # Skip batch if all samples are invalid
|
||||||
|
|
||||||
# Stack tensors for the batch
|
# Stack tensors for the batch
|
||||||
images = torch.stack([x['image'] for x in processed_batch])
|
images = torch.stack([x['image'] for x in processed_batch])
|
||||||
targets = torch.stack([x['target'] for x in processed_batch])
|
targets = [x['target'] for x in processed_batch] # Don't stack targets yet
|
||||||
tasks = [x['task'] for x in processed_batch]
|
tasks = [x['task'] for x in processed_batch]
|
||||||
|
|
||||||
return (images, targets, tasks)
|
return (images, targets, tasks)
|
||||||
|
@ -81,18 +81,21 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
|
||||||
|
|
||||||
noisy_imgs, targets, tasks = batch
|
noisy_imgs, targets, tasks = batch
|
||||||
noisy_imgs = noisy_imgs.to(device)
|
noisy_imgs = noisy_imgs.to(device)
|
||||||
targets = targets.to(device)
|
|
||||||
|
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|
||||||
outputs = model(noisy_imgs)
|
outputs = model(noisy_imgs)
|
||||||
task_losses = []
|
task_losses = []
|
||||||
|
|
||||||
for i, task in enumerate(tasks):
|
for i, (output, target, task) in enumerate(zip(outputs, targets, tasks)):
|
||||||
if task == 'denoise':
|
if task == 'denoise':
|
||||||
loss = criterion_denoise(outputs[i], targets[i])
|
target = target.to(device) # Move target to device
|
||||||
else:
|
# Ensure output and target have same shape
|
||||||
loss = criterion_rotate(outputs[i].unsqueeze(0), targets[i].unsqueeze(0))
|
loss = criterion_denoise(output, target)
|
||||||
|
else: # rotate task
|
||||||
|
target = target.to(device) # Move target to device
|
||||||
|
# For rotation task, output should be [batch_size, num_classes]
|
||||||
|
loss = criterion_rotate(output.view(1, -1), target.view(-1))
|
||||||
task_losses.append(loss)
|
task_losses.append(loss)
|
||||||
|
|
||||||
batch_loss = sum(task_losses) / len(task_losses)
|
batch_loss = sum(task_losses) / len(task_losses)
|
||||||
|
@ -101,8 +104,8 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
|
||||||
|
|
||||||
total_train_loss += batch_loss.item()
|
total_train_loss += batch_loss.item()
|
||||||
|
|
||||||
avg_total_train_loss = total_train_loss / len(train_loader)
|
avg_train_loss = total_train_loss / len(train_loader)
|
||||||
print(f"Training Loss: {avg_total_train_loss:.4f}")
|
print(f"Training Loss: {avg_train_loss:.4f}")
|
||||||
|
|
||||||
# Validation phase
|
# Validation phase
|
||||||
model.eval()
|
model.eval()
|
||||||
|
@ -115,16 +118,17 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
|
||||||
|
|
||||||
noisy_imgs, targets, tasks = batch
|
noisy_imgs, targets, tasks = batch
|
||||||
noisy_imgs = noisy_imgs.to(device)
|
noisy_imgs = noisy_imgs.to(device)
|
||||||
targets = targets.to(device)
|
|
||||||
|
|
||||||
outputs = model(noisy_imgs)
|
outputs = model(noisy_imgs)
|
||||||
task_losses = []
|
task_losses = []
|
||||||
|
|
||||||
for i, task in enumerate(tasks):
|
for i, (output, target, task) in enumerate(zip(outputs, targets, tasks)):
|
||||||
if task == 'denoise':
|
if task == 'denoise':
|
||||||
loss = criterion_denoise(outputs[i], targets[i])
|
target = target.to(device)
|
||||||
|
loss = criterion_denoise(output, target)
|
||||||
else:
|
else:
|
||||||
loss = criterion_rotate(outputs[i].unsqueeze(0), targets[i].unsqueeze(0))
|
target = target.to(device)
|
||||||
|
loss = criterion_rotate(output.view(1, -1), target.view(-1))
|
||||||
task_losses.append(loss)
|
task_losses.append(loss)
|
||||||
|
|
||||||
batch_loss = sum(task_losses) / len(task_losses)
|
batch_loss = sum(task_losses) / len(task_losses)
|
||||||
|
|
Loading…
Reference in New Issue