develop #41

Merged
Fabel merged 27 commits from develop into main 2025-04-17 17:08:57 +00:00
1 changed file with 8 additions and 8 deletions
Showing only changes of commit 22e5d0023e - Show all commits

View File

@ -3,7 +3,7 @@ from torch import nn
import csv
import pandas as pd
from tqdm import tqdm
from ..model.Model import AIIA
from transformers import PreTrainedModel
from ..model.config import AIIAConfig
from ..data.DataLoader import AIIADataLoader
import os
@ -21,12 +21,12 @@ class ProjectionHead(nn.Module):
return self.conv_rotate(x).mean(dim=(2, 3)) # Global average pooling for rotation task
class Pretrainer:
def __init__(self, model: AIIA, learning_rate=1e-4, config: AIIAConfig=None):
def __init__(self, model: PreTrainedModel, learning_rate=1e-4, config: AIIAConfig=None):
"""
Initialize the pretrainer with a model.
Args:
model (AIIA): The model instance to pretrain
model (PreTrainedModel): The model instance to pretrain
learning_rate (float): Learning rate for optimization
config (dict): Model configuration containing hidden_size
"""
@ -186,11 +186,11 @@ class Pretrainer:
if val_loss < best_val_loss:
best_val_loss = val_loss
self.model.save(output_path)
print("Best model saved!")
self.model.save_pretrained(output_path)
print("Best model save_pretrainedd!")
losses_path = os.path.join(os.path.dirname(output_path), 'losses.csv')
self.save_losses(losses_path)
self.save_pretrained_losses(losses_path)
def _validate(self, val_loader, criterion_denoise, criterion_rotate):
"""Perform validation and return average validation loss."""
@ -216,8 +216,8 @@ class Pretrainer:
return avg_val_loss
def save_losses(self, csv_file):
"""Save training and validation losses to a CSV file."""
def save_pretrained_losses(self, csv_file):
"""save_pretrained training and validation losses to a CSV file."""
data = list(zip(
range(1, len(self.train_losses) + 1),
self.train_losses,