updated pretrainer to work with correct imports #39

Merged
Fabel merged 1 commits from feat/fix_imports into develop 2025-04-17 11:09:30 +00:00
1 changed files with 2 additions and 2 deletions

View File

@ -5,7 +5,7 @@ import datetime
import time import time
import pandas as pd import pandas as pd
from tqdm import tqdm from tqdm import tqdm
from ..model.Model import AIIA from transformers import PreTrainedModel
from ..model.config import AIIAConfig from ..model.config import AIIAConfig
from ..data.DataLoader import AIIADataLoader from ..data.DataLoader import AIIADataLoader
import os import os
@ -23,7 +23,7 @@ class ProjectionHead(nn.Module):
return self.conv_rotate(x).mean(dim=(2, 3)) # Global average pooling for rotation task return self.conv_rotate(x).mean(dim=(2, 3)) # Global average pooling for rotation task
class Pretrainer: class Pretrainer:
def __init__(self, model: AIIA, learning_rate=1e-4, config: AIIAConfig=None): def __init__(self, model: PreTrainedModel, learning_rate=1e-4, config: AIIAConfig=None):
""" """
Initialize the pretrainer with a model. Initialize the pretrainer with a model.