From 58baf0ad3c9b212000fbbb4372120161d9ac9950 Mon Sep 17 00:00:00 2001 From: Falko Habel Date: Mon, 27 Jan 2025 10:43:59 +0100 Subject: [PATCH] overall improvement --- src/pretrain.py | 78 +++++++++++++++++++++++++++++-------------------- 1 file changed, 46 insertions(+), 32 deletions(-) diff --git a/src/pretrain.py b/src/pretrain.py index 0792f09..55b4650 100644 --- a/src/pretrain.py +++ b/src/pretrain.py @@ -6,6 +6,18 @@ from aiia.model import AIIABase from aiia.data.DataLoader import AIIADataLoader from tqdm import tqdm +class ProjectionHead(nn.Module): + def __init__(self): + super().__init__() + self.conv_denoise = nn.Conv2d(512, 3, kernel_size=1) + self.conv_rotate = nn.Conv2d(512, 4, kernel_size=1) # 4 classes for 0, 90, 180, 270 degrees + + def forward(self, x, task='denoise'): + if task == 'denoise': + return self.conv_denoise(x) + else: + return self.conv_rotate(x).mean(dim=(2, 3)) # Global average pooling for rotation task + def pretrain_model(data_path1, data_path2, num_epochs=3): # Read and merge datasets df1 = pd.read_parquet(data_path1).head(10000) @@ -17,9 +29,14 @@ def pretrain_model(data_path1, data_path2, num_epochs=3): model_name="AIIA-Base-512x20k", ) - # Initialize model and data loader + # Initialize model and projection head model = AIIABase(config) + projection_head = ProjectionHead() + device = "cuda" if torch.cuda.is_available() else "cpu" + model.to(device) + projection_head.to(device) + def safe_collate(batch): denoise_batch = [] rotate_batch = [] @@ -51,13 +68,11 @@ def pretrain_model(data_path1, data_path2, num_epochs=3): 'rotate': None } - # Process denoise batch if denoise_batch: images = torch.stack([x['image'] for x in denoise_batch]) targets = torch.stack([x['target'] for x in denoise_batch]) batch_data['denoise'] = (images, targets) - # Process rotate batch if rotate_batch: images = torch.stack([x['image'] for x in rotate_batch]) targets = torch.stack([x['target'] for x in rotate_batch]) @@ -78,10 +93,12 @@ def pretrain_model(data_path1, data_path2, num_epochs=3): criterion_denoise = nn.MSELoss() criterion_rotate = nn.CrossEntropyLoss() - optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate) - - device = "cuda" if torch.cuda.is_available() else "cpu" - model.to(device) + + # Update optimizer to include projection head parameters + optimizer = torch.optim.AdamW( + list(model.parameters()) + list(projection_head.parameters()), + lr=config.learning_rate + ) best_val_loss = float('inf') @@ -91,6 +108,7 @@ def pretrain_model(data_path1, data_path2, num_epochs=3): # Training phase model.train() + projection_head.train() total_train_loss = 0.0 batch_count = 0 @@ -107,18 +125,16 @@ def pretrain_model(data_path1, data_path2, num_epochs=3): noisy_imgs = noisy_imgs.to(device) targets = targets.to(device) - # Print shapes for debugging + # Get features from base model + features = model(noisy_imgs) + # Project features back to image space + outputs = projection_head(features, task='denoise') + print(f"\nDenoising task shapes:") print(f"Input shape: {noisy_imgs.shape}") print(f"Target shape: {targets.shape}") - - outputs = model(noisy_imgs) - print(f"Raw output shape: {outputs.shape}") - - # Reshape output to match target dimensions - batch_size = targets.size(0) - outputs = outputs.view(batch_size, 3, 224, 224) - print(f"Reshaped output shape: {outputs.shape}") + print(f"Features shape: {features.shape}") + print(f"Output shape: {outputs.shape}") loss = criterion_denoise(outputs, targets) batch_loss += loss @@ -129,17 +145,16 @@ def pretrain_model(data_path1, data_path2, num_epochs=3): imgs = imgs.to(device) targets = targets.long().to(device) - # Print shapes for debugging + # Get features from base model + features = model(imgs) + # Project features to rotation predictions + outputs = projection_head(features, task='rotate') + print(f"\nRotation task shapes:") print(f"Input shape: {imgs.shape}") print(f"Target shape: {targets.shape}") - - outputs = model(imgs) - print(f"Raw output shape: {outputs.shape}") - - # Reshape output for rotation classification - outputs = outputs.view(targets.size(0), -1) # Flatten to [batch_size, features] - print(f"Reshaped output shape: {outputs.shape}") + print(f"Features shape: {features.shape}") + print(f"Output shape: {outputs.shape}") loss = criterion_rotate(outputs, targets) batch_loss += loss @@ -155,6 +170,7 @@ def pretrain_model(data_path1, data_path2, num_epochs=3): # Validation phase model.eval() + projection_head.eval() val_loss = 0.0 val_batch_count = 0 @@ -165,26 +181,23 @@ def pretrain_model(data_path1, data_path2, num_epochs=3): batch_loss = 0 - # Handle denoise task if batch_data['denoise'] is not None: noisy_imgs, targets = batch_data['denoise'] noisy_imgs = noisy_imgs.to(device) targets = targets.to(device) - outputs = model(noisy_imgs) - batch_size = targets.size(0) - outputs = outputs.view(batch_size, 3, 224, 224) + features = model(noisy_imgs) + outputs = projection_head(features, task='denoise') loss = criterion_denoise(outputs, targets) batch_loss += loss - # Handle rotate task if batch_data['rotate'] is not None: imgs, targets = batch_data['rotate'] imgs = imgs.to(device) targets = targets.long().to(device) - outputs = model(imgs) - outputs = outputs.view(targets.size(0), -1) + features = model(imgs) + outputs = projection_head(features, task='rotate') loss = criterion_rotate(outputs, targets) batch_loss += loss @@ -197,7 +210,8 @@ def pretrain_model(data_path1, data_path2, num_epochs=3): if avg_val_loss < best_val_loss: best_val_loss = avg_val_loss - model.save("BASEv0.1") + # Save both model and projection head + model.save("AIIA-base-512") print("Best model saved!") if __name__ == "__main__":