finetune improvement

Falko Victor Habel 2025-01-30 10:36:15 +01:00
parent 4a60045320
commit 2121316e3b
1 changed file with 75 additions and 70 deletions


@@ -5,10 +5,9 @@ import io
 from torch import nn
 from torch.utils.data import Dataset, DataLoader
 import torchvision.transforms as transforms
-from aiia.model import AIIABase
+from aiia.model import AIIABase, AIIA
 from sklearn.model_selection import train_test_split
-from typing import Dict, List, Union
+from typing import Dict, List, Union, Optional
 
 class ImageDataset(Dataset):
     def __init__(self, dataframe, transform=None):
@@ -36,24 +35,21 @@ class ImageDataset(Dataset):
             low_res_image = self.transform(low_res_image)
             high_res_image = self.transform(high_res_image)
-        return {'low_res': low_res_image, 'high_res': high_res_image}
+        return {'low_ress': low_res_image, 'high_ress': high_res_image}
 
 class ModelTrainer:
     def __init__(self,
-                 model_name: str = "AIIA-Base-512",
-                 dataset_paths: List[str] = None,
+                 model: AIIA,
+                 dataset_paths: List[str],
                  batch_size: int = 32,
                  learning_rate: float = 0.001,
                  num_workers: int = 4,
                  train_ratio: float = 0.8):
         """
         Specialized trainer for image super resolution tasks
 
         Args:
-            model_name (str): Name of the model to initialize
+            model (nn.Module): Model instance to finetune
             dataset_paths (List[str]): Paths to datasets
             batch_size (int): Batch size for training
             learning_rate (float): Learning rate for optimizer
@@ -64,120 +60,126 @@ class ModelTrainer:
         self.batch_size = batch_size
         self.num_workers = num_workers
         self.dataset_paths = dataset_paths
-        self.model_name = model_name
         self.learning_rate = learning_rate
         self.train_ratio = train_ratio
+        self.model = model
 
         # Initialize datasets and loaders
         self._initialize_datasets()
 
-        # Initialize model and training parameters
-        self._initialize_model()
+        # Initialize training parameters
+        self._initialize_training()
 
     def _initialize_datasets(self):
         """
         Helper method to initialize datasets
         """
         # Read training data based on input format
         if isinstance(self.dataset_paths, list):
             df_train = pd.concat([pd.read_parquet(path) for path in self.dataset_paths], ignore_index=True)
         else:
-            raise ValueError("Invalid dataset_paths format. Must be a list or dictionary.")
+            raise ValueError("Invalid dataset_paths format. Must be a list.")
 
         df_train, df_val = train_test_split(
             df_train,
             test_size=1 - self.train_ratio,
             random_state=42
         )
 
         # Define preprocessing transforms
         self.transform = transforms.Compose([
             transforms.ToTensor(),
             transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
         ])
 
         # Create datasets and dataloaders
         self.train_dataset = ImageDataset(df_train, transform=self.transform)
         self.val_dataset = ImageDataset(df_val, transform=self.transform)
 
         self.train_loader = DataLoader(
             self.train_dataset,
             batch_size=self.batch_size,
             shuffle=True,
             num_workers=self.num_workers
         )
 
         self.val_loader = DataLoader(
             self.val_dataset,
             batch_size=self.batch_size,
             shuffle=False,
             num_workers=self.num_workers
         ) if df_val is not None else None
 
-    def _initialize_model(self):
+    def _initialize_training(self):
         """
-        Helper method to initialize model architecture and training parameters
+        Helper method to initialize training parameters
         """
-        # Load base model
-        self.model = AIIABase.load(self.model_name)
-
-        # Freeze CNN layers
-        for param in self.model.cnn.parameters():
-            param.requires_grad = False
-
-        # Add upscaling layer
-        hidden_size = self.model.config.hidden_size
-        kernel_size = self.model.config.kernel_size
-        self.model.upsample = nn.Sequential(
-            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
-            nn.Conv2d(hidden_size, 3, kernel_size=kernel_size, padding=1)
-        )
+        # Freeze CNN layers (if applicable)
+        try:
+            for param in self.model.cnn.parameters():
+                param.requires_grad = False
+        except AttributeError:
+            pass  # If model doesn't have a 'cnn' attribute, just continue
+
+        # Add upscaling layer if not already present
+        if not hasattr(self.model, 'upsample'):
+            hidden_size = self.model.config.hidden_size
+            kernel_size = self.model.config.kernel_size
+            self.model.upsample = nn.Sequential(
+                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
+                nn.Conv2d(hidden_size, 3, kernel_size=kernel_size, padding=1)
+            )
 
         # Initialize optimizer and loss function
         self.criterion = nn.MSELoss()
+
+        # Get parameters of the upsample layer for training
+        params = [p for p in self.model.upsample.parameters() if p.requires_grad]
+        if not params:
+            raise ValueError("No parameters found in upsample layer to optimize")
+
         self.optimizer = torch.optim.Adam(
-            [param for param in self.model.parameters() if 'upsample' in str(param)],
+            params,
             lr=self.learning_rate
         )
 
         self.best_val_loss = float('inf')
 
     def train(self, num_epochs: int = 10):
         """
        Train the model for specified number of epochs
         """
         self.model.to(self.device)
 
         for epoch in range(num_epochs):
             print(f"Epoch {epoch+1}/{num_epochs}")
 
             # Train phase
             self._train_epoch()
 
             # Validation phase
             if self.val_loader is not None:
                 self._validate_epoch()
 
             # Save best model based on validation loss
             if self.val_loader is not None and self.current_val_loss < self.best_val_loss:
-                self.model.save("aiuNN-base")
+                self.model.save("aiuNN-finetuned")
 
     def _train_epoch(self):
         """
         Train model for one epoch
         """
         self.model.train()
         running_loss = 0.0
 
         for batch in self.train_loader:
-            low_res = batch['low_ress'].to(self.device)
-            high_res = batch['high_ress'].to(self.device)
+            low_ress = batch['low_ress'].to(self.device)
+            high_ress = batch['high_ress'].to(self.device)
 
             # Forward pass
-            features = self.model.cnn(low_res)
+            features = self.model.cnn(low_ress) if hasattr(self.model, 'cnn') else self.model.extract_features(low_ress)
             outputs = self.model.upsample(features)
 
-            loss = self.criterion(outputs, high_res)
+            loss = self.criterion(outputs, high_ress)
 
             # Backward pass and optimize
             self.optimizer.zero_grad()
@@ -185,41 +187,44 @@ class ModelTrainer:
             self.optimizer.step()
 
             running_loss += loss.item()
 
         epoch_loss = running_loss / len(self.train_loader)
         print(f"Train Loss: {epoch_loss:.4f}")
 
     def _validate_epoch(self):
         """
         Validate model performance
         """
         self.model.eval()
-        val_oss = 0.0
+        val_loss = 0.0
 
         with torch.no_grad():
             for batch in self.val_loader:
-                low_res = batch['low_ress'].to(self.device)
-                high_res = batch['high_ress'].to(self.device)
+                low_ress = batch['low_ress'].to(self.device)
+                high_ress = batch['high_ress'].to(self.device)
 
-                features = self.model.cnn(low_res)
+                features = self.model.cnn(low_ress) if hasattr(self.model, 'cnn') else self.model.extract_features(low_ress)
                 outputs = self.model.upsample(features)
 
-                loss = self.criterion(outputs, high_res)
+                loss = self.criterion(outputs, high_ress)
                 val_loss += loss.item()
 
         avg_val_loss = val_loss / len(self.val_loader) if self.val_loader else 0
         print(f"Validation Loss: {avg_val_loss:.4f}")
 
         # Update best model
         if avg_val_loss < self.best_val_loss:
             self.best_val_loss = avg_val_loss
 
     def __repr__(self):
-        return f"Model ({self.model_name}, batch_size={self.batch_size})"
+        return f"ModelTrainer (model={type(self.model).__name__}, batch_size={self.batch_size})"
 
 if __name__ == "__main__":
+    # Load your model first
+    model = AIIABase.load("/root/vision/AIIA/AIIA-base-512/")
+
     trainer = ModelTrainer(
-        model_name="/root/vision/AIIA/AIIA-base-512/",
+        model=model,
         dataset_paths=[
             "/root/training_data/vision-dataset/image_upscaler.parquet",
             "/root/training_data/vision-dataset/image_vec_upscaler.parquet"
@ -227,5 +232,5 @@ if __name__ == "__main__":
batch_size=2,
learning_rate=0.001
)
trainer.train(num__epochs=3)
trainer.train(num_epochs=3)
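
For reference, the finetuning pattern this commit converges on (freeze the pretrained feature extractor, attach an upsampling head, and optimize only that head's parameters) can be sketched in plain PyTorch. This is a minimal illustration under assumed shapes; DummyBackbone is a hypothetical stand-in, not the AIIA API, and only the freeze/attach/optimizer wiring mirrors what ModelTrainer does.

import torch
from torch import nn

# Stand-in feature extractor; in the trainer this role is played by model.cnn
# (or model.extract_features) from the pretrained AIIA model.
class DummyBackbone(nn.Module):
    def __init__(self, hidden_size: int = 64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, hidden_size, kernel_size=3, padding=1),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.net(x)

backbone = DummyBackbone(hidden_size=64)

# 1. Freeze the pretrained layers so only the new head receives gradients.
for param in backbone.parameters():
    param.requires_grad = False

# 2. Attach a 2x upsampling head that maps hidden features back to RGB.
upsample = nn.Sequential(
    nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
    nn.Conv2d(64, 3, kernel_size=3, padding=1),
)

# 3. Optimize only the parameters that still require gradients (the head).
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(
    [p for p in upsample.parameters() if p.requires_grad], lr=1e-3
)

# One illustrative training step on random tensors shaped like the dataset batches.
low_res = torch.rand(2, 3, 64, 64)     # corresponds to batch['low_ress']
high_res = torch.rand(2, 3, 128, 128)  # corresponds to batch['high_ress']

features = backbone(low_res)
outputs = upsample(features)
loss = criterion(outputs, high_res)

optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"step loss: {loss.item():.4f}")

In the actual trainer the frozen module is self.model.cnn, the head is self.model.upsample, and the loop iterates over DataLoader batches rather than random tensors.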