develop #4
|
@ -1,409 +1,155 @@
|
|||
import torch
|
||||
import pandas as pd
|
||||
from PIL import Image, ImageFile
|
||||
import io
|
||||
from torch import nn
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
import torchvision.transforms as transforms
|
||||
from aiia.model import AIIABase, AIIA, AIIAConfig
|
||||
from sklearn.model_selection import train_test_split
|
||||
from typing import Dict, List, Union, Optional
|
||||
import base64
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
import cv2
|
||||
import os
|
||||
|
||||
|
||||
class ImageDataset(Dataset):
|
||||
def __init__(self, dataframe, transform=None):
|
||||
"""
|
||||
Initialize the dataset with a dataframe containing image data
|
||||
|
||||
Args:
|
||||
dataframe (pd.DataFrame): DataFrame containing 'image_512' and 'image_1024' columns
|
||||
transform (callable, optional): Optional transform to be applied to both images
|
||||
"""
|
||||
self.dataframe = dataframe
|
||||
self.transform = transform
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataframe)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
"""
|
||||
Get a pair of low and high resolution images
|
||||
|
||||
Args:
|
||||
idx (int): Index of the data point
|
||||
|
||||
Returns:
|
||||
dict: Contains 'low_ress' and 'high_ress' PIL images or transformed tensors
|
||||
"""
|
||||
row = self.dataframe.iloc[idx]
|
||||
|
||||
try:
|
||||
# Handle both bytes and base64 encoded strings
|
||||
low_res_data = row['image_512']
|
||||
high_res_data = row['image_1024']
|
||||
|
||||
if isinstance(low_res_data, str):
|
||||
# Decode base64 string to bytes
|
||||
low_res_data = base64.b64decode(low_res_data)
|
||||
high_res_data = base64.b64decode(high_res_data)
|
||||
|
||||
# Verify data is valid before creating BytesIO
|
||||
if not isinstance(low_res_data, bytes) or not isinstance(high_res_data, bytes):
|
||||
raise ValueError(f"Invalid image data format at index {idx}")
|
||||
|
||||
# Create image streams
|
||||
low_res_stream = io.BytesIO(low_res_data)
|
||||
high_res_stream = io.BytesIO(high_res_data)
|
||||
|
||||
# Enable loading of truncated images
|
||||
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||
|
||||
# Load and convert images to RGB
|
||||
low_res_image = Image.open(low_res_stream).convert('RGB')
|
||||
high_res_image = Image.open(high_res_stream).convert('RGB')
|
||||
|
||||
# Create fresh copies for verify() since it modifies the image object
|
||||
low_res_verify = low_res_image.copy()
|
||||
high_res_verify = high_res_image.copy()
|
||||
|
||||
# Verify images are valid
|
||||
try:
|
||||
low_res_verify.verify()
|
||||
high_res_verify.verify()
|
||||
except Exception as e:
|
||||
raise ValueError(f"Image verification failed at index {idx}: {str(e)}")
|
||||
finally:
|
||||
low_res_verify.close()
|
||||
high_res_verify.close()
|
||||
|
||||
# Apply transforms if specified
|
||||
if self.transform:
|
||||
low_res_image = self.transform(low_res_image)
|
||||
high_res_image = self.transform(high_res_image)
|
||||
|
||||
return {
|
||||
'low_ress': low_res_image, # Note: Using 'low_ress' to match ModelTrainer
|
||||
'high_ress': high_res_image # Note: Using 'high_ress' to match ModelTrainer
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Error loading images at index {idx}: {str(e)}")
|
||||
|
||||
finally:
|
||||
# Ensure streams are closed
|
||||
if 'low_res_stream' in locals():
|
||||
low_res_stream.close()
|
||||
if 'high_res_stream' in locals():
|
||||
high_res_stream.close()
|
||||
|
||||
class SuperResolutionModel(AIIA):
|
||||
def __init__(self, base_model: AIIA, config: AIIAConfig):
|
||||
super(SuperResolutionModel, self).__init__(config=config)
|
||||
# Use base model as encoder
|
||||
self.encoder = base_model
|
||||
for param in self.encoder.parameters():
|
||||
param.requires_grad = False # Freeze encoder layers
|
||||
|
||||
# Add decoder layers to reconstruct image
|
||||
self.decoder = nn.Sequential(
|
||||
nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1),
|
||||
nn.ReLU(),
|
||||
nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1),
|
||||
nn.ReLU(),
|
||||
nn.ConvTranspose2d(128, 3, kernel_size=3, padding=1)
|
||||
from albumentations import (
|
||||
Compose, Resize, Normalize, RandomBrightnessContrast,
|
||||
HorizontalFlip, VerticalFlip, Rotate, GaussianBlur
|
||||
)
|
||||
from albumentations.pytorch import ToTensorV2
|
||||
from torch import nn
|
||||
# Import the model and config from your existing code
|
||||
from aiia import AIIA, AIIAConfig, AIIABase, AIIABaseShared, AIIAmoe, AIIAchunked, AIIArecursive
|
||||
|
||||
def forward(self, x):
|
||||
features = self.encoder(x)
|
||||
output = self.decoder(features)
|
||||
return output
|
||||
class aiuNNDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, parquet_path, config=None):
|
||||
# Read the Parquet file
|
||||
self.df = pd.read_parquet(parquet_path)
|
||||
|
||||
class FineTuner:
|
||||
def __init__(self,
|
||||
model: AIIA,
|
||||
dataset_paths: List[str],
|
||||
batch_size: int = 32,
|
||||
learning_rate: float = 0.001,
|
||||
num_workers: int = 4,
|
||||
train_ratio: float = 0.8,
|
||||
output_dir: str = "./training_logs"):
|
||||
"""
|
||||
Specialized trainer for image super resolution tasks
|
||||
|
||||
Args:
|
||||
model (nn.Module): Model instance to finetune
|
||||
dataset_paths (List[str]): Paths to datasets
|
||||
batch_size (int): Batch size for training
|
||||
learning_rate (float): Learning rate for optimizer
|
||||
num_workers (int): Number of workers for data loading
|
||||
train_ratio (float): Ratio of data to use for training (rest goes to validation)
|
||||
output_dir (str): Directory to save training logs and model checkpoints
|
||||
"""
|
||||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
self.batch_size = batch_size
|
||||
self.num_workers = num_workers
|
||||
self.dataset_paths = dataset_paths
|
||||
self.learning_rate = learning_rate
|
||||
self.train_ratio = train_ratio
|
||||
self.model = model
|
||||
self.output_dir = output_dir
|
||||
|
||||
# Create output directory if it doesn't exist
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# Initialize history tracking
|
||||
self.train_losses = []
|
||||
self.val_losses = []
|
||||
self.best_val_loss = float('inf')
|
||||
self.current_val_loss = float('inf')
|
||||
|
||||
# Initialize CSV logging
|
||||
self.log_file = os.path.join(output_dir, 'training_log.csv')
|
||||
if not os.path.exists(self.log_file):
|
||||
with open(self.log_file, 'w') as f:
|
||||
f.write('epoch,train_loss,val_loss,best_val_loss\n')
|
||||
|
||||
# Initialize datasets and loaders
|
||||
self._initialize_datasets()
|
||||
|
||||
# Initialize training parameters
|
||||
self._initialize_training()
|
||||
|
||||
def _freeze_layers(self):
|
||||
"""
|
||||
Freeze all layers except those that are part of the decoder or upsampling
|
||||
We'll assume the last few layers are responsible for upsampling/reconstruction
|
||||
"""
|
||||
try:
|
||||
# Try to identify encoder layers and freeze them
|
||||
for name, param in self.model.named_parameters():
|
||||
if 'encoder' in name:
|
||||
param.requires_grad = False
|
||||
|
||||
# Unfreeze certain layers (example: last 3 decoder layers)
|
||||
# Modify this based on your actual model architecture
|
||||
for name, param in self.model.named_parameters():
|
||||
if 'decoder' in name and 'block4' in name or 'block5' in name:
|
||||
param.requires_grad = True
|
||||
|
||||
except Exception as e:
|
||||
print(f"Warning: Couldn't freeze layers - {str(e)}")
|
||||
pass
|
||||
|
||||
def _initialize_datasets(self):
|
||||
"""
|
||||
Helper method to initialize datasets
|
||||
"""
|
||||
if isinstance(self.dataset_paths, list):
|
||||
df_train = pd.concat([pd.read_parquet(path) for path in self.dataset_paths], ignore_index=True)
|
||||
else:
|
||||
raise ValueError("Invalid dataset_paths format. Must be a list.")
|
||||
|
||||
# Split into train and validation sets
|
||||
df_train, df_val = train_test_split(
|
||||
df_train,
|
||||
test_size=1 - self.train_ratio,
|
||||
random_state=42
|
||||
)
|
||||
|
||||
# Define preprocessing transforms with augmentation
|
||||
self.transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.RandomResizedCrop(256),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
||||
# Data augmentation pipeline
|
||||
self.augmentation = Compose([
|
||||
Resize(height=512, width=512),
|
||||
RandomBrightnessContrast(),
|
||||
HorizontalFlip(p=0.5),
|
||||
VerticalFlip(p=0.5),
|
||||
Rotate(limit=45),
|
||||
GaussianBlur(p=0.3),
|
||||
Normalize(mean=[0.5], std=[0.5]),
|
||||
ToTensorV2()
|
||||
])
|
||||
|
||||
# Create datasets and dataloaders
|
||||
self.train_dataset = ImageDataset(df_train, transform=self.transform)
|
||||
self.val_dataset = ImageDataset(df_val, transform=self.transform)
|
||||
def __len__(self):
|
||||
return len(self.df)
|
||||
|
||||
self.train_loader = DataLoader(
|
||||
self.train_dataset,
|
||||
batch_size=self.batch_size,
|
||||
def __getitem__(self, idx):
|
||||
# Get the byte strings
|
||||
low_res_bytes = self.df.iloc[idx]['low_res']
|
||||
high_res_bytes = self.df.iloc[idx]['high_res']
|
||||
|
||||
# Convert bytes to numpy arrays
|
||||
low_res = cv2.imdecode(np.frombuffer(low_res_bytes, np.uint8), -1)
|
||||
high_res = cv2.imdecode(np.frombuffer(high_res_bytes, np.uint8), -1)
|
||||
|
||||
# Apply augmentation and normalization
|
||||
augmented = self.augmentation(image=low_res, mask=high_res)
|
||||
low_res = augmented['image']
|
||||
high_res = augmented['mask']
|
||||
|
||||
return {
|
||||
'low_res': low_res,
|
||||
'high_res': high_res
|
||||
}
|
||||
|
||||
|
||||
def finetune_model(model: AIIA, train_parquet_path, val_parquet_path, batch_size=8, epochs = 10):
|
||||
# Initialize dataset and dataloader
|
||||
train_dataset = aiuNNDataset(train_parquet_path)
|
||||
val_dataset = aiuNNDataset(val_parquet_path)
|
||||
|
||||
train_loader = torch.utils.data.DataLoader(
|
||||
train_dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
num_workers=self.num_workers
|
||||
num_workers=4
|
||||
)
|
||||
|
||||
self.val_loader = DataLoader(
|
||||
self.val_dataset,
|
||||
batch_size=self.batch_size,
|
||||
val_loader = torch.utils.data.DataLoader(
|
||||
val_dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=False,
|
||||
num_workers=self.num_workers
|
||||
) if df_val is not None else None
|
||||
|
||||
def _log_metrics(self, epoch: int, train_loss: float, val_loss: float):
|
||||
"""
|
||||
Log training metrics to CSV file
|
||||
|
||||
Args:
|
||||
epoch (int): Current epoch number
|
||||
train_loss (float): Training loss for the epoch
|
||||
val_loss (float): Validation loss for the epoch
|
||||
"""
|
||||
with open(self.log_file, 'a') as f:
|
||||
f.write(f'{epoch},{train_loss:.6f},{val_loss:.6f},{self.best_val_loss:.6f}\n')
|
||||
|
||||
def _initialize_training(self):
|
||||
"""
|
||||
Helper method to initialize training parameters
|
||||
"""
|
||||
# Freeze layers except those we want to finetune
|
||||
self._freeze_layers()
|
||||
|
||||
# Initialize optimizer and scheduler
|
||||
params_to_optimize = [p for p in self.model.parameters() if p.requires_grad]
|
||||
|
||||
if not params_to_optimize:
|
||||
raise ValueError("No parameters found to optimize")
|
||||
|
||||
# Use Adam with weight decay for better regularization
|
||||
self.optimizer = torch.optim.Adam(
|
||||
params_to_optimize,
|
||||
lr=self.learning_rate,
|
||||
weight_decay=1e-4 # Add L2 regularization
|
||||
num_workers=4
|
||||
)
|
||||
|
||||
# Reduce learning rate when validation loss plateaus
|
||||
self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
||||
self.optimizer,
|
||||
factor=0.1, # Multiply LR by this factor on plateau
|
||||
patience=3, # Number of epochs to wait before reducing LR
|
||||
verbose=True
|
||||
)
|
||||
# Set device
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
model.to(device)
|
||||
|
||||
# Use a combination of L1 and L2 losses for better performance
|
||||
self.criterion = nn.L1Loss()
|
||||
self.mse_criterion = nn.MSELoss()
|
||||
# Define loss function and optimizer
|
||||
criterion = nn.MSELoss()
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=model.config.learning_rate)
|
||||
|
||||
def _train_epoch(self):
|
||||
"""
|
||||
Train model for one epoch
|
||||
Returns:
|
||||
float: Average training loss for the epoch
|
||||
"""
|
||||
self.model.train()
|
||||
running_loss = 0.0
|
||||
best_val_loss = float('inf')
|
||||
|
||||
for batch in tqdm(self.train_loader, desc="Training"):
|
||||
low_ress = batch['low_ress'].to(self.device)
|
||||
high_ress = batch['high_ress'].to(self.device)
|
||||
for epoch in range(epochs):
|
||||
model.train()
|
||||
|
||||
# Forward pass (we'll use the model's existing architecture without adding layers)
|
||||
try:
|
||||
features = self.model(low_ress)
|
||||
print("Features shape:", features.shape) # Check output dimensions
|
||||
print("High-res shape:", high_ress.shape) # Check target dimensions
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Error during forward pass: {str(e)}")
|
||||
train_loss = 0.0
|
||||
|
||||
# Calculate loss with different scaling for L1 and MSE components
|
||||
l1_loss = self.criterion(features, high_ress) * 0.5
|
||||
mse_loss = self.mse_criterion(features, high_ress) * 0.5
|
||||
total_loss = l1_loss + mse_loss
|
||||
|
||||
# Backward pass and optimize
|
||||
self.optimizer.zero_grad()
|
||||
total_loss.backward()
|
||||
self.optimizer.step()
|
||||
|
||||
running_loss += total_loss.item()
|
||||
|
||||
epoch_loss = running_loss / len(self.train_loader)
|
||||
self.train_losses.append(epoch_loss)
|
||||
print(f"Train Loss: {epoch_loss:.4f}")
|
||||
return epoch_loss
|
||||
|
||||
def _train_epoch(self):
|
||||
"""Train model for one epoch"""
|
||||
self.model.train()
|
||||
running_loss = 0.0
|
||||
|
||||
for batch in tqdm(self.train_loader, desc="Training"):
|
||||
low_ress = batch['low_ress'].to(self.device)
|
||||
high_ress = batch['high_ress'].to(self.device)
|
||||
for batch_idx, batch in enumerate(train_loader):
|
||||
low_res = batch['low_res'].to(device)
|
||||
high_res = batch['high_res'].to(device)
|
||||
|
||||
# Forward pass
|
||||
try:
|
||||
outputs = self.model(low_ress) # Now outputs are images
|
||||
print("Output shape:", outputs.shape)
|
||||
print("High-res shape:", high_ress.shape)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Error during forward pass: {str(e)}")
|
||||
outputs = model(low_res)
|
||||
|
||||
# Calculate loss
|
||||
l1_loss = self.criterion(outputs, high_ress) * 0.5
|
||||
mse_loss = self.mse_criterion(outputs, high_ress) * 0.5
|
||||
total_loss = l1_loss + mse_loss
|
||||
loss = criterion(outputs, high_res.permute(0, 3, 1, 2)) # Adjust for channel dimensions
|
||||
|
||||
# Backward pass and optimize
|
||||
self.optimizer.zero_grad()
|
||||
total_loss.backward()
|
||||
self.optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
running_loss += total_loss.item()
|
||||
train_loss += loss.item()
|
||||
|
||||
epoch_loss = running_loss / len(self.train_loader)
|
||||
self.train_lossess.append(epoch_loss)
|
||||
print(f"Train Loss: {epoch_loss:.4f}")
|
||||
return epoch_loss
|
||||
avg_train_loss = train_loss / len(train_loader)
|
||||
|
||||
def train(self, num_epochs: int = 10):
|
||||
"""
|
||||
Train the model for specified number of epochs
|
||||
Args:
|
||||
num_epochs (int): Number of epochs to train for
|
||||
"""
|
||||
self.model.to(self.device)
|
||||
print(f"Epoch {epoch+1}, Training Loss: {avg_train_loss:.4f}")
|
||||
|
||||
print(f"Training metrics will be logged to: {self.log_file}")
|
||||
# Validation
|
||||
model.eval()
|
||||
val_loss = 0.0
|
||||
|
||||
for epoch in range(num_epochs):
|
||||
print(f"\nEpoch {epoch+1}/{num_epochs}")
|
||||
with torch.no_grad():
|
||||
for batch in val_loader:
|
||||
low_res = batch['low_res'].to(device)
|
||||
high_res = batch['high_res'].to(device)
|
||||
|
||||
# Train phase
|
||||
train_loss = self._train_epoch()
|
||||
outputs = model(low_res)
|
||||
loss = criterion(outputs, high_res.permute(0, 3, 1, 2))
|
||||
|
||||
# Validation phase
|
||||
if self.val_loader is not None:
|
||||
val_loss = self._validate_epoch()
|
||||
val_loss += loss.item()
|
||||
|
||||
# Update learning rate scheduler based on validation loss
|
||||
self.scheduler.step(val_loss)
|
||||
avg_val_loss = val_loss / len(val_loader)
|
||||
|
||||
# Log metrics
|
||||
self._log_metrics(epoch + 1, train_loss, val_loss)
|
||||
print(f"Epoch {epoch+1}, Validation Loss: {avg_val_loss:.4f}")
|
||||
|
||||
# Save best model based on validation loss
|
||||
if self.current_val_loss < self.best_val_loss:
|
||||
print(f"Validation loss improved from {self.best_val_loss:.4f} to {self.current_val_loss:.4f}")
|
||||
self.best_val_loss = self.current_val_loss
|
||||
model_save_path = os.path.join(self.output_dir, "aiuNN-optimized")
|
||||
self.model.save(model_save_path)
|
||||
print(f"Model saved to: {model_save_path}")
|
||||
# Save best model
|
||||
if avg_val_loss < best_val_loss:
|
||||
best_val_loss = avg_val_loss
|
||||
model.save("best_model")
|
||||
|
||||
# After training, save the final model
|
||||
final_model_path = os.path.join(self.output_dir, "aiuNN-final")
|
||||
self.model.save(final_model_path)
|
||||
print(f"\nFinal model saved to: {final_model_path}")
|
||||
return model
|
||||
|
||||
def main():
|
||||
# Paths to your data
|
||||
train_parquet_path = "/root/training_data/vision-dataset/image_upscaler.parquet"
|
||||
val_parquet_path = "/root/training_data/vision-dataset/image_vec_upscaler.parquet"
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Load your model first
|
||||
config = AIIAConfig.load("/root/vision/AIIA/AIIA-base-512")
|
||||
model = SuperResolutionModel(base_model=AIIABase.load("/root/vision/AIIA/AIIA-base-512"), config=config)
|
||||
# Load pretrained model
|
||||
model = AIIA.load("/root/vision/AIIA/AIIA-base-512")
|
||||
|
||||
trainer = FineTuner(
|
||||
model=model,
|
||||
dataset_paths=[
|
||||
"/root/training_data/vision-dataset/image_upscaler.parquet",
|
||||
"/root/training_data/vision-dataset/image_vec_upscaler.parquet"
|
||||
],
|
||||
batch_size=2, # Increased batch size
|
||||
learning_rate=1e-4 # Reduced initial LR
|
||||
# Add final upsampling layer if needed (depending on your specific architecture)
|
||||
if hasattr(model, 'chunked_'):
|
||||
model.add_module('final_upsample', nn.Upsample(scale_factor=2, mode='bilinear'))
|
||||
|
||||
# Fine-tune
|
||||
finetune_model(
|
||||
model,
|
||||
train_parquet_path,
|
||||
val_parquet_path
|
||||
)
|
||||
|
||||
trainer.train(num_epochs=10) # Extended training time
|
||||
if __name__ == '__main__':
|
||||
main()
|
Loading…
Reference in New Issue