From 91a568731faa68bdcaa4eef474aaaac05ed51c68 Mon Sep 17 00:00:00 2001 From: Falko Habel Date: Mon, 27 Jan 2025 09:26:08 +0100 Subject: [PATCH] removed placeholder collate function --- src/pretrain.py | 44 ++++++++++++++++++++++++-------------------- 1 file changed, 24 insertions(+), 20 deletions(-) diff --git a/src/pretrain.py b/src/pretrain.py index 45aac3f..c4bc1a7 100644 --- a/src/pretrain.py +++ b/src/pretrain.py @@ -26,22 +26,22 @@ def pretrain_model(data_path1, data_path2, num_epochs=3): processed_batch = [] for sample in batch: try: - # Process each sample here (e.g., decode image, preprocess, etc.) - # Replace with actual preprocessing steps - processed_sample = { - 'image': torch.randn(3, 224, 224), # Example tensor - 'target': torch.randint(0, 10, (1,)), # Example target - 'task': 'denoise' # Example task - } - processed_batch.append(processed_sample) + noisy_img, target, task = sample + processed_batch.append({ + 'image': noisy_img, + 'target': target, + 'task': task + }) except Exception as e: print(f"Skipping sample due to error: {e}") + continue + if not processed_batch: return None # Skip batch if all samples are invalid # Stack tensors for the batch images = torch.stack([x['image'] for x in processed_batch]) - targets = torch.stack([x['target'] for x in processed_batch]) + targets = [x['target'] for x in processed_batch] # Don't stack targets yet tasks = [x['task'] for x in processed_batch] return (images, targets, tasks) @@ -81,18 +81,21 @@ def pretrain_model(data_path1, data_path2, num_epochs=3): noisy_imgs, targets, tasks = batch noisy_imgs = noisy_imgs.to(device) - targets = targets.to(device) optimizer.zero_grad() outputs = model(noisy_imgs) task_losses = [] - for i, task in enumerate(tasks): + for i, (output, target, task) in enumerate(zip(outputs, targets, tasks)): if task == 'denoise': - loss = criterion_denoise(outputs[i], targets[i]) - else: - loss = criterion_rotate(outputs[i].unsqueeze(0), targets[i].unsqueeze(0)) + target = target.to(device) # Move target to device + # Ensure output and target have same shape + loss = criterion_denoise(output, target) + else: # rotate task + target = target.to(device) # Move target to device + # For rotation task, output should be [batch_size, num_classes] + loss = criterion_rotate(output.view(1, -1), target.view(-1)) task_losses.append(loss) batch_loss = sum(task_losses) / len(task_losses) @@ -101,8 +104,8 @@ def pretrain_model(data_path1, data_path2, num_epochs=3): total_train_loss += batch_loss.item() - avg_total_train_loss = total_train_loss / len(train_loader) - print(f"Training Loss: {avg_total_train_loss:.4f}") + avg_train_loss = total_train_loss / len(train_loader) + print(f"Training Loss: {avg_train_loss:.4f}") # Validation phase model.eval() @@ -115,16 +118,17 @@ def pretrain_model(data_path1, data_path2, num_epochs=3): noisy_imgs, targets, tasks = batch noisy_imgs = noisy_imgs.to(device) - targets = targets.to(device) outputs = model(noisy_imgs) task_losses = [] - for i, task in enumerate(tasks): + for i, (output, target, task) in enumerate(zip(outputs, targets, tasks)): if task == 'denoise': - loss = criterion_denoise(outputs[i], targets[i]) + target = target.to(device) + loss = criterion_denoise(output, target) else: - loss = criterion_rotate(outputs[i].unsqueeze(0), targets[i].unsqueeze(0)) + target = target.to(device) + loss = criterion_rotate(output.view(1, -1), target.view(-1)) task_losses.append(loss) batch_loss = sum(task_losses) / len(task_losses)