feat/tf_support #37
|
@ -3,7 +3,7 @@ from torch import nn
|
|||
import csv
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
from ..model.Model import AIIA
|
||||
from transformers import PreTrainedModel
|
||||
from ..model.config import AIIAConfig
|
||||
from ..data.DataLoader import AIIADataLoader
|
||||
import os
|
||||
|
@ -21,12 +21,12 @@ class ProjectionHead(nn.Module):
|
|||
return self.conv_rotate(x).mean(dim=(2, 3)) # Global average pooling for rotation task
|
||||
|
||||
class Pretrainer:
|
||||
def __init__(self, model: AIIA, learning_rate=1e-4, config: AIIAConfig=None):
|
||||
def __init__(self, model: PreTrainedModel, learning_rate=1e-4, config: AIIAConfig=None):
|
||||
"""
|
||||
Initialize the pretrainer with a model.
|
||||
|
||||
Args:
|
||||
model (AIIA): The model instance to pretrain
|
||||
model (PreTrainedModel): The model instance to pretrain
|
||||
learning_rate (float): Learning rate for optimization
|
||||
config (dict): Model configuration containing hidden_size
|
||||
"""
|
||||
|
@ -186,11 +186,11 @@ class Pretrainer:
|
|||
|
||||
if val_loss < best_val_loss:
|
||||
best_val_loss = val_loss
|
||||
self.model.save(output_path)
|
||||
print("Best model saved!")
|
||||
self.model.save_pretrained(output_path)
|
||||
print("Best model save_pretrainedd!")
|
||||
|
||||
losses_path = os.path.join(os.path.dirname(output_path), 'losses.csv')
|
||||
self.save_losses(losses_path)
|
||||
self.save_pretrained_losses(losses_path)
|
||||
|
||||
def _validate(self, val_loader, criterion_denoise, criterion_rotate):
|
||||
"""Perform validation and return average validation loss."""
|
||||
|
@ -216,8 +216,8 @@ class Pretrainer:
|
|||
return avg_val_loss
|
||||
|
||||
|
||||
def save_losses(self, csv_file):
|
||||
"""Save training and validation losses to a CSV file."""
|
||||
def save_pretrained_losses(self, csv_file):
|
||||
"""save_pretrained training and validation losses to a CSV file."""
|
||||
data = list(zip(
|
||||
range(1, len(self.train_losses) + 1),
|
||||
self.train_losses,
|
||||
|
|
Loading…
Reference in New Issue