develop #4

Merged
Fabel merged 103 commits from develop into main 2025-03-01 21:47:17 +00:00
1 changed files with 141 additions and 395 deletions
Showing only changes of commit de6a67cb4e - Show all commits

View File

@ -1,409 +1,155 @@
import torch
import pandas as pd
from PIL import Image, ImageFile
import io
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from aiia.model import AIIABase, AIIA, AIIAConfig
from sklearn.model_selection import train_test_split
from typing import Dict, List, Union, Optional
import base64
from tqdm import tqdm
import numpy as np
import cv2
import os
class ImageDataset(Dataset):
    """Dataset of paired low/high resolution images stored in a DataFrame.

    Every row carries two encoded image files, one in the ``image_512``
    column and one in the ``image_1024`` column. Each value may be raw
    ``bytes`` or a base64-encoded ``str``.
    """

    def __init__(self, dataframe, transform=None):
        """
        Initialize the dataset with a dataframe containing image data

        Args:
            dataframe (pd.DataFrame): DataFrame containing 'image_512' and 'image_1024' columns
            transform (callable, optional): Optional transform to be applied to both images
        """
        self.dataframe = dataframe
        self.transform = transform

    def __len__(self):
        """Return the number of image pairs, i.e. the number of rows."""
        return len(self.dataframe)

    @staticmethod
    def _to_bytes(data, idx):
        """Return ``data`` as raw bytes, decoding it first if it is a base64 string.

        Raises:
            ValueError: If the value is neither ``bytes`` nor a base64 ``str``.
        """
        if isinstance(data, str):
            data = base64.b64decode(data)
        if not isinstance(data, bytes):
            raise ValueError(f"Invalid image data format at index {idx}")
        return data

    @staticmethod
    def _load_rgb(data):
        """Decode an encoded image file held in ``data`` into an RGB PIL image."""
        # convert() forces a full decode, so the returned image no longer
        # depends on the stream and the stream can be closed right away.
        with io.BytesIO(data) as stream:
            return Image.open(stream).convert('RGB')

    def __getitem__(self, idx):
        """
        Get a pair of low and high resolution images

        Args:
            idx (int): Index of the data point

        Returns:
            dict: Contains 'low_ress' and 'high_ress' PIL images or transformed tensors

        Raises:
            RuntimeError: If the row cannot be turned into an image pair. The
                underlying exception is attached as ``__cause__``.
        """
        row = self.dataframe.iloc[idx]
        try:
            # Each column is normalised on its own, so a row mixing a base64
            # string with raw bytes is handled correctly.
            low_res_data = self._to_bytes(row['image_512'], idx)
            high_res_data = self._to_bytes(row['image_1024'], idx)

            # Process-wide Pillow switch that tolerates truncated files.
            # Setting it per access keeps it effective in whichever process
            # ends up running __getitem__.
            ImageFile.LOAD_TRUNCATED_IMAGES = True

            # No separate verify() step: calling verify() on an in-memory
            # copy does nothing, and convert('RGB') already decodes the whole
            # file, so corrupt data surfaces here.
            low_res_image = self._load_rgb(low_res_data)
            high_res_image = self._load_rgb(high_res_data)

            if self.transform is not None:
                low_res_image = self.transform(low_res_image)
                high_res_image = self.transform(high_res_image)

            return {
                'low_ress': low_res_image,   # key spelling is what the batch consumer reads
                'high_ress': high_res_image
            }
        except Exception as e:
            # Chain the original error so the real failure stays in the traceback.
            raise RuntimeError(f"Error loading images at index {idx}: {str(e)}") from e
class SuperResolutionModel(AIIA):
def __init__(self, base_model: AIIA, config: AIIAConfig):
super(SuperResolutionModel, self).__init__(config=config)
# Use base model as encoder
self.encoder = base_model
for param in self.encoder.parameters():
param.requires_grad = False # Freeze encoder layers
# Add decoder layers to reconstruct image
self.decoder = nn.Sequential(
nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1),
nn.ReLU(),
nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1),
nn.ReLU(),
nn.ConvTranspose2d(128, 3, kernel_size=3, padding=1)
from albumentations import (
Compose, Resize, Normalize, RandomBrightnessContrast,
HorizontalFlip, VerticalFlip, Rotate, GaussianBlur
)
from albumentations.pytorch import ToTensorV2
from torch import nn
# Import the model and config from your existing code
from aiia import AIIA, AIIAConfig, AIIABase, AIIABaseShared, AIIAmoe, AIIAchunked, AIIArecursive
def forward(self, x):
    """Run ``x`` through the encoder and feed the resulting features to the decoder.

    Args:
        x: Input passed unchanged to ``self.encoder``.

    Returns:
        Whatever ``self.decoder`` returns for the encoder's output.
    """
    return self.decoder(self.encoder(x))
class aiuNNDataset(torch.utils.data.Dataset):
def __init__(self, parquet_path, config=None):
# Read the Parquet file
self.df = pd.read_parquet(parquet_path)
class FineTuner:
def __init__(self,
             model: AIIA,
             dataset_paths: List[str],
             batch_size: int = 32,
             learning_rate: float = 0.001,
             num_workers: int = 4,
             train_ratio: float = 0.8,
             output_dir: str = "./training_logs"):
    """
    Set up a fine-tuning run: store settings, prepare logging, build data and optimizer state

    Args:
        model (AIIA): Model instance to finetune
        dataset_paths (List[str]): Paths to datasets
        batch_size (int): Batch size for training
        learning_rate (float): Learning rate for optimizer
        num_workers (int): Number of workers for data loading
        train_ratio (float): Ratio of data to use for training (rest goes to validation)
        output_dir (str): Directory to save training logs and model checkpoints
    """
    # Caller-supplied settings, stored untouched
    self.model = model
    self.dataset_paths = dataset_paths
    self.batch_size = batch_size
    self.learning_rate = learning_rate
    self.num_workers = num_workers
    self.train_ratio = train_ratio
    self.output_dir = output_dir

    # Prefer the GPU whenever torch reports one
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    self.device = torch.device(device_name)

    # The output directory must exist before the log file is created in it
    os.makedirs(output_dir, exist_ok=True)

    # Loss history; both "best" and "current" start at +inf
    self.train_losses = []
    self.val_losses = []
    self.best_val_loss = float('inf')
    self.current_val_loss = float('inf')

    # CSV log: the header row is written only when the file is new, so an
    # existing log from an earlier run is kept
    self.log_file = os.path.join(output_dir, 'training_log.csv')
    header_missing = not os.path.exists(self.log_file)
    if header_missing:
        with open(self.log_file, 'w') as log:
            log.write('epoch,train_loss,val_loss,best_val_loss\n')

    # Data first, then the training state
    self._initialize_datasets()
    self._initialize_training()
def _freeze_layers(self):
    """
    Freeze encoder parameters and make sure the late decoder blocks stay trainable

    Parameters whose name contains 'encoder' get ``requires_grad = False``.
    Parameters whose name contains 'decoder' together with 'block4' or
    'block5' get ``requires_grad = True``. All other parameters are left
    untouched. Adjust the name patterns to the actual model architecture.

    This is best effort: any failure is reported as a printed warning and
    the model is left in whatever state was reached.
    """
    try:
        for name, param in self.model.named_parameters():
            if 'encoder' in name:
                param.requires_grad = False
            # Parenthesised on purpose: `and` binds tighter than `or`, so
            # without the parentheses every parameter containing 'block5',
            # encoder ones included, would be unfrozen.
            if 'decoder' in name and ('block4' in name or 'block5' in name):
                param.requires_grad = True
    except Exception as e:
        # Deliberately broad: freezing is an optimisation, not a precondition.
        print(f"Warning: Couldn't freeze layers - {str(e)}")
def _initialize_datasets(self):
"""
Helper method to initialize datasets
"""
if isinstance(self.dataset_paths, list):
df_train = pd.concat([pd.read_parquet(path) for path in self.dataset_paths], ignore_index=True)
else:
raise ValueError("Invalid dataset_paths format. Must be a list.")
# Split into train and validation sets
df_train, df_val = train_test_split(
df_train,
test_size=1 - self.train_ratio,
random_state=42
)
# Define preprocessing transforms with augmentation
self.transform = transforms.Compose([
transforms.ToTensor(),
transforms.RandomResizedCrop(256),
transforms.RandomHorizontalFlip(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
# Data augmentation pipeline
self.augmentation = Compose([
Resize(height=512, width=512),
RandomBrightnessContrast(),
HorizontalFlip(p=0.5),
VerticalFlip(p=0.5),
Rotate(limit=45),
GaussianBlur(p=0.3),
Normalize(mean=[0.5], std=[0.5]),
ToTensorV2()
])
# Create datasets and dataloaders
self.train_dataset = ImageDataset(df_train, transform=self.transform)
self.val_dataset = ImageDataset(df_val, transform=self.transform)
def __len__(self):
    """Return the number of samples, i.e. the length of ``self.df``."""
    sample_count = len(self.df)
    return sample_count
self.train_loader = DataLoader(
self.train_dataset,
batch_size=self.batch_size,
def __getitem__(self, idx):
    """
    Decode and augment one low/high resolution image pair

    Args:
        idx (int): Positional row index into ``self.df``

    Returns:
        dict: 'low_res' and 'high_res' entries holding the augmentation
        pipeline's 'image' and 'mask' outputs respectively

    Raises:
        ValueError: If either column's bytes cannot be decoded as an image
    """
    # Fetch the row once instead of once per column
    row = self.df.iloc[idx]

    decoded = {}
    for column in ('low_res', 'high_res'):
        buffer = np.frombuffer(row[column], np.uint8)
        # Flag -1 is cv2.IMREAD_UNCHANGED: keep channels and depth as stored
        image = cv2.imdecode(buffer, -1)
        if image is None:
            # imdecode reports invalid data by returning None; fail here with
            # a clear message instead of deep inside the augmentation pipeline
            raise ValueError(f"Could not decode '{column}' image at index {idx}")
        decoded[column] = image

    # The high-res target goes in as `mask` so both images pass through one
    # augmentation call.
    # NOTE(review): albumentations may skip pixel-level transforms such as
    # Normalize for `mask` inputs - confirm the target is meant to bypass them.
    augmented = self.augmentation(image=decoded['low_res'], mask=decoded['high_res'])

    return {
        'low_res': augmented['image'],
        'high_res': augmented['mask']
    }
def finetune_model(model: AIIA, train_parquet_path, val_parquet_path, batch_size=8, epochs = 10):
# Initialize dataset and dataloader
train_dataset = aiuNNDataset(train_parquet_path)
val_dataset = aiuNNDataset(val_parquet_path)
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=self.num_workers
num_workers=4
)
self.val_loader = DataLoader(
self.val_dataset,
batch_size=self.batch_size,
val_loader = torch.utils.data.DataLoader(
val_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=self.num_workers
) if df_val is not None else None
def _log_metrics(self, epoch: int, train_loss: float, val_loss: float):
    """
    Append one row of epoch metrics to the CSV log

    The row holds the epoch number, then the training loss, validation loss
    and ``self.best_val_loss``, each formatted with six decimals.

    Args:
        epoch (int): Current epoch number
        train_loss (float): Training loss for the epoch
        val_loss (float): Validation loss for the epoch
    """
    with open(self.log_file, 'a') as log:
        losses = (train_loss, val_loss, self.best_val_loss)
        columns = [f'{epoch}'] + [f'{value:.6f}' for value in losses]
        log.write(','.join(columns) + '\n')
def _initialize_training(self):
"""
Helper method to initialize training parameters
"""
# Freeze layers except those we want to finetune
self._freeze_layers()
# Initialize optimizer and scheduler
params_to_optimize = [p for p in self.model.parameters() if p.requires_grad]
if not params_to_optimize:
raise ValueError("No parameters found to optimize")
# Use Adam with weight decay for better regularization
self.optimizer = torch.optim.Adam(
params_to_optimize,
lr=self.learning_rate,
weight_decay=1e-4 # Add L2 regularization
num_workers=4
)
# Reduce learning rate when validation loss plateaus
self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
self.optimizer,
factor=0.1, # Multiply LR by this factor on plateau
patience=3, # Number of epochs to wait before reducing LR
verbose=True
)
# Set device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
# Use a combination of L1 and L2 losses for better performance
self.criterion = nn.L1Loss()
self.mse_criterion = nn.MSELoss()
# Define loss function and optimizer
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=model.config.learning_rate)
def _train_epoch(self):
"""
Train model for one epoch
Returns:
float: Average training loss for the epoch
"""
self.model.train()
running_loss = 0.0
best_val_loss = float('inf')
for batch in tqdm(self.train_loader, desc="Training"):
low_ress = batch['low_ress'].to(self.device)
high_ress = batch['high_ress'].to(self.device)
for epoch in range(epochs):
model.train()
# Forward pass (we'll use the model's existing architecture without adding layers)
try:
features = self.model(low_ress)
print("Features shape:", features.shape) # Check output dimensions
print("High-res shape:", high_ress.shape) # Check target dimensions
except Exception as e:
raise RuntimeError(f"Error during forward pass: {str(e)}")
train_loss = 0.0
# Calculate loss with different scaling for L1 and MSE components
l1_loss = self.criterion(features, high_ress) * 0.5
mse_loss = self.mse_criterion(features, high_ress) * 0.5
total_loss = l1_loss + mse_loss
# Backward pass and optimize
self.optimizer.zero_grad()
total_loss.backward()
self.optimizer.step()
running_loss += total_loss.item()
epoch_loss = running_loss / len(self.train_loader)
self.train_losses.append(epoch_loss)
print(f"Train Loss: {epoch_loss:.4f}")
return epoch_loss
def _train_epoch(self):
"""Train model for one epoch"""
self.model.train()
running_loss = 0.0
for batch in tqdm(self.train_loader, desc="Training"):
low_ress = batch['low_ress'].to(self.device)
high_ress = batch['high_ress'].to(self.device)
for batch_idx, batch in enumerate(train_loader):
low_res = batch['low_res'].to(device)
high_res = batch['high_res'].to(device)
# Forward pass
try:
outputs = self.model(low_ress) # Now outputs are images
print("Output shape:", outputs.shape)
print("High-res shape:", high_ress.shape)
except Exception as e:
raise RuntimeError(f"Error during forward pass: {str(e)}")
outputs = model(low_res)
# Calculate loss
l1_loss = self.criterion(outputs, high_ress) * 0.5
mse_loss = self.mse_criterion(outputs, high_ress) * 0.5
total_loss = l1_loss + mse_loss
loss = criterion(outputs, high_res.permute(0, 3, 1, 2)) # Adjust for channel dimensions
# Backward pass and optimize
self.optimizer.zero_grad()
total_loss.backward()
self.optimizer.step()
optimizer.zero_grad()
loss.backward()
optimizer.step()
running_loss += total_loss.item()
train_loss += loss.item()
epoch_loss = running_loss / len(self.train_loader)
self.train_lossess.append(epoch_loss)
print(f"Train Loss: {epoch_loss:.4f}")
return epoch_loss
avg_train_loss = train_loss / len(train_loader)
def train(self, num_epochs: int = 10):
"""
Train the model for specified number of epochs
Args:
num_epochs (int): Number of epochs to train for
"""
self.model.to(self.device)
print(f"Epoch {epoch+1}, Training Loss: {avg_train_loss:.4f}")
print(f"Training metrics will be logged to: {self.log_file}")
# Validation
model.eval()
val_loss = 0.0
for epoch in range(num_epochs):
print(f"\nEpoch {epoch+1}/{num_epochs}")
with torch.no_grad():
for batch in val_loader:
low_res = batch['low_res'].to(device)
high_res = batch['high_res'].to(device)
# Train phase
train_loss = self._train_epoch()
outputs = model(low_res)
loss = criterion(outputs, high_res.permute(0, 3, 1, 2))
# Validation phase
if self.val_loader is not None:
val_loss = self._validate_epoch()
val_loss += loss.item()
# Update learning rate scheduler based on validation loss
self.scheduler.step(val_loss)
avg_val_loss = val_loss / len(val_loader)
# Log metrics
self._log_metrics(epoch + 1, train_loss, val_loss)
print(f"Epoch {epoch+1}, Validation Loss: {avg_val_loss:.4f}")
# Save best model based on validation loss
if self.current_val_loss < self.best_val_loss:
print(f"Validation loss improved from {self.best_val_loss:.4f} to {self.current_val_loss:.4f}")
self.best_val_loss = self.current_val_loss
model_save_path = os.path.join(self.output_dir, "aiuNN-optimized")
self.model.save(model_save_path)
print(f"Model saved to: {model_save_path}")
# Save best model
if avg_val_loss < best_val_loss:
best_val_loss = avg_val_loss
model.save("best_model")
# After training, save the final model
final_model_path = os.path.join(self.output_dir, "aiuNN-final")
self.model.save(final_model_path)
print(f"\nFinal model saved to: {final_model_path}")
return model
def main():
# Paths to your data
train_parquet_path = "/root/training_data/vision-dataset/image_upscaler.parquet"
val_parquet_path = "/root/training_data/vision-dataset/image_vec_upscaler.parquet"
if __name__ == "__main__":
# Load your model first
config = AIIAConfig.load("/root/vision/AIIA/AIIA-base-512")
model = SuperResolutionModel(base_model=AIIABase.load("/root/vision/AIIA/AIIA-base-512"), config=config)
# Load pretrained model
model = AIIA.load("/root/vision/AIIA/AIIA-base-512")
trainer = FineTuner(
model=model,
dataset_paths=[
"/root/training_data/vision-dataset/image_upscaler.parquet",
"/root/training_data/vision-dataset/image_vec_upscaler.parquet"
],
batch_size=2, # Increased batch size
learning_rate=1e-4 # Reduced initial LR
# Add final upsampling layer if needed (depending on your specific architecture)
if hasattr(model, 'chunked_'):
model.add_module('final_upsample', nn.Upsample(scale_factor=2, mode='bilinear'))
# Fine-tune
finetune_model(
model,
train_parquet_path,
val_parquet_path
)
trainer.train(num_epochs=10) # Extended training time
if __name__ == '__main__':
main()