develop #4

Merged
Fabel merged 103 commits from develop into main 2025-03-01 21:47:17 +00:00
1 changed files with 141 additions and 395 deletions
Showing only changes of commit de6a67cb4e - Show all commits

View File

@ -1,409 +1,155 @@
import torch import torch
import pandas as pd import pandas as pd
from PIL import Image, ImageFile import numpy as np
import io import cv2
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from aiia.model import AIIABase, AIIA, AIIAConfig
from sklearn.model_selection import train_test_split
from typing import Dict, List, Union, Optional
import base64
from tqdm import tqdm
import os import os
from albumentations import (
Compose, Resize, Normalize, RandomBrightnessContrast,
HorizontalFlip, VerticalFlip, Rotate, GaussianBlur
)
from albumentations.pytorch import ToTensorV2
from torch import nn
# Import the model and config from your existing code
from aiia import AIIA, AIIAConfig, AIIABase, AIIABaseShared, AIIAmoe, AIIAchunked, AIIArecursive
class aiuNNDataset(torch.utils.data.Dataset):
def __init__(self, parquet_path, config=None):
# Read the Parquet file
self.df = pd.read_parquet(parquet_path)
class ImageDataset(Dataset): # Data augmentation pipeline
def __init__(self, dataframe, transform=None): self.augmentation = Compose([
""" Resize(height=512, width=512),
Initialize the dataset with a dataframe containing image data RandomBrightnessContrast(),
HorizontalFlip(p=0.5),
Args: VerticalFlip(p=0.5),
dataframe (pd.DataFrame): DataFrame containing 'image_512' and 'image_1024' columns Rotate(limit=45),
transform (callable, optional): Optional transform to be applied to both images GaussianBlur(p=0.3),
""" Normalize(mean=[0.5], std=[0.5]),
self.dataframe = dataframe ToTensorV2()
self.transform = transform
def __len__(self):
return len(self.dataframe)
def __getitem__(self, idx):
"""
Get a pair of low and high resolution images
Args:
idx (int): Index of the data point
Returns:
dict: Contains 'low_ress' and 'high_ress' PIL images or transformed tensors
"""
row = self.dataframe.iloc[idx]
try:
# Handle both bytes and base64 encoded strings
low_res_data = row['image_512']
high_res_data = row['image_1024']
if isinstance(low_res_data, str):
# Decode base64 string to bytes
low_res_data = base64.b64decode(low_res_data)
high_res_data = base64.b64decode(high_res_data)
# Verify data is valid before creating BytesIO
if not isinstance(low_res_data, bytes) or not isinstance(high_res_data, bytes):
raise ValueError(f"Invalid image data format at index {idx}")
# Create image streams
low_res_stream = io.BytesIO(low_res_data)
high_res_stream = io.BytesIO(high_res_data)
# Enable loading of truncated images
ImageFile.LOAD_TRUNCATED_IMAGES = True
# Load and convert images to RGB
low_res_image = Image.open(low_res_stream).convert('RGB')
high_res_image = Image.open(high_res_stream).convert('RGB')
# Create fresh copies for verify() since it modifies the image object
low_res_verify = low_res_image.copy()
high_res_verify = high_res_image.copy()
# Verify images are valid
try:
low_res_verify.verify()
high_res_verify.verify()
except Exception as e:
raise ValueError(f"Image verification failed at index {idx}: {str(e)}")
finally:
low_res_verify.close()
high_res_verify.close()
# Apply transforms if specified
if self.transform:
low_res_image = self.transform(low_res_image)
high_res_image = self.transform(high_res_image)
return {
'low_ress': low_res_image, # Note: Using 'low_ress' to match ModelTrainer
'high_ress': high_res_image # Note: Using 'high_ress' to match ModelTrainer
}
except Exception as e:
raise RuntimeError(f"Error loading images at index {idx}: {str(e)}")
finally:
# Ensure streams are closed
if 'low_res_stream' in locals():
low_res_stream.close()
if 'high_res_stream' in locals():
high_res_stream.close()
class SuperResolutionModel(AIIA):
def __init__(self, base_model: AIIA, config: AIIAConfig):
super(SuperResolutionModel, self).__init__(config=config)
# Use base model as encoder
self.encoder = base_model
for param in self.encoder.parameters():
param.requires_grad = False # Freeze encoder layers
# Add decoder layers to reconstruct image
self.decoder = nn.Sequential(
nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1),
nn.ReLU(),
nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1),
nn.ReLU(),
nn.ConvTranspose2d(128, 3, kernel_size=3, padding=1)
)
def forward(self, x):
features = self.encoder(x)
output = self.decoder(features)
return output
class FineTuner:
def __init__(self,
model: AIIA,
dataset_paths: List[str],
batch_size: int = 32,
learning_rate: float = 0.001,
num_workers: int = 4,
train_ratio: float = 0.8,
output_dir: str = "./training_logs"):
"""
Specialized trainer for image super resolution tasks
Args:
model (nn.Module): Model instance to finetune
dataset_paths (List[str]): Paths to datasets
batch_size (int): Batch size for training
learning_rate (float): Learning rate for optimizer
num_workers (int): Number of workers for data loading
train_ratio (float): Ratio of data to use for training (rest goes to validation)
output_dir (str): Directory to save training logs and model checkpoints
"""
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.batch_size = batch_size
self.num_workers = num_workers
self.dataset_paths = dataset_paths
self.learning_rate = learning_rate
self.train_ratio = train_ratio
self.model = model
self.output_dir = output_dir
# Create output directory if it doesn't exist
os.makedirs(output_dir, exist_ok=True)
# Initialize history tracking
self.train_losses = []
self.val_losses = []
self.best_val_loss = float('inf')
self.current_val_loss = float('inf')
# Initialize CSV logging
self.log_file = os.path.join(output_dir, 'training_log.csv')
if not os.path.exists(self.log_file):
with open(self.log_file, 'w') as f:
f.write('epoch,train_loss,val_loss,best_val_loss\n')
# Initialize datasets and loaders
self._initialize_datasets()
# Initialize training parameters
self._initialize_training()
def _freeze_layers(self):
"""
Freeze all layers except those that are part of the decoder or upsampling
We'll assume the last few layers are responsible for upsampling/reconstruction
"""
try:
# Try to identify encoder layers and freeze them
for name, param in self.model.named_parameters():
if 'encoder' in name:
param.requires_grad = False
# Unfreeze certain layers (example: last 3 decoder layers)
# Modify this based on your actual model architecture
for name, param in self.model.named_parameters():
if 'decoder' in name and 'block4' in name or 'block5' in name:
param.requires_grad = True
except Exception as e:
print(f"Warning: Couldn't freeze layers - {str(e)}")
pass
def _initialize_datasets(self):
"""
Helper method to initialize datasets
"""
if isinstance(self.dataset_paths, list):
df_train = pd.concat([pd.read_parquet(path) for path in self.dataset_paths], ignore_index=True)
else:
raise ValueError("Invalid dataset_paths format. Must be a list.")
# Split into train and validation sets
df_train, df_val = train_test_split(
df_train,
test_size=1 - self.train_ratio,
random_state=42
)
# Define preprocessing transforms with augmentation
self.transform = transforms.Compose([
transforms.ToTensor(),
transforms.RandomResizedCrop(256),
transforms.RandomHorizontalFlip(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
]) ])
# Create datasets and dataloaders def __len__(self):
self.train_dataset = ImageDataset(df_train, transform=self.transform) return len(self.df)
self.val_dataset = ImageDataset(df_val, transform=self.transform)
self.train_loader = DataLoader( def __getitem__(self, idx):
self.train_dataset, # Get the byte strings
batch_size=self.batch_size, low_res_bytes = self.df.iloc[idx]['low_res']
high_res_bytes = self.df.iloc[idx]['high_res']
# Convert bytes to numpy arrays
low_res = cv2.imdecode(np.frombuffer(low_res_bytes, np.uint8), -1)
high_res = cv2.imdecode(np.frombuffer(high_res_bytes, np.uint8), -1)
# Apply augmentation and normalization
augmented = self.augmentation(image=low_res, mask=high_res)
low_res = augmented['image']
high_res = augmented['mask']
return {
'low_res': low_res,
'high_res': high_res
}
def finetune_model(model: AIIA, train_parquet_path, val_parquet_path, batch_size=8, epochs = 10):
# Initialize dataset and dataloader
train_dataset = aiuNNDataset(train_parquet_path)
val_dataset = aiuNNDataset(val_parquet_path)
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True, shuffle=True,
num_workers=self.num_workers num_workers=4
) )
self.val_loader = DataLoader( val_loader = torch.utils.data.DataLoader(
self.val_dataset, val_dataset,
batch_size=self.batch_size, batch_size=batch_size,
shuffle=False, shuffle=False,
num_workers=self.num_workers num_workers=4
) if df_val is not None else None
def _log_metrics(self, epoch: int, train_loss: float, val_loss: float):
"""
Log training metrics to CSV file
Args:
epoch (int): Current epoch number
train_loss (float): Training loss for the epoch
val_loss (float): Validation loss for the epoch
"""
with open(self.log_file, 'a') as f:
f.write(f'{epoch},{train_loss:.6f},{val_loss:.6f},{self.best_val_loss:.6f}\n')
def _initialize_training(self):
"""
Helper method to initialize training parameters
"""
# Freeze layers except those we want to finetune
self._freeze_layers()
# Initialize optimizer and scheduler
params_to_optimize = [p for p in self.model.parameters() if p.requires_grad]
if not params_to_optimize:
raise ValueError("No parameters found to optimize")
# Use Adam with weight decay for better regularization
self.optimizer = torch.optim.Adam(
params_to_optimize,
lr=self.learning_rate,
weight_decay=1e-4 # Add L2 regularization
) )
# Reduce learning rate when validation loss plateaus # Set device
self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.optimizer, model.to(device)
factor=0.1, # Multiply LR by this factor on plateau
patience=3, # Number of epochs to wait before reducing LR
verbose=True
)
# Use a combination of L1 and L2 losses for better performance # Define loss function and optimizer
self.criterion = nn.L1Loss() criterion = nn.MSELoss()
self.mse_criterion = nn.MSELoss() optimizer = torch.optim.Adam(model.parameters(), lr=model.config.learning_rate)
def _train_epoch(self): best_val_loss = float('inf')
"""
Train model for one epoch
Returns:
float: Average training loss for the epoch
"""
self.model.train()
running_loss = 0.0
for batch in tqdm(self.train_loader, desc="Training"): for epoch in range(epochs):
low_ress = batch['low_ress'].to(self.device) model.train()
high_ress = batch['high_ress'].to(self.device)
# Forward pass (we'll use the model's existing architecture without adding layers) train_loss = 0.0
try:
features = self.model(low_ress)
print("Features shape:", features.shape) # Check output dimensions
print("High-res shape:", high_ress.shape) # Check target dimensions
except Exception as e:
raise RuntimeError(f"Error during forward pass: {str(e)}")
# Calculate loss with different scaling for L1 and MSE components for batch_idx, batch in enumerate(train_loader):
l1_loss = self.criterion(features, high_ress) * 0.5 low_res = batch['low_res'].to(device)
mse_loss = self.mse_criterion(features, high_ress) * 0.5 high_res = batch['high_res'].to(device)
total_loss = l1_loss + mse_loss
# Backward pass and optimize
self.optimizer.zero_grad()
total_loss.backward()
self.optimizer.step()
running_loss += total_loss.item()
epoch_loss = running_loss / len(self.train_loader)
self.train_losses.append(epoch_loss)
print(f"Train Loss: {epoch_loss:.4f}")
return epoch_loss
def _train_epoch(self):
"""Train model for one epoch"""
self.model.train()
running_loss = 0.0
for batch in tqdm(self.train_loader, desc="Training"):
low_ress = batch['low_ress'].to(self.device)
high_ress = batch['high_ress'].to(self.device)
# Forward pass # Forward pass
try: outputs = model(low_res)
outputs = self.model(low_ress) # Now outputs are images
print("Output shape:", outputs.shape)
print("High-res shape:", high_ress.shape)
except Exception as e:
raise RuntimeError(f"Error during forward pass: {str(e)}")
# Calculate loss # Calculate loss
l1_loss = self.criterion(outputs, high_ress) * 0.5 loss = criterion(outputs, high_res.permute(0, 3, 1, 2)) # Adjust for channel dimensions
mse_loss = self.mse_criterion(outputs, high_ress) * 0.5
total_loss = l1_loss + mse_loss
# Backward pass and optimize # Backward pass and optimize
self.optimizer.zero_grad() optimizer.zero_grad()
total_loss.backward() loss.backward()
self.optimizer.step() optimizer.step()
running_loss += total_loss.item() train_loss += loss.item()
epoch_loss = running_loss / len(self.train_loader) avg_train_loss = train_loss / len(train_loader)
self.train_lossess.append(epoch_loss)
print(f"Train Loss: {epoch_loss:.4f}")
return epoch_loss
def train(self, num_epochs: int = 10): print(f"Epoch {epoch+1}, Training Loss: {avg_train_loss:.4f}")
"""
Train the model for specified number of epochs
Args:
num_epochs (int): Number of epochs to train for
"""
self.model.to(self.device)
print(f"Training metrics will be logged to: {self.log_file}") # Validation
model.eval()
val_loss = 0.0
for epoch in range(num_epochs): with torch.no_grad():
print(f"\nEpoch {epoch+1}/{num_epochs}") for batch in val_loader:
low_res = batch['low_res'].to(device)
high_res = batch['high_res'].to(device)
# Train phase outputs = model(low_res)
train_loss = self._train_epoch() loss = criterion(outputs, high_res.permute(0, 3, 1, 2))
# Validation phase val_loss += loss.item()
if self.val_loader is not None:
val_loss = self._validate_epoch()
# Update learning rate scheduler based on validation loss avg_val_loss = val_loss / len(val_loader)
self.scheduler.step(val_loss)
# Log metrics print(f"Epoch {epoch+1}, Validation Loss: {avg_val_loss:.4f}")
self._log_metrics(epoch + 1, train_loss, val_loss)
# Save best model based on validation loss # Save best model
if self.current_val_loss < self.best_val_loss: if avg_val_loss < best_val_loss:
print(f"Validation loss improved from {self.best_val_loss:.4f} to {self.current_val_loss:.4f}") best_val_loss = avg_val_loss
self.best_val_loss = self.current_val_loss model.save("best_model")
model_save_path = os.path.join(self.output_dir, "aiuNN-optimized")
self.model.save(model_save_path)
print(f"Model saved to: {model_save_path}")
# After training, save the final model return model
final_model_path = os.path.join(self.output_dir, "aiuNN-final")
self.model.save(final_model_path)
print(f"\nFinal model saved to: {final_model_path}")
def main():
# Paths to your data
train_parquet_path = "/root/training_data/vision-dataset/image_upscaler.parquet"
val_parquet_path = "/root/training_data/vision-dataset/image_vec_upscaler.parquet"
if __name__ == "__main__": # Load pretrained model
# Load your model first model = AIIA.load("/root/vision/AIIA/AIIA-base-512")
config = AIIAConfig.load("/root/vision/AIIA/AIIA-base-512")
model = SuperResolutionModel(base_model=AIIABase.load("/root/vision/AIIA/AIIA-base-512"), config=config)
trainer = FineTuner( # Add final upsampling layer if needed (depending on your specific architecture)
model=model, if hasattr(model, 'chunked_'):
dataset_paths=[ model.add_module('final_upsample', nn.Upsample(scale_factor=2, mode='bilinear'))
"/root/training_data/vision-dataset/image_upscaler.parquet",
"/root/training_data/vision-dataset/image_vec_upscaler.parquet" # Fine-tune
], finetune_model(
batch_size=2, # Increased batch size model,
learning_rate=1e-4 # Reduced initial LR train_parquet_path,
val_parquet_path
) )
trainer.train(num_epochs=10) # Extended training time if __name__ == '__main__':
main()