finetune_class #1

Merged
Fabel merged 96 commits from finetune_class into develop 2025-02-26 12:13:09 +00:00
1 changed file with 141 additions and 395 deletions
Showing only changes of commit de6a67cb4e - Show all commits

View File

@ -1,409 +1,155 @@
import torch import torch
import pandas as pd import pandas as pd
from PIL import Image, ImageFile import numpy as np
import io import cv2
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from aiia.model import AIIABase, AIIA, AIIAConfig
from sklearn.model_selection import train_test_split
from typing import Dict, List, Union, Optional
import base64
from tqdm import tqdm
import os import os
from albumentations import (
Compose, Resize, Normalize, RandomBrightnessContrast,
HorizontalFlip, VerticalFlip, Rotate, GaussianBlur
)
from albumentations.pytorch import ToTensorV2
from torch import nn
# Import the model and config from your existing code
from aiia import AIIA, AIIAConfig, AIIABase, AIIABaseShared, AIIAmoe, AIIAchunked, AIIArecursive
# NOTE(review): this span is a flattened side-by-side diff. Where a line holds two
# statements, the first belongs to the removed ImageDataset code and the second to
# the added aiuNNDataset code; indentation was lost, so it is not runnable as written.
#
# aiuNNDataset.__init__ (added side): loads the Parquet file at `parquet_path` into
# self.df and starts building self.augmentation. `config` is accepted but not used
# on any line shown here.
class aiuNNDataset(torch.utils.data.Dataset):
def __init__(self, parquet_path, config=None):
# Read the Parquet file
self.df = pd.read_parquet(parquet_path)
# ImageDataset.__init__ (removed side): stores `dataframe` and `transform` on the instance.
# The Compose([...]) list opened below is not closed inside this span.
class ImageDataset(Dataset): # Data augmentation pipeline
def __init__(self, dataframe, transform=None): self.augmentation = Compose([
""" Resize(height=512, width=512),
Initialize the dataset with a dataframe containing image data RandomBrightnessContrast(),
HorizontalFlip(p=0.5),
Args: VerticalFlip(p=0.5),
dataframe (pd.DataFrame): DataFrame containing 'image_512' and 'image_1024' columns Rotate(limit=45),
transform (callable, optional): Optional transform to be applied to both images GaussianBlur(p=0.3),
""" Normalize(mean=[0.5], std=[0.5]),
self.dataframe = dataframe ToTensorV2()
self.transform = transform
def __len__(self):
    """Return how many rows the wrapped dataframe holds."""
    row_count = len(self.dataframe)
    return row_count
def __getitem__(self, idx):
    """
    Load the low/high resolution image pair stored in row ``idx``.

    Args:
        idx (int): Positional index into ``self.dataframe``.

    Returns:
        dict: ``{'low_ress': ..., 'high_ress': ...}`` holding RGB PIL images,
            or whatever ``self.transform`` returns for them when it is set.

    Raises:
        RuntimeError: If the row cannot be decoded, verified or transformed.
            The triggering exception is attached as ``__cause__``.
    """
    row = self.dataframe.iloc[idx]
    try:
        payloads = []
        for column in ('image_512', 'image_1024'):
            data = row[column]
            # Each column may hold raw bytes or a base64 string; decide per
            # column instead of assuming both share the low-res encoding.
            if isinstance(data, str):
                data = base64.b64decode(data)
            if not isinstance(data, bytes):
                raise ValueError(f"Invalid image data format at index {idx}")
            payloads.append(data)
        low_res_data, high_res_data = payloads

        # Enable loading of truncated images (module-level Pillow switch)
        ImageFile.LOAD_TRUNCATED_IMAGES = True

        # The context managers close both streams on every exit path.
        with io.BytesIO(low_res_data) as low_res_stream, \
                io.BytesIO(high_res_data) as high_res_stream:
            # Load and convert images to RGB
            low_res_image = Image.open(low_res_stream).convert('RGB')
            high_res_image = Image.open(high_res_stream).convert('RGB')

            # verify() is run on copies so the images returned below are untouched
            low_res_verify = low_res_image.copy()
            high_res_verify = high_res_image.copy()
            try:
                low_res_verify.verify()
                high_res_verify.verify()
            except Exception as e:
                raise ValueError(f"Image verification failed at index {idx}: {str(e)}") from e
            finally:
                low_res_verify.close()
                high_res_verify.close()

            # Apply transforms if specified
            if self.transform:
                low_res_image = self.transform(low_res_image)
                high_res_image = self.transform(high_res_image)

            return {
                'low_ress': low_res_image,    # key spelling matches the training loop
                'high_ress': high_res_image,  # key spelling matches the training loop
            }
    except Exception as e:
        # Chain the cause so the original traceback is not lost.
        raise RuntimeError(f"Error loading images at index {idx}: {str(e)}") from e
class SuperResolutionModel(AIIA):
    """AIIA model that wraps a frozen base model with a transposed-conv decoder."""

    def __init__(self, base_model: AIIA, config: AIIAConfig):
        """
        Args:
            base_model (AIIA): Model used as the encoder; its parameters are frozen.
            config (AIIAConfig): Configuration forwarded to the AIIA constructor.
        """
        super().__init__(config=config)

        # The base model acts as the encoder and is excluded from training.
        self.encoder = base_model
        for weight in self.encoder.parameters():
            weight.requires_grad = False

        # Decoder: 512 -> 256 -> 128 -> 3 channels via transposed convolutions.
        decoder_layers = [
            nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 3, kernel_size=3, padding=1),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Run ``x`` through the encoder, then the decoder, and return the result."""
        return self.decoder(self.encoder(x))
class FineTuner:
def __init__(self,
             model: AIIA,
             dataset_paths: List[str],
             batch_size: int = 32,
             learning_rate: float = 0.001,
             num_workers: int = 4,
             train_ratio: float = 0.8,
             output_dir: str = "./training_logs"):
    """
    Set up a fine-tuning run: store settings, prepare logging, build loaders.

    Args:
        model (AIIA): Model instance to finetune
        dataset_paths (List[str]): Paths to datasets
        batch_size (int): Batch size for training
        learning_rate (float): Learning rate for optimizer
        num_workers (int): Number of workers for data loading
        train_ratio (float): Share of the data used for training (rest is validation)
        output_dir (str): Directory for the CSV log and model checkpoints
    """
    # Plain configuration, kept on the instance for the helper methods.
    self.model = model
    self.dataset_paths = dataset_paths
    self.batch_size = batch_size
    self.learning_rate = learning_rate
    self.num_workers = num_workers
    self.train_ratio = train_ratio
    self.output_dir = output_dir
    self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # The output directory must exist before the log file is created.
    os.makedirs(output_dir, exist_ok=True)

    # Loss history and best-so-far bookkeeping.
    self.train_losses, self.val_losses = [], []
    self.best_val_loss = self.current_val_loss = float('inf')

    # Write the CSV header only when the log does not exist yet.
    self.log_file = os.path.join(output_dir, 'training_log.csv')
    if not os.path.exists(self.log_file):
        with open(self.log_file, 'w') as log_handle:
            log_handle.write('epoch,train_loss,val_loss,best_val_loss\n')

    # Data first, then optimizer/scheduler/criteria.
    self._initialize_datasets()
    self._initialize_training()
def _freeze_layers(self):
    """
    Freeze encoder parameters and re-enable training for late decoder blocks.

    Every parameter whose name contains 'encoder' gets ``requires_grad = False``.
    Parameters whose name contains 'decoder' together with 'block4' or 'block5'
    are then switched back to trainable. Any failure is reported with a printed
    warning instead of being raised.
    """
    try:
        for name, param in self.model.named_parameters():
            if 'encoder' in name:
                param.requires_grad = False
            # `and` binds tighter than `or`: without the parentheses every name
            # containing 'block5' (encoder ones included) was unfrozen.
            if 'decoder' in name and ('block4' in name or 'block5' in name):
                param.requires_grad = True
    except Exception as e:
        # Best effort by design: training continues with the layers as they are.
        print(f"Warning: Couldn't freeze layers - {str(e)}")
# NOTE(review): flattened side-by-side diff. Lines holding two statements combine the
# removed FineTuner code (first) with the added aiuNNDataset code (second).
#
# _initialize_datasets (removed side): concatenates every Parquet file listed in
# self.dataset_paths, splits the rows into train/validation frames using
# self.train_ratio, builds one transform pipeline and wraps both frames in
# ImageDataset/DataLoader objects stored on the instance.
# Raises ValueError when self.dataset_paths is not a list.
def _initialize_datasets(self):
"""
Helper method to initialize datasets
"""
if isinstance(self.dataset_paths, list):
df_train = pd.concat([pd.read_parquet(path) for path in self.dataset_paths], ignore_index=True)
else:
raise ValueError("Invalid dataset_paths format. Must be a list.")
# Split into train and validation sets
df_train, df_val = train_test_split(
df_train,
test_size=1 - self.train_ratio,
random_state=42
)
# Define preprocessing transforms with augmentation
# NOTE(review): the same pipeline, random crop and flip included, is handed to the
# validation dataset as well — confirm that is intended.
self.transform = transforms.Compose([
transforms.ToTensor(),
transforms.RandomResizedCrop(256),
transforms.RandomHorizontalFlip(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
]) ])
# Added side from here on: aiuNNDataset.__len__ returns len(self.df), and
# aiuNNDataset.__getitem__ reads the 'low_res'/'high_res' byte columns of row idx,
# decodes both with cv2.imdecode and passes them through self.augmentation as
# image= and mask=.
# NOTE(review): the high-res picture is passed as `mask`; confirm the pipeline resizes
# and normalizes masks the way the training loop expects — TODO verify.
# NOTE(review): df_val comes straight from train_test_split, so the trailing
# `if df_val is not None` guard on val_loader looks always true — confirm.
# Create datasets and dataloaders def __len__(self):
self.train_dataset = ImageDataset(df_train, transform=self.transform) return len(self.df)
self.val_dataset = ImageDataset(df_val, transform=self.transform)
self.train_loader = DataLoader( def __getitem__(self, idx):
self.train_dataset, # Get the byte strings
batch_size=self.batch_size, low_res_bytes = self.df.iloc[idx]['low_res']
shuffle=True, high_res_bytes = self.df.iloc[idx]['high_res']
num_workers=self.num_workers
)
self.val_loader = DataLoader( # Convert bytes to numpy arrays
self.val_dataset, low_res = cv2.imdecode(np.frombuffer(low_res_bytes, np.uint8), -1)
batch_size=self.batch_size, high_res = cv2.imdecode(np.frombuffer(high_res_bytes, np.uint8), -1)
shuffle=False,
num_workers=self.num_workers
) if df_val is not None else None
# _log_metrics (removed side): appends one "epoch,train_loss,val_loss,best_val_loss"
# row, six decimals per loss, to self.log_file. The added side on the same lines is
# the tail of aiuNNDataset.__getitem__, which returns a dict with the 'low_res' and
# 'high_res' entries.
def _log_metrics(self, epoch: int, train_loss: float, val_loss: float): # Apply augmentation and normalization
""" augmented = self.augmentation(image=low_res, mask=high_res)
Log training metrics to CSV file low_res = augmented['image']
high_res = augmented['mask']
Args: return {
epoch (int): Current epoch number 'low_res': low_res,
train_loss (float): Training loss for the epoch 'high_res': high_res
val_loss (float): Validation loss for the epoch }
"""
with open(self.log_file, 'a') as f:
f.write(f'{epoch},{train_loss:.6f},{val_loss:.6f},{self.best_val_loss:.6f}\n')
def _initialize_training(self):
    """
    Freeze layers, then build the optimizer, LR scheduler and loss functions.

    Raises:
        ValueError: If no model parameter is left trainable after freezing.
    """
    self._freeze_layers()

    # Only parameters that still require gradients are handed to the optimizer.
    trainable = list(filter(lambda p: p.requires_grad, self.model.parameters()))
    if not trainable:
        raise ValueError("No parameters found to optimize")

    # Adam with weight decay (L2 regularization).
    self.optimizer = torch.optim.Adam(
        trainable,
        lr=self.learning_rate,
        weight_decay=1e-4,
    )

    # Multiply the LR by 0.1 once the monitored loss has stalled for 3 epochs.
    plateau_options = dict(factor=0.1, patience=3, verbose=True)
    self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, **plateau_options)

    # Two criteria; the training loop combines them.
    self.criterion = nn.L1Loss()
    self.mse_criterion = nn.MSELoss()
def _train_epoch(self):
    """
    Train the model for one pass over ``self.train_loader``.

    The method was previously defined twice; the later definition shadowed the
    earlier one and appended to a misspelled attribute. This single definition
    keeps the later behavior and records the loss in ``self.train_losses``.

    Each batch loss is ``0.5 * L1 + 0.5 * MSE`` between the model output and the
    batch's 'high_ress' entry.

    Returns:
        float: Mean batch loss for the epoch (also appended to ``self.train_losses``).

    Raises:
        RuntimeError: If the forward pass (or the shape printout) fails; the
            original exception is attached as ``__cause__``.
    """
    self.model.train()
    running_loss = 0.0

    for batch in tqdm(self.train_loader, desc="Training"):
        low_ress = batch['low_ress'].to(self.device)
        high_ress = batch['high_ress'].to(self.device)

        # Forward pass
        try:
            outputs = self.model(low_ress)  # Now outputs are images
            # Debug output kept from the original: printed once per batch.
            print("Output shape:", outputs.shape)
            print("High-res shape:", high_ress.shape)
        except Exception as e:
            raise RuntimeError(f"Error during forward pass: {str(e)}") from e

        # Equal-weight blend of the L1 and MSE criteria
        l1_loss = self.criterion(outputs, high_ress) * 0.5
        mse_loss = self.mse_criterion(outputs, high_ress) * 0.5
        total_loss = l1_loss + mse_loss

        # Backward pass and optimize
        self.optimizer.zero_grad()
        total_loss.backward()
        self.optimizer.step()

        running_loss += total_loss.item()

    epoch_loss = running_loss / len(self.train_loader)
    # Was `self.train_lossess`, which does not exist and raised AttributeError.
    self.train_losses.append(epoch_loss)
    print(f"Train Loss: {epoch_loss:.4f}")
    return epoch_loss
def train(self, num_epochs: int = 10):
    """
    Train the model for the given number of epochs.

    Per epoch: run ``_train_epoch``; when ``self.val_loader`` is set, run
    ``_validate_epoch`` (expected to exist on the instance and return the
    validation loss), step the scheduler with that loss, update the best loss,
    log the metrics and save the model as "aiuNN-optimized" on improvement.
    After the last epoch the model is saved as "aiuNN-final".

    Args:
        num_epochs (int): Number of epochs to train for
    """
    self.model.to(self.device)
    print(f"Training metrics will be logged to: {self.log_file}")

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")

        # Train phase
        train_loss = self._train_epoch()

        # Without a validation loader there is nothing to schedule, log or compare.
        if self.val_loader is None:
            continue

        # Validation phase
        val_loss = self._validate_epoch()
        # Track the returned loss; the check below used to read an attribute
        # that no code here updated, so the best model was never saved.
        self.current_val_loss = val_loss

        # Update learning rate scheduler based on validation loss
        self.scheduler.step(val_loss)

        improved = val_loss < self.best_val_loss
        if improved:
            print(f"Validation loss improved from {self.best_val_loss:.4f} to {val_loss:.4f}")
            self.best_val_loss = val_loss

        # Log after the update so the best_val_loss column includes this epoch.
        self._log_metrics(epoch + 1, train_loss, val_loss)

        if improved:
            model_save_path = os.path.join(self.output_dir, "aiuNN-optimized")
            self.model.save(model_save_path)
            print(f"Model saved to: {model_save_path}")

    # After training, save the final model
    final_model_path = os.path.join(self.output_dir, "aiuNN-final")
    self.model.save(final_model_path)
    print(f"\nFinal model saved to: {final_model_path}")
# NOTE(review): flattened side-by-side diff. Through the line that calls
# trainer.train(...), each line starts with the removed `__main__` script (left side)
# and may continue with the added finetune_model code (right side).
#
# finetune_model (added side): builds aiuNNDataset/DataLoader pairs for the two Parquet
# paths, moves `model` to CUDA when available (otherwise CPU), trains it with MSELoss
# and Adam for `epochs` epochs, evaluates on the validation loader after every epoch,
# calls model.save("best_model") whenever the average validation loss improves, and
# returns the model.
if __name__ == "__main__": def finetune_model(model: AIIA, train_parquet_path, val_parquet_path, batch_size=8, epochs = 10):
# Load your model first # Initialize dataset and dataloader
config = AIIAConfig.load("/root/vision/AIIA/AIIA-base-512") train_dataset = aiuNNDataset(train_parquet_path)
model = SuperResolutionModel(base_model=AIIABase.load("/root/vision/AIIA/AIIA-base-512"), config=config) val_dataset = aiuNNDataset(val_parquet_path)
trainer = FineTuner( train_loader = torch.utils.data.DataLoader(
model=model, train_dataset,
dataset_paths=[ batch_size=batch_size,
"/root/training_data/vision-dataset/image_upscaler.parquet", shuffle=True,
"/root/training_data/vision-dataset/image_vec_upscaler.parquet" num_workers=4
],
batch_size=2, # Increased batch size
learning_rate=1e-4 # Reduced initial LR
) )
trainer.train(num_epochs=10) # Extended training time val_loader = torch.utils.data.DataLoader(
val_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=4
)
# Set device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
# Define loss function and optimizer
# The learning rate is read from model.config.learning_rate.
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=model.config.learning_rate)
best_val_loss = float('inf')
for epoch in range(epochs):
model.train()
train_loss = 0.0
for batch_idx, batch in enumerate(train_loader):
low_res = batch['low_res'].to(device)
high_res = batch['high_res'].to(device)
# Forward pass
outputs = model(low_res)
# Calculate loss
# NOTE(review): the target is permuted, presumably from (N, H, W, C) to (N, C, H, W);
# its dtype is not converted here — confirm it matches `outputs`.
loss = criterion(outputs, high_res.permute(0, 3, 1, 2)) # Adjust for channel dimensions
# Backward pass and optimize
optimizer.zero_grad()
loss.backward()
optimizer.step()
train_loss += loss.item()
# Average over the number of batches in the loader.
avg_train_loss = train_loss / len(train_loader)
print(f"Epoch {epoch+1}, Training Loss: {avg_train_loss:.4f}")
# Validation
model.eval()
val_loss = 0.0
with torch.no_grad():
for batch in val_loader:
low_res = batch['low_res'].to(device)
high_res = batch['high_res'].to(device)
outputs = model(low_res)
loss = criterion(outputs, high_res.permute(0, 3, 1, 2))
val_loss += loss.item()
# NOTE(review): divides by len(val_loader); an empty validation loader would raise
# ZeroDivisionError here.
avg_val_loss = val_loss / len(val_loader)
print(f"Epoch {epoch+1}, Validation Loss: {avg_val_loss:.4f}")
# Save best model
if avg_val_loss < best_val_loss:
best_val_loss = avg_val_loss
model.save("best_model")
return model
def main():
    """Load the pretrained AIIA checkpoint and fine-tune it on the two parquet files."""
    # Training file first, validation file second.
    parquet_files = (
        "/root/training_data/vision-dataset/image_upscaler.parquet",
        "/root/training_data/vision-dataset/image_vec_upscaler.parquet",
    )

    # Pretrained weights to start from.
    pretrained = AIIA.load("/root/vision/AIIA/AIIA-base-512")

    # Models exposing a `chunked_` attribute get an extra 2x bilinear upsampling module.
    needs_upsample = hasattr(pretrained, 'chunked_')
    if needs_upsample:
        upsample = nn.Upsample(scale_factor=2, mode='bilinear')
        pretrained.add_module('final_upsample', upsample)

    # Fine-tune with the default batch size and epoch count.
    finetune_model(pretrained, *parquet_files)


if __name__ == '__main__':
    main()