updated finetuning

Falko Victor Habel 2025-01-31 16:03:28 +01:00
parent 9ec563e86d
commit 13fb2b76c1
2 changed files with 103 additions and 45 deletions

View File

@@ -99,6 +99,8 @@ class ImageDataset(Dataset):
             if 'high_res_stream' in locals():
                 high_res_stream.close()
+
+
 class FineTuner:
     def __init__(self,
                  model: AIIA,
@@ -127,7 +129,7 @@ class FineTuner:
         self.learning_rate = learning_rate
         self.train_ratio = train_ratio
         self.model = model
-        self.output_dir = output_dir
+        self.ouptut_dir = output_dir

         # Create output directory if it doesn't exist
         os.makedirs(output_dir, exist_ok=True)
@@ -150,6 +152,17 @@ class FineTuner:
         # Initialize training parameters
         self._initialize_training()

+    def _freeze_layers(self):
+        """Freeze all layers except the upsample layer"""
+        try:
+            # Try to freeze layers based on their names
+            for name, param in self.model.named_parameters():
+                if 'upsample' not in name:
+                    param.requires_grad = False
+        except Exception as e:
+            print(f"Warning: Couldn't freeze layers - {str(e)}")
+            pass
+
     def _initialize_datasets(self):
         """
         Helper method to initialize datasets
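
The new freezing logic keys on parameter names rather than a fixed backbone attribute, so it works even when the model layout is not known in advance. A minimal sketch of the same idea on a stand-in module (not the actual AIIA model), including a check of which parameters stay trainable:

import torch.nn as nn

# Stand-in network with an 'upsample' head, used only to illustrate name-based freezing
net = nn.Sequential()
net.add_module("backbone", nn.Conv2d(3, 16, kernel_size=3, padding=1))
net.add_module("upsample", nn.ConvTranspose2d(16, 3, kernel_size=3, stride=2, padding=1, output_padding=1))

# Freeze every parameter whose name does not contain 'upsample'
for name, param in net.named_parameters():
    if 'upsample' not in name:
        param.requires_grad = False

trainable = [name for name, p in net.named_parameters() if p.requires_grad]
print(trainable)  # ['upsample.weight', 'upsample.bias']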
@@ -159,15 +172,18 @@ class FineTuner:
         else:
             raise ValueError("Invalid dataset_paths format. Must be a list.")

+        # Split into train and validation sets
         df_train, df_val = train_test_split(
             df_train,
             test_size=1 - self.train_ratio,
             random_state=42
         )

-        # Define preprocessing transforms
+        # Define preprocessing transforms with augmentation
         self.transform = transforms.Compose([
             transforms.ToTensor(),
+            transforms.RandomResizedCrop(256),
+            transforms.RandomHorizontalFlip(),
             transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
         ])
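
Because ToTensor() runs first, the random crop and flip operate on tensors, which torchvision supports from 0.8 onward. A small sketch of what the augmented pipeline produces, using a synthetic placeholder image in place of a dataset sample:

import torch
from PIL import Image
from torchvision import transforms

transform = transforms.Compose([
    transforms.ToTensor(),                # PIL image -> float tensor in [0, 1], shape (C, H, W)
    transforms.RandomResizedCrop(256),    # random crop, resized to 256x256
    transforms.RandomHorizontalFlip(),    # flip with probability 0.5
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # map roughly to [-1, 1]
])

img = Image.new("RGB", (512, 384))        # placeholder image, stands in for a dataset sample
print(transform(img).shape)               # torch.Size([3, 256, 256])

If the dataset yields aligned low/high-resolution pairs, the random crop and flip would need shared parameters per pair, otherwise inputs and targets drift out of alignment.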
@@ -201,46 +217,66 @@ class FineTuner:
         with open(self.log_file, 'a') as f:
             f.write(f'{epoch},{train_loss:.6f},{val_loss:.6f},{self.best_val_loss:.6f}\n')

     def _initialize_training(self):
         """
         Helper method to initialize training parameters
         """
-        # Freeze CNN layers (if applicable)
-        try:
-            for param in self.model.cnn.parameters():
-                param.requires_grad = False
-        except AttributeError:
-            pass  # If model doesn't have a 'cnn' attribute, just continue
+        # Freeze all layers except upsample layer
+        self._freeze_layers()

         # Add upscaling layer if not already present
         if not hasattr(self.model, 'upsample'):
-            # Get existing configuration values or set defaults if necessary
-            hidden_size = self.model.config.hidden_size
-            kernel_size = self.model.config.kernel_size
+            # Try to get existing configuration or set defaults
+            try:
+                hidden_size = self.model.config.hidden_size
+                kernel_size = 3  # Use odd-sized kernel for better performance
+            except AttributeError:
+                # Fallback values if config isn't available
+                hidden_size = 512
+                kernel_size = 3

             self.model.upsample = nn.Sequential(
-                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
-                nn.Conv2d(hidden_size, 3, kernel_size=kernel_size, padding=1)
+                nn.ConvTranspose2d(hidden_size,
+                                   hidden_size//2,
+                                   kernel_size=kernel_size,
+                                   stride=2,
+                                   padding=1,
+                                   output_padding=1),
+                nn.ReLU(inplace=True),
+                nn.ConvTranspose2d(hidden_size//2,
+                                   3,
+                                   kernel_size=kernel_size,
+                                   stride=2,
+                                   padding=1,
+                                   output_padding=1)
             )

+            # Update the model's configuration with new parameters
             self.model.config.upsample_hidden_size = hidden_size
             self.model.config.upsample_kernel_size = kernel_size

-        # Initialize optimizer and loss function
-        self.criterion = nn.MSELoss()
-        # Get parameters of the upsample layer for training
-        params = [p for p in self.model.upsample.parameters() if p.requires_grad]
-        if not params:
-            raise ValueError("No parameters found in upsample layer to optimize")
+        # Initialize optimizer and scheduler
+        params_to_optimize = [p for p in self.model.parameters() if p.requires_grad]
+        if not params_to_optimize:
+            raise ValueError("No parameters found to optimize")

+        # Use Adam with weight decay for better regularization
         self.optimizer = torch.optim.Adam(
-            params,
-            lr=self.learning_rate
+            params_to_optimize,
+            lr=self.learning_rate,
+            weight_decay=1e-4  # Add L2 regularization
         )
-        self.best_val_loss = float('inf')
+
+        # Reduce learning rate when validation loss plateaus
+        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+            self.optimizer,
+            factor=0.1,   # Multiply LR by this factor on plateau
+            patience=3,   # Number of epochs to wait before reducing LR
+            verbose=True
+        )
+
+        # Use a combination of L1 and L2 losses for better performance
+        self.criterion = nn.L1Loss()
+        self.mse_criterion = nn.MSELoss()

     def _train_epoch(self):
         """
@@ -255,18 +291,26 @@ class FineTuner:
             low_ress = batch['low_ress'].to(self.device)
             high_ress = batch['high_ress'].to(self.device)

-            # Forward pass
-            features = self.model.cnn(low_ress) if hasattr(self.model, 'cnn') else self.model.extract_features(low_ress)
+            try:
+                # Try using CNN layer if available
+                features = self.model.cnn(low_ress)
+            except AttributeError:
+                # Fallback to extract_features method
+                features = self.model.extract_features(low_ress)
+
             outputs = self.model.upsample(features)

-            loss = self.criterion(outputs, high_ress)
+            # Calculate loss with different scaling for L1 and MSE components
+            l1_loss = self.criterion(outputs, high_ress) * 0.5
+            mse_loss = self.mse_criterion(outputs, high_ress) * 0.5
+            total_loss = l1_loss + mse_loss

             # Backward pass and optimize
             self.optimizer.zero_grad()
-            loss.backward()
+            total_loss.backward()
             self.optimizer.step()

-            running_loss += loss.item()
+            running_loss += total_loss.item()

         epoch_loss = running_loss / len(self.train_loader)
         self.train_losses.append(epoch_loss)
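
The objective is now an equally weighted blend of L1 and MSE. A minimal sketch on dummy tensors of how the combined term is formed and backpropagated:

import torch
import torch.nn as nn

criterion = nn.L1Loss()
mse_criterion = nn.MSELoss()

outputs = torch.rand(2, 3, 256, 256, requires_grad=True)   # stand-in for upscaled predictions
high_ress = torch.rand(2, 3, 256, 256)                     # stand-in for high-res targets

# Same 0.5 / 0.5 weighting as in the training loop above
total_loss = 0.5 * criterion(outputs, high_ress) + 0.5 * mse_criterion(outputs, high_ress)
total_loss.backward()                                       # gradients flow through both terms
print(total_loss.item())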
@@ -287,11 +331,19 @@ class FineTuner:
                 low_ress = batch['low_ress'].to(self.device)
                 high_ress = batch['high_ress'].to(self.device)

-                features = self.model.cnn(low_ress) if hasattr(self.model, 'cnn') else self.model.extract_features(low_ress)
+                try:
+                    features = self.model.cnn(low_ress)
+                except AttributeError:
+                    features = self.model.extract_features(low_ress)
+
                 outputs = self.model.upsample(features)

-                loss = self.criterion(outputs, high_ress)
-                val_loss += loss.item()
+                # Calculate same loss combination
+                l1_loss = self.criterion(outputs, high_ress) * 0.5
+                mse_loss = self.mse_criterion(outputs, high_ress) * 0.5
+                total_loss = l1_loss + mse_loss
+
+                val_loss += total_loss.item()

         self.current_val_loss = val_loss / len(self.val_loader)
         self.val_losses.append(self.current_val_loss)
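
This hunk only shows the per-batch body; the surrounding validation loop presumably runs under torch.no_grad() with the model in eval mode. A hedged, standalone sketch of that pattern, assuming the same batch keys and the extract_features fallback:

import torch

def validate(model, val_loader, criterion, mse_criterion, device):
    model.eval()                                    # disable dropout / batch-norm updates
    val_loss = 0.0
    with torch.no_grad():                           # no gradients needed during validation
        for batch in val_loader:
            low_ress = batch['low_ress'].to(device)
            high_ress = batch['high_ress'].to(device)
            outputs = model.upsample(model.extract_features(low_ress))
            val_loss += (0.5 * criterion(outputs, high_ress)
                         + 0.5 * mse_criterion(outputs, high_ress)).item()
    return val_loss / len(val_loader)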
@@ -318,6 +370,9 @@ class FineTuner:
             if self.val_loader is not None:
                 val_loss = self._validate_epoch()

+                # Update learning rate scheduler based on validation loss
+                self.scheduler.step(val_loss)
+
             # Log metrics
             self._log_metrics(epoch + 1, train_loss, val_loss)
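
ReduceLROnPlateau only reacts to the metric passed into step(): once the validation loss stops improving for more than patience epochs, the learning rate is multiplied by factor. A tiny self-contained sketch of that behaviour:

import torch

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=3)

for epoch in range(6):
    scheduler.step(1.0)                              # feed a flat validation loss that never improves
    print(epoch, optimizer.param_groups[0]['lr'])    # drops from 1e-4 to 1e-5 once patience is exceeded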
@@ -325,25 +380,28 @@ class FineTuner:
             if self.current_val_loss < self.best_val_loss:
                 print(f"Validation loss improved from {self.best_val_loss:.4f} to {self.current_val_loss:.4f}")
                 self.best_val_loss = self.current_val_loss
-                model_save_path = os.path.join(self.output_dir, "aiuNN-finetuned")
+                model_save_path = os.path.join(self.ouptut_dir, "aiuNN-optimized")
                 self.model.save(model_save_path)
                 print(f"Model saved to: {model_save_path}")

-    def __repr__(self):
-        return f"ModelTrainer (model={type(self.model).__name__}, batch_size={self.batch_size})"
+        # After training, save the final model
+        final_model_path = os.path.join(self.ouptut_dir, "aiuNN-final")
+        self.model.save(final_model_path)
+        print(f"\nFinal model saved to: {final_model_path}")

 if __name__ == "__main__":
     # Load your model first
-    model = AIIABase.load("/root/vision/AIIA/AIIA-base-512/")
+    model = AIIABase.load("/root/vision/dataset/AIIA-base-512/")

     trainer = FineTuner(
         model=model,
         dataset_paths=[
             "/root/training_data/vision-dataset/image_upscaler.parquet",
             "/root/training_data/vision-dataset/image_vec_upscaler.parquet"
         ],
-        batch_size=2,
-        learning_rate=0.001
+        batch_size=8,         # Increased batch size
+        learning_rate=1e-4    # Reduced initial LR
     )

-    trainer.train(num_epochs=3)
+    trainer.train(num_epochs=10)  # Extended training time
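
Training now leaves two checkpoints under the output directory, "aiuNN-optimized" (best validation loss) and "aiuNN-final" (state after the last epoch), and appends one comma-separated line per epoch to the metrics log. A hedged sketch of inspecting that log afterwards; the actual value of self.log_file is not shown in this diff, so "training_log.csv" is only a placeholder:

import pandas as pd

# Columns follow the f.write() format string in _log_metrics above;
# "training_log.csv" is a placeholder for whatever self.log_file points to
log = pd.read_csv("training_log.csv",
                  names=["epoch", "train_loss", "val_loss", "best_val_loss"])
print(log.sort_values("val_loss").head(1))   # epoch with the lowest validation loss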

View File

@@ -5,7 +5,7 @@ from torch.nn import functional as F
 from aiia.model import AIIABase

 class UpScaler:
-    def __init__(self, model_path="AIIA-base-512-upscaler", device="cuda"):
+    def __init__(self, model_path="aiuNN-finetuned", device="cuda"):
         self.device = torch.device(device)
         self.model = AIIABase.load(model_path).to(self.device)
         self.model.eval()
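
Only __init__ is visible in this hunk; instantiating the class with the new default looks like the sketch below. Note that the trainer above saves its checkpoints as "aiuNN-optimized" and "aiuNN-final", so the default "aiuNN-finetuned" path may need to point at one of those directories.

import torch

# Fall back to CPU when CUDA is unavailable; the default model_path matches the new signature above
upscaler = UpScaler(model_path="aiuNN-finetuned",
                    device="cuda" if torch.cuda.is_available() else "cpu")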