addeed tasks for both denosing and rotation

This commit is contained in:
Falko Victor Habel 2025-01-27 10:15:00 +01:00
parent 13cf1897ae
commit b6b63851ca
1 changed files with 88 additions and 59 deletions

View File

@ -6,6 +6,7 @@ from aiia.model import AIIABase
from aiia.data.DataLoader import AIIADataLoader from aiia.data.DataLoader import AIIADataLoader
from tqdm import tqdm from tqdm import tqdm
def pretrain_model(data_path1, data_path2, num_epochs=3): def pretrain_model(data_path1, data_path2, num_epochs=3):
# Read and merge datasets # Read and merge datasets
df1 = pd.read_parquet(data_path1).head(10000) df1 = pd.read_parquet(data_path1).head(10000)
@ -21,28 +22,49 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
model = AIIABase(config) model = AIIABase(config)
def safe_collate(batch): def safe_collate(batch):
processed_batch = [] denoise_batch = []
rotate_batch = []
for sample in batch: for sample in batch:
try: try:
noisy_img, target, task = sample noisy_img, target, task = sample
processed_batch.append({ if task == 'denoise':
'image': noisy_img, denoise_batch.append({
'target': target, 'image': noisy_img,
'task': task 'target': target,
}) 'task': task
})
else: # rotate task
rotate_batch.append({
'image': noisy_img,
'target': target,
'task': task
})
except Exception as e: except Exception as e:
print(f"Skipping sample due to error: {e}") print(f"Skipping sample due to error: {e}")
continue continue
if not processed_batch: if not denoise_batch and not rotate_batch:
return None return None
# Stack tensors for the batch batch_data = {
images = torch.stack([x['image'] for x in processed_batch]) 'denoise': None,
targets = torch.stack([x['target'] for x in processed_batch]) # Stack targets 'rotate': None
tasks = [x['task'] for x in processed_batch] }
return (images, targets, tasks) # Process denoise batch
if denoise_batch:
images = torch.stack([x['image'] for x in denoise_batch])
targets = torch.stack([x['target'] for x in denoise_batch])
batch_data['denoise'] = (images, targets)
# Process rotate batch
if rotate_batch:
images = torch.stack([x['image'] for x in rotate_batch])
targets = torch.stack([x['target'] for x in rotate_batch])
batch_data['rotate'] = (images, targets)
return batch_data
aiia_loader = AIIADataLoader( aiia_loader = AIIADataLoader(
merged_df, merged_df,
@ -71,74 +93,81 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
# Training phase # Training phase
model.train() model.train()
total_train_loss = 0.0 total_train_loss = 0.0
batch_count = 0
for batch in tqdm(train_loader): for batch_data in tqdm(train_loader):
if batch is None: if batch_data is None:
continue continue
noisy_imgs, targets, tasks = batch
batch_size = noisy_imgs.size(0)
noisy_imgs = noisy_imgs.to(device)
targets = targets.to(device)
optimizer.zero_grad() optimizer.zero_grad()
batch_loss = 0
# Get model outputs and reshape if necessary # Handle denoise task
outputs = model(noisy_imgs) if batch_data['denoise'] is not None:
noisy_imgs, targets = batch_data['denoise']
noisy_imgs = noisy_imgs.to(device)
targets = targets.to(device)
task_losses = [] outputs = model(noisy_imgs)
for i, task in enumerate(tasks): loss = criterion_denoise(outputs, targets)
if task == 'denoise': batch_loss += loss
# Ensure output matches target shape for denoising
output = outputs[i].view(3, 224, 224) # Reshape to match image dimensions
target = targets[i]
loss = criterion_denoise(output, target)
else: # rotate task
output = outputs[i].view(-1) # Flatten output for rotation prediction
target = targets[i].long() # Convert target to long for classification
loss = criterion_rotate(output.unsqueeze(0), target.unsqueeze(0))
task_losses.append(loss)
batch_loss = sum(task_losses) / len(task_losses) # Handle rotate task
batch_loss.backward() if batch_data['rotate'] is not None:
optimizer.step() imgs, targets = batch_data['rotate']
imgs = imgs.to(device)
targets = targets.long().to(device)
total_train_loss += batch_loss.item() outputs = model(imgs)
loss = criterion_rotate(outputs, targets)
batch_loss += loss
avg_train_loss = total_train_loss / len(train_loader) if batch_loss > 0:
batch_loss.backward()
optimizer.step()
total_train_loss += batch_loss.item()
batch_count += 1
avg_train_loss = total_train_loss / max(batch_count, 1)
print(f"Training Loss: {avg_train_loss:.4f}") print(f"Training Loss: {avg_train_loss:.4f}")
# Validation phase # Validation phase
model.eval() model.eval()
val_loss = 0.0 val_loss = 0.0
val_batch_count = 0
with torch.no_grad(): with torch.no_grad():
for batch in val_loader: for batch_data in val_loader:
if batch is None: if batch_data is None:
continue continue
noisy_imgs, targets, tasks = batch batch_loss = 0
noisy_imgs = noisy_imgs.to(device)
targets = targets.to(device)
outputs = model(noisy_imgs) # Handle denoise task
if batch_data['denoise'] is not None:
noisy_imgs, targets = batch_data['denoise']
noisy_imgs = noisy_imgs.to(device)
targets = targets.to(device)
task_losses = [] outputs = model(noisy_imgs)
for i, task in enumerate(tasks): loss = criterion_denoise(outputs, targets)
if task == 'denoise': batch_loss += loss
output = outputs[i].view(3, 224, 224)
target = targets[i]
loss = criterion_denoise(output, target)
else:
output = outputs[i].view(-1)
target = targets[i].long()
loss = criterion_rotate(output.unsqueeze(0), target.unsqueeze(0))
task_losses.append(loss)
batch_loss = sum(task_losses) / len(task_losses) # Handle rotate task
val_loss += batch_loss.item() if batch_data['rotate'] is not None:
imgs, targets = batch_data['rotate']
imgs = imgs.to(device)
targets = targets.long().to(device)
avg_val_loss = val_loss / len(val_loader) outputs = model(imgs)
loss = criterion_rotate(outputs, targets)
batch_loss += loss
if batch_loss > 0:
val_loss += batch_loss.item()
val_batch_count += 1
avg_val_loss = val_loss / max(val_batch_count, 1)
print(f"Validation Loss: {avg_val_loss:.4f}") print(f"Validation Loss: {avg_val_loss:.4f}")
if avg_val_loss < best_val_loss: if avg_val_loss < best_val_loss: