develop #41
|
@ -114,9 +114,55 @@ class Pretrainer:
|
|||
|
||||
return batch_loss
|
||||
|
||||
def _save_checkpoint(self, checkpoint_dir, epoch, batch_count, checkpoint_name):
|
||||
"""Save a model checkpoint.
|
||||
|
||||
Args:
|
||||
checkpoint_dir (str): Directory to save the checkpoint
|
||||
epoch (int): Current epoch number
|
||||
batch_count (int): Current batch count
|
||||
checkpoint_name (str): Name for the checkpoint file
|
||||
|
||||
Returns:
|
||||
str: Path to the saved checkpoint
|
||||
"""
|
||||
checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name)
|
||||
checkpoint_data = {
|
||||
'epoch': epoch + 1,
|
||||
'batch': batch_count,
|
||||
'model_state_dict': self.model.state_dict(),
|
||||
'projection_head_state_dict': self.projection_head.state_dict(),
|
||||
'optimizer_state_dict': self.optimizer.state_dict(),
|
||||
'train_losses': self.train_losses,
|
||||
'val_losses': self.val_losses,
|
||||
}
|
||||
torch.save(checkpoint_data, checkpoint_path)
|
||||
return checkpoint_path
|
||||
|
||||
def train(self, dataset_paths, output_path="AIIA", column="image_bytes",
|
||||
num_epochs=3, batch_size=2, sample_size=10000, checkpoint_dir=None):
|
||||
"""Train the model using multiple specified datasets."""
|
||||
"""Train the model using multiple specified datasets.
|
||||
|
||||
Args:
|
||||
dataset_paths (List[str]): List of paths to parquet dataset files
|
||||
output_path (str, optional): Path to save the trained model. Defaults to "AIIA".
|
||||
column (str, optional): Column name containing image data. Defaults to "image_bytes".
|
||||
num_epochs (int, optional): Number of training epochs. Defaults to 3.
|
||||
batch_size (int, optional): Size of training batches. Defaults to 2.
|
||||
sample_size (int, optional): Number of samples to use from each dataset. Defaults to 10000.
|
||||
checkpoint_dir (str, optional): Directory to save checkpoints. If None, no checkpoints are saved.
|
||||
|
||||
Raises:
|
||||
ValueError: If no dataset paths are provided or if no valid datasets could be loaded.
|
||||
|
||||
The function performs the following:
|
||||
1. Loads and merges multiple parquet datasets
|
||||
2. Trains the model using denoising and rotation tasks
|
||||
3. Validates the model performance
|
||||
4. Saves checkpoints at regular intervals (every 2 hours) and at 22:00
|
||||
5. Maintains only the 3 most recent regular checkpoints
|
||||
6. Saves the best model based on validation loss
|
||||
"""
|
||||
if not dataset_paths:
|
||||
raise ValueError("No dataset paths provided")
|
||||
|
||||
|
@ -129,7 +175,6 @@ class Pretrainer:
|
|||
# Create checkpoint directory if specified
|
||||
if checkpoint_dir is not None:
|
||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
|
||||
# Read and merge all datasets
|
||||
dataframes = []
|
||||
for path in dataset_paths:
|
||||
|
@ -178,19 +223,8 @@ class Pretrainer:
|
|||
|
||||
# Regular 2-hour checkpoint
|
||||
if checkpoint_dir and (current_time - last_checkpoint_time) >= checkpoint_interval:
|
||||
checkpoint_path = os.path.join(
|
||||
checkpoint_dir,
|
||||
f"checkpoint_epoch{epoch+1}_batch{batch_count}.pt"
|
||||
)
|
||||
torch.save({
|
||||
'epoch': epoch + 1,
|
||||
'batch': batch_count,
|
||||
'model_state_dict': self.model.state_dict(),
|
||||
'projection_head_state_dict': self.projection_head.state_dict(),
|
||||
'optimizer_state_dict': self.optimizer.state_dict(),
|
||||
'train_losses': self.train_losses,
|
||||
'val_losses': self.val_losses,
|
||||
}, checkpoint_path)
|
||||
checkpoint_name = f"checkpoint_epoch{epoch+1}_batch{batch_count}.pt"
|
||||
checkpoint_path = self._save_checkpoint(checkpoint_dir, epoch, batch_count, checkpoint_name)
|
||||
|
||||
# Track and maintain only 3 recent checkpoints
|
||||
recent_checkpoints.append(checkpoint_path)
|
||||
|
@ -206,23 +240,12 @@ class Pretrainer:
|
|||
is_22_oclock = current_dt.hour == 22 and current_dt.minute == 0 and current_dt.second < 10
|
||||
|
||||
if checkpoint_dir and is_22_oclock and last_22_date != today:
|
||||
checkpoint_path = os.path.join(
|
||||
checkpoint_dir,
|
||||
f"checkpoint_22h_{today.strftime('%Y%m%d')}.pt"
|
||||
)
|
||||
torch.save({
|
||||
'epoch': epoch + 1,
|
||||
'batch': batch_count,
|
||||
'model_state_dict': self.model.state_dict(),
|
||||
'projection_head_state_dict': self.projection_head.state_dict(),
|
||||
'optimizer_state_dict': self.optimizer.state_dict(),
|
||||
'train_losses': self.train_losses,
|
||||
'val_losses': self.val_losses,
|
||||
}, checkpoint_path)
|
||||
|
||||
checkpoint_name = f"checkpoint_22h_{today.strftime('%Y%m%d')}.pt"
|
||||
checkpoint_path = self._save_checkpoint(checkpoint_dir, epoch, batch_count, checkpoint_name)
|
||||
last_22_date = today
|
||||
print(f"22:00 Checkpoint saved at {checkpoint_path}")
|
||||
|
||||
|
||||
# Process the batch
|
||||
self.optimizer.zero_grad()
|
||||
batch_loss = self._process_batch(batch_data, criterion_denoise, criterion_rotate)
|
||||
|
|
Loading…
Reference in New Issue