409 lines
16 KiB
Python
409 lines
16 KiB
Python
import torch
|
|
from torch import nn
|
|
import csv
|
|
import datetime
|
|
import time
|
|
import pandas as pd
|
|
from tqdm import tqdm
|
|
from transformers import PreTrainedModel
|
|
from ..model.config import AIIAConfig
|
|
from ..data.DataLoader import AIIADataLoader
|
|
import os
|
|
|
|
class ProjectionHead(nn.Module):
    """Task-specific 1x1-convolution heads applied to backbone feature maps.

    Both heads read the same ``hidden_size``-channel input:

    * ``'denoise'`` projects to 3 channels and keeps the spatial layout.
    * Every other ``task`` value is treated as rotation prediction: the input is
      projected to 4 logits (0, 90, 180, 270 degrees) and averaged over the last
      two dimensions.
    """

    def __init__(self, hidden_size):
        super().__init__()
        # Registration order is kept (denoise first) so parameter initialisation
        # and state_dict key order match earlier checkpoints.
        self.conv_denoise = nn.Conv2d(hidden_size, 3, kernel_size=1)
        # One output channel per rotation class: 0, 90, 180, 270 degrees.
        self.conv_rotate = nn.Conv2d(hidden_size, 4, kernel_size=1)

    def forward(self, x, task='denoise'):
        """Apply the head selected by ``task`` to the feature map ``x``.

        Args:
            x (torch.Tensor): Feature map whose channel dimension is ``hidden_size``.
            task (str): ``'denoise'`` for the reconstruction head; any other value
                selects the rotation head.

        Returns:
            torch.Tensor: The 3-channel denoise output, or rotation logits pooled
            over dimensions 2 and 3.
        """
        if task != 'denoise':
            rotation_maps = self.conv_rotate(x)
            # Global average pooling turns per-location logits into per-image logits.
            return rotation_maps.mean(dim=(2, 3))
        return self.conv_denoise(x)
|
|
|
|
class Pretrainer:
    """Self-supervised pretraining loop for a backbone plus :class:`ProjectionHead`.

    Two pretext tasks are trained jointly: denoising (``nn.MSELoss``) and rotation
    prediction (``nn.CrossEntropyLoss``). Training writes periodic checkpoints and
    can resume from the most recent checkpoint found in a directory.
    """

    def __init__(self, model: "PreTrainedModel", learning_rate=1e-4, config: "AIIAConfig" = None):
        """
        Initialize the pretrainer with a model.

        Args:
            model (PreTrainedModel): The model instance to pretrain.
            learning_rate (float): Learning rate for the AdamW optimizer.
            config (AIIAConfig, optional): Configuration providing ``hidden_size``.
                When omitted, ``model.config`` is used if the model has one.

        Raises:
            ValueError: If no config is given and the model has no ``config`` attribute.
        """
        if config is None:
            # PreTrainedModel instances carry their configuration; presumably an
            # AIIAConfig for AIIA models — verify hidden_size exists for other models.
            config = getattr(model, "config", None)
        if config is None:
            raise ValueError(
                "A config with `hidden_size` is required: none was passed and the model has no `.config`."
            )

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = model.to(self.device)
        self.projection_head = ProjectionHead(config.hidden_size).to(self.device)
        # One optimizer updates both the backbone and the task heads.
        self.optimizer = torch.optim.AdamW(
            list(self.model.parameters()) + list(self.projection_head.parameters()),
            lr=learning_rate
        )
        self.train_losses = []
        self.val_losses = []
        self.checkpoint_dir = None  # set by train()
        self.current_epoch = 0  # 0-indexed epoch currently being trained
        # Make checkpoint bookkeeping available even before train() is called.
        self._initialize_checkpoint_variables()

    @staticmethod
    def safe_collate(batch):
        """Collate ``(image, target, task)`` samples into per-task stacked tensors.

        Samples whose task is ``'denoise'`` form the denoise group; every other
        sample is treated as a rotation sample. Samples that cannot be unpacked
        into three items (or whose task cannot be compared) are reported and skipped.

        Args:
            batch: Iterable of ``(image, target, task)`` samples.

        Returns:
            dict | None: ``{'denoise': (images, targets) or None,
            'rotate': (images, targets) or None}`` built with ``torch.stack``, or
            ``None`` when no usable sample remains.
        """
        grouped = {'denoise': [], 'rotate': []}

        for sample in batch:
            try:
                noisy_img, target, task = sample
                key = 'denoise' if task == 'denoise' else 'rotate'
            except Exception as e:
                print(f"Skipping sample due to error: {e}")
                continue
            grouped[key].append((noisy_img, target))

        if not grouped['denoise'] and not grouped['rotate']:
            return None

        batch_data = {'denoise': None, 'rotate': None}
        for key, samples in grouped.items():  # denoise is stacked first, then rotate
            if samples:
                images = torch.stack([image for image, _ in samples])
                targets = torch.stack([target for _, target in samples])
                batch_data[key] = (images, targets)

        return batch_data

    def _process_batch(self, batch_data, criterion_denoise, criterion_rotate, training=True):
        """Compute the summed loss over every task present in ``batch_data``.

        Args:
            batch_data (dict): Output of :meth:`safe_collate`.
            criterion_denoise: Loss applied to denoise outputs and targets.
            criterion_rotate: Loss applied to rotation logits and ``long`` targets.
            training (bool): Currently unused; gradient tracking is controlled by
                the caller (``_validate`` wraps calls in ``torch.no_grad()``).

        Returns:
            torch.Tensor | int: The summed loss, or the int ``0`` when neither task
            has data.
        """
        batch_loss = 0

        if batch_data['denoise'] is not None:
            noisy_imgs, targets = batch_data['denoise']
            noisy_imgs = noisy_imgs.to(self.device)
            targets = targets.to(self.device)

            features = self.model(noisy_imgs)
            outputs = self.projection_head(features, task='denoise')
            batch_loss += criterion_denoise(outputs, targets)

        if batch_data['rotate'] is not None:
            imgs, targets = batch_data['rotate']
            imgs = imgs.to(self.device)
            # CrossEntropyLoss expects integer class indices.
            targets = targets.long().to(self.device)

            features = self.model(imgs)
            outputs = self.projection_head(features, task='rotate')
            batch_loss += criterion_rotate(outputs, targets)

        return batch_loss

    def _save_checkpoint(self, checkpoint_dir, epoch, batch_count, checkpoint_name):
        """Save a model checkpoint.

        Args:
            checkpoint_dir (str): Directory to save the checkpoint in.
            epoch (int): Current 0-indexed epoch; stored as ``epoch + 1``.
            batch_count (int): Number of batches of this epoch trained so far.
            checkpoint_name (str): Name for the checkpoint file.

        Returns:
            str: Path to the saved checkpoint.
        """
        checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name)
        checkpoint_data = {
            'epoch': epoch + 1,
            'batch': batch_count,
            'model_state_dict': self.model.state_dict(),
            'projection_head_state_dict': self.projection_head.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'train_losses': self.train_losses,
            'val_losses': self.val_losses,
        }
        torch.save(checkpoint_data, checkpoint_path)
        return checkpoint_path

    def load_checkpoint(self, checkpoint_dir, specific_checkpoint=None):
        """Check for checkpoints and load one if available.

        Args:
            checkpoint_dir (str): Directory where checkpoints are stored (created if missing).
            specific_checkpoint (str, optional): File name to load. If None, the
                ``checkpoint_*.pt`` file with the newest modification time is loaded.

        Returns:
            tuple: ``(loaded_epoch, loaded_batch)`` if a checkpoint was loaded, None otherwise.
        """
        os.makedirs(checkpoint_dir, exist_ok=True)

        if specific_checkpoint:
            checkpoint_path = os.path.join(checkpoint_dir, specific_checkpoint)
            if os.path.exists(checkpoint_path):
                return self._load_checkpoint_file(checkpoint_path)
            print(f"Specified checkpoint {specific_checkpoint} not found.")
            return None

        checkpoint_files = [
            f for f in os.listdir(checkpoint_dir)
            if f.startswith("checkpoint_") and f.endswith(".pt")
        ]
        if not checkpoint_files:
            print("No checkpoints found in directory.")
            return None

        most_recent = max(
            checkpoint_files,
            key=lambda name: os.path.getmtime(os.path.join(checkpoint_dir, name)),
        )
        return self._load_checkpoint_file(os.path.join(checkpoint_dir, most_recent))

    def _load_checkpoint_file(self, checkpoint_path):
        """Restore model, head, optimizer and loss history from a checkpoint file.

        Args:
            checkpoint_path (str): Path to the checkpoint file.

        Returns:
            tuple: ``(loaded_epoch, loaded_batch)`` if the checkpoint was loaded,
            None if loading failed (the error is printed).
        """
        try:
            # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
            checkpoint = torch.load(checkpoint_path, map_location=self.device)

            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.projection_head.load_state_dict(checkpoint['projection_head_state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            self.train_losses = checkpoint.get('train_losses', [])
            self.val_losses = checkpoint.get('val_losses', [])

            loaded_epoch = checkpoint['epoch']
            loaded_batch = checkpoint['batch']

            print(f"Checkpoint loaded from {checkpoint_path}")
            print(f"Resuming from epoch {loaded_epoch}, batch {loaded_batch}")
            return loaded_epoch, loaded_batch

        except Exception as e:
            print(f"Error loading checkpoint: {e}")
            return None

    def train(self, dataset_paths, output_path="AIIA", column="image_bytes",
              num_epochs=3, batch_size=2, sample_size=10000, checkpoint_dir=None):
        """Train the model on one or more parquet datasets with checkpoint resumption.

        Args:
            dataset_paths (list[str]): Parquet files; each contributes at most
                ``sample_size`` rows.
            output_path (str): Passed to ``model.save_pretrained`` whenever the
                validation loss improves. ``losses.csv`` is written to its parent
                directory (the current directory for a bare name).
            column (str): Dataframe column name handed to ``AIIADataLoader``.
            num_epochs (int): Total epochs, including epochs completed before a resume.
            batch_size (int): Batch size for the data loader.
            sample_size (int): Maximum rows taken from each dataset.
            checkpoint_dir (str, optional): Directory for periodic checkpoints. When
                set, the most recent checkpoint in it is loaded and training resumes.

        Raises:
            ValueError: If ``dataset_paths`` is empty or no dataset could be loaded.
        """
        if not dataset_paths:
            raise ValueError("No dataset paths provided")

        self.checkpoint_dir = checkpoint_dir
        self._initialize_checkpoint_variables()
        start_epoch, start_batch, resume_training = self._load_checkpoints(checkpoint_dir)

        dataframes = self._load_and_merge_datasets(dataset_paths, sample_size)
        aiia_loader = self._initialize_data_loader(dataframes, column, batch_size)

        criterion_denoise, criterion_rotate, best_val_loss = self._initialize_loss_functions()
        if resume_training and self.val_losses:
            # Do not let the first post-resume epoch overwrite a better model saved
            # before the interruption (assumes the same output_path as that run).
            best_val_loss = min(self.val_losses)

        for epoch in range(start_epoch, num_epochs):
            self.current_epoch = epoch
            print(f"\nEpoch {epoch+1}/{num_epochs}")
            print("-" * 20)
            # Only the epoch that was interrupted skips already-trained batches.
            skip_batches = start_batch if (epoch == start_epoch and resume_training) else 0
            total_train_loss, batch_count = self._training_phase(
                aiia_loader.train_loader, skip_batches, criterion_denoise, criterion_rotate
            )

            avg_train_loss = total_train_loss / max(batch_count, 1)
            self.train_losses.append(avg_train_loss)
            print(f"Training Loss: {avg_train_loss:.4f}")

            val_loss = self._validation_phase(aiia_loader.val_loader, criterion_denoise, criterion_rotate)

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                self.model.save_pretrained(output_path)
                print("Best model saved!")

        losses_path = os.path.join(os.path.dirname(output_path), 'losses.csv')
        self.save_losses(losses_path)

    def _initialize_checkpoint_variables(self):
        """Reset checkpoint bookkeeping: timer, interval, daily-22:00 date and rotation list."""
        # Monotonic clock so wall-clock adjustments cannot trigger or delay checkpoints.
        self.last_checkpoint_time = time.monotonic()
        self.checkpoint_interval = 2 * 60 * 60  # 2 hours in seconds
        self.last_22_date = None  # date of the last daily 22:00 checkpoint
        self.recent_checkpoints = []  # paths of interval checkpoints, oldest first

    def _load_checkpoints(self, checkpoint_dir):
        """Load the latest checkpoint, if any.

        Returns:
            tuple: ``(start_epoch, start_batch, resume_training)`` where
            ``start_epoch`` is 0-indexed for the training loop.
        """
        start_epoch = 0
        start_batch = 0
        resume_training = False

        if checkpoint_dir is not None:
            os.makedirs(checkpoint_dir, exist_ok=True)
            checkpoint_info = self.load_checkpoint(checkpoint_dir)
            if checkpoint_info:
                start_epoch, start_batch = checkpoint_info
                resume_training = True
                # Checkpoints store epoch + 1; convert back to the 0-indexed loop value.
                start_epoch -= 1

        return start_epoch, start_batch, resume_training

    def _load_and_merge_datasets(self, dataset_paths, sample_size):
        """Read up to ``sample_size`` rows from each parquet file and concatenate them.

        Unreadable files are reported and skipped.

        Raises:
            ValueError: If no dataset could be loaded.
        """
        dataframes = []
        for path in dataset_paths:
            try:
                dataframes.append(pd.read_parquet(path).head(sample_size))
            except Exception as e:
                print(f"Error loading dataset {path}: {e}")

        if not dataframes:
            raise ValueError("No valid datasets could be loaded")

        return pd.concat(dataframes, ignore_index=True)

    def _initialize_data_loader(self, merged_df, column, batch_size):
        """Build the pretraining ``AIIADataLoader`` using :meth:`safe_collate`."""
        return AIIADataLoader(
            merged_df,
            column=column,
            batch_size=batch_size,
            pretraining=True,
            collate_fn=self.safe_collate
        )

    def _initialize_loss_functions(self):
        """Return ``(MSELoss, CrossEntropyLoss, inf)`` for denoise, rotation and best-loss tracking."""
        criterion_denoise = nn.MSELoss()
        criterion_rotate = nn.CrossEntropyLoss()
        best_val_loss = float('inf')
        return criterion_denoise, criterion_rotate, best_val_loss

    def _training_phase(self, train_loader, skip_batches, criterion_denoise, criterion_rotate):
        """Run one training epoch.

        Args:
            train_loader: Iterable of collated batches (``None`` batches are skipped).
            skip_batches (int): Number of leading batches already trained before a resume.
            criterion_denoise: Denoise loss.
            criterion_rotate: Rotation loss.

        Returns:
            tuple: ``(total_train_loss, batch_count)`` over the batches that
            produced a positive loss.
        """
        import itertools

        self.model.train()
        self.projection_head.train()
        total_train_loss = 0.0
        batch_count = 0

        try:
            total_batches = len(train_loader)
        except TypeError:
            total_batches = None  # iterable-style loaders may not define a length

        # Stream batches instead of materialising the whole epoch in memory. Skipped
        # batches are still drawn so enumeration indices line up with the
        # interrupted run (only meaningful if the loader order is reproducible).
        remaining = itertools.islice(enumerate(train_loader), skip_batches, None)
        for i, batch_data in tqdm(remaining, initial=skip_batches, total=total_batches):
            if batch_data is None:
                continue

            self.optimizer.zero_grad()
            batch_loss = self._process_batch(batch_data, criterion_denoise, criterion_rotate)

            if batch_loss > 0:
                batch_loss.backward()
                self.optimizer.step()
                total_train_loss += batch_loss.item()
                batch_count += 1

            # Checkpoint after batch i is applied: the stored count i + 1 equals the
            # number of batches already trained, so a resume skips exactly those.
            self._handle_checkpoints(i + 1)

        return total_train_loss, batch_count

    @staticmethod
    def _checkpoint_timezone():
        """Return the Europe/Berlin timezone, or a fixed UTC+2 offset if unavailable."""
        try:
            from zoneinfo import ZoneInfo
            return ZoneInfo("Europe/Berlin")
        except (ImportError, KeyError, ValueError):
            # zoneinfo needs Python 3.9+ and tz data; ZoneInfoNotFoundError is a KeyError.
            return datetime.timezone(datetime.timedelta(hours=2))

    def _handle_checkpoints(self, current_batch):
        """Write interval and daily checkpoints when due.

        * Every ``checkpoint_interval`` seconds an interval checkpoint is written;
          only the 3 most recent interval checkpoints are kept on disk.
        * Once per day, during 22:00-22:14 German local time, an extra
          ``checkpoint_22h_YYYYMMDD.pt`` checkpoint is written (never rotated).

        Nothing is written when ``checkpoint_dir`` is not set.

        Args:
            current_batch (int): Number of batches of the current epoch trained so far.
        """
        current_time = time.monotonic()
        current_dt = datetime.datetime.now(self._checkpoint_timezone())  # German local time (CET/CEST)
        today = current_dt.date()

        if self.checkpoint_dir and (current_time - self.last_checkpoint_time) >= self.checkpoint_interval:
            checkpoint_name = f"checkpoint_epoch{self.current_epoch+1}_batch{current_batch}.pt"
            checkpoint_path = self._save_checkpoint(self.checkpoint_dir, self.current_epoch, current_batch, checkpoint_name)

            # Keep only the 3 most recent interval checkpoints.
            self.recent_checkpoints.append(checkpoint_path)
            if len(self.recent_checkpoints) > 3:
                oldest = self.recent_checkpoints.pop(0)
                if os.path.exists(oldest):
                    os.remove(oldest)

            self.last_checkpoint_time = current_time
            print(f"Checkpoint saved at {checkpoint_path}")

        # Daily checkpoint in the first 15 minutes after 22:00; last_22_date limits it to one per day.
        is_22_oclock = current_dt.hour == 22 and current_dt.minute < 15

        if self.checkpoint_dir and is_22_oclock and self.last_22_date != today:
            checkpoint_name = f"checkpoint_22h_{today.strftime('%Y%m%d')}.pt"
            checkpoint_path = self._save_checkpoint(self.checkpoint_dir, self.current_epoch, current_batch, checkpoint_name)
            self.last_22_date = today
            print(f"22:00 Checkpoint saved at {checkpoint_path}")

    def _validation_phase(self, val_loader, criterion_denoise, criterion_rotate):
        """Switch model and head to eval mode and return the average validation loss."""
        self.model.eval()
        self.projection_head.eval()
        return self._validate(val_loader, criterion_denoise, criterion_rotate)

    def _validate(self, val_loader, criterion_denoise, criterion_rotate):
        """Compute, record and print the average validation loss without gradients.

        Returns:
            float: Mean loss over batches with a positive loss (0.0 if there were none).
        """
        val_loss = 0.0
        val_batch_count = 0

        with torch.no_grad():
            for batch_data in val_loader:
                if batch_data is None:
                    continue

                batch_loss = self._process_batch(
                    batch_data, criterion_denoise, criterion_rotate, training=False
                )

                if batch_loss > 0:
                    val_loss += batch_loss.item()
                    val_batch_count += 1

        avg_val_loss = val_loss / max(val_batch_count, 1)
        self.val_losses.append(avg_val_loss)
        print(f"Validation Loss: {avg_val_loss:.4f}")
        return avg_val_loss

    def save_losses(self, csv_file):
        """Save per-epoch training and validation losses to a CSV file.

        Rows are truncated to the shorter of the two loss histories (``zip`` semantics).
        """
        data = list(zip(
            range(1, len(self.train_losses) + 1),
            self.train_losses,
            self.val_losses
        ))

        with open(csv_file, mode='w', newline='', encoding='utf-8') as file:
            writer = csv.writer(file)
            writer.writerow(['Epoch', 'Train Loss', 'Validation Loss'])
            writer.writerows(data)

        print(f"Loss data has been written to {csv_file}")