diff --git a/src/pretrain.py b/src/pretrain.py
index c4bc1a7..690bd9e 100644
--- a/src/pretrain.py
+++ b/src/pretrain.py
@@ -6,7 +6,6 @@
 from aiia.model import AIIABase
 from aiia.data.DataLoader import AIIADataLoader
 from tqdm import tqdm
-
 def pretrain_model(data_path1, data_path2, num_epochs=3):
     # Read and merge datasets
     df1 = pd.read_parquet(data_path1).head(10000)
@@ -21,7 +20,6 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
     # Initialize model and data loader
     model = AIIABase(config)
 
-    # Define a custom collate function to handle preprocessing and skip bad samples
     def safe_collate(batch):
         processed_batch = []
         for sample in batch:
@@ -37,11 +35,11 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
                 continue
 
         if not processed_batch:
-            return None  # Skip batch if all samples are invalid
+            return None
 
         # Stack tensors for the batch
         images = torch.stack([x['image'] for x in processed_batch])
-        targets = [x['target'] for x in processed_batch]  # Don't stack targets yet
+        targets = torch.stack([x['target'] for x in processed_batch])  # Stack targets
         tasks = [x['task'] for x in processed_batch]
 
         return (images, targets, tasks)
@@ -57,7 +55,6 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
     train_loader = aiia_loader.train_loader
     val_loader = aiia_loader.val_loader
 
-    # Define loss functions and optimizer
    criterion_denoise = nn.MSELoss()
     criterion_rotate = nn.CrossEntropyLoss()
     optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
@@ -77,25 +74,29 @@
 
         for batch in tqdm(train_loader):
             if batch is None:
-                continue  # Skip empty batches
+                continue
 
             noisy_imgs, targets, tasks = batch
+            batch_size = noisy_imgs.size(0)
             noisy_imgs = noisy_imgs.to(device)
+            targets = targets.to(device)
 
             optimizer.zero_grad()
 
+            # Get model outputs and reshape if necessary
             outputs = model(noisy_imgs)
 
-            task_losses = []
-            for i, (output, target, task) in enumerate(zip(outputs, targets, tasks)):
+            task_losses = []
+            for i, task in enumerate(tasks):
                 if task == 'denoise':
-                    target = target.to(device)  # Move target to device
-                    # Ensure output and target have same shape
+                    # Ensure output matches target shape for denoising
+                    output = outputs[i].view(3, 224, 224)  # Reshape to match image dimensions
+                    target = targets[i]
                     loss = criterion_denoise(output, target)
                 else:  # rotate task
-                    target = target.to(device)  # Move target to device
-                    # For rotation task, output should be [batch_size, num_classes]
-                    loss = criterion_rotate(output.view(1, -1), target.view(-1))
+                    output = outputs[i].view(-1)  # Flatten output for rotation prediction
+                    target = targets[i].long()  # Convert target to long for classification
+                    loss = criterion_rotate(output.unsqueeze(0), target.unsqueeze(0))
                 task_losses.append(loss)
 
             batch_loss = sum(task_losses) / len(task_losses)
@@ -118,17 +119,20 @@
 
                 noisy_imgs, targets, tasks = batch
                 noisy_imgs = noisy_imgs.to(device)
+                targets = targets.to(device)
 
                 outputs = model(noisy_imgs)
 
-                task_losses = []
-                for i, (output, target, task) in enumerate(zip(outputs, targets, tasks)):
+                task_losses = []
+                for i, task in enumerate(tasks):
                     if task == 'denoise':
-                        target = target.to(device)
+                        output = outputs[i].view(3, 224, 224)
+                        target = targets[i]
                         loss = criterion_denoise(output, target)
                     else:
-                        target = target.to(device)
-                        loss = criterion_rotate(output.view(1, -1), target.view(-1))
+                        output = outputs[i].view(-1)
+                        target = targets[i].long()
+                        loss = criterion_rotate(output.unsqueeze(0), target.unsqueeze(0))
                     task_losses.append(loss)
 
                 batch_loss = sum(task_losses) / len(task_losses)
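
A note on the stacked-targets change in safe_collate: torch.stack requires every target in a batch to share a shape, but denoise targets are full images while rotation targets are scalar class ids, so a mixed-task batch will raise a RuntimeError unless batches are task-homogeneous. Below is a minimal, self-contained sketch of the per-sample loss flow this patch implements, with targets kept in a list to sidestep that constraint. The helper name multitask_batch_loss, the dummy shapes, the 3x224x224 image size, and treating the whole flattened output as rotation logits are illustrative assumptions, not taken from the patch:

    import torch
    import torch.nn as nn

    criterion_denoise = nn.MSELoss()
    criterion_rotate = nn.CrossEntropyLoss()

    def multitask_batch_loss(outputs, targets, tasks):
        # outputs: [B, 3*224*224] tensor; targets: list of per-sample tensors;
        # tasks: list of 'denoise' / 'rotate' strings, one per sample.
        task_losses = []
        for i, task in enumerate(tasks):
            if task == 'denoise':
                output = outputs[i].view(3, 224, 224)  # reshape flat output to image
                loss = criterion_denoise(output, targets[i])
            else:  # rotate: flattened output treated as class logits, as in the patch
                output = outputs[i].view(-1)
                target = targets[i].long()
                loss = criterion_rotate(output.unsqueeze(0), target.unsqueeze(0))
            task_losses.append(loss)
        return sum(task_losses) / len(task_losses)

    # Dummy usage: one denoise sample and one rotate sample (class id 1).
    outputs = torch.randn(2, 3 * 224 * 224, requires_grad=True)
    targets = [torch.randn(3, 224, 224), torch.tensor(1)]
    loss = multitask_batch_loss(outputs, targets, ['denoise', 'rotate'])
    loss.backward()

For the rotation branch to be a real classifier, the model would need a head emitting exactly num_rotation_classes values per sample; the sketch only demonstrates the control flow the patch uses.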