develop #4
|
@ -10,7 +10,7 @@ from sklearn.model_selection import train_test_split
|
||||||
from typing import Dict, List, Union, Optional
|
from typing import Dict, List, Union, Optional
|
||||||
import base64
|
import base64
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
class ImageDataset(Dataset):
|
class ImageDataset(Dataset):
|
||||||
|
@ -99,14 +99,15 @@ class ImageDataset(Dataset):
|
||||||
if 'high_res_stream' in locals():
|
if 'high_res_stream' in locals():
|
||||||
high_res_stream.close()
|
high_res_stream.close()
|
||||||
|
|
||||||
class ModelTrainer:
|
class FineTuner:
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
model: AIIA,
|
model: AIIA,
|
||||||
dataset_paths: List[str],
|
dataset_paths: List[str],
|
||||||
batch_size: int = 32,
|
batch_size: int = 32,
|
||||||
learning_rate: float = 0.001,
|
learning_rate: float = 0.001,
|
||||||
num_workers: int = 4,
|
num_workers: int = 4,
|
||||||
train_ratio: float = 0.8):
|
train_ratio: float = 0.8,
|
||||||
|
output_dir: str = "./training_logs"):
|
||||||
"""
|
"""
|
||||||
Specialized trainer for image super resolution tasks
|
Specialized trainer for image super resolution tasks
|
||||||
|
|
||||||
|
@ -117,6 +118,7 @@ class ModelTrainer:
|
||||||
learning_rate (float): Learning rate for optimizer
|
learning_rate (float): Learning rate for optimizer
|
||||||
num_workers (int): Number of workers for data loading
|
num_workers (int): Number of workers for data loading
|
||||||
train_ratio (float): Ratio of data to use for training (rest goes to validation)
|
train_ratio (float): Ratio of data to use for training (rest goes to validation)
|
||||||
|
output_dir (str): Directory to save training logs and model checkpoints
|
||||||
"""
|
"""
|
||||||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
|
@ -125,6 +127,22 @@ class ModelTrainer:
|
||||||
self.learning_rate = learning_rate
|
self.learning_rate = learning_rate
|
||||||
self.train_ratio = train_ratio
|
self.train_ratio = train_ratio
|
||||||
self.model = model
|
self.model = model
|
||||||
|
self.output_dir = output_dir
|
||||||
|
|
||||||
|
# Create output directory if it doesn't exist
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# Initialize history tracking
|
||||||
|
self.train_losses = []
|
||||||
|
self.val_losses = []
|
||||||
|
self.best_val_loss = float('inf')
|
||||||
|
self.current_val_loss = float('inf')
|
||||||
|
|
||||||
|
# Initialize CSV logging
|
||||||
|
self.log_file = os.path.join(output_dir, 'training_log.csv')
|
||||||
|
if not os.path.exists(self.log_file):
|
||||||
|
with open(self.log_file, 'w') as f:
|
||||||
|
f.write('epoch,train_loss,val_loss,best_val_loss\n')
|
||||||
|
|
||||||
# Initialize datasets and loaders
|
# Initialize datasets and loaders
|
||||||
self._initialize_datasets()
|
self._initialize_datasets()
|
||||||
|
@ -171,6 +189,19 @@ class ModelTrainer:
|
||||||
num_workers=self.num_workers
|
num_workers=self.num_workers
|
||||||
) if df_val is not None else None
|
) if df_val is not None else None
|
||||||
|
|
||||||
|
def _log_metrics(self, epoch: int, train_loss: float, val_loss: float):
    """
    Append one row of epoch metrics to the CSV training log

    The row is written as ``epoch,train_loss,val_loss,best_val_loss`` with
    the three losses formatted to six decimal places. The last field is
    read from ``self.best_val_loss`` at the moment of the call.

    Args:
        epoch (int): Current epoch number
        train_loss (float): Training loss for the epoch
        val_loss (float): Validation loss for the epoch
    """
    with open(self.log_file, 'a') as log_handle:
        fields = (
            f'{epoch}',
            f'{train_loss:.6f}',
            f'{val_loss:.6f}',
            f'{self.best_val_loss:.6f}',
        )
        log_handle.write(','.join(fields) + '\n')
|
||||||
|
|
||||||
|
|
||||||
def _initialize_training(self):
|
def _initialize_training(self):
|
||||||
"""
|
"""
|
||||||
Helper method to initialize training parameters
|
Helper method to initialize training parameters
|
||||||
|
@ -211,29 +242,11 @@ class ModelTrainer:
|
||||||
|
|
||||||
self.best_val_loss = float('inf')
|
self.best_val_loss = float('inf')
|
||||||
|
|
||||||
def train(self, num_epochs: int = 10):
    """
    Train the model for specified number of epochs

    Each epoch runs a training pass and, when a validation loader is
    available, a validation pass followed by a checkpoint check.

    Args:
        num_epochs (int): Number of epochs to train for
    """
    self.model.to(self.device)

    progress = tqdm(range(num_epochs), desc="Training")
    for epoch in progress:
        print(f"Epoch {epoch+1}/{num_epochs}")

        # Train phase
        self._train_epoch()

        # Without a validation loader there is nothing to validate or compare
        if self.val_loader is None:
            continue

        # Validation phase
        self._validate_epoch()

        # Save best model based on validation loss
        # NOTE(review): relies on _validate_epoch keeping current_val_loss and
        # best_val_loss up to date - confirm against that method.
        if self.current_val_loss < self.best_val_loss:
            self.model.save("aiuNN-finetuned")
|
|
||||||
|
|
||||||
def _train_epoch(self):
|
def _train_epoch(self):
|
||||||
"""
|
"""
|
||||||
Train model for one epoch
|
Train model for one epoch
|
||||||
|
Returns:
|
||||||
|
float: Average training loss for the epoch
|
||||||
"""
|
"""
|
||||||
self.model.train()
|
self.model.train()
|
||||||
running_loss = 0.0
|
running_loss = 0.0
|
||||||
|
@ -256,11 +269,15 @@ class ModelTrainer:
|
||||||
running_loss += loss.item()
|
running_loss += loss.item()
|
||||||
|
|
||||||
epoch_loss = running_loss / len(self.train_loader)
|
epoch_loss = running_loss / len(self.train_loader)
|
||||||
|
self.train_losses.append(epoch_loss)
|
||||||
print(f"Train Loss: {epoch_loss:.4f}")
|
print(f"Train Loss: {epoch_loss:.4f}")
|
||||||
|
return epoch_loss
|
||||||
|
|
||||||
def _validate_epoch(self):
|
def _validate_epoch(self):
|
||||||
"""
|
"""
|
||||||
Validate model performance
|
Validate model performance
|
||||||
|
Returns:
|
||||||
|
float: Average validation loss for the epoch
|
||||||
"""
|
"""
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
val_loss = 0.0
|
val_loss = 0.0
|
||||||
|
@ -276,12 +293,41 @@ class ModelTrainer:
|
||||||
loss = self.criterion(outputs, high_ress)
|
loss = self.criterion(outputs, high_ress)
|
||||||
val_loss += loss.item()
|
val_loss += loss.item()
|
||||||
|
|
||||||
avg_val_loss = val_loss / len(self.val_loader) if self.val_loader else 0
|
self.current_val_loss = val_loss / len(self.val_loader)
|
||||||
print(f"Validation Loss: {avg_val_loss:.4f}")
|
self.val_losses.append(self.current_val_loss)
|
||||||
|
print(f"Validation Loss: {self.current_val_loss:.4f}")
|
||||||
|
return self.current_val_loss
|
||||||
|
|
||||||
# Update best model
|
def train(self, num_epochs: int = 10):
    """
    Train the model for specified number of epochs

    Every epoch runs one training pass and appends a row to the CSV log.
    When a validation loader exists, the epoch is also validated and the
    model is saved under ``self.output_dir`` each time the validation loss
    improves on the best value seen so far. Without a validation loader the
    validation loss is logged as NaN and no checkpoint is saved.

    Args:
        num_epochs (int): Number of epochs to train for
    """
    self.model.to(self.device)

    print(f"Training metrics will be logged to: {self.log_file}")

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")

        # Train phase
        train_loss = self._train_epoch()

        if self.val_loader is None:
            # No validation split: there is no validation loss for this
            # epoch, so log NaN rather than reading an unassigned variable.
            self._log_metrics(epoch + 1, train_loss, float('nan'))
            continue

        # Validation phase
        val_loss = self._validate_epoch()

        # Update the best loss before logging so the logged best_val_loss
        # column already reflects the current epoch.
        previous_best = self.best_val_loss
        improved = val_loss < previous_best
        if improved:
            self.best_val_loss = val_loss

        # Log metrics
        self._log_metrics(epoch + 1, train_loss, val_loss)

        # Save best model based on validation loss
        if improved:
            print(f"Validation loss improved from {previous_best:.4f} to {val_loss:.4f}")
            model_save_path = os.path.join(self.output_dir, "aiuNN-finetuned")
            self.model.save(model_save_path)
            print(f"Model saved to: {model_save_path}")
|
||||||
|
|
||||||
def __repr__(self):
    """
    Return a short description of the trainer

    The class name is taken from the instance instead of being hard-coded,
    so the text stays correct after the class is renamed or subclassed.

    Returns:
        str: Class name, wrapped model type name and batch size
    """
    return f"{type(self).__name__} (model={type(self.model).__name__}, batch_size={self.batch_size})"
|
||||||
|
@ -290,7 +336,7 @@ if __name__ == "__main__":
|
||||||
# Load your model first
|
# Load your model first
|
||||||
model = AIIABase.load("/root/vision/AIIA/AIIA-base-512/")
|
model = AIIABase.load("/root/vision/AIIA/AIIA-base-512/")
|
||||||
|
|
||||||
trainer = ModelTrainer(
|
trainer = FineTuner(
|
||||||
model=model,
|
model=model,
|
||||||
dataset_paths=[
|
dataset_paths=[
|
||||||
"/root/training_data/vision-dataset/image_upscaler.parquet",
|
"/root/training_data/vision-dataset/image_upscaler.parquet",
|
||||||
|
|
Loading…
Reference in New Issue