diff --git a/src/pretrain.py b/src/pretrain.py
index 690bd9e..201c03f 100644
--- a/src/pretrain.py
+++ b/src/pretrain.py
@@ -6,6 +6,7 @@
 from aiia.model import AIIABase
 from aiia.data.DataLoader import AIIADataLoader
 from tqdm import tqdm
+
 def pretrain_model(data_path1, data_path2, num_epochs=3):
     # Read and merge datasets
     df1 = pd.read_parquet(data_path1).head(10000)
@@ -21,28 +22,49 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
     model = AIIABase(config)
 
     def safe_collate(batch):
-        processed_batch = []
+        denoise_batch = []
+        rotate_batch = []
+
         for sample in batch:
             try:
                 noisy_img, target, task = sample
-                processed_batch.append({
-                    'image': noisy_img,
-                    'target': target,
-                    'task': task
-                })
+                if task == 'denoise':
+                    denoise_batch.append({
+                        'image': noisy_img,
+                        'target': target,
+                        'task': task
+                    })
+                else:  # rotate task
+                    rotate_batch.append({
+                        'image': noisy_img,
+                        'target': target,
+                        'task': task
+                    })
             except Exception as e:
                 print(f"Skipping sample due to error: {e}")
                 continue
 
-        if not processed_batch:
+        if not denoise_batch and not rotate_batch:
             return None
 
-        # Stack tensors for the batch
-        images = torch.stack([x['image'] for x in processed_batch])
-        targets = torch.stack([x['target'] for x in processed_batch])  # Stack targets
-        tasks = [x['task'] for x in processed_batch]
+        batch_data = {
+            'denoise': None,
+            'rotate': None
+        }
 
-        return (images, targets, tasks)
+        # Process denoise batch
+        if denoise_batch:
+            images = torch.stack([x['image'] for x in denoise_batch])
+            targets = torch.stack([x['target'] for x in denoise_batch])
+            batch_data['denoise'] = (images, targets)
+
+        # Process rotate batch
+        if rotate_batch:
+            images = torch.stack([x['image'] for x in rotate_batch])
+            targets = torch.stack([x['target'] for x in rotate_batch])
+            batch_data['rotate'] = (images, targets)
+
+        return batch_data
 
     aiia_loader = AIIADataLoader(
         merged_df,
@@ -71,74 +93,81 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
         # Training phase
         model.train()
         total_train_loss = 0.0
+        batch_count = 0
 
-        for batch in tqdm(train_loader):
-            if batch is None:
+        for batch_data in tqdm(train_loader):
+            if batch_data is None:
                 continue
 
-            noisy_imgs, targets, tasks = batch
-            batch_size = noisy_imgs.size(0)
-            noisy_imgs = noisy_imgs.to(device)
-            targets = targets.to(device)
 
             optimizer.zero_grad()
+            batch_loss = 0
 
-            # Get model outputs and reshape if necessary
-            outputs = model(noisy_imgs)
+            # Handle denoise task
+            if batch_data['denoise'] is not None:
+                noisy_imgs, targets = batch_data['denoise']
+                noisy_imgs = noisy_imgs.to(device)
+                targets = targets.to(device)
+
+                outputs = model(noisy_imgs)
+                loss = criterion_denoise(outputs, targets)
+                batch_loss += loss
 
-            task_losses = []
-            for i, task in enumerate(tasks):
-                if task == 'denoise':
-                    # Ensure output matches target shape for denoising
-                    output = outputs[i].view(3, 224, 224)  # Reshape to match image dimensions
-                    target = targets[i]
-                    loss = criterion_denoise(output, target)
-                else:  # rotate task
-                    output = outputs[i].view(-1)  # Flatten output for rotation prediction
-                    target = targets[i].long()  # Convert target to long for classification
-                    loss = criterion_rotate(output.unsqueeze(0), target.unsqueeze(0))
-                task_losses.append(loss)
+            # Handle rotate task
+            if batch_data['rotate'] is not None:
+                imgs, targets = batch_data['rotate']
+                imgs = imgs.to(device)
+                targets = targets.long().to(device)
+
+                outputs = model(imgs)
+                loss = criterion_rotate(outputs, targets)
+                batch_loss += loss
 
-            batch_loss = sum(task_losses) / len(task_losses)
-            batch_loss.backward()
-            optimizer.step()
-
-            total_train_loss += batch_loss.item()
+            if batch_loss > 0:
+                batch_loss.backward()
+                optimizer.step()
+                total_train_loss += batch_loss.item()
+                batch_count += 1
 
-        avg_train_loss = total_train_loss / len(train_loader)
+        avg_train_loss = total_train_loss / max(batch_count, 1)
         print(f"Training Loss: {avg_train_loss:.4f}")
 
         # Validation phase
         model.eval()
        val_loss = 0.0
+        val_batch_count = 0
 
         with torch.no_grad():
-            for batch in val_loader:
-                if batch is None:
+            for batch_data in val_loader:
+                if batch_data is None:
                     continue
 
-                noisy_imgs, targets, tasks = batch
-                noisy_imgs = noisy_imgs.to(device)
-                targets = targets.to(device)
+                batch_loss = 0
 
-                outputs = model(noisy_imgs)
+                # Handle denoise task
+                if batch_data['denoise'] is not None:
+                    noisy_imgs, targets = batch_data['denoise']
+                    noisy_imgs = noisy_imgs.to(device)
+                    targets = targets.to(device)
+
+                    outputs = model(noisy_imgs)
+                    loss = criterion_denoise(outputs, targets)
+                    batch_loss += loss
 
-                task_losses = []
-                for i, task in enumerate(tasks):
-                    if task == 'denoise':
-                        output = outputs[i].view(3, 224, 224)
-                        target = targets[i]
-                        loss = criterion_denoise(output, target)
-                    else:
-                        output = outputs[i].view(-1)
-                        target = targets[i].long()
-                        loss = criterion_rotate(output.unsqueeze(0), target.unsqueeze(0))
-                    task_losses.append(loss)
+                # Handle rotate task
+                if batch_data['rotate'] is not None:
+                    imgs, targets = batch_data['rotate']
+                    imgs = imgs.to(device)
+                    targets = targets.long().to(device)
+
+                    outputs = model(imgs)
+                    loss = criterion_rotate(outputs, targets)
+                    batch_loss += loss
 
-                batch_loss = sum(task_losses) / len(task_losses)
-                val_loss += batch_loss.item()
+                if batch_loss > 0:
+                    val_loss += batch_loss.item()
+                    val_batch_count += 1
 
-        avg_val_loss = val_loss / len(val_loader)
+        avg_val_loss = val_loss / max(val_batch_count, 1)
         print(f"Validation Loss: {avg_val_loss:.4f}")
 
         if avg_val_loss < best_val_loss:
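
Note: for reference, a minimal standalone sketch of the new task-grouped safe_collate, runnable outside the training script. The synthetic samples below (3x224x224 image tensors, a scalar rotation label) are assumptions for illustration; real samples come from AIIADataLoader.

import torch

def safe_collate(batch):
    # Split incoming (image, target, task) samples into per-task buckets,
    # mirroring the collate function in the diff above.
    denoise_batch, rotate_batch = [], []
    for sample in batch:
        try:
            noisy_img, target, task = sample
        except Exception as e:
            print(f"Skipping sample due to error: {e}")
            continue
        bucket = denoise_batch if task == 'denoise' else rotate_batch
        bucket.append({'image': noisy_img, 'target': target, 'task': task})

    if not denoise_batch and not rotate_batch:
        return None

    batch_data = {'denoise': None, 'rotate': None}
    # Stack each bucket into its own (images, targets) pair so the training
    # loop can run one homogeneous forward pass per task.
    if denoise_batch:
        batch_data['denoise'] = (
            torch.stack([x['image'] for x in denoise_batch]),
            torch.stack([x['target'] for x in denoise_batch]),
        )
    if rotate_batch:
        batch_data['rotate'] = (
            torch.stack([x['image'] for x in rotate_batch]),
            torch.stack([x['target'] for x in rotate_batch]),
        )
    return batch_data

# Synthetic mixed batch: two denoise pairs plus one rotation sample.
batch = [
    (torch.randn(3, 224, 224), torch.randn(3, 224, 224), 'denoise'),
    (torch.randn(3, 224, 224), torch.tensor(2), 'rotate'),
    (torch.randn(3, 224, 224), torch.randn(3, 224, 224), 'denoise'),
]
out = safe_collate(batch)
assert out['denoise'][0].shape == (2, 3, 224, 224)  # stacked denoise images
assert out['rotate'][1].shape == (1,)               # stacked rotation labels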