Updated Pretrainer to use PreTrainedModel
Gitea Actions For AIIA / Explore-Gitea-Actions (push) Failing after 0s
Details
Gitea Actions For AIIA / Explore-Gitea-Actions (push) Failing after 0s
Details
This commit is contained in:
parent
30695154b4
commit
22e5d0023e
|
@ -3,7 +3,7 @@ from torch import nn
|
||||||
import csv
|
import csv
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from ..model.Model import AIIA
|
from transformers import PreTrainedModel
|
||||||
from ..model.config import AIIAConfig
|
from ..model.config import AIIAConfig
|
||||||
from ..data.DataLoader import AIIADataLoader
|
from ..data.DataLoader import AIIADataLoader
|
||||||
import os
|
import os
|
||||||
|
@ -21,12 +21,12 @@ class ProjectionHead(nn.Module):
|
||||||
return self.conv_rotate(x).mean(dim=(2, 3)) # Global average pooling for rotation task
|
return self.conv_rotate(x).mean(dim=(2, 3)) # Global average pooling for rotation task
|
||||||
|
|
||||||
class Pretrainer:
|
class Pretrainer:
|
||||||
def __init__(self, model: AIIA, learning_rate=1e-4, config: AIIAConfig=None):
|
def __init__(self, model: PreTrainedModel, learning_rate=1e-4, config: AIIAConfig=None):
|
||||||
"""
|
"""
|
||||||
Initialize the pretrainer with a model.
|
Initialize the pretrainer with a model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model (AIIA): The model instance to pretrain
|
model (PreTrainedModel): The model instance to pretrain
|
||||||
learning_rate (float): Learning rate for optimization
|
learning_rate (float): Learning rate for optimization
|
||||||
config (dict): Model configuration containing hidden_size
|
config (AIIAConfig): Model configuration containing hidden_size
|
||||||
"""
|
"""
|
||||||
|
@ -186,11 +186,11 @@ class Pretrainer:
|
||||||
|
|
||||||
if val_loss < best_val_loss:
|
if val_loss < best_val_loss:
|
||||||
best_val_loss = val_loss
|
best_val_loss = val_loss
|
||||||
self.model.save(output_path)
|
self.model.save_pretrained(output_path)
|
||||||
print("Best model saved!")
|
print("Best model saved!")
|
||||||
|
|
||||||
losses_path = os.path.join(os.path.dirname(output_path), 'losses.csv')
|
losses_path = os.path.join(os.path.dirname(output_path), 'losses.csv')
|
||||||
self.save_losses(losses_path)
|
self.save_pretrained_losses(losses_path)
|
||||||
|
|
||||||
def _validate(self, val_loader, criterion_denoise, criterion_rotate):
|
def _validate(self, val_loader, criterion_denoise, criterion_rotate):
|
||||||
"""Perform validation and return average validation loss."""
|
"""Perform validation and return average validation loss."""
|
||||||
|
@ -216,8 +216,8 @@ class Pretrainer:
|
||||||
return avg_val_loss
|
return avg_val_loss
|
||||||
|
|
||||||
|
|
||||||
def save_losses(self, csv_file):
|
def save_pretrained_losses(self, csv_file):
|
||||||
"""Save training and validation losses to a CSV file."""
|
"""Save training and validation losses to a CSV file."""
|
||||||
data = list(zip(
|
data = list(zip(
|
||||||
range(1, len(self.train_losses) + 1),
|
range(1, len(self.train_losses) + 1),
|
||||||
self.train_losses,
|
self.train_losses,
|
||||||
|
|
Loading…
Reference in New Issue