Merge pull request 'updated pretrainer to work with correct imports' (#39) from feat/fix_imports into develop
Gitea Actions For AIIA / Explore-Gitea-Actions (push) Successful in 37s Details

Reviewed-on: #39
This commit is contained in:
Falko Victor Habel 2025-04-17 11:09:30 +00:00
commit f0f3f05584
1 changed files with 2 additions and 2 deletions

View File

@ -5,7 +5,7 @@ import datetime
import time
import pandas as pd
from tqdm import tqdm
from ..model.Model import AIIA
from transformers import PreTrainedModel
from ..model.config import AIIAConfig
from ..data.DataLoader import AIIADataLoader
import os
@ -23,7 +23,7 @@ class ProjectionHead(nn.Module):
return self.conv_rotate(x).mean(dim=(2, 3)) # Global average pooling for rotation task
class Pretrainer:
def __init__(self, model: AIIA, learning_rate=1e-4, config: AIIAConfig=None):
def __init__(self, model: PreTrainedModel, learning_rate=1e-4, config: AIIAConfig=None):
"""
Initialize the pretrainer with a model.