removed placeholder collate function

This commit is contained in:
Falko Victor Habel 2025-01-27 09:26:08 +01:00
parent 6c146f2767
commit 91a568731f
1 changed file with 24 additions and 20 deletions

View File

@@ -26,22 +26,22 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
processed_batch = [] processed_batch = []
for sample in batch: for sample in batch:
try: try:
# Process each sample here (e.g., decode image, preprocess, etc.) noisy_img, target, task = sample
# Replace with actual preprocessing steps processed_batch.append({
processed_sample = { 'image': noisy_img,
'image': torch.randn(3, 224, 224), # Example tensor 'target': target,
'target': torch.randint(0, 10, (1,)), # Example target 'task': task
'task': 'denoise' # Example task })
}
processed_batch.append(processed_sample)
except Exception as e: except Exception as e:
print(f"Skipping sample due to error: {e}") print(f"Skipping sample due to error: {e}")
continue
if not processed_batch: if not processed_batch:
return None # Skip batch if all samples are invalid return None # Skip batch if all samples are invalid
# Stack tensors for the batch # Stack tensors for the batch
images = torch.stack([x['image'] for x in processed_batch]) images = torch.stack([x['image'] for x in processed_batch])
targets = torch.stack([x['target'] for x in processed_batch]) targets = [x['target'] for x in processed_batch] # Don't stack targets yet
tasks = [x['task'] for x in processed_batch] tasks = [x['task'] for x in processed_batch]
return (images, targets, tasks) return (images, targets, tasks)
@@ -81,18 +81,21 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
noisy_imgs, targets, tasks = batch noisy_imgs, targets, tasks = batch
noisy_imgs = noisy_imgs.to(device) noisy_imgs = noisy_imgs.to(device)
targets = targets.to(device)
optimizer.zero_grad() optimizer.zero_grad()
outputs = model(noisy_imgs) outputs = model(noisy_imgs)
task_losses = [] task_losses = []
for i, task in enumerate(tasks): for i, (output, target, task) in enumerate(zip(outputs, targets, tasks)):
if task == 'denoise': if task == 'denoise':
loss = criterion_denoise(outputs[i], targets[i]) target = target.to(device) # Move target to device
else: # Ensure output and target have same shape
loss = criterion_rotate(outputs[i].unsqueeze(0), targets[i].unsqueeze(0)) loss = criterion_denoise(output, target)
else: # rotate task
target = target.to(device) # Move target to device
# For rotation task, output should be [batch_size, num_classes]
loss = criterion_rotate(output.view(1, -1), target.view(-1))
task_losses.append(loss) task_losses.append(loss)
batch_loss = sum(task_losses) / len(task_losses) batch_loss = sum(task_losses) / len(task_losses)
@@ -101,8 +104,8 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
total_train_loss += batch_loss.item() total_train_loss += batch_loss.item()
avg_total_train_loss = total_train_loss / len(train_loader) avg_train_loss = total_train_loss / len(train_loader)
print(f"Training Loss: {avg_total_train_loss:.4f}") print(f"Training Loss: {avg_train_loss:.4f}")
# Validation phase # Validation phase
model.eval() model.eval()
@@ -115,16 +118,17 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
noisy_imgs, targets, tasks = batch noisy_imgs, targets, tasks = batch
noisy_imgs = noisy_imgs.to(device) noisy_imgs = noisy_imgs.to(device)
targets = targets.to(device)
outputs = model(noisy_imgs) outputs = model(noisy_imgs)
task_losses = [] task_losses = []
for i, task in enumerate(tasks): for i, (output, target, task) in enumerate(zip(outputs, targets, tasks)):
if task == 'denoise': if task == 'denoise':
loss = criterion_denoise(outputs[i], targets[i]) target = target.to(device)
loss = criterion_denoise(output, target)
else: else:
loss = criterion_rotate(outputs[i].unsqueeze(0), targets[i].unsqueeze(0)) target = target.to(device)
loss = criterion_rotate(output.view(1, -1), target.view(-1))
task_losses.append(loss) task_losses.append(loss)
batch_loss = sum(task_losses) / len(task_losses) batch_loss = sum(task_losses) / len(task_losses)