develop #41

Merged
Fabel merged 27 commits from develop into main 2025-04-17 17:08:57 +00:00
1 changed files with 2 additions and 2 deletions
Showing only changes of commit f0f3f05584 - Show all commits

View File

@ -5,7 +5,7 @@ import datetime
import time import time
import pandas as pd import pandas as pd
from tqdm import tqdm from tqdm import tqdm
from ..model.Model import AIIA from transformers import PreTrainedModel
from ..model.config import AIIAConfig from ..model.config import AIIAConfig
from ..data.DataLoader import AIIADataLoader from ..data.DataLoader import AIIADataLoader
import os import os
@ -23,7 +23,7 @@ class ProjectionHead(nn.Module):
return self.conv_rotate(x).mean(dim=(2, 3)) # Global average pooling for rotation task return self.conv_rotate(x).mean(dim=(2, 3)) # Global average pooling for rotation task
class Pretrainer: class Pretrainer:
def __init__(self, model: AIIA, learning_rate=1e-4, config: AIIAConfig=None): def __init__(self, model: PreTrainedModel, learning_rate=1e-4, config: AIIAConfig=None):
""" """
Initialize the pretrainer with a model. Initialize the pretrainer with a model.