updated finetuning
parent 9ec563e86d
commit 13fb2b76c1
@@ -99,6 +99,8 @@ class ImageDataset(Dataset):
         if 'high_res_stream' in locals():
             high_res_stream.close()

+
+
 class FineTuner:
     def __init__(self,
                  model: AIIA,
@@ -127,7 +129,7 @@ class FineTuner:
         self.learning_rate = learning_rate
         self.train_ratio = train_ratio
         self.model = model
-        self.output_dir = output_dir
+        self.ouptut_dir = output_dir

         # Create output directory if it doesn't exist
         os.makedirs(output_dir, exist_ok=True)
@@ -150,6 +152,17 @@ class FineTuner:
         # Initialize training parameters
         self._initialize_training()

+    def _freeze_layers(self):
+        """Freeze all layers except the upsample layer"""
+        try:
+            # Try to freeze layers based on their names
+            for name, param in self.model.named_parameters():
+                if 'upsample' not in name:
+                    param.requires_grad = False
+        except Exception as e:
+            print(f"Warning: Couldn't freeze layers - {str(e)}")
+            pass
+
     def _initialize_datasets(self):
         """
         Helper method to initialize datasets
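For reference, a minimal sketch (not part of the commit) of how the name-based freezing in _freeze_layers can be sanity-checked; model stands in for any nn.Module that carries an upsample submodule:

import torch.nn as nn

def report_trainable(model: nn.Module):
    """List the parameters that survive name-based freezing."""
    trainable = [name for name, p in model.named_parameters() if p.requires_grad]
    frozen = [name for name, p in model.named_parameters() if not p.requires_grad]
    print(f"{len(trainable)} trainable / {len(frozen)} frozen parameter tensors")
    return trainable

# After _freeze_layers() only names containing 'upsample' should remain,
# e.g. ['upsample.0.weight', 'upsample.0.bias', ...].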
@@ -159,15 +172,18 @@ class FineTuner:
         else:
             raise ValueError("Invalid dataset_paths format. Must be a list.")

+        # Split into train and validation sets
         df_train, df_val = train_test_split(
             df_train,
             test_size=1 - self.train_ratio,
             random_state=42
         )

-        # Define preprocessing transforms
+        # Define preprocessing transforms with augmentation
         self.transform = transforms.Compose([
             transforms.ToTensor(),
+            transforms.RandomResizedCrop(256),
+            transforms.RandomHorizontalFlip(),
             transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
         ])

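A caveat on the augmentation added above, shown as a minimal sketch with a hypothetical helper (not something this commit implements): random geometric transforms on paired low/high-resolution images normally need to be applied with identical parameters to both tensors, otherwise input and target stop corresponding pixel-for-pixel. One common pattern:

import random
import torchvision.transforms.functional as TF

def paired_hflip(low_res, high_res, p=0.5):
    # Reuse one random decision for both images so the pair stays aligned.
    if random.random() < p:
        low_res = TF.hflip(low_res)
        high_res = TF.hflip(high_res)
    return low_res, high_res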
@@ -201,46 +217,66 @@ class FineTuner:
         with open(self.log_file, 'a') as f:
             f.write(f'{epoch},{train_loss:.6f},{val_loss:.6f},{self.best_val_loss:.6f}\n')


     def _initialize_training(self):
         """
         Helper method to initialize training parameters
         """
-        # Freeze CNN layers (if applicable)
-        try:
-            for param in self.model.cnn.parameters():
-                param.requires_grad = False
-        except AttributeError:
-            pass # If model doesn't have a 'cnn' attribute, just continue
+        # Freeze all layers except upsample layer
+        self._freeze_layers()

         # Add upscaling layer if not already present
         if not hasattr(self.model, 'upsample'):
-            # Get existing configuration values or set defaults if necessary
-            hidden_size = self.model.config.hidden_size
-            kernel_size = self.model.config.kernel_size
+            # Try to get existing configuration or set defaults
+            try:
+                hidden_size = self.model.config.hidden_size
+                kernel_size = 3 # Use odd-sized kernel for better performance
+            except AttributeError:
+                # Fallback values if config isn't available
+                hidden_size = 512
+                kernel_size = 3

             self.model.upsample = nn.Sequential(
-                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
-                nn.Conv2d(hidden_size, 3, kernel_size=kernel_size, padding=1)
+                nn.ConvTranspose2d(hidden_size,
+                                   hidden_size // 2,
+                                   kernel_size=kernel_size,
+                                   stride=2,
+                                   padding=1,
+                                   output_padding=1),
+                nn.ReLU(inplace=True),
+                nn.ConvTranspose2d(hidden_size // 2,
+                                   3,
+                                   kernel_size=kernel_size,
+                                   stride=2,
+                                   padding=1,
+                                   output_padding=1)
             )
-            # Update the model's configuration with new parameters
             self.model.config.upsample_hidden_size = hidden_size
             self.model.config.upsample_kernel_size = kernel_size

-        # Initialize optimizer and loss function
-        self.criterion = nn.MSELoss()
+        # Initialize optimizer and scheduler
+        params_to_optimize = [p for p in self.model.parameters() if p.requires_grad]

-        # Get parameters of the upsample layer for training
-        params = [p for p in self.model.upsample.parameters() if p.requires_grad]
-        if not params:
-            raise ValueError("No parameters found in upsample layer to optimize")
+        if not params_to_optimize:
+            raise ValueError("No parameters found to optimize")

+        # Use Adam with weight decay for better regularization
         self.optimizer = torch.optim.Adam(
-            params,
-            lr=self.learning_rate
+            params_to_optimize,
+            lr=self.learning_rate,
+            weight_decay=1e-4 # Add L2 regularization
         )

-        self.best_val_loss = float('inf')
+        # Reduce learning rate when validation loss plateaus
+        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+            self.optimizer,
+            factor=0.1, # Multiply LR by this factor on plateau
+            patience=3, # Number of epochs to wait before reducing LR
+            verbose=True
+        )
+
+        # Use a combination of L1 and L2 losses for better performance
+        self.criterion = nn.L1Loss()
+        self.mse_criterion = nn.MSELoss()

     def _train_epoch(self):
         """
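For orientation, a quick shape check of the new two-stage ConvTranspose2d head (a standalone sketch; the 64x64 feature map is an assumed input size, and 512 matches the commit's fallback hidden_size):

import torch
import torch.nn as nn

hidden_size, kernel_size = 512, 3
upsample = nn.Sequential(
    nn.ConvTranspose2d(hidden_size, hidden_size // 2, kernel_size=kernel_size,
                       stride=2, padding=1, output_padding=1),
    nn.ReLU(inplace=True),
    nn.ConvTranspose2d(hidden_size // 2, 3, kernel_size=kernel_size,
                       stride=2, padding=1, output_padding=1),
)

x = torch.randn(1, hidden_size, 64, 64)  # stand-in feature map from the frozen backbone
print(upsample(x).shape)                 # torch.Size([1, 3, 256, 256])
# Each stage maps H to (H - 1)*2 - 2*1 + 3 + 1 = 2*H, so the stack is an exact 4x upscale,
# whereas the replaced Upsample + Conv2d head gave a single 2x upscale.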
@@ -255,18 +291,26 @@ class FineTuner:
             low_ress = batch['low_ress'].to(self.device)
             high_ress = batch['high_ress'].to(self.device)

-            # Forward pass
-            features = self.model.cnn(low_ress) if hasattr(self.model, 'cnn') else self.model.extract_features(low_ress)
+            try:
+                # Try using CNN layer if available
+                features = self.model.cnn(low_ress)
+            except AttributeError:
+                # Fallback to extract_features method
+                features = self.model.extract_features(low_ress)
+
             outputs = self.model.upsample(features)

-            loss = self.criterion(outputs, high_ress)
+            # Calculate loss with different scaling for L1 and MSE components
+            l1_loss = self.criterion(outputs, high_ress) * 0.5
+            mse_loss = self.mse_criterion(outputs, high_ress) * 0.5
+            total_loss = l1_loss + mse_loss

             # Backward pass and optimize
             self.optimizer.zero_grad()
-            loss.backward()
+            total_loss.backward()
             self.optimizer.step()

-            running_loss += loss.item()
+            running_loss += total_loss.item()

         epoch_loss = running_loss / len(self.train_loader)
         self.train_losses.append(epoch_loss)
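The equal-weighted L1/MSE blend used in both the training and validation passes can be read as a single loss function; a minimal standalone sketch:

import torch
import torch.nn as nn

l1 = nn.L1Loss()
mse = nn.MSELoss()

def combined_loss(outputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    # 0.5 * L1 keeps gradients robust to outlier pixels; 0.5 * MSE still
    # penalizes large reconstruction errors more heavily.
    return 0.5 * l1(outputs, targets) + 0.5 * mse(outputs, targets)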
@@ -287,11 +331,19 @@ class FineTuner:
                 low_ress = batch['low_ress'].to(self.device)
                 high_ress = batch['high_ress'].to(self.device)

-                features = self.model.cnn(low_ress) if hasattr(self.model, 'cnn') else self.model.extract_features(low_ress)
+                try:
+                    features = self.model.cnn(low_ress)
+                except AttributeError:
+                    features = self.model.extract_features(low_ress)
+
                 outputs = self.model.upsample(features)

-                loss = self.criterion(outputs, high_ress)
-                val_loss += loss.item()
+                # Calculate same loss combination
+                l1_loss = self.criterion(outputs, high_ress) * 0.5
+                mse_loss = self.mse_criterion(outputs, high_ress) * 0.5
+                total_loss = l1_loss + mse_loss
+
+                val_loss += total_loss.item()

         self.current_val_loss = val_loss / len(self.val_loader)
         self.val_losses.append(self.current_val_loss)
@@ -318,6 +370,9 @@ class FineTuner:
             if self.val_loader is not None:
                 val_loss = self._validate_epoch()

+                # Update learning rate scheduler based on validation loss
+                self.scheduler.step(val_loss)
+
                 # Log metrics
                 self._log_metrics(epoch + 1, train_loss, val_loss)

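A minimal sketch of the plateau-scheduler pattern wired in above (the module and loop here are placeholders, not the trainer's real objects): ReduceLROnPlateau watches the metric passed to step() and multiplies the learning rate by factor once it has not improved for patience epochs.

import torch
import torch.nn as nn

model = nn.Linear(8, 8)  # placeholder module
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=3)

for epoch in range(10):
    val_loss = 1.0  # stand-in for the real validation loss
    scheduler.step(val_loss)  # after 3 epochs without improvement, lr is multiplied by 0.1
    print(epoch, optimizer.param_groups[0]['lr'])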
@@ -325,25 +380,28 @@ class FineTuner:
             if self.current_val_loss < self.best_val_loss:
                 print(f"Validation loss improved from {self.best_val_loss:.4f} to {self.current_val_loss:.4f}")
                 self.best_val_loss = self.current_val_loss
-                model_save_path = os.path.join(self.output_dir, "aiuNN-finetuned")
+                model_save_path = os.path.join(self.ouptut_dir, "aiuNN-optimized")
                 self.model.save(model_save_path)
                 print(f"Model saved to: {model_save_path}")

-    def __repr__(self):
-        return f"ModelTrainer (model={type(self.model).__name__}, batch_size={self.batch_size})"
+        # After training, save the final model
+        final_model_path = os.path.join(self.ouptut_dir, "aiuNN-final")
+        self.model.save(final_model_path)
+        print(f"\nFinal model saved to: {final_model_path}")
+

 if __name__ == "__main__":
     # Load your model first
-    model = AIIABase.load("/root/vision/AIIA/AIIA-base-512/")
+    model = AIIABase.load("/root/vision/dataset/AIIA-base-512/")

     trainer = FineTuner(
         model=model,
         dataset_paths=[
-            "/root/training_data/vision-dataset/image_upscaler.parquet",
-            "/root/training_data/vision-dataset/image_vec_upscaler.parquet"
+            "/root/training_data/vision-dataset/image_upscaler.0",
+            "/root/training_data/vision-dataset/image_vec_upscaler.0"
         ],
-        batch_size=2,
-        learning_rate=0.001
+        batch_size=8, # Increased batch size
+        learning_rate=1e-4 # Reduced initial LR
     )

-    trainer.train(num_epochs=3)
+    trainer.train(num_epochs=10) # Extended training time
@@ -5,7 +5,7 @@ from torch.nn import functional as F
 from aiia.model import AIIABase

 class UpScaler:
-    def __init__(self, model_path="AIIA-base-512-upscaler", device="cuda"):
+    def __init__(self, model_path="aiuNN-finetuned", device="cuda"):
         self.device = torch.device(device)
         self.model = AIIABase.load(model_path).to(self.device)
         self.model.eval()