Merge pull request 'feat/checkpoints' (#14) from feat/checkpoints into develop

Reviewed-on: #14
Author: Falko Victor Habel
Date: 2025-04-20 20:44:27 +00:00
Commit: 96e14b9674
4 changed files with 169 additions and 83 deletions


@@ -26,15 +26,22 @@ pip install git+https://gitea.fabelous.app/Machine-Learning/aiuNN.git
 Here's a basic example of how to use `aiuNN` for image upscaling:
 ```python src/main.py
-from aiia import AIIABase
+from aiia import AIIABase, AIIAConfig
 from aiunn import aiuNN, aiuNNTrainer
 import pandas as pd
 from torchvision import transforms
 
+# Create a configuration and build a base model.
+config = AIIAConfig()
+ai_config = aiuNNConfig()
+base_model = AIIABase(config)
+upscaler = aiuNN(config=ai_config)
+
 # Load your base model and upscaler
 pretrained_model_path = "path/to/aiia/model"
-base_model = AIIABase.load(pretrained_model_path, precision="bf16")
-upscaler = aiuNN(base_model)
+base_model = AIIABase.from_pretrained(pretrained_model_path)
+upscaler.load_base_model(base_model)
 
 # Create trainer with your dataset class
 trainer = aiuNNTrainer(upscaler, dataset_class=UpscaleDataset)
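Note: a minimal sketch of how the revised example might continue end to end, based on the `load_data`, `finetune`, and `save` signatures that appear in the trainer diff below. The `dataset_params` contents, the output path, and the `UpscaleDataset` class are placeholders, not part of this commit; the example also calls `aiuNNConfig()`, which `aiunn` exports (see the `__init__.py` hunk) but which is not yet on the import line.

```python
# Hypothetical continuation of the README example; paths and dataset_params are placeholders.
trainer.load_data(
    dataset_params={"parquet_files": ["path/to/dataset.parquet"]},  # assumed kwargs for UpscaleDataset
    batch_size=2,
    validation_split=0.2,
)

# Training writes checkpoints/ and training_log.csv under output_path (see the trainer changes below).
trainer.finetune(output_path="training_output", epochs=10, lr=1e-4, patience=3, min_delta=0.001)

# Export the best model after training.
trainer.save()
```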


@@ -2,7 +2,7 @@ from setuptools import setup, find_packages
 
 setup(
     name="aiunn",
-    version="0.2.0",
+    version="0.2.1",
     packages=find_packages(where="src"),
     package_dir={"": "src"},
     install_requires=[


@@ -3,4 +3,4 @@ from .upsampler.aiunn import aiuNN
 from .upsampler.config import aiuNNConfig
 from .inference.inference import aiuNNInference
 
-__version__ = "0.2.0"
+__version__ = "0.2.1"


@@ -10,6 +10,7 @@ from torch.utils.checkpoint import checkpoint
 import gc
 import time
 import shutil
+import datetime
 
 
 class EarlyStopping:
@@ -50,10 +51,16 @@ class aiuNNTrainer:
         self.optimizer = None
         self.scaler = GradScaler()
         self.best_loss = float('inf')
-        self.use_checkpointing = True
+        self.csv_path = None
+        self.checkpoint_dir = None
         self.data_loader = None
         self.validation_loader = None
-        self.log_dir = None
+        self.last_checkpoint_time = time.time()
+        self.checkpoint_interval = 2 * 60 * 60  # 2 hours
+        self.last_22_date = None
+        self.recent_checkpoints = []
+        self.current_epoch = 0
 
     def load_data(self, dataset_params=None, batch_size=1, validation_split=0.2, custom_train_dataset=None, custom_val_dataset=None):
         """
@@ -110,23 +117,19 @@ class aiuNNTrainer:
         return self.data_loader, self.validation_loader
 
     def _setup_logging(self, output_path):
-        """Set up directory structure for logging and model checkpoints"""
-        timestamp = time.strftime("%Y%m%d-%H%M%S")
-        self.log_dir = os.path.join(output_path, f"training_run_{timestamp}")
-        os.makedirs(self.log_dir, exist_ok=True)
-
+        """Set up basic logging and checkpoint directory"""
         # Create checkpoint directory
-        self.checkpoint_dir = os.path.join(self.log_dir, "checkpoints")
+        self.checkpoint_dir = os.path.join(output_path, "checkpoints")
         os.makedirs(self.checkpoint_dir, exist_ok=True)
 
         # Set up CSV logging
-        self.csv_path = os.path.join(self.log_dir, 'training_log.csv')
+        self.csv_path = os.path.join(output_path, 'training_log.csv')
         with open(self.csv_path, mode='w', newline='') as file:
             writer = csv.writer(file)
             if self.validation_loader:
-                writer.writerow(['Epoch', 'Train Loss', 'Validation Loss', 'Improved'])
+                writer.writerow(['Epoch', 'Train Loss', 'Validation Loss'])
             else:
-                writer.writerow(['Epoch', 'Train Loss', 'Improved'])
+                writer.writerow(['Epoch', 'Train Loss'])
 
     def _evaluate(self):
         """Evaluate the model on validation data"""
@@ -152,63 +155,99 @@ class aiuNNTrainer:
         self.model.train()
         return val_loss
 
-    def _save_checkpoint(self, epoch, is_best=False):
-        """Save model checkpoint"""
-        checkpoint_path = os.path.join(self.checkpoint_dir, f"epoch_{epoch}.pt")
-        best_model_path = os.path.join(self.log_dir, "best_model")
-
-        # Save the model checkpoint
-        self.model.save(checkpoint_path)
-
-        # If this is the best model so far, copy it to best_model
+    def _save_checkpoint(self, epoch, batch_count, is_best=False, is_22=False):
+        """Save checkpoint with support for regular, best, and 22:00 saves"""
+        if is_22:
+            today = datetime.datetime.now(datetime.timezone(datetime.timedelta(hours=2))).date()
+            checkpoint_name = f"checkpoint_22h_{today.strftime('%Y%m%d')}.pt"
+        else:
+            checkpoint_name = f"checkpoint_epoch{epoch}_batch{batch_count}.pt"
+
+        checkpoint_path = os.path.join(self.checkpoint_dir, checkpoint_name)
+
+        checkpoint_data = {
+            'epoch': epoch,
+            'batch': batch_count,
+            'model_state_dict': self.model.state_dict(),
+            'optimizer_state_dict': self.optimizer.state_dict(),
+            'best_loss': self.best_loss,
+            'scaler_state_dict': self.scaler.state_dict()
+        }
+        torch.save(checkpoint_data, checkpoint_path)
+
+        # Save best model separately
         if is_best:
-            if os.path.exists(best_model_path):
-                shutil.rmtree(best_model_path)
-            self.model.save(best_model_path)
-            print(f"Saved new best model with loss: {self.best_loss:.6f}")
+            best_model_path = os.path.join(os.path.dirname(self.checkpoint_dir), "best_model")
+            self.model.save_pretrained(best_model_path)
+
+        return checkpoint_path
+
+    def _handle_checkpoints(self, epoch, batch_count, is_improved):
+        """Handle periodic and 22:00 checkpoint saving"""
+        current_time = time.time()
+        current_dt = datetime.datetime.now(datetime.timezone(datetime.timedelta(hours=2)))
+
+        # Regular interval checkpoint
+        if (current_time - self.last_checkpoint_time) >= self.checkpoint_interval:
+            self._save_checkpoint(epoch, batch_count, is_improved)
+            self.last_checkpoint_time = current_time
+
+        # Special 22:00 checkpoint
+        is_22_oclock = current_dt.hour == 22 and current_dt.minute < 15
+        if is_22_oclock and self.last_22_date != current_dt.date():
+            self._save_checkpoint(epoch, batch_count, is_improved, is_22=True)
+            self.last_22_date = current_dt.date()
 
     def finetune(self, output_path, epochs=10, lr=1e-4, patience=3, min_delta=0.001):
-        """
-        Finetune the upscaler model
-
-        Args:
-            output_path (str): Directory to save models and logs
-            epochs (int): Maximum number of training epochs
-            lr (float): Learning rate
-            patience (int): Early stopping patience
-            min_delta (float): Minimum improvement for early stopping
-        """
-        # Check if data is loaded
+        """Finetune the upscaler model"""
         if self.data_loader is None:
             raise ValueError("Data not loaded. Call load_data first.")
 
-        # Setup optimizer
+        # Setup optimizer and directories
         self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
+        self.checkpoint_dir = os.path.join(output_path, "checkpoints")
+        os.makedirs(self.checkpoint_dir, exist_ok=True)
 
-        # Set up logging
-        self._setup_logging(output_path)
+        # Setup CSV logging
+        self.csv_path = os.path.join(output_path, 'training_log.csv')
+        with open(self.csv_path, mode='w', newline='') as file:
+            writer = csv.writer(file)
+            header = ['Epoch', 'Train Loss', 'Validation Loss'] if self.validation_loader else ['Epoch', 'Train Loss']
+            writer.writerow(header)
+
+        # Load existing checkpoint if available
+        checkpoint_info = self.load_checkpoint()
+        start_epoch = checkpoint_info[0] if checkpoint_info else 0
+        start_batch = checkpoint_info[1] if checkpoint_info else 0
 
         # Setup early stopping
         early_stopping = EarlyStopping(patience=patience, min_delta=min_delta)
+        self.best_loss = float('inf')
 
         # Training loop
         self.model.train()
-        for epoch in range(epochs):
-            # Training phase
+        for epoch in range(start_epoch, epochs):
+            self.current_epoch = epoch
             epoch_loss = 0.0
-            progress_bar = tqdm(self.data_loader, desc=f"Epoch {epoch + 1}/{epochs}")
 
-            for low_res, high_res in progress_bar:
-                # Move data to GPU with channels_last format where possible
+            train_batches = list(enumerate(self.data_loader))
+            start_idx = start_batch if epoch == start_epoch else 0
+            progress_bar = tqdm(train_batches[start_idx:],
+                                initial=start_idx,
+                                total=len(train_batches),
+                                desc=f"Epoch {epoch + 1}/{epochs}")
+
+            for batch_idx, (low_res, high_res) in progress_bar:
+                # Training step
                 low_res = low_res.to(self.device, non_blocking=True).to(memory_format=torch.channels_last)
                 high_res = high_res.to(self.device, non_blocking=True)
 
                 self.optimizer.zero_grad()
 
                 with autocast(device_type=self.device.type):
-                    if self.use_checkpointing:
-                        # Ensure the input tensor requires gradient so that checkpointing records the computation graph
+                    if hasattr(self, 'use_checkpointing') and self.use_checkpointing:
                         low_res.requires_grad_()
                         outputs = checkpoint(self.model, low_res)
                     else:
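Note: the two checkpoint flavours introduced above differ mainly in their file names: periodic saves encode epoch and batch, while the daily save is pinned to a 22:00-22:15 window in a hard-coded UTC+2 timezone. A standalone sketch of the names this produces; the epoch and batch values are made-up examples.

```python
import datetime

epoch, batch_count = 3, 128

# Periodic checkpoint, written once checkpoint_interval has elapsed.
print(f"checkpoint_epoch{epoch}_batch{batch_count}.pt")    # checkpoint_epoch3_batch128.pt

# Daily checkpoint, keyed to 22:00 in the fixed UTC+2 timezone used by the trainer.
tz = datetime.timezone(datetime.timedelta(hours=2))
today = datetime.datetime.now(tz).date()
print(f"checkpoint_22h_{today.strftime('%Y%m%d')}.pt")     # e.g. checkpoint_22h_20250420.pt
```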
@@ -222,69 +261,109 @@ class aiuNNTrainer:
                 epoch_loss += loss.item()
                 progress_bar.set_postfix({'loss': loss.item()})
 
-                # Optionally delete variables to free memory
+                # Handle checkpoints
+                self._handle_checkpoints(epoch + 1, batch_idx + 1, loss.item() < self.best_loss)
+
                 del low_res, high_res, outputs, loss
 
-            # Calculate average epoch loss
+            # End of epoch processing
             avg_train_loss = epoch_loss / len(self.data_loader)
 
-            # Validation phase (if validation loader exists)
+            # Validation phase
             if self.validation_loader:
                 val_loss = self._evaluate() / len(self.validation_loader)
                 is_improved = val_loss < self.best_loss
                 if is_improved:
                     self.best_loss = val_loss
 
-                # Log results
-                print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {avg_train_loss:.6f}, Val Loss: {val_loss:.6f}")
+                # Log to CSV
                 with open(self.csv_path, mode='a', newline='') as file:
                     writer = csv.writer(file)
-                    writer.writerow([epoch + 1, avg_train_loss, val_loss, "Yes" if is_improved else "No"])
+                    writer.writerow([epoch + 1, avg_train_loss, val_loss])
+                print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {avg_train_loss:.6f}, Val Loss: {val_loss:.6f}")
             else:
-                # If no validation, use training loss for improvement tracking
                 is_improved = avg_train_loss < self.best_loss
                 if is_improved:
                     self.best_loss = avg_train_loss
 
-                # Log results
-                print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {avg_train_loss:.6f}")
+                # Log to CSV
                 with open(self.csv_path, mode='a', newline='') as file:
                     writer = csv.writer(file)
-                    writer.writerow([epoch + 1, avg_train_loss, "Yes" if is_improved else "No"])
+                    writer.writerow([epoch + 1, avg_train_loss])
+                print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {avg_train_loss:.6f}")
 
-            # Save checkpoint
-            self._save_checkpoint(epoch + 1, is_best=is_improved)
-
-            # Perform garbage collection and clear GPU cache after each epoch
-            gc.collect()
-            torch.cuda.empty_cache()
+            # Save best model if improved
+            if is_improved:
+                best_model_path = os.path.join(output_path, "best_model")
+                self.model.save_pretrained(best_model_path)
 
             # Check early stopping
-            early_stopping(val_loss if self.validation_loader else avg_train_loss)
-            if early_stopping.early_stop:
+            if early_stopping(val_loss if self.validation_loader else avg_train_loss):
                 print(f"Early stopping triggered at epoch {epoch + 1}")
                 break
 
+            # Cleanup
+            gc.collect()
+            torch.cuda.empty_cache()
+
         return self.best_loss
 
+    def load_checkpoint(self, specific_checkpoint=None):
+        """Enhanced checkpoint loading with specific checkpoint support"""
+        if specific_checkpoint:
+            checkpoint_path = os.path.join(self.checkpoint_dir, specific_checkpoint)
+        else:
+            checkpoint_files = [f for f in os.listdir(self.checkpoint_dir)
+                                if f.startswith("checkpoint_") and f.endswith(".pt")]
+            if not checkpoint_files:
+                return None
+
+            checkpoint_files.sort(key=lambda x: os.path.getmtime(
+                os.path.join(self.checkpoint_dir, x)), reverse=True)
+            checkpoint_path = os.path.join(self.checkpoint_dir, checkpoint_files[0])
+
+        if not os.path.exists(checkpoint_path):
+            return None
+
+        checkpoint = torch.load(checkpoint_path, map_location=self.device)
+        self.model.load_state_dict(checkpoint['model_state_dict'])
+        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+        self.scaler.load_state_dict(checkpoint['scaler_state_dict'])
+        self.best_loss = checkpoint['best_loss']
+        print(f"Loaded checkpoint from {checkpoint_path}")
+        return checkpoint['epoch'], checkpoint['batch']
+
     def save(self, output_path=None):
         """
         Save the best model to the specified path
 
         Args:
-            output_path (str, optional): Path to save the model. If None, uses the best model from training.
+            output_path (str, optional): Path to save the model. If None, tries to use the checkpoint directory from training.
+
+        Returns:
+            str: Path where the model was saved
+
+        Raises:
+            ValueError: If no output path is specified and no checkpoint directory exists
         """
-        if output_path is None and self.log_dir is not None:
-            best_model_path = os.path.join(self.log_dir, "best_model")
+        if output_path is None and self.checkpoint_dir is not None:
+            # First try to copy the best model if it exists
+            best_model_path = os.path.join(os.path.dirname(self.checkpoint_dir), "best_model")
             if os.path.exists(best_model_path):
-                print(f"Best model already saved at {best_model_path}")
-                return best_model_path
+                output_path = os.path.join(os.path.dirname(self.checkpoint_dir), "final_model")
+                shutil.copytree(best_model_path, output_path, dirs_exist_ok=True)
+                print(f"Copied best model to {output_path}")
+                return output_path
             else:
-                output_path = os.path.join(self.log_dir, "final_model")
+                # If no best model exists, save current model state
+                output_path = os.path.join(os.path.dirname(self.checkpoint_dir), "final_model")
 
         if output_path is None:
-            raise ValueError("No output path specified and no training has been done yet.")
+            raise ValueError("No output path specified and no checkpoint directory exists from training.")
 
-        self.model.save(output_path)
+        self.model.save_pretrained(output_path)
         print(f"Model saved to {output_path}")
         return output_path
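Note: because `finetune` now calls `load_checkpoint` itself before the training loop, resuming is largely automatic, but the method can also be called directly, for example to restore a specific file. A sketch assuming a trainer whose optimizer and `checkpoint_dir` were already set up by a previous `finetune` run (since optimizer and scaler state are restored too); the checkpoint file name is a made-up example following the naming scheme above.

```python
# Restore state from the newest checkpoint in trainer.checkpoint_dir, or a specific one.
resume = trainer.load_checkpoint()                                   # newest checkpoint by mtime
# resume = trainer.load_checkpoint("checkpoint_epoch3_batch128.pt")  # or a specific file (example name)

if resume is not None:
    start_epoch, start_batch = resume
    print(f"Would resume at epoch {start_epoch}, batch {start_batch}")

# Export the best model (or, failing that, the current weights) to a final_model directory.
final_path = trainer.save()
```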