develop #41

Merged
Fabel merged 27 commits from develop into main 2025-04-17 17:08:57 +00:00
1 changed file with 116 additions and 108 deletions
Showing only changes of commit 09662d6102


@@ -221,18 +221,49 @@ class Pretrainer:
         if not dataset_paths:
             raise ValueError("No dataset paths provided")
 
-        # Initialize checkpoint tracking variables
-        last_checkpoint_time = time.time()
-        checkpoint_interval = 2 * 60 * 60  # 2 hours in seconds
-        last_22_date = None
-        recent_checkpoints = []
+        self._initialize_checkpoint_variables()
+        start_epoch, start_batch, resume_training = self._load_checkpoints(checkpoint_dir)
-        # Initialize resumption variables
+        dataframes = self._load_and_merge_datasets(dataset_paths, sample_size)
+        aiia_loader = self._initialize_data_loader(dataframes, column, batch_size)
+        criterion_denoise, criterion_rotate, best_val_loss = self._initialize_loss_functions()
+
+        for epoch in range(start_epoch, num_epochs):
+            print(f"\nEpoch {epoch+1}/{num_epochs}")
+            print("-" * 20)
+
+            total_train_loss, batch_count = self._training_phase(
+                aiia_loader.train_loader,
+                start_batch if (epoch == start_epoch and resume_training) else 0,
+                criterion_denoise,
+                criterion_rotate)
+
+            avg_train_loss = total_train_loss / max(batch_count, 1)
+            self.train_losses.append(avg_train_loss)
+            print(f"Training Loss: {avg_train_loss:.4f}")
+
+            val_loss = self._validation_phase(aiia_loader.val_loader, criterion_denoise, criterion_rotate)
+
+            if val_loss < best_val_loss:
+                best_val_loss = val_loss
+                self.model.save(output_path)
+                print("Best model saved!")
+
+        losses_path = os.path.join(os.path.dirname(output_path), 'losses.csv')
+        self.save_losses(losses_path)
+
+    def _initialize_checkpoint_variables(self):
+        """Initialize checkpoint tracking variables."""
+        self.last_checkpoint_time = time.time()
+        self.checkpoint_interval = 2 * 60 * 60  # 2 hours in seconds
+        self.last_22_date = None
+        self.recent_checkpoints = []
+
+    def _load_checkpoints(self, checkpoint_dir):
+        """Load checkpoints and return start epoch, batch, and resumption flag."""
         start_epoch = 0
         start_batch = 0
         resume_training = False
 
         # Check for existing checkpoint and load if available
         if checkpoint_dir is not None:
             os.makedirs(checkpoint_dir, exist_ok=True)
             checkpoint_info = self.load_checkpoint(checkpoint_dir)
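
Note on the resumption arithmetic used above: `pretrain` only skips batches during the epoch it resumed into, via `start_batch if (epoch == start_epoch and resume_training) else 0`. A minimal standalone sketch of that rule (hypothetical helper name, illustration only, not project code):

```python
def batches_to_skip(epoch: int, start_epoch: int, start_batch: int, resume_training: bool) -> int:
    """Skip already-processed batches only in the epoch we resumed into."""
    return start_batch if (epoch == start_epoch and resume_training) else 0

# Resumed at epoch 3, batch 120: skip 120 batches in epoch 3, none in later epochs.
assert batches_to_skip(3, 3, 120, True) == 120
assert batches_to_skip(4, 3, 120, True) == 0
# Fresh run: nothing is skipped.
assert batches_to_skip(0, 0, 0, False) == 0
```
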
@@ -242,7 +273,10 @@ class Pretrainer:
                 # Adjust epoch to be 0-indexed for the loop
                 start_epoch -= 1
 
-        # Load and merge datasets
+        return start_epoch, start_batch, resume_training
+
+    def _load_and_merge_datasets(self, dataset_paths, sample_size):
+        """Load and merge datasets."""
         dataframes = []
         for path in dataset_paths:
             try:
@@ -254,10 +288,11 @@ class Pretrainer:
         if not dataframes:
             raise ValueError("No valid datasets could be loaded")
 
-        merged_df = pd.concat(dataframes, ignore_index=True)
-
-        # Initialize data loader
-        aiia_loader = AIIADataLoader(
+        return pd.concat(dataframes, ignore_index=True)
+
+    def _initialize_data_loader(self, merged_df, column, batch_size):
+        """Initialize the data loader."""
+        return AIIADataLoader(
             merged_df,
             column=column,
             batch_size=batch_size,
@@ -265,68 +300,30 @@ class Pretrainer:
             collate_fn=self.safe_collate
         )
 
-        # Initialize loss functions and tracking variables
+    def _initialize_loss_functions(self):
+        """Initialize loss functions and tracking variables."""
         criterion_denoise = nn.MSELoss()
         criterion_rotate = nn.CrossEntropyLoss()
         best_val_loss = float('inf')
+        return criterion_denoise, criterion_rotate, best_val_loss
 
-        # Main training loop
-        for epoch in range(start_epoch, num_epochs):
-            print(f"\nEpoch {epoch+1}/{num_epochs}")
-            print("-" * 20)
-
-            # Training phase
+    def _training_phase(self, train_loader, skip_batches, criterion_denoise, criterion_rotate):
+        """Handle the training phase."""
         self.model.train()
         self.projection_head.train()
         total_train_loss = 0.0
         batch_count = 0
 
-            # Convert data loader to enumerated list for batch tracking and resumption
-            train_batches = list(enumerate(aiia_loader.train_loader))
-            # Determine how many batches to skip if resuming from checkpoint
-            skip_batches = start_batch if (epoch == start_epoch and resume_training) else 0
-
-            # Process batches with proper resumption handling
+        train_batches = list(enumerate(train_loader))
+
         for i, batch_data in tqdm(train_batches[skip_batches:],
                                   initial=skip_batches,
                                   total=len(train_batches)):
             if batch_data is None:
                 continue
 
             # Use i+1 as the actual batch count (to match 1-indexed batch numbers in checkpoints)
             current_batch = i + 1
+            self._handle_checkpoints(current_batch)
 
-                # Check if we need to save a checkpoint
-                current_time = time.time()
-                current_dt = datetime.datetime.now(datetime.timezone(datetime.timedelta(hours=2)))  # German time
-                today = current_dt.date()
-
-                # Regular 2-hour checkpoint
-                if checkpoint_dir and (current_time - last_checkpoint_time) >= checkpoint_interval:
-                    checkpoint_name = f"checkpoint_epoch{epoch+1}_batch{current_batch}.pt"
-                    checkpoint_path = self._save_checkpoint(checkpoint_dir, epoch, current_batch, checkpoint_name)
-
-                    # Track and maintain only 3 recent checkpoints
-                    recent_checkpoints.append(checkpoint_path)
-                    if len(recent_checkpoints) > 3:
-                        oldest = recent_checkpoints.pop(0)
-                        if os.path.exists(oldest):
-                            os.remove(oldest)
-
-                    last_checkpoint_time = current_time
-                    print(f"Checkpoint saved at {checkpoint_path}")
-
-                # Special 22:00 checkpoint (considering it's currently 10:15 PM)
-                is_22_oclock = current_dt.hour == 22 and current_dt.minute < 15
-                if checkpoint_dir and is_22_oclock and last_22_date != today:
-                    checkpoint_name = f"checkpoint_22h_{today.strftime('%Y%m%d')}.pt"
-                    checkpoint_path = self._save_checkpoint(checkpoint_dir, epoch, current_batch, checkpoint_name)
-                    last_22_date = today
-                    print(f"22:00 Checkpoint saved at {checkpoint_path}")
-
             # Process the batch
             self.optimizer.zero_grad()
             batch_loss = self._process_batch(batch_data, criterion_denoise, criterion_rotate)
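
For context on the loop above: `_training_phase` resumes mid-epoch by materializing the loader with `enumerate`, slicing off the batches that were already processed, and passing `initial=`/`total=` to tqdm so the progress bar still reflects the full epoch. A self-contained sketch of the same pattern, with dummy batches standing in for the real loader (illustration only):

```python
from tqdm import tqdm

dummy_loader = [f"batch_{n}" for n in range(10)]  # stand-in for aiia_loader.train_loader
skip_batches = 4                                  # e.g. restored from a checkpoint

train_batches = list(enumerate(dummy_loader))     # keeps the original batch indices
for i, batch_data in tqdm(train_batches[skip_batches:],
                          initial=skip_batches,
                          total=len(train_batches)):
    current_batch = i + 1  # 1-indexed, matching the checkpoint file names
    # ... process batch_data ...
```

The slice drops the first `skip_batches` entries while `enumerate` preserves their original indices, so `current_batch` still lines up with the batch numbers recorded in the checkpoints.
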
@@ -336,31 +333,42 @@ class Pretrainer:
             total_train_loss += batch_loss.item()
             batch_count += 1
 
-            # Reset batch skipping after completing the resumed epoch
-            if resume_training and epoch == start_epoch:
-                resume_training = False
-                start_batch = 0
+        return total_train_loss, batch_count
 
-            # Calculate and store training loss
-            avg_train_loss = total_train_loss / max(batch_count, 1)
-            self.train_losses.append(avg_train_loss)
-            print(f"Training Loss: {avg_train_loss:.4f}")
+    def _handle_checkpoints(self, current_batch):
+        """Handle checkpoint saving logic."""
+        current_time = time.time()
+        current_dt = datetime.datetime.now(datetime.timezone(datetime.timedelta(hours=2)))  # German time
+        today = current_dt.date()
-            # Validation phase
+        if self.checkpoint_dir and (current_time - self.last_checkpoint_time) >= self.checkpoint_interval:
+            checkpoint_name = f"checkpoint_epoch{self.current_epoch+1}_batch{current_batch}.pt"
+            checkpoint_path = self._save_checkpoint(self.checkpoint_dir, self.current_epoch, current_batch, checkpoint_name)
+
+            # Track and maintain only 3 recent checkpoints
+            self.recent_checkpoints.append(checkpoint_path)
+            if len(self.recent_checkpoints) > 3:
+                oldest = self.recent_checkpoints.pop(0)
+                if os.path.exists(oldest):
+                    os.remove(oldest)
+
+            self.last_checkpoint_time = current_time
+            print(f"Checkpoint saved at {checkpoint_path}")
+
+        # Special 22:00 checkpoint (considering it's currently 10:15 PM)
+        is_22_oclock = current_dt.hour == 22 and current_dt.minute < 15
+        if self.checkpoint_dir and is_22_oclock and self.last_22_date != today:
+            checkpoint_name = f"checkpoint_22h_{today.strftime('%Y%m%d')}.pt"
+            checkpoint_path = self._save_checkpoint(self.checkpoint_dir, self.current_epoch, current_batch, checkpoint_name)
+            self.last_22_date = today
+            print(f"22:00 Checkpoint saved at {checkpoint_path}")
+
+    def _validation_phase(self, val_loader, criterion_denoise, criterion_rotate):
+        """Handle the validation phase."""
         self.model.eval()
         self.projection_head.eval()
-            val_loss = self._validate(aiia_loader.val_loader, criterion_denoise, criterion_rotate)
-
-            # Save best model based on validation loss
-            if val_loss < best_val_loss:
-                best_val_loss = val_loss
-                self.model.save(output_path)
-                print("Best model saved!")
-
-        # Save training history
-        losses_path = os.path.join(os.path.dirname(output_path), 'losses.csv')
-        self.save_losses(losses_path)
+        return self._validate(val_loader, criterion_denoise, criterion_rotate)
 
     def _validate(self, val_loader, criterion_denoise, criterion_rotate):
         """Perform validation and return average validation loss."""