develop #4
|
@ -10,7 +10,7 @@ from sklearn.model_selection import train_test_split
|
|||
from typing import Dict, List, Union, Optional
|
||||
import base64
|
||||
from tqdm import tqdm
|
||||
|
||||
import os
|
||||
|
||||
|
||||
class ImageDataset(Dataset):
|
||||
|
@ -99,14 +99,15 @@ class ImageDataset(Dataset):
|
|||
if 'high_res_stream' in locals():
|
||||
high_res_stream.close()
|
||||
|
||||
class ModelTrainer:
|
||||
class FineTuner:
|
||||
def __init__(self,
|
||||
model: AIIA,
|
||||
dataset_paths: List[str],
|
||||
batch_size: int = 32,
|
||||
learning_rate: float = 0.001,
|
||||
num_workers: int = 4,
|
||||
train_ratio: float = 0.8):
|
||||
train_ratio: float = 0.8,
|
||||
output_dir: str = "./training_logs"):
|
||||
"""
|
||||
Specialized trainer for image super resolution tasks
|
||||
|
||||
|
@ -117,6 +118,7 @@ class ModelTrainer:
|
|||
learning_rate (float): Learning rate for optimizer
|
||||
num_workers (int): Number of workers for data loading
|
||||
train_ratio (float): Ratio of data to use for training (rest goes to validation)
|
||||
output_dir (str): Directory to save training logs and model checkpoints
|
||||
"""
|
||||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
self.batch_size = batch_size
|
||||
|
@ -125,6 +127,22 @@ class ModelTrainer:
|
|||
self.learning_rate = learning_rate
|
||||
self.train_ratio = train_ratio
|
||||
self.model = model
|
||||
self.output_dir = output_dir
|
||||
|
||||
# Create output directory if it doesn't exist
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# Initialize history tracking
|
||||
self.train_losses = []
|
||||
self.val_losses = []
|
||||
self.best_val_loss = float('inf')
|
||||
self.current_val_loss = float('inf')
|
||||
|
||||
# Initialize CSV logging
|
||||
self.log_file = os.path.join(output_dir, 'training_log.csv')
|
||||
if not os.path.exists(self.log_file):
|
||||
with open(self.log_file, 'w') as f:
|
||||
f.write('epoch,train_loss,val_loss,best_val_loss\n')
|
||||
|
||||
# Initialize datasets and loaders
|
||||
self._initialize_datasets()
|
||||
|
@ -171,6 +189,19 @@ class ModelTrainer:
|
|||
num_workers=self.num_workers
|
||||
) if df_val is not None else None
|
||||
|
||||
def _log_metrics(self, epoch: int, train_loss: float, val_loss: float):
|
||||
"""
|
||||
Log training metrics to CSV file
|
||||
|
||||
Args:
|
||||
epoch (int): Current epoch number
|
||||
train_loss (float): Training loss for the epoch
|
||||
val_loss (float): Validation loss for the epoch
|
||||
"""
|
||||
with open(self.log_file, 'a') as f:
|
||||
f.write(f'{epoch},{train_loss:.6f},{val_loss:.6f},{self.best_val_loss:.6f}\n')
|
||||
|
||||
|
||||
def _initialize_training(self):
|
||||
"""
|
||||
Helper method to initialize training parameters
|
||||
|
@ -211,29 +242,11 @@ class ModelTrainer:
|
|||
|
||||
self.best_val_loss = float('inf')
|
||||
|
||||
def train(self, num_epochs: int = 10):
|
||||
"""
|
||||
Train the model for specified number of epochs
|
||||
"""
|
||||
self.model.to(self.device)
|
||||
|
||||
for epoch in tqdm(range(num_epochs), desc="Training"):
|
||||
print(f"Epoch {epoch+1}/{num_epochs}")
|
||||
|
||||
# Train phase
|
||||
self._train_epoch()
|
||||
|
||||
# Validation phase
|
||||
if self.val_loader is not None:
|
||||
self._validate_epoch()
|
||||
|
||||
# Save best model based on validation loss
|
||||
if self.val_loader is not None and self.current_val_loss < self.best_val_loss:
|
||||
self.model.save("aiuNN-finetuned")
|
||||
|
||||
def _train_epoch(self):
|
||||
"""
|
||||
Train model for one epoch
|
||||
Returns:
|
||||
float: Average training loss for the epoch
|
||||
"""
|
||||
self.model.train()
|
||||
running_loss = 0.0
|
||||
|
@ -256,11 +269,15 @@ class ModelTrainer:
|
|||
running_loss += loss.item()
|
||||
|
||||
epoch_loss = running_loss / len(self.train_loader)
|
||||
self.train_losses.append(epoch_loss)
|
||||
print(f"Train Loss: {epoch_loss:.4f}")
|
||||
return epoch_loss
|
||||
|
||||
def _validate_epoch(self):
|
||||
"""
|
||||
Validate model performance
|
||||
Returns:
|
||||
float: Average validation loss for the epoch
|
||||
"""
|
||||
self.model.eval()
|
||||
val_loss = 0.0
|
||||
|
@ -276,12 +293,41 @@ class ModelTrainer:
|
|||
loss = self.criterion(outputs, high_ress)
|
||||
val_loss += loss.item()
|
||||
|
||||
avg_val_loss = val_loss / len(self.val_loader) if self.val_loader else 0
|
||||
print(f"Validation Loss: {avg_val_loss:.4f}")
|
||||
self.current_val_loss = val_loss / len(self.val_loader)
|
||||
self.val_losses.append(self.current_val_loss)
|
||||
print(f"Validation Loss: {self.current_val_loss:.4f}")
|
||||
return self.current_val_loss
|
||||
|
||||
# Update best model
|
||||
if avg_val_loss < self.best_val_loss:
|
||||
self.best_val_loss = avg_val_loss
|
||||
def train(self, num_epochs: int = 10):
|
||||
"""
|
||||
Train the model for specified number of epochs
|
||||
Args:
|
||||
num_epochs (int): Number of epochs to train for
|
||||
"""
|
||||
self.model.to(self.device)
|
||||
|
||||
print(f"Training metrics will be logged to: {self.log_file}")
|
||||
|
||||
for epoch in range(num_epochs):
|
||||
print(f"\nEpoch {epoch+1}/{num_epochs}")
|
||||
|
||||
# Train phase
|
||||
train_loss = self._train_epoch()
|
||||
|
||||
# Validation phase
|
||||
if self.val_loader is not None:
|
||||
val_loss = self._validate_epoch()
|
||||
|
||||
# Log metrics
|
||||
self._log_metrics(epoch + 1, train_loss, val_loss)
|
||||
|
||||
# Save best model based on validation loss
|
||||
if self.current_val_loss < self.best_val_loss:
|
||||
print(f"Validation loss improved from {self.best_val_loss:.4f} to {self.current_val_loss:.4f}")
|
||||
self.best_val_loss = self.current_val_loss
|
||||
model_save_path = os.path.join(self.output_dir, "aiuNN-finetuned")
|
||||
self.model.save(model_save_path)
|
||||
print(f"Model saved to: {model_save_path}")
|
||||
|
||||
def __repr__(self):
|
||||
return f"ModelTrainer (model={type(self.model).__name__}, batch_size={self.batch_size})"
|
||||
|
@ -290,7 +336,7 @@ if __name__ == "__main__":
|
|||
# Load your model first
|
||||
model = AIIABase.load("/root/vision/AIIA/AIIA-base-512/")
|
||||
|
||||
trainer = ModelTrainer(
|
||||
trainer = FineTuner(
|
||||
model=model,
|
||||
dataset_paths=[
|
||||
"/root/training_data/vision-dataset/image_upscaler.parquet",
|
||||
|
|
Loading…
Reference in New Issue