diff --git a/src/pretrain.py b/src/pretrain.py index 55b4650..6e4c05c 100644 --- a/src/pretrain.py +++ b/src/pretrain.py @@ -1,5 +1,6 @@ import torch -from torch import nn, utils +from torch import nn +import csv import pandas as pd from aiia.model.config import AIIAConfig from aiia.model import AIIABase @@ -101,7 +102,8 @@ def pretrain_model(data_path1, data_path2, num_epochs=3): ) best_val_loss = float('inf') - + train_losses = [] + val_losses = [] for epoch in range(num_epochs): print(f"\nEpoch {epoch+1}/{num_epochs}") print("-" * 20) @@ -128,14 +130,7 @@ def pretrain_model(data_path1, data_path2, num_epochs=3): # Get features from base model features = model(noisy_imgs) # Project features back to image space - outputs = projection_head(features, task='denoise') - - print(f"\nDenoising task shapes:") - print(f"Input shape: {noisy_imgs.shape}") - print(f"Target shape: {targets.shape}") - print(f"Features shape: {features.shape}") - print(f"Output shape: {outputs.shape}") - + outputs = projection_head(features, task='denoise') loss = criterion_denoise(outputs, targets) batch_loss += loss @@ -150,12 +145,6 @@ def pretrain_model(data_path1, data_path2, num_epochs=3): # Project features to rotation predictions outputs = projection_head(features, task='rotate') - print(f"\nRotation task shapes:") - print(f"Input shape: {imgs.shape}") - print(f"Target shape: {targets.shape}") - print(f"Features shape: {features.shape}") - print(f"Output shape: {outputs.shape}") - loss = criterion_rotate(outputs, targets) batch_loss += loss @@ -166,6 +155,7 @@ def pretrain_model(data_path1, data_path2, num_epochs=3): batch_count += 1 avg_train_loss = total_train_loss / max(batch_count, 1) + train_losses.append(avg_train_loss) print(f"Training Loss: {avg_train_loss:.4f}") # Validation phase @@ -206,6 +196,7 @@ def pretrain_model(data_path1, data_path2, num_epochs=3): val_batch_count += 1 avg_val_loss = val_loss / max(val_batch_count, 1) + val_losses.append(avg_val_loss) print(f"Validation Loss: {avg_val_loss:.4f}") if avg_val_loss < best_val_loss: @@ -214,6 +205,21 @@ def pretrain_model(data_path1, data_path2, num_epochs=3): model.save("AIIA-base-512") print("Best model saved!") + # Prepare the data to be written to the CSV file + data = list(zip(range(1, len(train_losses) + 1), train_losses, val_losses)) + + # Specify the CSV file name + csv_file = 'losses.csv' + + # Write the data to the CSV file + with open(csv_file, mode='w', newline='') as file: + writer = csv.writer(file) + # Write the header + writer.writerow(['Epoch', 'Train Loss', 'Validation Loss']) + # Write the data + writer.writerows(data) + print(f"Data has been written to {csv_file}") + if __name__ == "__main__": data_path1 = "/root/training_data/vision-dataset/images_checkpoint.parquet" data_path2 = "/root/training_data/vision-dataset/vec_images_dataset.parquet"