updated pretrainer to handle multiple classes and configs.

Falko Victor Habel 2025-01-28 11:42:03 +01:00
parent a369c49f15
commit 3631df7f0a
2 changed files with 10 additions and 6 deletions

View File

@@ -10,7 +10,7 @@ config = AIIAConfig(model_name="AIIA-Base-512x10k-small", num_hidden_layers=6, h
 model = AIIABase(config)
 
 # Initialize pretrainer with the model
-pretrainer = Pretrainer(model, learning_rate=config.learning_rate)
+pretrainer = Pretrainer(model, learning_rate=config.learning_rate, config=config)
 
 # List of dataset paths
 dataset_paths = [

View File

@@ -4,13 +4,15 @@ import csv
 import pandas as pd
 from tqdm import tqdm
 from ..model.Model import AIIA
+from ..model.config import AIIAConfig
 from ..data.DataLoader import AIIADataLoader
 
 class ProjectionHead(nn.Module):
-    def __init__(self):
+    def __init__(self, hidden_size):
         super().__init__()
-        self.conv_denoise = nn.Conv2d(512, 3, kernel_size=1)
-        self.conv_rotate = nn.Conv2d(512, 4, kernel_size=1)  # 4 classes for 0, 90, 180, 270 degrees
+        self.conv_denoise = nn.Conv2d(hidden_size, 3, kernel_size=1)
+        self.conv_rotate = nn.Conv2d(hidden_size, 4, kernel_size=1)  # 4 classes for 0, 90, 180, 270 degrees
 
     def forward(self, x, task='denoise'):
         if task == 'denoise':
@@ -19,17 +21,19 @@ class ProjectionHead(nn.Module):
         return self.conv_rotate(x).mean(dim=(2, 3))  # Global average pooling for rotation task
 
 class Pretrainer:
-    def __init__(self, model: AIIA, learning_rate=1e-4):
+    def __init__(self, model: AIIA, learning_rate=1e-4, config: AIIAConfig=None):
         """
         Initialize the pretrainer with a model.
 
         Args:
             model (AIIA): The model instance to pretrain
             learning_rate (float): Learning rate for optimization
+            config (AIIAConfig): Model configuration containing hidden_size
         """
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
         self.model = model.to(self.device)
-        self.projection_head = ProjectionHead().to(self.device)
+        hidden_size = config.hidden_size
+        self.projection_head = ProjectionHead(hidden_size).to(self.device)
         self.optimizer = torch.optim.AdamW(
             list(self.model.parameters()) + list(self.projection_head.parameters()),
             lr=learning_rate
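
For context, here is a minimal, self-contained sketch (not part of the commit) of how the new hidden_size wiring is intended to behave. It mirrors the ProjectionHead shown in the diff above; the SimpleNamespace config, the hidden_size of 256, and the random feature map are illustrative stand-ins, not names from the repository.

import torch
import torch.nn as nn
from types import SimpleNamespace

class ProjectionHead(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        # 1x1 convolutions map hidden_size feature channels to the per-task outputs
        self.conv_denoise = nn.Conv2d(hidden_size, 3, kernel_size=1)
        self.conv_rotate = nn.Conv2d(hidden_size, 4, kernel_size=1)  # 4 rotation classes

    def forward(self, x, task='denoise'):
        if task == 'denoise':
            return self.conv_denoise(x)
        return self.conv_rotate(x).mean(dim=(2, 3))  # global average pooling

config = SimpleNamespace(hidden_size=256)             # stand-in for AIIAConfig
head = ProjectionHead(config.hidden_size)
features = torch.randn(2, config.hidden_size, 8, 8)   # fake backbone feature map
print(head(features, task='denoise').shape)           # torch.Size([2, 3, 8, 8])
print(head(features, task='rotate').shape)            # torch.Size([2, 4])

Before this change the head hard-coded 512 input channels, so it only matched one backbone width; passing config.hidden_size through Pretrainer lets the same head attach to models of any configured size.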