diff --git a/src/aiia/pretrain/pretrainer.py b/src/aiia/pretrain/pretrainer.py index 42ba4b8..93df9dd 100644 --- a/src/aiia/pretrain/pretrainer.py +++ b/src/aiia/pretrain/pretrainer.py @@ -3,7 +3,7 @@ from torch import nn import csv import pandas as pd from tqdm import tqdm -from ..model.Model import AIIA +from transformers import PreTrainedModel from ..model.config import AIIAConfig from ..data.DataLoader import AIIADataLoader import os @@ -21,12 +21,12 @@ class ProjectionHead(nn.Module): return self.conv_rotate(x).mean(dim=(2, 3)) # Global average pooling for rotation task class Pretrainer: - def __init__(self, model: AIIA, learning_rate=1e-4, config: AIIAConfig=None): + def __init__(self, model: PreTrainedModel, learning_rate=1e-4, config: AIIAConfig=None): """ Initialize the pretrainer with a model. Args: - model (AIIA): The model instance to pretrain + model (PreTrainedModel): The model instance to pretrain learning_rate (float): Learning rate for optimization config (dict): Model configuration containing hidden_size """ @@ -186,11 +186,11 @@ class Pretrainer: if val_loss < best_val_loss: best_val_loss = val_loss - self.model.save(output_path) - print("Best model saved!") + self.model.save_pretrained(output_path) + print("Best model saved!") losses_path = os.path.join(os.path.dirname(output_path), 'losses.csv') - self.save_losses(losses_path) + self.save_pretrained_losses(losses_path) def _validate(self, val_loader, criterion_denoise, criterion_rotate): """Perform validation and return average validation loss.""" @@ -216,8 +216,8 @@ class Pretrainer: return avg_val_loss - def save_losses(self, csv_file): - """Save training and validation losses to a CSV file.""" + def save_pretrained_losses(self, csv_file): + """Save training and validation losses to a CSV file.""" data = list(zip( range(1, len(self.train_losses) + 1), self.train_losses,