removed rint statemetns and added csv saving
This commit is contained in:
parent
58baf0ad3c
commit
8d08bfc14c
|
@ -1,5 +1,6 @@
|
|||
import torch
|
||||
from torch import nn, utils
|
||||
from torch import nn
|
||||
import csv
|
||||
import pandas as pd
|
||||
from aiia.model.config import AIIAConfig
|
||||
from aiia.model import AIIABase
|
||||
|
@ -101,7 +102,8 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
|
|||
)
|
||||
|
||||
best_val_loss = float('inf')
|
||||
|
||||
train_losses = []
|
||||
val_losses = []
|
||||
for epoch in range(num_epochs):
|
||||
print(f"\nEpoch {epoch+1}/{num_epochs}")
|
||||
print("-" * 20)
|
||||
|
@ -128,14 +130,7 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
|
|||
# Get features from base model
|
||||
features = model(noisy_imgs)
|
||||
# Project features back to image space
|
||||
outputs = projection_head(features, task='denoise')
|
||||
|
||||
print(f"\nDenoising task shapes:")
|
||||
print(f"Input shape: {noisy_imgs.shape}")
|
||||
print(f"Target shape: {targets.shape}")
|
||||
print(f"Features shape: {features.shape}")
|
||||
print(f"Output shape: {outputs.shape}")
|
||||
|
||||
outputs = projection_head(features, task='denoise')
|
||||
loss = criterion_denoise(outputs, targets)
|
||||
batch_loss += loss
|
||||
|
||||
|
@ -150,12 +145,6 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
|
|||
# Project features to rotation predictions
|
||||
outputs = projection_head(features, task='rotate')
|
||||
|
||||
print(f"\nRotation task shapes:")
|
||||
print(f"Input shape: {imgs.shape}")
|
||||
print(f"Target shape: {targets.shape}")
|
||||
print(f"Features shape: {features.shape}")
|
||||
print(f"Output shape: {outputs.shape}")
|
||||
|
||||
loss = criterion_rotate(outputs, targets)
|
||||
batch_loss += loss
|
||||
|
||||
|
@ -166,6 +155,7 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
|
|||
batch_count += 1
|
||||
|
||||
avg_train_loss = total_train_loss / max(batch_count, 1)
|
||||
train_losses.append(avg_train_loss)
|
||||
print(f"Training Loss: {avg_train_loss:.4f}")
|
||||
|
||||
# Validation phase
|
||||
|
@ -206,6 +196,7 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
|
|||
val_batch_count += 1
|
||||
|
||||
avg_val_loss = val_loss / max(val_batch_count, 1)
|
||||
val_losses.append(avg_val_loss)
|
||||
print(f"Validation Loss: {avg_val_loss:.4f}")
|
||||
|
||||
if avg_val_loss < best_val_loss:
|
||||
|
@ -214,6 +205,21 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
|
|||
model.save("AIIA-base-512")
|
||||
print("Best model saved!")
|
||||
|
||||
# Prepare the data to be written to the CSV file
|
||||
data = list(zip(range(1, len(train_losses) + 1), train_losses, val_losses))
|
||||
|
||||
# Specify the CSV file name
|
||||
csv_file = 'losses.csv'
|
||||
|
||||
# Write the data to the CSV file
|
||||
with open(csv_file, mode='w', newline='') as file:
|
||||
writer = csv.writer(file)
|
||||
# Write the header
|
||||
writer.writerow(['Epoch', 'Train Loss', 'Validation Loss'])
|
||||
# Write the data
|
||||
writer.writerows(data)
|
||||
print(f"Data has been written to {csv_file}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
data_path1 = "/root/training_data/vision-dataset/images_checkpoint.parquet"
|
||||
data_path2 = "/root/training_data/vision-dataset/vec_images_dataset.parquet"
|
||||
|
|
Loading…
Reference in New Issue