feat/energy_efficenty #38
@@ -139,43 +139,110 @@ class Pretrainer:
         torch.save(checkpoint_data, checkpoint_path)
         return checkpoint_path

-    def train(self, dataset_paths, output_path="AIIA", column="image_bytes",
-              num_epochs=3, batch_size=2, sample_size=10000, checkpoint_dir=None):
-        """Train the model using multiple specified datasets.
+    def load_checkpoint(self, checkpoint_dir, specific_checkpoint=None):
+        """
+        Check for checkpoints and load if available.

         Args:
-            dataset_paths (List[str]): List of paths to parquet dataset files
-            output_path (str, optional): Path to save the trained model. Defaults to "AIIA".
-            column (str, optional): Column name containing image data. Defaults to "image_bytes".
-            num_epochs (int, optional): Number of training epochs. Defaults to 3.
-            batch_size (int, optional): Size of training batches. Defaults to 2.
-            sample_size (int, optional): Number of samples to use from each dataset. Defaults to 10000.
-            checkpoint_dir (str, optional): Directory to save checkpoints. If None, no checkpoints are saved.
+            checkpoint_dir (str): Directory where checkpoints are stored
+            specific_checkpoint (str, optional): Specific checkpoint file to load. If None, loads the most recent.

-        Raises:
-            ValueError: If no dataset paths are provided or if no valid datasets could be loaded.
+        Returns:
+            tuple: (loaded_epoch, loaded_batch) if checkpoint was loaded, None otherwise

-        The function performs the following:
-        1. Loads and merges multiple parquet datasets
-        2. Trains the model using denoising and rotation tasks
-        3. Validates the model performance
-        4. Saves checkpoints at regular intervals (every 2 hours) and at 22:00
-        5. Maintains only the 3 most recent regular checkpoints
-        6. Saves the best model based on validation loss
         """
+        # Create checkpoint directory if it doesn't exist
+        os.makedirs(checkpoint_dir, exist_ok=True)
+
+        # If a specific checkpoint is requested
+        if specific_checkpoint:
+            checkpoint_path = os.path.join(checkpoint_dir, specific_checkpoint)
+            if os.path.exists(checkpoint_path):
+                return self._load_checkpoint_file(checkpoint_path)
+            else:
+                print(f"Specified checkpoint {specific_checkpoint} not found.")
+                return None
+
+        # Find all checkpoint files
+        checkpoint_files = [f for f in os.listdir(checkpoint_dir) if f.startswith("checkpoint_") and f.endswith(".pt")]
+
+        if not checkpoint_files:
+            print("No checkpoints found in directory.")
+            return None
+
+        # Find the most recent checkpoint
+        checkpoint_files.sort(key=lambda x: os.path.getmtime(os.path.join(checkpoint_dir, x)), reverse=True)
+        most_recent = checkpoint_files[0]
+        checkpoint_path = os.path.join(checkpoint_dir, most_recent)
+
+        return self._load_checkpoint_file(checkpoint_path)
+
+    def _load_checkpoint_file(self, checkpoint_path):
+        """
+        Load a specific checkpoint file.
+
+        Args:
+            checkpoint_path (str): Path to the checkpoint file
+
+        Returns:
+            tuple: (loaded_epoch, loaded_batch) if checkpoint was loaded, None otherwise
+        """
+        try:
+            checkpoint = torch.load(checkpoint_path, map_location=self.device)
+
+            # Load model state
+            self.model.load_state_dict(checkpoint['model_state_dict'])
+
+            # Load projection head state
+            self.projection_head.load_state_dict(checkpoint['projection_head_state_dict'])
+
+            # Load optimizer state
+            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+
+            # Load loss history
+            self.train_losses = checkpoint.get('train_losses', [])
+            self.val_losses = checkpoint.get('val_losses', [])
+
+            loaded_epoch = checkpoint['epoch']
+            loaded_batch = checkpoint['batch']
+
+            print(f"Checkpoint loaded from {checkpoint_path}")
+            print(f"Resuming from epoch {loaded_epoch}, batch {loaded_batch}")
+
+            return loaded_epoch, loaded_batch
+
+        except Exception as e:
+            print(f"Error loading checkpoint: {e}")
+            return None
+
+    def train(self, dataset_paths, output_path="AIIA", column="image_bytes",
+              num_epochs=3, batch_size=2, sample_size=10000, checkpoint_dir=None):
+        """Train the model using multiple specified datasets with checkpoint resumption support."""
         if not dataset_paths:
             raise ValueError("No dataset paths provided")

-        # Checkpoint tracking variables
+        # Initialize checkpoint tracking variables
         last_checkpoint_time = time.time()
         checkpoint_interval = 2 * 60 * 60  # 2 hours in seconds
         last_22_date = None
         recent_checkpoints = []

-        # Create checkpoint directory if specified
+        # Initialize resumption variables
+        start_epoch = 0
+        start_batch = 0
+        resume_training = False
+
+        # Check for existing checkpoint and load if available
         if checkpoint_dir is not None:
             os.makedirs(checkpoint_dir, exist_ok=True)
-        # Read and merge all datasets
+            checkpoint_info = self.load_checkpoint(checkpoint_dir)
+            if checkpoint_info:
+                start_epoch, start_batch = checkpoint_info
+                resume_training = True
+                # Adjust epoch to be 0-indexed for the loop
+                start_epoch -= 1
+
+        # Load and merge datasets
         dataframes = []
         for path in dataset_paths:
             try:
@@ -198,11 +265,13 @@ class Pretrainer:
             collate_fn=self.safe_collate
         )

+        # Initialize loss functions and tracking variables
         criterion_denoise = nn.MSELoss()
         criterion_rotate = nn.CrossEntropyLoss()
         best_val_loss = float('inf')

-        for epoch in range(num_epochs):
+        # Main training loop
+        for epoch in range(start_epoch, num_epochs):
             print(f"\nEpoch {epoch+1}/{num_epochs}")
             print("-" * 20)

@@ -212,10 +281,22 @@ class Pretrainer:
             total_train_loss = 0.0
             batch_count = 0

-            for batch_data in tqdm(aiia_loader.train_loader):
+            # Convert data loader to enumerated list for batch tracking and resumption
+            train_batches = list(enumerate(aiia_loader.train_loader))
+
+            # Determine how many batches to skip if resuming from checkpoint
+            skip_batches = start_batch if (epoch == start_epoch and resume_training) else 0
+
+            # Process batches with proper resumption handling
+            for i, batch_data in tqdm(train_batches[skip_batches:],
+                                      initial=skip_batches,
+                                      total=len(train_batches)):
                 if batch_data is None:
                     continue

+                # Use i+1 as the actual batch count (to match 1-indexed batch numbers in checkpoints)
+                current_batch = i + 1
+
                 # Check if we need to save a checkpoint
                 current_time = time.time()
                 current_dt = datetime.datetime.now(datetime.timezone(datetime.timedelta(hours=2)))  # German time
@@ -223,8 +304,8 @@ class Pretrainer:

                 # Regular 2-hour checkpoint
                 if checkpoint_dir and (current_time - last_checkpoint_time) >= checkpoint_interval:
-                    checkpoint_name = f"checkpoint_epoch{epoch+1}_batch{batch_count}.pt"
-                    checkpoint_path = self._save_checkpoint(checkpoint_dir, epoch, batch_count, checkpoint_name)
+                    checkpoint_name = f"checkpoint_epoch{epoch+1}_batch{current_batch}.pt"
+                    checkpoint_path = self._save_checkpoint(checkpoint_dir, epoch, current_batch, checkpoint_name)

                     # Track and maintain only 3 recent checkpoints
                     recent_checkpoints.append(checkpoint_path)
@@ -236,16 +317,15 @@ class Pretrainer:
                     last_checkpoint_time = current_time
                     print(f"Checkpoint saved at {checkpoint_path}")

-                # Special 22:00 checkpoint
-                is_22_oclock = current_dt.hour == 22 and current_dt.minute == 0 and current_dt.second < 10
+                # Special 22:00 checkpoint (considering it's currently 10:15 PM)
+                is_22_oclock = current_dt.hour == 22 and current_dt.minute < 15

                 if checkpoint_dir and is_22_oclock and last_22_date != today:
                     checkpoint_name = f"checkpoint_22h_{today.strftime('%Y%m%d')}.pt"
-                    checkpoint_path = self._save_checkpoint(checkpoint_dir, epoch, batch_count, checkpoint_name)
+                    checkpoint_path = self._save_checkpoint(checkpoint_dir, epoch, current_batch, checkpoint_name)
                     last_22_date = today
                     print(f"22:00 Checkpoint saved at {checkpoint_path}")

-
                 # Process the batch
                 self.optimizer.zero_grad()
                 batch_loss = self._process_batch(batch_data, criterion_denoise, criterion_rotate)
@@ -256,6 +336,12 @@ class Pretrainer:
                 total_train_loss += batch_loss.item()
                 batch_count += 1

+            # Reset batch skipping after completing the resumed epoch
+            if resume_training and epoch == start_epoch:
+                resume_training = False
+                start_batch = 0
+
+            # Calculate and store training loss
             avg_train_loss = total_train_loss / max(batch_count, 1)
             self.train_losses.append(avg_train_loss)
             print(f"Training Loss: {avg_train_loss:.4f}")
@@ -265,11 +351,13 @@ class Pretrainer:
             self.projection_head.eval()
             val_loss = self._validate(aiia_loader.val_loader, criterion_denoise, criterion_rotate)

+            # Save best model based on validation loss
             if val_loss < best_val_loss:
                 best_val_loss = val_loss
                 self.model.save(output_path)
                 print("Best model saved!")

+            # Save training history
             losses_path = os.path.join(os.path.dirname(output_path), 'losses.csv')
             self.save_losses(losses_path)

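For review context, the sketch below walks through the resume mechanics this PR adds, in isolation from the rest of the trainer: it writes two dummy checkpoints with the same keys the checkpoint dict uses, picks the most recent file by modification time the way `load_checkpoint` does, and then skips already-processed batches the way the updated training loop does. The directory name, the toy state dicts, and the batch counts are made up for illustration; only the selection and skip logic mirrors the diff.

```python
# A hypothetical, self-contained sketch of the resume mechanics added in this PR.
# The directory name, toy state dicts, and batch counts are illustrative only.
import os
import time

import torch

ckpt_dir = "checkpoints_demo"  # hypothetical directory
os.makedirs(ckpt_dir, exist_ok=True)

# Write two dummy checkpoints with the same keys the checkpoint dict stores.
for epoch, batch in [(1, 50), (2, 120)]:
    torch.save(
        {
            "epoch": epoch,
            "batch": batch,
            "model_state_dict": {"w": torch.zeros(1)},
            "projection_head_state_dict": {"w": torch.zeros(1)},
            "optimizer_state_dict": {},
            "train_losses": [],
            "val_losses": [],
        },
        os.path.join(ckpt_dir, f"checkpoint_epoch{epoch}_batch{batch}.pt"),
    )
    time.sleep(0.01)  # keep modification times distinct for the demo

# Most-recent selection, mirroring load_checkpoint(): newest mtime wins.
files = [f for f in os.listdir(ckpt_dir) if f.startswith("checkpoint_") and f.endswith(".pt")]
files.sort(key=lambda f: os.path.getmtime(os.path.join(ckpt_dir, f)), reverse=True)
checkpoint = torch.load(os.path.join(ckpt_dir, files[0]), map_location="cpu")

start_epoch, start_batch = checkpoint["epoch"], checkpoint["batch"]
start_epoch -= 1  # same 0-indexing adjustment train() applies before its loop

# Batch skipping, mirroring the updated training loop.
num_epochs, batches_per_epoch = 3, 200
resume_training = True
for epoch in range(start_epoch, num_epochs):
    skip_batches = start_batch if (epoch == start_epoch and resume_training) else 0
    processed = 0
    for i in range(skip_batches, batches_per_epoch):
        current_batch = i + 1  # 1-indexed, matching the checkpoint file names
        processed += 1
    if resume_training and epoch == start_epoch:
        resume_training, start_batch = False, 0
    print(f"epoch {epoch + 1}: processed {processed} of {batches_per_epoch} batches")
```

Under these assumptions the resumed run continues epoch 2 at batch 121 and then processes epoch 3 in full, which is the behaviour the new `train()` loop is written to produce.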
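
The time-based triggers are easy to lose in the diff, so here is a small sketch of the two conditions the loop checks on every batch: a rolling two-hour interval for regular checkpoints, and a once-per-day save during the 22:00–22:14 window (the old code only matched the first ten seconds of 22:00). The fixed UTC+2 offset is the diff's own shortcut for German time rather than proper timezone handling, and the helper name `checkpoint_due` is made up for this sketch.

```python
# Hypothetical helper illustrating the two checkpoint triggers from the training loop.
# The fixed UTC+2 offset mirrors the diff's "German time" shortcut.
import datetime
import time

CHECKPOINT_INTERVAL = 2 * 60 * 60  # 2 hours in seconds


def checkpoint_due(last_checkpoint_time, last_22_date):
    """Return (interval_due, nightly_due, today) for the current instant."""
    current_time = time.time()
    current_dt = datetime.datetime.now(datetime.timezone(datetime.timedelta(hours=2)))
    today = current_dt.date()

    # Trigger 1: at least two hours since the last regular checkpoint.
    interval_due = (current_time - last_checkpoint_time) >= CHECKPOINT_INTERVAL

    # Trigger 2: once per day, any time between 22:00 and 22:14 local time.
    is_22_oclock = current_dt.hour == 22 and current_dt.minute < 15
    nightly_due = is_22_oclock and last_22_date != today

    return interval_due, nightly_due, today


# Example: last regular checkpoint three hours ago, no 22:00 checkpoint yet today.
print(checkpoint_due(time.time() - 3 * 60 * 60, None))
```

Because both triggers are only evaluated at batch boundaries, a checkpoint can fire only when a batch finishes inside the window, which appears to be why the nightly window was widened in this change.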