removed placeholder collate function
This commit is contained in:
parent
6c146f2767
commit
91a568731f
|
@ -26,22 +26,22 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
|
|||
processed_batch = []
|
||||
for sample in batch:
|
||||
try:
|
||||
# Process each sample here (e.g., decode image, preprocess, etc.)
|
||||
# Replace with actual preprocessing steps
|
||||
processed_sample = {
|
||||
'image': torch.randn(3, 224, 224), # Example tensor
|
||||
'target': torch.randint(0, 10, (1,)), # Example target
|
||||
'task': 'denoise' # Example task
|
||||
}
|
||||
processed_batch.append(processed_sample)
|
||||
noisy_img, target, task = sample
|
||||
processed_batch.append({
|
||||
'image': noisy_img,
|
||||
'target': target,
|
||||
'task': task
|
||||
})
|
||||
except Exception as e:
|
||||
print(f"Skipping sample due to error: {e}")
|
||||
continue
|
||||
|
||||
if not processed_batch:
|
||||
return None # Skip batch if all samples are invalid
|
||||
|
||||
# Stack tensors for the batch
|
||||
images = torch.stack([x['image'] for x in processed_batch])
|
||||
targets = torch.stack([x['target'] for x in processed_batch])
|
||||
targets = [x['target'] for x in processed_batch] # Don't stack targets yet
|
||||
tasks = [x['task'] for x in processed_batch]
|
||||
|
||||
return (images, targets, tasks)
|
||||
|
@ -81,18 +81,21 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
|
|||
|
||||
noisy_imgs, targets, tasks = batch
|
||||
noisy_imgs = noisy_imgs.to(device)
|
||||
targets = targets.to(device)
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
outputs = model(noisy_imgs)
|
||||
task_losses = []
|
||||
|
||||
for i, task in enumerate(tasks):
|
||||
for i, (output, target, task) in enumerate(zip(outputs, targets, tasks)):
|
||||
if task == 'denoise':
|
||||
loss = criterion_denoise(outputs[i], targets[i])
|
||||
else:
|
||||
loss = criterion_rotate(outputs[i].unsqueeze(0), targets[i].unsqueeze(0))
|
||||
target = target.to(device) # Move target to device
|
||||
# Ensure output and target have same shape
|
||||
loss = criterion_denoise(output, target)
|
||||
else: # rotate task
|
||||
target = target.to(device) # Move target to device
|
||||
# For rotation task, output should be [batch_size, num_classes]
|
||||
loss = criterion_rotate(output.view(1, -1), target.view(-1))
|
||||
task_losses.append(loss)
|
||||
|
||||
batch_loss = sum(task_losses) / len(task_losses)
|
||||
|
@ -101,8 +104,8 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
|
|||
|
||||
total_train_loss += batch_loss.item()
|
||||
|
||||
avg_total_train_loss = total_train_loss / len(train_loader)
|
||||
print(f"Training Loss: {avg_total_train_loss:.4f}")
|
||||
avg_train_loss = total_train_loss / len(train_loader)
|
||||
print(f"Training Loss: {avg_train_loss:.4f}")
|
||||
|
||||
# Validation phase
|
||||
model.eval()
|
||||
|
@ -115,16 +118,17 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
|
|||
|
||||
noisy_imgs, targets, tasks = batch
|
||||
noisy_imgs = noisy_imgs.to(device)
|
||||
targets = targets.to(device)
|
||||
|
||||
outputs = model(noisy_imgs)
|
||||
task_losses = []
|
||||
|
||||
for i, task in enumerate(tasks):
|
||||
for i, (output, target, task) in enumerate(zip(outputs, targets, tasks)):
|
||||
if task == 'denoise':
|
||||
loss = criterion_denoise(outputs[i], targets[i])
|
||||
target = target.to(device)
|
||||
loss = criterion_denoise(output, target)
|
||||
else:
|
||||
loss = criterion_rotate(outputs[i].unsqueeze(0), targets[i].unsqueeze(0))
|
||||
target = target.to(device)
|
||||
loss = criterion_rotate(output.view(1, -1), target.view(-1))
|
||||
task_losses.append(loss)
|
||||
|
||||
batch_loss = sum(task_losses) / len(task_losses)
|
||||
|
|
Loading…
Reference in New Issue