updated pretrainer to feature PreTrainedModel
Gitea Actions For AIIA / Explore-Gitea-Actions (push) Failing after 0s Details

This commit is contained in:
Falko Victor Habel 2025-04-12 21:44:53 +02:00
parent 30695154b4
commit 22e5d0023e
1 changed files with 8 additions and 8 deletions

View File

@ -3,7 +3,7 @@ from torch import nn
import csv import csv
import pandas as pd import pandas as pd
from tqdm import tqdm from tqdm import tqdm
from ..model.Model import AIIA from transformers import PreTrainedModel
from ..model.config import AIIAConfig from ..model.config import AIIAConfig
from ..data.DataLoader import AIIADataLoader from ..data.DataLoader import AIIADataLoader
import os import os
@ -21,12 +21,12 @@ class ProjectionHead(nn.Module):
return self.conv_rotate(x).mean(dim=(2, 3)) # Global average pooling for rotation task return self.conv_rotate(x).mean(dim=(2, 3)) # Global average pooling for rotation task
class Pretrainer: class Pretrainer:
def __init__(self, model: AIIA, learning_rate=1e-4, config: AIIAConfig=None): def __init__(self, model: PreTrainedModel, learning_rate=1e-4, config: AIIAConfig=None):
""" """
Initialize the pretrainer with a model. Initialize the pretrainer with a model.
Args: Args:
model (AIIA): The model instance to pretrain model (PreTrainedModel): The model instance to pretrain
learning_rate (float): Learning rate for optimization learning_rate (float): Learning rate for optimization
config (dict): Model configuration containing hidden_size config (dict): Model configuration containing hidden_size
""" """
@ -186,11 +186,11 @@ class Pretrainer:
if val_loss < best_val_loss: if val_loss < best_val_loss:
best_val_loss = val_loss best_val_loss = val_loss
self.model.save(output_path) self.model.save_pretrained(output_path)
print("Best model saved!") print("Best model save_pretrainedd!")
losses_path = os.path.join(os.path.dirname(output_path), 'losses.csv') losses_path = os.path.join(os.path.dirname(output_path), 'losses.csv')
self.save_losses(losses_path) self.save_pretrained_losses(losses_path)
def _validate(self, val_loader, criterion_denoise, criterion_rotate): def _validate(self, val_loader, criterion_denoise, criterion_rotate):
"""Perform validation and return average validation loss.""" """Perform validation and return average validation loss."""
@ -216,8 +216,8 @@ class Pretrainer:
return avg_val_loss return avg_val_loss
def save_losses(self, csv_file): def save_pretrained_losses(self, csv_file):
"""Save training and validation losses to a CSV file.""" """save_pretrained training and validation losses to a CSV file."""
data = list(zip( data = list(zip(
range(1, len(self.train_losses) + 1), range(1, len(self.train_losses) + 1),
self.train_losses, self.train_losses,