develop #4

Merged
Fabel merged 103 commits from develop into main 2025-03-01 21:47:17 +00:00
1 changed file with 75 additions and 29 deletions
Showing only changes of commit 9ec563e86d - Show all commits

View File

@ -10,7 +10,7 @@ from sklearn.model_selection import train_test_split
from typing import Dict, List, Union, Optional
import base64
from tqdm import tqdm
import os
class ImageDataset(Dataset):
@ -99,14 +99,15 @@ class ImageDataset(Dataset):
if 'high_res_stream' in locals():
high_res_stream.close()
class ModelTrainer:
class FineTuner:
def __init__(self,
model: AIIA,
dataset_paths: List[str],
batch_size: int = 32,
learning_rate: float = 0.001,
num_workers: int = 4,
train_ratio: float = 0.8):
train_ratio: float = 0.8,
output_dir: str = "./training_logs"):
"""
Specialized trainer for image super resolution tasks
@ -117,6 +118,7 @@ class ModelTrainer:
learning_rate (float): Learning rate for optimizer
num_workers (int): Number of workers for data loading
train_ratio (float): Ratio of data to use for training (rest goes to validation)
output_dir (str): Directory to save training logs and model checkpoints
"""
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.batch_size = batch_size
@ -125,6 +127,22 @@ class ModelTrainer:
self.learning_rate = learning_rate
self.train_ratio = train_ratio
self.model = model
self.output_dir = output_dir
# Create output directory if it doesn't exist
os.makedirs(output_dir, exist_ok=True)
# Initialize history tracking
self.train_losses = []
self.val_losses = []
self.best_val_loss = float('inf')
self.current_val_loss = float('inf')
# Initialize CSV logging
self.log_file = os.path.join(output_dir, 'training_log.csv')
if not os.path.exists(self.log_file):
with open(self.log_file, 'w') as f:
f.write('epoch,train_loss,val_loss,best_val_loss\n')
# Initialize datasets and loaders
self._initialize_datasets()
@ -171,6 +189,19 @@ class ModelTrainer:
num_workers=self.num_workers
) if df_val is not None else None
def _log_metrics(self, epoch: int, train_loss: float, val_loss: float):
"""
Log training metrics to CSV file
Args:
epoch (int): Current epoch number
train_loss (float): Training loss for the epoch
val_loss (float): Validation loss for the epoch
"""
with open(self.log_file, 'a') as f:
f.write(f'{epoch},{train_loss:.6f},{val_loss:.6f},{self.best_val_loss:.6f}\n')
def _initialize_training(self):
"""
Helper method to initialize training parameters
@ -211,29 +242,11 @@ class ModelTrainer:
self.best_val_loss = float('inf')
def train(self, num_epochs: int = 10):
"""
Train the model for specified number of epochs
"""
self.model.to(self.device)
for epoch in tqdm(range(num_epochs), desc="Training"):
print(f"Epoch {epoch+1}/{num_epochs}")
# Train phase
self._train_epoch()
# Validation phase
if self.val_loader is not None:
self._validate_epoch()
# Save best model based on validation loss
if self.val_loader is not None and self.current_val_loss < self.best_val_loss:
self.model.save("aiuNN-finetuned")
def _train_epoch(self):
"""
Train model for one epoch
Returns:
float: Average training loss for the epoch
"""
self.model.train()
running_loss = 0.0
@ -256,11 +269,15 @@ class ModelTrainer:
running_loss += loss.item()
epoch_loss = running_loss / len(self.train_loader)
self.train_losses.append(epoch_loss)
print(f"Train Loss: {epoch_loss:.4f}")
return epoch_loss
def _validate_epoch(self):
"""
Validate model performance
Returns:
float: Average validation loss for the epoch
"""
self.model.eval()
val_loss = 0.0
@ -276,12 +293,41 @@ class ModelTrainer:
loss = self.criterion(outputs, high_ress)
val_loss += loss.item()
avg_val_loss = val_loss / len(self.val_loader) if self.val_loader else 0
print(f"Validation Loss: {avg_val_loss:.4f}")
self.current_val_loss = val_loss / len(self.val_loader)
self.val_losses.append(self.current_val_loss)
print(f"Validation Loss: {self.current_val_loss:.4f}")
return self.current_val_loss
# Update best model
if avg_val_loss < self.best_val_loss:
self.best_val_loss = avg_val_loss
def train(self, num_epochs: int = 10):
    """
    Train the model for the specified number of epochs.

    Every epoch runs one training pass. When a validation loader is
    available the epoch is also validated, the best validation loss is
    updated, one row is appended to the CSV log, and the model is saved to
    ``output_dir`` if the validation loss improved. Without a validation
    loader nothing is logged or saved, because there is no validation loss
    to record or compare.

    Args:
        num_epochs (int): Number of epochs to train for
    """
    self.model.to(self.device)

    print(f"Training metrics will be logged to: {self.log_file}")

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")

        # Train phase
        train_loss = self._train_epoch()

        # No validation data: skip logging and checkpointing for this epoch
        # instead of reading a validation loss that was never computed.
        if self.val_loader is None:
            continue

        # Validation phase
        val_loss = self._validate_epoch()

        # Update the running best BEFORE logging, so the best_val_loss column
        # of the CSV row already includes the current epoch (otherwise it
        # lags one epoch behind and the first row reports "inf").
        improved = self.current_val_loss < self.best_val_loss
        if improved:
            print(f"Validation loss improved from {self.best_val_loss:.4f} to {self.current_val_loss:.4f}")
            self.best_val_loss = self.current_val_loss

        # Log metrics
        self._log_metrics(epoch + 1, train_loss, val_loss)

        # Save best model based on validation loss
        if improved:
            model_save_path = os.path.join(self.output_dir, "aiuNN-finetuned")
            self.model.save(model_save_path)
            print(f"Model saved to: {model_save_path}")
def __repr__(self):
return f"ModelTrainer (model={type(self.model).__name__}, batch_size={self.batch_size})"
@ -290,7 +336,7 @@ if __name__ == "__main__":
# Load your model first
model = AIIABase.load("/root/vision/AIIA/AIIA-base-512/")
trainer = ModelTrainer(
trainer = FineTuner(
model=model,
dataset_paths=[
"/root/training_data/vision-dataset/image_upscaler.parquet",