updated pretrainer to handle multiple classes and configs.
This commit is contained in:
parent
a369c49f15
commit
3631df7f0a
|
@ -10,7 +10,7 @@ config = AIIAConfig(model_name="AIIA-Base-512x10k-small", num_hidden_layers=6, h
|
|||
model = AIIABase(config)
|
||||
|
||||
# Initialize pretrainer with the model
|
||||
pretrainer = Pretrainer(model, learning_rate=config.learning_rate)
|
||||
pretrainer = Pretrainer(model, learning_rate=config.learning_rate, config=config)
|
||||
|
||||
# List of dataset paths
|
||||
dataset_paths = [
|
||||
|
|
|
@ -4,13 +4,15 @@ import csv
|
|||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
from ..model.Model import AIIA
|
||||
from ..model.config import AIIAConfig
|
||||
from ..data.DataLoader import AIIADataLoader
|
||||
|
||||
|
||||
class ProjectionHead(nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self, hidden_size):
|
||||
super().__init__()
|
||||
self.conv_denoise = nn.Conv2d(512, 3, kernel_size=1)
|
||||
self.conv_rotate = nn.Conv2d(512, 4, kernel_size=1) # 4 classes for 0, 90, 180, 270 degrees
|
||||
self.conv_denoise = nn.Conv2d(hidden_size, 3, kernel_size=1)
|
||||
self.conv_rotate = nn.Conv2d(hidden_size, 4, kernel_size=1) # 4 classes for 0, 90, 180, 270 degrees
|
||||
|
||||
def forward(self, x, task='denoise'):
|
||||
if task == 'denoise':
|
||||
|
@ -19,17 +21,19 @@ class ProjectionHead(nn.Module):
|
|||
return self.conv_rotate(x).mean(dim=(2, 3)) # Global average pooling for rotation task
|
||||
|
||||
class Pretrainer:
|
||||
def __init__(self, model: AIIA, learning_rate=1e-4):
|
||||
def __init__(self, model: AIIA, learning_rate=1e-4, config: AIIAConfig=None):
|
||||
"""
|
||||
Initialize the pretrainer with a model.
|
||||
|
||||
Args:
|
||||
model (AIIA): The model instance to pretrain
|
||||
learning_rate (float): Learning rate for optimization
|
||||
config (dict): Model configuration containing hidden_size
|
||||
"""
|
||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
self.model = model.to(self.device)
|
||||
self.projection_head = ProjectionHead().to(self.device)
|
||||
hidden_size = config.hidden_size
|
||||
self.projection_head = ProjectionHead(hidden_size).to(self.device)
|
||||
self.optimizer = torch.optim.AdamW(
|
||||
list(self.model.parameters()) + list(self.projection_head.parameters()),
|
||||
lr=learning_rate
|
||||
|
|
Loading…
Reference in New Issue