removed rint statemetns and added csv saving

This commit is contained in:
Falko Victor Habel 2025-01-27 10:56:16 +01:00
parent 58baf0ad3c
commit 8d08bfc14c
1 changed files with 22 additions and 16 deletions

View File

@ -1,5 +1,6 @@
import torch
from torch import nn, utils
from torch import nn
import csv
import pandas as pd
from aiia.model.config import AIIAConfig
from aiia.model import AIIABase
@ -101,7 +102,8 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
)
best_val_loss = float('inf')
train_losses = []
val_losses = []
for epoch in range(num_epochs):
print(f"\nEpoch {epoch+1}/{num_epochs}")
print("-" * 20)
@ -128,14 +130,7 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
# Get features from base model
features = model(noisy_imgs)
# Project features back to image space
outputs = projection_head(features, task='denoise')
print(f"\nDenoising task shapes:")
print(f"Input shape: {noisy_imgs.shape}")
print(f"Target shape: {targets.shape}")
print(f"Features shape: {features.shape}")
print(f"Output shape: {outputs.shape}")
outputs = projection_head(features, task='denoise')
loss = criterion_denoise(outputs, targets)
batch_loss += loss
@ -150,12 +145,6 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
# Project features to rotation predictions
outputs = projection_head(features, task='rotate')
print(f"\nRotation task shapes:")
print(f"Input shape: {imgs.shape}")
print(f"Target shape: {targets.shape}")
print(f"Features shape: {features.shape}")
print(f"Output shape: {outputs.shape}")
loss = criterion_rotate(outputs, targets)
batch_loss += loss
@ -166,6 +155,7 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
batch_count += 1
avg_train_loss = total_train_loss / max(batch_count, 1)
train_losses.append(avg_train_loss)
print(f"Training Loss: {avg_train_loss:.4f}")
# Validation phase
@ -206,6 +196,7 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
val_batch_count += 1
avg_val_loss = val_loss / max(val_batch_count, 1)
val_losses.append(avg_val_loss)
print(f"Validation Loss: {avg_val_loss:.4f}")
if avg_val_loss < best_val_loss:
@ -214,6 +205,21 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
model.save("AIIA-base-512")
print("Best model saved!")
# Prepare the data to be written to the CSV file
data = list(zip(range(1, len(train_losses) + 1), train_losses, val_losses))
# Specify the CSV file name
csv_file = 'losses.csv'
# Write the data to the CSV file
with open(csv_file, mode='w', newline='') as file:
writer = csv.writer(file)
# Write the header
writer.writerow(['Epoch', 'Train Loss', 'Validation Loss'])
# Write the data
writer.writerows(data)
print(f"Data has been written to {csv_file}")
if __name__ == "__main__":
data_path1 = "/root/training_data/vision-dataset/images_checkpoint.parquet"
data_path2 = "/root/training_data/vision-dataset/vec_images_dataset.parquet"