Compare commits

...

2 Commits

Author SHA1 Message Date
Falko Victor Habel f0f3f05584 Merge pull request 'updated pretrainer to work with correct imports' (#39) from feat/fix_imports into develop
Gitea Actions For AIIA / Explore-Gitea-Actions (push) Successful in 37s Details
Reviewed-on: #39
2025-04-17 11:09:30 +00:00
Falko Victor Habel 9b39a69265 updated pretrainer to work with correct imports
Gitea Actions For AIIA / Explore-Gitea-Actions (push) Successful in 37s Details
2025-04-17 13:04:23 +02:00
1 changed files with 2 additions and 2 deletions

View File

@ -5,7 +5,7 @@ import datetime
import time import time
import pandas as pd import pandas as pd
from tqdm import tqdm from tqdm import tqdm
from ..model.Model import AIIA from transformers import PreTrainedModel
from ..model.config import AIIAConfig from ..model.config import AIIAConfig
from ..data.DataLoader import AIIADataLoader from ..data.DataLoader import AIIADataLoader
import os import os
@ -23,7 +23,7 @@ class ProjectionHead(nn.Module):
return self.conv_rotate(x).mean(dim=(2, 3)) # Global average pooling for rotation task return self.conv_rotate(x).mean(dim=(2, 3)) # Global average pooling for rotation task
class Pretrainer: class Pretrainer:
def __init__(self, model: AIIA, learning_rate=1e-4, config: AIIAConfig=None): def __init__(self, model: PreTrainedModel, learning_rate=1e-4, config: AIIAConfig=None):
""" """
Initialize the pretrainer with a model. Initialize the pretrainer with a model.