feat/energy_efficenty #38

Merged
Fabel merged 6 commits from feat/energy_efficenty into develop 2025-04-17 10:52:25 +00:00
1 changed file with 53 additions and 30 deletions
Showing only changes of commit 47b42c3ab3


@@ -114,9 +114,55 @@ class Pretrainer:
         return batch_loss
+    def _save_checkpoint(self, checkpoint_dir, epoch, batch_count, checkpoint_name):
+        """Save a model checkpoint.
+        Args:
+            checkpoint_dir (str): Directory to save the checkpoint
+            epoch (int): Current epoch number
+            batch_count (int): Current batch count
+            checkpoint_name (str): Name for the checkpoint file
+        Returns:
+            str: Path to the saved checkpoint
+        """
+        checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name)
+        checkpoint_data = {
+            'epoch': epoch + 1,
+            'batch': batch_count,
+            'model_state_dict': self.model.state_dict(),
+            'projection_head_state_dict': self.projection_head.state_dict(),
+            'optimizer_state_dict': self.optimizer.state_dict(),
+            'train_losses': self.train_losses,
+            'val_losses': self.val_losses,
+        }
+        torch.save(checkpoint_data, checkpoint_path)
+        return checkpoint_path
     def train(self, dataset_paths, output_path="AIIA", column="image_bytes",
               num_epochs=3, batch_size=2, sample_size=10000, checkpoint_dir=None):
-        """Train the model using multiple specified datasets."""
+        """Train the model using multiple specified datasets.
+        Args:
+            dataset_paths (List[str]): List of paths to parquet dataset files
+            output_path (str, optional): Path to save the trained model. Defaults to "AIIA".
+            column (str, optional): Column name containing image data. Defaults to "image_bytes".
+            num_epochs (int, optional): Number of training epochs. Defaults to 3.
+            batch_size (int, optional): Size of training batches. Defaults to 2.
+            sample_size (int, optional): Number of samples to use from each dataset. Defaults to 10000.
+            checkpoint_dir (str, optional): Directory to save checkpoints. If None, no checkpoints are saved.
+        Raises:
+            ValueError: If no dataset paths are provided or if no valid datasets could be loaded.
+        The function performs the following:
+        1. Loads and merges multiple parquet datasets
+        2. Trains the model using denoising and rotation tasks
+        3. Validates the model performance
+        4. Saves checkpoints at regular intervals (every 2 hours) and at 22:00
+        5. Maintains only the 3 most recent regular checkpoints
+        6. Saves the best model based on validation loss
+        """
         if not dataset_paths:
             raise ValueError("No dataset paths provided")
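
The extracted `_save_checkpoint` helper gives both checkpoint variants a single serialization format. For reference, a minimal sketch of the loading side, which is not part of this diff; it assumes a `Pretrainer` with the same attributes that `_save_checkpoint` serializes:

import torch

def load_checkpoint(pretrainer, checkpoint_path, device="cpu"):
    # Restore the state written by _save_checkpoint (sketch; no loader in this PR).
    checkpoint = torch.load(checkpoint_path, map_location=device)
    pretrainer.model.load_state_dict(checkpoint['model_state_dict'])
    pretrainer.projection_head.load_state_dict(checkpoint['projection_head_state_dict'])
    pretrainer.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    pretrainer.train_losses = checkpoint['train_losses']
    pretrainer.val_losses = checkpoint['val_losses']
    # 'epoch' was stored as epoch + 1, i.e. the epoch to resume from.
    return checkpoint['epoch'], checkpoint['batch']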
@@ -129,7 +175,6 @@ class Pretrainer:
         # Create checkpoint directory if specified
         if checkpoint_dir is not None:
             os.makedirs(checkpoint_dir, exist_ok=True)
-
         # Read and merge all datasets
         dataframes = []
         for path in dataset_paths:
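
The loop body is truncated here by the diff. Going by the docstring ("Loads and merges multiple parquet datasets", up to `sample_size` rows each, `ValueError` if none load), the read-and-merge step plausibly looks like the following sketch; the `pandas` calls are an assumption, not the code from this commit:

import pandas as pd

dataframes = []
for path in dataset_paths:
    try:
        # Read only the image column and cap each dataset at sample_size rows.
        df = pd.read_parquet(path, columns=[column]).head(sample_size)
        dataframes.append(df)
    except Exception as e:
        print(f"Error loading dataset {path}: {e}")
if not dataframes:
    raise ValueError("No valid datasets could be loaded")
merged_df = pd.concat(dataframes, ignore_index=True)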
@@ -178,19 +223,8 @@ class Pretrainer:
                 # Regular 2-hour checkpoint
                 if checkpoint_dir and (current_time - last_checkpoint_time) >= checkpoint_interval:
-                    checkpoint_path = os.path.join(
-                        checkpoint_dir,
-                        f"checkpoint_epoch{epoch+1}_batch{batch_count}.pt"
-                    )
-                    torch.save({
-                        'epoch': epoch + 1,
-                        'batch': batch_count,
-                        'model_state_dict': self.model.state_dict(),
-                        'projection_head_state_dict': self.projection_head.state_dict(),
-                        'optimizer_state_dict': self.optimizer.state_dict(),
-                        'train_losses': self.train_losses,
-                        'val_losses': self.val_losses,
-                    }, checkpoint_path)
+                    checkpoint_name = f"checkpoint_epoch{epoch+1}_batch{batch_count}.pt"
+                    checkpoint_path = self._save_checkpoint(checkpoint_dir, epoch, batch_count, checkpoint_name)
                     # Track and maintain only 3 recent checkpoints
                     recent_checkpoints.append(checkpoint_path)
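
The pruning that enforces the three-checkpoint limit falls outside this hunk. The usual pattern, assuming `recent_checkpoints` is a plain list of file paths (a sketch, not the commit's code):

import os

# Drop the oldest regular checkpoints once more than 3 are tracked.
while len(recent_checkpoints) > 3:
    oldest = recent_checkpoints.pop(0)
    if os.path.exists(oldest):
        os.remove(oldest)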
@@ -206,23 +240,12 @@ class Pretrainer:
                 is_22_oclock = current_dt.hour == 22 and current_dt.minute == 0 and current_dt.second < 10
                 if checkpoint_dir and is_22_oclock and last_22_date != today:
-                    checkpoint_path = os.path.join(
-                        checkpoint_dir,
-                        f"checkpoint_22h_{today.strftime('%Y%m%d')}.pt"
-                    )
-                    torch.save({
-                        'epoch': epoch + 1,
-                        'batch': batch_count,
-                        'model_state_dict': self.model.state_dict(),
-                        'projection_head_state_dict': self.projection_head.state_dict(),
-                        'optimizer_state_dict': self.optimizer.state_dict(),
-                        'train_losses': self.train_losses,
-                        'val_losses': self.val_losses,
-                    }, checkpoint_path)
+                    checkpoint_name = f"checkpoint_22h_{today.strftime('%Y%m%d')}.pt"
+                    checkpoint_path = self._save_checkpoint(checkpoint_dir, epoch, batch_count, checkpoint_name)
                     last_22_date = today
                     print(f"22:00 Checkpoint saved at {checkpoint_path}")
                 # Process the batch
                 self.optimizer.zero_grad()
                 batch_loss = self._process_batch(batch_data, criterion_denoise, criterion_rotate)
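
Both checkpoint triggers compare against wall-clock state that is initialized before the loop and refreshed each batch. A sketch of that bookkeeping, consistent with the conditions above; only the conditions themselves appear in the diff, so the setup below is assumed:

import time
from datetime import datetime

# Before the training loop (assumed setup):
checkpoint_interval = 2 * 60 * 60   # "regular intervals (every 2 hours)"
last_checkpoint_time = time.time()
last_22_date = None                 # date of the most recent 22:00 checkpoint
recent_checkpoints = []

# Refreshed inside the batch loop:
current_time = time.time()
current_dt = datetime.now()
today = current_dt.date()

# After a regular 2-hour checkpoint is written, the timer resets:
last_checkpoint_time = current_time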