updated pretrainer to handle multiple classes and configs.

Falko Victor Habel 2025-01-28 11:42:03 +01:00
parent a369c49f15
commit 3631df7f0a
2 changed files with 10 additions and 6 deletions


@@ -10,7 +10,7 @@ config = AIIAConfig(model_name="AIIA-Base-512x10k-small", num_hidden_layers=6, h
 model = AIIABase(config)
 
 # Initialize pretrainer with the model
-pretrainer = Pretrainer(model, learning_rate=config.learning_rate)
+pretrainer = Pretrainer(model, learning_rate=config.learning_rate, config=config)
 
 # List of dataset paths
 dataset_paths = [
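
For context, a minimal sketch of the updated call site. The top-level import paths, the hidden_size value, and the dataset paths are assumptions for illustration, not taken from this commit:

# Hypothetical usage sketch; import paths and dataset paths below are assumptions.
from aiia.model.config import AIIAConfig
from aiia.model.Model import AIIABase
from aiia.pretrain.Pretrainer import Pretrainer

config = AIIAConfig(model_name="AIIA-Base-512x10k-small", num_hidden_layers=6, hidden_size=256)
model = AIIABase(config)

# The config is now passed through so the projection head can be sized to hidden_size.
pretrainer = Pretrainer(model, learning_rate=config.learning_rate, config=config)

dataset_paths = [
    "/data/set_a.parquet",  # placeholder paths
    "/data/set_b.parquet",
]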


@@ -4,13 +4,15 @@ import csv
 import pandas as pd
 from tqdm import tqdm
 from ..model.Model import AIIA
+from ..model.config import AIIAConfig
 from ..data.DataLoader import AIIADataLoader
+
 
 class ProjectionHead(nn.Module):
-    def __init__(self):
+    def __init__(self, hidden_size):
         super().__init__()
-        self.conv_denoise = nn.Conv2d(512, 3, kernel_size=1)
-        self.conv_rotate = nn.Conv2d(512, 4, kernel_size=1)  # 4 classes for 0, 90, 180, 270 degrees
+        self.conv_denoise = nn.Conv2d(hidden_size, 3, kernel_size=1)
+        self.conv_rotate = nn.Conv2d(hidden_size, 4, kernel_size=1)  # 4 classes for 0, 90, 180, 270 degrees
 
     def forward(self, x, task='denoise'):
         if task == 'denoise':
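
As a quick sanity check on the parameterized head, a sketch of the expected output shapes; hidden_size=256, the feature-map size, and the import path are illustrative assumptions:

# Assumes ProjectionHead is importable from the pretrainer module shown above.
from aiia.pretrain.Pretrainer import ProjectionHead
import torch

head = ProjectionHead(hidden_size=256)
x = torch.randn(2, 256, 16, 16)    # (batch, hidden_size, H, W) feature map
out_d = head(x, task='denoise')    # 1x1 conv to 3 channels -> shape (2, 3, 16, 16)
out_r = head(x, task='rotate')     # 1x1 conv to 4 logits + global average pooling -> shape (2, 4)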
@@ -19,17 +21,19 @@ class ProjectionHead(nn.Module):
         return self.conv_rotate(x).mean(dim=(2, 3))  # Global average pooling for rotation task
 
 class Pretrainer:
-    def __init__(self, model: AIIA, learning_rate=1e-4):
+    def __init__(self, model: AIIA, learning_rate=1e-4, config: AIIAConfig = None):
         """
         Initialize the pretrainer with a model.
 
         Args:
             model (AIIA): The model instance to pretrain
             learning_rate (float): Learning rate for optimization
+            config (AIIAConfig): Model configuration containing hidden_size
         """
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
         self.model = model.to(self.device)
-        self.projection_head = ProjectionHead().to(self.device)
+        hidden_size = config.hidden_size
+        self.projection_head = ProjectionHead(hidden_size).to(self.device)
         self.optimizer = torch.optim.AdamW(
             list(self.model.parameters()) + list(self.projection_head.parameters()),
             lr=learning_rate
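
One behavioral note on the new signature: config defaults to None but is dereferenced unconditionally, so constructing a Pretrainer without a config raises AttributeError at config.hidden_size. A defensive variant, as a sketch that is not part of this commit:

# Hypothetical guard inside __init__, not in this commit:
if config is None:
    raise ValueError("Pretrainer requires an AIIAConfig so ProjectionHead can match hidden_size")
hidden_size = config.hidden_size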