updated pretrainer to work with correct imports #39
|
@ -5,7 +5,7 @@ import datetime
|
||||||
import time
|
import time
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from ..model.Model import AIIA
|
from transformers import PreTrainedModel
|
||||||
from ..model.config import AIIAConfig
|
from ..model.config import AIIAConfig
|
||||||
from ..data.DataLoader import AIIADataLoader
|
from ..data.DataLoader import AIIADataLoader
|
||||||
import os
|
import os
|
||||||
|
@ -23,7 +23,7 @@ class ProjectionHead(nn.Module):
|
||||||
return self.conv_rotate(x).mean(dim=(2, 3)) # Global average pooling for rotation task
|
return self.conv_rotate(x).mean(dim=(2, 3)) # Global average pooling for rotation task
|
||||||
|
|
||||||
class Pretrainer:
|
class Pretrainer:
|
||||||
def __init__(self, model: AIIA, learning_rate=1e-4, config: AIIAConfig=None):
|
def __init__(self, model: PreTrainedModel, learning_rate=1e-4, config: AIIAConfig=None):
|
||||||
"""
|
"""
|
||||||
Initialize the pretrainer with a model.
|
Initialize the pretrainer with a model.
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue