updated pretrainer to handle multiple classes and configs.

parent a369c49f15
commit 3631df7f0a
@@ -10,7 +10,7 @@ config = AIIAConfig(model_name="AIIA-Base-512x10k-small", num_hidden_layers=6, h
 model = AIIABase(config)
 
 # Initialize pretrainer with the model
-pretrainer = Pretrainer(model, learning_rate=config.learning_rate)
+pretrainer = Pretrainer(model, learning_rate=config.learning_rate, config=config)
 
 # List of dataset paths
 dataset_paths = [
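For context (a hedged reading of why the commit threads config through): the old ProjectionHead hard-coded 512 input channels, so a model built with any other hidden_size failed at the first 1x1 convolution. A minimal, self-contained sketch with hypothetical sizes:

import torch
import torch.nn as nn

# Old behaviour (sketch): the head's input channels were fixed at 512.
old_head = nn.Conv2d(512, 3, kernel_size=1)
features = torch.randn(1, 256, 8, 8)  # features from a model with hidden_size=256
try:
    old_head(features)
except RuntimeError as err:
    print("channel mismatch:", err)

# New behaviour: the channel count follows the config, so any hidden_size works.
new_head = nn.Conv2d(256, 3, kernel_size=1)
print(new_head(features).shape)  # torch.Size([1, 3, 8, 8])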
@@ -4,13 +4,15 @@ import csv
 import pandas as pd
 from tqdm import tqdm
 from ..model.Model import AIIA
+from ..model.config import AIIAConfig
 from ..data.DataLoader import AIIADataLoader
 
+
 class ProjectionHead(nn.Module):
-    def __init__(self):
+    def __init__(self, hidden_size):
         super().__init__()
-        self.conv_denoise = nn.Conv2d(512, 3, kernel_size=1)
-        self.conv_rotate = nn.Conv2d(512, 4, kernel_size=1) # 4 classes for 0, 90, 180, 270 degrees
+        self.conv_denoise = nn.Conv2d(hidden_size, 3, kernel_size=1)
+        self.conv_rotate = nn.Conv2d(hidden_size, 4, kernel_size=1) # 4 classes for 0, 90, 180, 270 degrees
 
     def forward(self, x, task='denoise'):
         if task == 'denoise':
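To make the shape contract concrete, here is a standalone sketch of the updated head. The denoise branch's return statement is not visible in this hunk, so its body below is an assumption, and hidden_size=768 is a hypothetical value:

import torch
import torch.nn as nn

class ProjectionHead(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        # 1x1 convolutions project hidden_size feature channels to task outputs
        self.conv_denoise = nn.Conv2d(hidden_size, 3, kernel_size=1)
        self.conv_rotate = nn.Conv2d(hidden_size, 4, kernel_size=1)

    def forward(self, x, task='denoise'):
        if task == 'denoise':
            return self.conv_denoise(x)  # assumed body; not shown in the hunk
        return self.conv_rotate(x).mean(dim=(2, 3))  # global average pooling

head = ProjectionHead(hidden_size=768)
x = torch.randn(2, 768, 16, 16)       # batch of backbone feature maps
print(head(x, task='denoise').shape)  # torch.Size([2, 3, 16, 16])
print(head(x, task='rotate').shape)   # torch.Size([2, 4])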
@@ -19,17 +21,19 @@ class ProjectionHead(nn.Module):
         return self.conv_rotate(x).mean(dim=(2, 3)) # Global average pooling for rotation task
 
 class Pretrainer:
-    def __init__(self, model: AIIA, learning_rate=1e-4):
+    def __init__(self, model: AIIA, learning_rate=1e-4, config: AIIAConfig=None):
         """
         Initialize the pretrainer with a model.
 
         Args:
             model (AIIA): The model instance to pretrain
             learning_rate (float): Learning rate for optimization
+            config (AIIAConfig): Model configuration containing hidden_size
         """
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
         self.model = model.to(self.device)
-        self.projection_head = ProjectionHead().to(self.device)
+        hidden_size = config.hidden_size
+        self.projection_head = ProjectionHead(hidden_size).to(self.device)
         self.optimizer = torch.optim.AdamW(
             list(self.model.parameters()) + list(self.projection_head.parameters()),
             lr=learning_rate
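One caveat worth noting: config defaults to None, yet __init__ dereferences config.hidden_size unconditionally, so calling Pretrainer(model) without a config now raises AttributeError. A defensive variant of the method (a sketch, not part of the commit) could fail with a clearer message:

def __init__(self, model: AIIA, learning_rate=1e-4, config: AIIAConfig = None):
    # Fail fast with an explicit error instead of AttributeError on None
    if config is None:
        raise ValueError("config is required: the projection head is sized from config.hidden_size")
    self.device = "cuda" if torch.cuda.is_available() else "cpu"
    self.model = model.to(self.device)
    self.projection_head = ProjectionHead(config.hidden_size).to(self.device)
    self.optimizer = torch.optim.AdamW(
        list(self.model.parameters()) + list(self.projection_head.parameters()),
        lr=learning_rate,
    )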