From 9b39a6926552fd2124585f4b971b509f00871ddd Mon Sep 17 00:00:00 2001 From: Falko Habel Date: Thu, 17 Apr 2025 13:04:23 +0200 Subject: [PATCH] updated pretrainer to work with correct imports --- src/aiia/pretrain/pretrainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/aiia/pretrain/pretrainer.py b/src/aiia/pretrain/pretrainer.py index f94af2c..6815a90 100644 --- a/src/aiia/pretrain/pretrainer.py +++ b/src/aiia/pretrain/pretrainer.py @@ -5,7 +5,7 @@ import datetime import time import pandas as pd from tqdm import tqdm -from ..model.Model import AIIA +from transformers import PreTrainedModel from ..model.config import AIIAConfig from ..data.DataLoader import AIIADataLoader import os @@ -23,7 +23,7 @@ class ProjectionHead(nn.Module): return self.conv_rotate(x).mean(dim=(2, 3)) # Global average pooling for rotation task class Pretrainer: - def __init__(self, model: AIIA, learning_rate=1e-4, config: AIIAConfig=None): + def __init__(self, model: PreTrainedModel, learning_rate=1e-4, config: AIIAConfig=None): """ Initialize the pretrainer with a model. -- 2.34.1