diff --git a/example.py b/example.py
index 6e1620b..6d605ca 100644
--- a/example.py
+++ b/example.py
@@ -10,7 +10,7 @@ config = AIIAConfig(model_name="AIIA-Base-512x10k-small", num_hidden_layers=6, h
 model = AIIABase(config)
 
 # Initialize pretrainer with the model
-pretrainer = Pretrainer(model, learning_rate=config.learning_rate)
+pretrainer = Pretrainer(model, learning_rate=config.learning_rate, config=config)
 
 # List of dataset paths
 dataset_paths = [
diff --git a/src/aiia/pretrain/pretrainer.py b/src/aiia/pretrain/pretrainer.py
index 913b77a..fa84fcb 100644
--- a/src/aiia/pretrain/pretrainer.py
+++ b/src/aiia/pretrain/pretrainer.py
@@ -4,13 +4,15 @@ import csv
 import pandas as pd
 from tqdm import tqdm
 from ..model.Model import AIIA
+from ..model.config import AIIAConfig
 from ..data.DataLoader import AIIADataLoader
 
+
 class ProjectionHead(nn.Module):
-    def __init__(self):
+    def __init__(self, hidden_size):
         super().__init__()
-        self.conv_denoise = nn.Conv2d(512, 3, kernel_size=1)
-        self.conv_rotate = nn.Conv2d(512, 4, kernel_size=1) # 4 classes for 0, 90, 180, 270 degrees
+        self.conv_denoise = nn.Conv2d(hidden_size, 3, kernel_size=1)
+        self.conv_rotate = nn.Conv2d(hidden_size, 4, kernel_size=1) # 4 classes for 0, 90, 180, 270 degrees
 
     def forward(self, x, task='denoise'):
         if task == 'denoise':
@@ -19,17 +21,19 @@ class ProjectionHead(nn.Module):
         return self.conv_rotate(x).mean(dim=(2, 3)) # Global average pooling for rotation task
 
 class Pretrainer:
-    def __init__(self, model: AIIA, learning_rate=1e-4):
+    def __init__(self, model: AIIA, learning_rate=1e-4, config: AIIAConfig = None):
         """
         Initialize the pretrainer with a model.
 
         Args:
             model (AIIA): The model instance to pretrain
             learning_rate (float): Learning rate for optimization
+            config (AIIAConfig): Model configuration containing hidden_size
         """
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
         self.model = model.to(self.device)
-        self.projection_head = ProjectionHead().to(self.device)
+        hidden_size = config.hidden_size
+        self.projection_head = ProjectionHead(hidden_size).to(self.device)
         self.optimizer = torch.optim.AdamW(
            list(self.model.parameters()) + list(self.projection_head.parameters()),
            lr=learning_rate
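
For reference, a minimal usage sketch of the new calling convention (not part of the patch). It mirrors example.py; the import paths and the hidden_size value are assumptions inferred from the src/ layout and the hunks above, since example.py's import lines and the full AIIAConfig arguments are truncated in this diff.

# Usage sketch, assuming public import paths that match the src/aiia layout
# shown in the patch (the real imports in example.py are not visible here).
from aiia.model.config import AIIAConfig          # assumed path
from aiia.model.Model import AIIABase             # assumed path
from aiia.pretrain.pretrainer import Pretrainer   # assumed path

# hidden_size=512 is an illustrative value; it must equal the model's feature
# width, since ProjectionHead now sizes its 1x1 convolutions from
# config.hidden_size instead of a hard-coded 512.
config = AIIAConfig(model_name="AIIA-Base-512x10k-small",
                    num_hidden_layers=6,
                    hidden_size=512)
model = AIIABase(config)

# config.learning_rate is assumed to carry a default, as example.py reads it
# without setting it. Passing config through lets the pretrainer build a
# projection head that matches the model width.
pretrainer = Pretrainer(model, learning_rate=config.learning_rate, config=config)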