develop #41
@@ -221,18 +221,49 @@ class Pretrainer:
         if not dataset_paths:
             raise ValueError("No dataset paths provided")
 
-        # Initialize checkpoint tracking variables
-        last_checkpoint_time = time.time()
-        checkpoint_interval = 2 * 60 * 60  # 2 hours in seconds
-        last_22_date = None
-        recent_checkpoints = []
-
-        # Initialize resumption variables
+        self._initialize_checkpoint_variables()
+        start_epoch, start_batch, resume_training = self._load_checkpoints(checkpoint_dir)
+
+        dataframes = self._load_and_merge_datasets(dataset_paths, sample_size)
+        aiia_loader = self._initialize_data_loader(dataframes, column, batch_size)
+
+        criterion_denoise, criterion_rotate, best_val_loss = self._initialize_loss_functions()
+
+        for epoch in range(start_epoch, num_epochs):
+            print(f"\nEpoch {epoch+1}/{num_epochs}")
+            print("-" * 20)
+            total_train_loss, batch_count = self._training_phase(aiia_loader.train_loader,
+                                                                 start_batch if (epoch == start_epoch and resume_training) else 0,
+                                                                 criterion_denoise,
+                                                                 criterion_rotate)
+
+            avg_train_loss = total_train_loss / max(batch_count, 1)
+            self.train_losses.append(avg_train_loss)
+            print(f"Training Loss: {avg_train_loss:.4f}")
+
+            val_loss = self._validation_phase(aiia_loader.val_loader, criterion_denoise, criterion_rotate)
+
+            if val_loss < best_val_loss:
+                best_val_loss = val_loss
+                self.model.save(output_path)
+                print("Best model saved!")
+
+        losses_path = os.path.join(os.path.dirname(output_path), 'losses.csv')
+        self.save_losses(losses_path)
+
+    def _initialize_checkpoint_variables(self):
+        """Initialize checkpoint tracking variables."""
+        self.last_checkpoint_time = time.time()
+        self.checkpoint_interval = 2 * 60 * 60  # 2 hours in seconds
+        self.last_22_date = None
+        self.recent_checkpoints = []
+
+    def _load_checkpoints(self, checkpoint_dir):
+        """Load checkpoints and return start epoch, batch, and resumption flag."""
         start_epoch = 0
         start_batch = 0
         resume_training = False
 
-        # Check for existing checkpoint and load if available
         if checkpoint_dir is not None:
             os.makedirs(checkpoint_dir, exist_ok=True)
             checkpoint_info = self.load_checkpoint(checkpoint_dir)
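The refactored orchestration hinges on one rule: batches are only skipped in the first epoch after a resume, and every later epoch starts from batch 0. A minimal standalone sketch of that arithmetic (the function name and values are illustrative, not part of the diff):

def batches_to_skip(epoch, start_epoch, start_batch, resume_training):
    # Skip forward only in the epoch the checkpoint was taken in.
    return start_batch if (epoch == start_epoch and resume_training) else 0

assert batches_to_skip(epoch=3, start_epoch=3, start_batch=250, resume_training=True) == 250
assert batches_to_skip(epoch=4, start_epoch=3, start_batch=250, resume_training=True) == 0
assert batches_to_skip(epoch=3, start_epoch=3, start_batch=250, resume_training=False) == 0

Because the guard is computed inline per epoch, the pre-refactor bookkeeping that reset `resume_training` and `start_batch` after the resumed epoch becomes unnecessary (it is removed further down in this diff).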
@@ -242,7 +273,10 @@ class Pretrainer:
             # Adjust epoch to be 0-indexed for the loop
             start_epoch -= 1
 
-        # Load and merge datasets
+        return start_epoch, start_batch, resume_training
+
+    def _load_and_merge_datasets(self, dataset_paths, sample_size):
+        """Load and merge datasets."""
         dataframes = []
         for path in dataset_paths:
             try:
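`_load_checkpoints` converts the 1-indexed epoch stored in a checkpoint into the 0-indexed value that `range(start_epoch, num_epochs)` expects. A sketch of that contract, assuming `load_checkpoint` reports the epoch that was in progress (the lines that unpack `checkpoint_info` fall between these hunks and are not shown; the keys below are hypothetical):

# Hypothetical checkpoint metadata; the real keys are not visible in this diff.
checkpoint_info = {"epoch": 4, "batch": 250}
start_epoch = checkpoint_info["epoch"]  # 1-indexed, as saved
start_epoch -= 1                        # re-enter the loop at the same epoch
assert list(range(start_epoch, 6)) == [3, 4, 5]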
@@ -254,10 +288,11 @@ class Pretrainer:
         if not dataframes:
             raise ValueError("No valid datasets could be loaded")
 
-        merged_df = pd.concat(dataframes, ignore_index=True)
+        return pd.concat(dataframes, ignore_index=True)
 
-        # Initialize data loader
-        aiia_loader = AIIADataLoader(
+    def _initialize_data_loader(self, merged_df, column, batch_size):
+        """Initialize the data loader."""
+        return AIIADataLoader(
             merged_df,
             column=column,
             batch_size=batch_size,
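`_load_and_merge_datasets` now returns the merged frame directly. The `ignore_index=True` flag matters here: without it, row labels from each source file repeat, and positional bookkeeping downstream becomes ambiguous. A self-contained illustration:

import pandas as pd

df_a = pd.DataFrame({"image": ["a.png", "b.png"]})
df_b = pd.DataFrame({"image": ["c.png"]})

merged = pd.concat([df_a, df_b], ignore_index=True)
print(merged.index.tolist())  # [0, 1, 2] rather than [0, 1, 0]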
@@ -265,68 +300,30 @@ class Pretrainer:
             collate_fn=self.safe_collate
         )
 
-        # Initialize loss functions and tracking variables
+    def _initialize_loss_functions(self):
+        """Initialize loss functions and tracking variables."""
         criterion_denoise = nn.MSELoss()
         criterion_rotate = nn.CrossEntropyLoss()
         best_val_loss = float('inf')
+        return criterion_denoise, criterion_rotate, best_val_loss
 
-        # Main training loop
-        for epoch in range(start_epoch, num_epochs):
-            print(f"\nEpoch {epoch+1}/{num_epochs}")
-            print("-" * 20)
-
-            # Training phase
+    def _training_phase(self, train_loader, skip_batches, criterion_denoise, criterion_rotate):
+        """Handle the training phase."""
         self.model.train()
         self.projection_head.train()
         total_train_loss = 0.0
         batch_count = 0
 
-            # Convert data loader to enumerated list for batch tracking and resumption
-            train_batches = list(enumerate(aiia_loader.train_loader))
+        train_batches = list(enumerate(train_loader))
 
-            # Determine how many batches to skip if resuming from checkpoint
-            skip_batches = start_batch if (epoch == start_epoch and resume_training) else 0
-
-            # Process batches with proper resumption handling
         for i, batch_data in tqdm(train_batches[skip_batches:],
                                   initial=skip_batches,
                                   total=len(train_batches)):
             if batch_data is None:
                 continue
 
-                # Use i+1 as the actual batch count (to match 1-indexed batch numbers in checkpoints)
             current_batch = i + 1
+            self._handle_checkpoints(current_batch)
 
-                # Check if we need to save a checkpoint
-                current_time = time.time()
-                current_dt = datetime.datetime.now(datetime.timezone(datetime.timedelta(hours=2)))  # German time
-                today = current_dt.date()
-
-                # Regular 2-hour checkpoint
-                if checkpoint_dir and (current_time - last_checkpoint_time) >= checkpoint_interval:
-                    checkpoint_name = f"checkpoint_epoch{epoch+1}_batch{current_batch}.pt"
-                    checkpoint_path = self._save_checkpoint(checkpoint_dir, epoch, current_batch, checkpoint_name)
-
-                    # Track and maintain only 3 recent checkpoints
-                    recent_checkpoints.append(checkpoint_path)
-                    if len(recent_checkpoints) > 3:
-                        oldest = recent_checkpoints.pop(0)
-                        if os.path.exists(oldest):
-                            os.remove(oldest)
-
-                    last_checkpoint_time = current_time
-                    print(f"Checkpoint saved at {checkpoint_path}")
-
-                # Special 22:00 checkpoint (considering it's currently 10:15 PM)
-                is_22_oclock = current_dt.hour == 22 and current_dt.minute < 15
-
-                if checkpoint_dir and is_22_oclock and last_22_date != today:
-                    checkpoint_name = f"checkpoint_22h_{today.strftime('%Y%m%d')}.pt"
-                    checkpoint_path = self._save_checkpoint(checkpoint_dir, epoch, current_batch, checkpoint_name)
-                    last_22_date = today
-                    print(f"22:00 Checkpoint saved at {checkpoint_path}")
-
-                # Process the batch
             self.optimizer.zero_grad()
             batch_loss = self._process_batch(batch_data, criterion_denoise, criterion_rotate)
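`_training_phase` materialises the loader with `list(enumerate(train_loader))` so it can slice past already-processed batches while keeping absolute indices; `initial=skip_batches` keeps the tqdm bar honest about overall progress. A sketch with a plain list standing in for the DataLoader (names are illustrative):

from tqdm import tqdm

loader = ["batch0", "batch1", "batch2", "batch3"]  # stand-in for a DataLoader
skip_batches = 2

train_batches = list(enumerate(loader))  # absolute indices survive the slice
for i, batch in tqdm(train_batches[skip_batches:],
                     initial=skip_batches,
                     total=len(train_batches)):
    pass  # i + 1 matches the 1-indexed batch numbers in checkpoint names

One trade-off worth noting: `list(...)` pulls every batch out of the loader up front, which is fine for small map-style datasets but defeats streaming loaders and can be memory-heavy.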
@@ -336,31 +333,42 @@ class Pretrainer:
             total_train_loss += batch_loss.item()
             batch_count += 1
 
-            # Reset batch skipping after completing the resumed epoch
-            if resume_training and epoch == start_epoch:
-                resume_training = False
-                start_batch = 0
-
-            # Calculate and store training loss
-            avg_train_loss = total_train_loss / max(batch_count, 1)
-            self.train_losses.append(avg_train_loss)
-            print(f"Training Loss: {avg_train_loss:.4f}")
-
-            # Validation phase
+        return total_train_loss, batch_count
+
+    def _handle_checkpoints(self, current_batch):
+        """Handle checkpoint saving logic."""
+        current_time = time.time()
+        current_dt = datetime.datetime.now(datetime.timezone(datetime.timedelta(hours=2)))  # German time
+        today = current_dt.date()
+
+        if self.checkpoint_dir and (current_time - self.last_checkpoint_time) >= self.checkpoint_interval:
+            checkpoint_name = f"checkpoint_epoch{self.current_epoch+1}_batch{current_batch}.pt"
+            checkpoint_path = self._save_checkpoint(self.checkpoint_dir, self.current_epoch, current_batch, checkpoint_name)
+
+            # Track and maintain only 3 recent checkpoints
+            self.recent_checkpoints.append(checkpoint_path)
+            if len(self.recent_checkpoints) > 3:
+                oldest = self.recent_checkpoints.pop(0)
+                if os.path.exists(oldest):
+                    os.remove(oldest)
+
+            self.last_checkpoint_time = current_time
+            print(f"Checkpoint saved at {checkpoint_path}")
+
+        # Special 22:00 checkpoint (considering it's currently 10:15 PM)
+        is_22_oclock = current_dt.hour == 22 and current_dt.minute < 15
+
+        if self.checkpoint_dir and is_22_oclock and self.last_22_date != today:
+            checkpoint_name = f"checkpoint_22h_{today.strftime('%Y%m%d')}.pt"
+            checkpoint_path = self._save_checkpoint(self.checkpoint_dir, self.current_epoch, current_batch, checkpoint_name)
+            self.last_22_date = today
+            print(f"22:00 Checkpoint saved at {checkpoint_path}")
+
+    def _validation_phase(self, val_loader, criterion_denoise, criterion_rotate):
+        """Handle the validation phase."""
         self.model.eval()
         self.projection_head.eval()
-            val_loss = self._validate(aiia_loader.val_loader, criterion_denoise, criterion_rotate)
-
-            # Save best model based on validation loss
-            if val_loss < best_val_loss:
-                best_val_loss = val_loss
-                self.model.save(output_path)
-                print("Best model saved!")
-
-            # Save training history
-            losses_path = os.path.join(os.path.dirname(output_path), 'losses.csv')
-            self.save_losses(losses_path)
-
+        return self._validate(val_loader, criterion_denoise, criterion_rotate)
 
     def _validate(self, val_loader, criterion_denoise, criterion_rotate):
         """Perform validation and return average validation loss."""
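Two assumptions in `_handle_checkpoints` are worth flagging. First, it reads `self.checkpoint_dir` and `self.current_epoch`, neither of which is assigned anywhere in this diff, so `pretrain` or `_initialize_checkpoint_variables` presumably has to set them before the training phase runs. Second, the fixed `timezone(timedelta(hours=2))` offset is CEST only; once Germany switches to winter time (UTC+1), the "22:00" trigger fires at 21:00 local time. A DST-safe sketch using only the standard library (an alternative, not what this PR does):

import datetime
from zoneinfo import ZoneInfo  # Python 3.9+

# Track German wall-clock time correctly across DST transitions.
current_dt = datetime.datetime.now(ZoneInfo("Europe/Berlin"))
is_22_oclock = current_dt.hour == 22 and current_dt.minute < 15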