feat/energy_efficenty #38
|
@ -1,6 +1,8 @@
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
import csv
|
import csv
|
||||||
|
import datetime
|
||||||
|
import time
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from ..model.Model import AIIA
|
from ..model.Model import AIIA
|
||||||
|
@ -112,19 +114,22 @@ class Pretrainer:
|
||||||
|
|
||||||
return batch_loss
|
return batch_loss
|
||||||
|
|
||||||
def train(self, dataset_paths,output_path:str="AIIA", column="image_bytes", num_epochs=3, batch_size=2, sample_size=10000):
|
def train(self, dataset_paths, output_path="AIIA", column="image_bytes",
|
||||||
"""
|
num_epochs=3, batch_size=2, sample_size=10000, checkpoint_dir=None):
|
||||||
Train the model using multiple specified datasets.
|
"""Train the model using multiple specified datasets."""
|
||||||
|
|
||||||
Args:
|
|
||||||
dataset_paths (list): List of paths to parquet datasets
|
|
||||||
num_epochs (int): Number of training epochs
|
|
||||||
batch_size (int): Batch size for training
|
|
||||||
sample_size (int): Number of samples to use from each dataset
|
|
||||||
"""
|
|
||||||
if not dataset_paths:
|
if not dataset_paths:
|
||||||
raise ValueError("No dataset paths provided")
|
raise ValueError("No dataset paths provided")
|
||||||
|
|
||||||
|
# Checkpoint tracking variables
|
||||||
|
last_checkpoint_time = time.time()
|
||||||
|
checkpoint_interval = 2 * 60 * 60 # 2 hours in seconds
|
||||||
|
last_22_date = None
|
||||||
|
recent_checkpoints = []
|
||||||
|
|
||||||
|
# Create checkpoint directory if specified
|
||||||
|
if checkpoint_dir is not None:
|
||||||
|
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||||
|
|
||||||
# Read and merge all datasets
|
# Read and merge all datasets
|
||||||
dataframes = []
|
dataframes = []
|
||||||
for path in dataset_paths:
|
for path in dataset_paths:
|
||||||
|
@ -166,6 +171,59 @@ class Pretrainer:
|
||||||
if batch_data is None:
|
if batch_data is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# Check if we need to save a checkpoint
|
||||||
|
current_time = time.time()
|
||||||
|
current_dt = datetime.datetime.now(datetime.timezone(datetime.timedelta(hours=2))) # German time
|
||||||
|
today = current_dt.date()
|
||||||
|
|
||||||
|
# Regular 2-hour checkpoint
|
||||||
|
if checkpoint_dir and (current_time - last_checkpoint_time) >= checkpoint_interval:
|
||||||
|
checkpoint_path = os.path.join(
|
||||||
|
checkpoint_dir,
|
||||||
|
f"checkpoint_epoch{epoch+1}_batch{batch_count}.pt"
|
||||||
|
)
|
||||||
|
torch.save({
|
||||||
|
'epoch': epoch + 1,
|
||||||
|
'batch': batch_count,
|
||||||
|
'model_state_dict': self.model.state_dict(),
|
||||||
|
'projection_head_state_dict': self.projection_head.state_dict(),
|
||||||
|
'optimizer_state_dict': self.optimizer.state_dict(),
|
||||||
|
'train_losses': self.train_losses,
|
||||||
|
'val_losses': self.val_losses,
|
||||||
|
}, checkpoint_path)
|
||||||
|
|
||||||
|
# Track and maintain only 3 recent checkpoints
|
||||||
|
recent_checkpoints.append(checkpoint_path)
|
||||||
|
if len(recent_checkpoints) > 3:
|
||||||
|
oldest = recent_checkpoints.pop(0)
|
||||||
|
if os.path.exists(oldest):
|
||||||
|
os.remove(oldest)
|
||||||
|
|
||||||
|
last_checkpoint_time = current_time
|
||||||
|
print(f"Checkpoint saved at {checkpoint_path}")
|
||||||
|
|
||||||
|
# Special 22:00 checkpoint
|
||||||
|
is_22_oclock = current_dt.hour == 22 and current_dt.minute == 0 and current_dt.second < 10
|
||||||
|
|
||||||
|
if checkpoint_dir and is_22_oclock and last_22_date != today:
|
||||||
|
checkpoint_path = os.path.join(
|
||||||
|
checkpoint_dir,
|
||||||
|
f"checkpoint_22h_{today.strftime('%Y%m%d')}.pt"
|
||||||
|
)
|
||||||
|
torch.save({
|
||||||
|
'epoch': epoch + 1,
|
||||||
|
'batch': batch_count,
|
||||||
|
'model_state_dict': self.model.state_dict(),
|
||||||
|
'projection_head_state_dict': self.projection_head.state_dict(),
|
||||||
|
'optimizer_state_dict': self.optimizer.state_dict(),
|
||||||
|
'train_losses': self.train_losses,
|
||||||
|
'val_losses': self.val_losses,
|
||||||
|
}, checkpoint_path)
|
||||||
|
|
||||||
|
last_22_date = today
|
||||||
|
print(f"22:00 Checkpoint saved at {checkpoint_path}")
|
||||||
|
|
||||||
|
# Process the batch
|
||||||
self.optimizer.zero_grad()
|
self.optimizer.zero_grad()
|
||||||
batch_loss = self._process_batch(batch_data, criterion_denoise, criterion_rotate)
|
batch_loss = self._process_batch(batch_data, criterion_denoise, criterion_rotate)
|
||||||
|
|
||||||
|
@ -192,6 +250,7 @@ class Pretrainer:
|
||||||
losses_path = os.path.join(os.path.dirname(output_path), 'losses.csv')
|
losses_path = os.path.join(os.path.dirname(output_path), 'losses.csv')
|
||||||
self.save_losses(losses_path)
|
self.save_losses(losses_path)
|
||||||
|
|
||||||
|
|
||||||
def _validate(self, val_loader, criterion_denoise, criterion_rotate):
|
def _validate(self, val_loader, criterion_denoise, criterion_rotate):
|
||||||
"""Perform validation and return average validation loss."""
|
"""Perform validation and return average validation loss."""
|
||||||
val_loss = 0.0
|
val_loss = 0.0
|
||||||
|
|
Loading…
Reference in New Issue