Compare commits

..

No commits in common. "f0f3f05584d0cd7eb02451c42ccd1585ef246f90" and "bb65dec449d4cfc983d7a61b8d3aba56e6464a46" have entirely different histories.

1 changed files with 2 additions and 2 deletions

View File

@ -5,7 +5,7 @@ import datetime
import time import time
import pandas as pd import pandas as pd
from tqdm import tqdm from tqdm import tqdm
from transformers import PreTrainedModel from ..model.Model import AIIA
from ..model.config import AIIAConfig from ..model.config import AIIAConfig
from ..data.DataLoader import AIIADataLoader from ..data.DataLoader import AIIADataLoader
import os import os
@ -23,7 +23,7 @@ class ProjectionHead(nn.Module):
return self.conv_rotate(x).mean(dim=(2, 3)) # Global average pooling for rotation task return self.conv_rotate(x).mean(dim=(2, 3)) # Global average pooling for rotation task
class Pretrainer: class Pretrainer:
def __init__(self, model: PreTrainedModel, learning_rate=1e-4, config: AIIAConfig=None): def __init__(self, model: AIIA, learning_rate=1e-4, config: AIIAConfig=None):
""" """
Initialize the pretrainer with a model. Initialize the pretrainer with a model.