From fe4d6b5b22fed2c52cbaf3dcab7444300d81b6fc Mon Sep 17 00:00:00 2001 From: Falko Habel Date: Mon, 27 Jan 2025 10:26:28 +0100 Subject: [PATCH] corrected viewing and some prints --- src/pretrain.py | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/src/pretrain.py b/src/pretrain.py index 201c03f..0792f09 100644 --- a/src/pretrain.py +++ b/src/pretrain.py @@ -6,7 +6,6 @@ from aiia.model import AIIABase from aiia.data.DataLoader import AIIADataLoader from tqdm import tqdm - def pretrain_model(data_path1, data_path2, num_epochs=3): # Read and merge datasets df1 = pd.read_parquet(data_path1).head(10000) @@ -108,7 +107,19 @@ def pretrain_model(data_path1, data_path2, num_epochs=3): noisy_imgs = noisy_imgs.to(device) targets = targets.to(device) + # Print shapes for debugging + print(f"\nDenoising task shapes:") + print(f"Input shape: {noisy_imgs.shape}") + print(f"Target shape: {targets.shape}") + outputs = model(noisy_imgs) + print(f"Raw output shape: {outputs.shape}") + + # Reshape output to match target dimensions + batch_size = targets.size(0) + outputs = outputs.view(batch_size, 3, 224, 224) + print(f"Reshaped output shape: {outputs.shape}") + loss = criterion_denoise(outputs, targets) batch_loss += loss @@ -118,7 +129,18 @@ def pretrain_model(data_path1, data_path2, num_epochs=3): imgs = imgs.to(device) targets = targets.long().to(device) + # Print shapes for debugging + print(f"\nRotation task shapes:") + print(f"Input shape: {imgs.shape}") + print(f"Target shape: {targets.shape}") + outputs = model(imgs) + print(f"Raw output shape: {outputs.shape}") + + # Reshape output for rotation classification + outputs = outputs.view(targets.size(0), -1) # Flatten to [batch_size, features] + print(f"Reshaped output shape: {outputs.shape}") + loss = criterion_rotate(outputs, targets) batch_loss += loss @@ -150,6 +172,8 @@ def pretrain_model(data_path1, data_path2, num_epochs=3): targets = targets.to(device) outputs = model(noisy_imgs) + batch_size = targets.size(0) + outputs = outputs.view(batch_size, 3, 224, 224) loss = criterion_denoise(outputs, targets) batch_loss += loss @@ -160,6 +184,7 @@ def pretrain_model(data_path1, data_path2, num_epochs=3): targets = targets.long().to(device) outputs = model(imgs) + outputs = outputs.view(targets.size(0), -1) loss = criterion_rotate(outputs, targets) batch_loss += loss