Handle both pretraining methods (denoise and rotate tasks)

This commit is contained in:
Falko Victor Habel 2025-01-27 09:32:20 +01:00
parent 91a568731f
commit 13cf1897ae
1 changed file with 22 additions and 18 deletions

View File

@ -6,7 +6,6 @@ from aiia.model import AIIABase
from aiia.data.DataLoader import AIIADataLoader from aiia.data.DataLoader import AIIADataLoader
from tqdm import tqdm from tqdm import tqdm
def pretrain_model(data_path1, data_path2, num_epochs=3): def pretrain_model(data_path1, data_path2, num_epochs=3):
# Read and merge datasets # Read and merge datasets
df1 = pd.read_parquet(data_path1).head(10000) df1 = pd.read_parquet(data_path1).head(10000)
@ -21,7 +20,6 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
# Initialize model and data loader # Initialize model and data loader
model = AIIABase(config) model = AIIABase(config)
# Define a custom collate function to handle preprocessing and skip bad samples
def safe_collate(batch): def safe_collate(batch):
processed_batch = [] processed_batch = []
for sample in batch: for sample in batch:
@ -37,11 +35,11 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
continue continue
if not processed_batch: if not processed_batch:
return None # Skip batch if all samples are invalid return None
# Stack tensors for the batch # Stack tensors for the batch
images = torch.stack([x['image'] for x in processed_batch]) images = torch.stack([x['image'] for x in processed_batch])
targets = [x['target'] for x in processed_batch] # Don't stack targets yet targets = torch.stack([x['target'] for x in processed_batch]) # Stack targets
tasks = [x['task'] for x in processed_batch] tasks = [x['task'] for x in processed_batch]
return (images, targets, tasks) return (images, targets, tasks)
@ -57,7 +55,6 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
train_loader = aiia_loader.train_loader train_loader = aiia_loader.train_loader
val_loader = aiia_loader.val_loader val_loader = aiia_loader.val_loader
# Define loss functions and optimizer
criterion_denoise = nn.MSELoss() criterion_denoise = nn.MSELoss()
criterion_rotate = nn.CrossEntropyLoss() criterion_rotate = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate) optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
@ -77,25 +74,29 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
for batch in tqdm(train_loader): for batch in tqdm(train_loader):
if batch is None: if batch is None:
continue # Skip empty batches continue
noisy_imgs, targets, tasks = batch noisy_imgs, targets, tasks = batch
batch_size = noisy_imgs.size(0)
noisy_imgs = noisy_imgs.to(device) noisy_imgs = noisy_imgs.to(device)
targets = targets.to(device)
optimizer.zero_grad() optimizer.zero_grad()
# Get model outputs and reshape if necessary
outputs = model(noisy_imgs) outputs = model(noisy_imgs)
task_losses = []
for i, (output, target, task) in enumerate(zip(outputs, targets, tasks)): task_losses = []
for i, task in enumerate(tasks):
if task == 'denoise': if task == 'denoise':
target = target.to(device) # Move target to device # Ensure output matches target shape for denoising
# Ensure output and target have same shape output = outputs[i].view(3, 224, 224) # Reshape to match image dimensions
target = targets[i]
loss = criterion_denoise(output, target) loss = criterion_denoise(output, target)
else: # rotate task else: # rotate task
target = target.to(device) # Move target to device output = outputs[i].view(-1) # Flatten output for rotation prediction
# For rotation task, output should be [batch_size, num_classes] target = targets[i].long() # Convert target to long for classification
loss = criterion_rotate(output.view(1, -1), target.view(-1)) loss = criterion_rotate(output.unsqueeze(0), target.unsqueeze(0))
task_losses.append(loss) task_losses.append(loss)
batch_loss = sum(task_losses) / len(task_losses) batch_loss = sum(task_losses) / len(task_losses)
@ -118,17 +119,20 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
noisy_imgs, targets, tasks = batch noisy_imgs, targets, tasks = batch
noisy_imgs = noisy_imgs.to(device) noisy_imgs = noisy_imgs.to(device)
targets = targets.to(device)
outputs = model(noisy_imgs) outputs = model(noisy_imgs)
task_losses = []
for i, (output, target, task) in enumerate(zip(outputs, targets, tasks)): task_losses = []
for i, task in enumerate(tasks):
if task == 'denoise': if task == 'denoise':
target = target.to(device) output = outputs[i].view(3, 224, 224)
target = targets[i]
loss = criterion_denoise(output, target) loss = criterion_denoise(output, target)
else: else:
target = target.to(device) output = outputs[i].view(-1)
loss = criterion_rotate(output.view(1, -1), target.view(-1)) target = targets[i].long()
loss = criterion_rotate(output.unsqueeze(0), target.unsqueeze(0))
task_losses.append(loss) task_losses.append(loss)
batch_loss = sum(task_losses) / len(task_losses) batch_loss = sum(task_losses) / len(task_losses)