removed placeholder collate function

This commit is contained in:
Falko Victor Habel 2025-01-27 09:26:08 +01:00
parent 6c146f2767
commit 91a568731f
1 changed files with 24 additions and 20 deletions

View File

@ -26,22 +26,22 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
processed_batch = []
for sample in batch:
try:
# Process each sample here (e.g., decode image, preprocess, etc.)
# Replace with actual preprocessing steps
processed_sample = {
'image': torch.randn(3, 224, 224), # Example tensor
'target': torch.randint(0, 10, (1,)), # Example target
'task': 'denoise' # Example task
}
processed_batch.append(processed_sample)
noisy_img, target, task = sample
processed_batch.append({
'image': noisy_img,
'target': target,
'task': task
})
except Exception as e:
print(f"Skipping sample due to error: {e}")
continue
if not processed_batch:
return None # Skip batch if all samples are invalid
# Stack tensors for the batch
images = torch.stack([x['image'] for x in processed_batch])
targets = torch.stack([x['target'] for x in processed_batch])
targets = [x['target'] for x in processed_batch] # Don't stack targets yet
tasks = [x['task'] for x in processed_batch]
return (images, targets, tasks)
@ -81,18 +81,21 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
noisy_imgs, targets, tasks = batch
noisy_imgs = noisy_imgs.to(device)
targets = targets.to(device)
optimizer.zero_grad()
outputs = model(noisy_imgs)
task_losses = []
for i, task in enumerate(tasks):
for i, (output, target, task) in enumerate(zip(outputs, targets, tasks)):
if task == 'denoise':
loss = criterion_denoise(outputs[i], targets[i])
else:
loss = criterion_rotate(outputs[i].unsqueeze(0), targets[i].unsqueeze(0))
target = target.to(device) # Move target to device
# Ensure output and target have same shape
loss = criterion_denoise(output, target)
else: # rotate task
target = target.to(device) # Move target to device
# For rotation task, output should be [batch_size, num_classes]
loss = criterion_rotate(output.view(1, -1), target.view(-1))
task_losses.append(loss)
batch_loss = sum(task_losses) / len(task_losses)
@ -101,8 +104,8 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
total_train_loss += batch_loss.item()
avg_total_train_loss = total_train_loss / len(train_loader)
print(f"Training Loss: {avg_total_train_loss:.4f}")
avg_train_loss = total_train_loss / len(train_loader)
print(f"Training Loss: {avg_train_loss:.4f}")
# Validation phase
model.eval()
@ -115,16 +118,17 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
noisy_imgs, targets, tasks = batch
noisy_imgs = noisy_imgs.to(device)
targets = targets.to(device)
outputs = model(noisy_imgs)
task_losses = []
for i, task in enumerate(tasks):
for i, (output, target, task) in enumerate(zip(outputs, targets, tasks)):
if task == 'denoise':
loss = criterion_denoise(outputs[i], targets[i])
target = target.to(device)
loss = criterion_denoise(output, target)
else:
loss = criterion_rotate(outputs[i].unsqueeze(0), targets[i].unsqueeze(0))
target = target.to(device)
loss = criterion_rotate(output.view(1, -1), target.view(-1))
task_losses.append(loss)
batch_loss = sum(task_losses) / len(task_losses)