diff --git a/src/aiunn/finetune.py b/src/aiunn/finetune.py
index 4cf3d2e..9dc5b2e 100644
--- a/src/aiunn/finetune.py
+++ b/src/aiunn/finetune.py
@@ -99,6 +99,8 @@ class ImageDataset(Dataset):
             if 'high_res_stream' in locals():
                 high_res_stream.close()
 
+
+
 class FineTuner:
     def __init__(self, model: AIIA,
@@ -150,6 +152,17 @@ class FineTuner:
         # Initialize training parameters
         self._initialize_training()
 
+    def _freeze_layers(self):
+        """Freeze all layers except the upsample layer"""
+        try:
+            # Try to freeze layers based on their names
+            for name, param in self.model.named_parameters():
+                if 'upsample' not in name:
+                    param.requires_grad = False
+        except Exception as e:
+            print(f"Warning: Couldn't freeze layers - {str(e)}")
+            pass
+
     def _initialize_datasets(self):
         """
         Helper method to initialize datasets
@@ -159,15 +172,18 @@ class FineTuner:
         else:
             raise ValueError("Invalid dataset_paths format. Must be a list.")
 
+        # Split into train and validation sets
         df_train, df_val = train_test_split(
             df_train,
             test_size=1 - self.train_ratio,
             random_state=42
         )
 
-        # Define preprocessing transforms
+        # Define preprocessing transforms with augmentation
         self.transform = transforms.Compose([
             transforms.ToTensor(),
+            transforms.RandomResizedCrop(256),
+            transforms.RandomHorizontalFlip(),
             transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
         ])
@@ -200,47 +216,67 @@ class FineTuner:
         """
         with open(self.log_file, 'a') as f:
             f.write(f'{epoch},{train_loss:.6f},{val_loss:.6f},{self.best_val_loss:.6f}\n')
-
-
+
     def _initialize_training(self):
         """
         Helper method to initialize training parameters
         """
-        # Freeze CNN layers (if applicable)
-        try:
-            for param in self.model.cnn.parameters():
-                param.requires_grad = False
-        except AttributeError:
-            pass  # If model doesn't have a 'cnn' attribute, just continue
+        # Freeze all layers except upsample layer
+        self._freeze_layers()
 
         # Add upscaling layer if not already present
        if not hasattr(self.model, 'upsample'):
-            # Get existing configuration values or set defaults if necessary
-            hidden_size = self.model.config.hidden_size
-            kernel_size = self.model.config.kernel_size
+            # Try to get existing configuration or set defaults
+            try:
+                hidden_size = self.model.config.hidden_size
+                kernel_size = 3  # Use odd-sized kernel for better performance
+            except AttributeError:
+                # Fallback values if config isn't available
+                hidden_size = 512
+                kernel_size = 3
 
             self.model.upsample = nn.Sequential(
-                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
-                nn.Conv2d(hidden_size, 3, kernel_size=kernel_size, padding=1)
+                nn.ConvTranspose2d(hidden_size,
+                                   hidden_size // 2,
+                                   kernel_size=kernel_size,
+                                   stride=2,
+                                   padding=1,
+                                   output_padding=1),
+                nn.ReLU(inplace=True),
+                nn.ConvTranspose2d(hidden_size // 2,
+                                   3,
+                                   kernel_size=kernel_size,
+                                   stride=2,
+                                   padding=1,
+                                   output_padding=1)
             )
 
-            # Update the model's configuration with new parameters
             self.model.config.upsample_hidden_size = hidden_size
             self.model.config.upsample_kernel_size = kernel_size
 
-        # Initialize optimizer and loss function
-        self.criterion = nn.MSELoss()
+        # Initialize optimizer and scheduler
+        params_to_optimize = [p for p in self.model.parameters() if p.requires_grad]
 
-        # Get parameters of the upsample layer for training
-        params = [p for p in self.model.upsample.parameters() if p.requires_grad]
-        if not params:
-            raise ValueError("No parameters found in upsample layer to optimize")
+        if not params_to_optimize:
+            raise ValueError("No parameters found to optimize")
 
+        # Use Adam with weight decay for better regularization
         self.optimizer = torch.optim.Adam(
-            params,
-            lr=self.learning_rate
+            params_to_optimize,
+            lr=self.learning_rate,
+            weight_decay=1e-4  # Add L2 regularization
         )
 
         self.best_val_loss = float('inf')
 
+        # Reduce learning rate when validation loss plateaus
+        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+            self.optimizer,
+            factor=0.1,   # Multiply LR by this factor on plateau
+            patience=3,   # Number of epochs to wait before reducing LR
+            verbose=True
+        )
+
+        # Use a combination of L1 and L2 losses for better performance
+        self.criterion = nn.L1Loss()
+        self.mse_criterion = nn.MSELoss()
 
     def _train_epoch(self):
         """
@@ -255,18 +291,26 @@ class FineTuner:
             low_ress = batch['low_ress'].to(self.device)
             high_ress = batch['high_ress'].to(self.device)
 
-            # Forward pass
-            features = self.model.cnn(low_ress) if hasattr(self.model, 'cnn') else self.model.extract_features(low_ress)
+            try:
+                # Try using CNN layer if available
+                features = self.model.cnn(low_ress)
+            except AttributeError:
+                # Fallback to extract_features method
+                features = self.model.extract_features(low_ress)
+
             outputs = self.model.upsample(features)
 
-            loss = self.criterion(outputs, high_ress)
+            # Calculate loss with equal weighting for the L1 and MSE components
+            l1_loss = self.criterion(outputs, high_ress) * 0.5
+            mse_loss = self.mse_criterion(outputs, high_ress) * 0.5
+            total_loss = l1_loss + mse_loss
 
             # Backward pass and optimize
             self.optimizer.zero_grad()
-            loss.backward()
+            total_loss.backward()
             self.optimizer.step()
 
-            running_loss += loss.item()
+            running_loss += total_loss.item()
 
         epoch_loss = running_loss / len(self.train_loader)
         self.train_losses.append(epoch_loss)
@@ -287,11 +331,19 @@ class FineTuner:
                 low_ress = batch['low_ress'].to(self.device)
                 high_ress = batch['high_ress'].to(self.device)
 
-                features = self.model.cnn(low_ress) if hasattr(self.model, 'cnn') else self.model.extract_features(low_ress)
+                try:
+                    features = self.model.cnn(low_ress)
+                except AttributeError:
+                    features = self.model.extract_features(low_ress)
+
                 outputs = self.model.upsample(features)
 
-                loss = self.criterion(outputs, high_ress)
-                val_loss += loss.item()
+                # Calculate the same loss combination as in training
+                l1_loss = self.criterion(outputs, high_ress) * 0.5
+                mse_loss = self.mse_criterion(outputs, high_ress) * 0.5
+                total_loss = l1_loss + mse_loss
+
+                val_loss += total_loss.item()
 
         self.current_val_loss = val_loss / len(self.val_loader)
         self.val_losses.append(self.current_val_loss)
@@ -318,6 +370,9 @@ class FineTuner:
             if self.val_loader is not None:
                 val_loss = self._validate_epoch()
 
+                # Update learning rate scheduler based on validation loss
+                self.scheduler.step(val_loss)
+
             # Log metrics
             self._log_metrics(epoch + 1, train_loss, val_loss)
 
@@ -325,25 +380,28 @@ class FineTuner:
             if self.current_val_loss < self.best_val_loss:
                 print(f"Validation loss improved from {self.best_val_loss:.4f} to {self.current_val_loss:.4f}")
                 self.best_val_loss = self.current_val_loss
-                model_save_path = os.path.join(self.output_dir, "aiuNN-finetuned")
+                model_save_path = os.path.join(self.output_dir, "aiuNN-optimized")
                 self.model.save(model_save_path)
                 print(f"Model saved to: {model_save_path}")
 
-    def __repr__(self):
-        return f"ModelTrainer (model={type(self.model).__name__}, batch_size={self.batch_size})"
-
+        # After training, save the final model
+        final_model_path = os.path.join(self.output_dir, "aiuNN-final")
+        self.model.save(final_model_path)
+        print(f"\nFinal model saved to: {final_model_path}")
+
+
 if __name__ == "__main__":
     # Load your model first
-    model = AIIABase.load("/root/vision/AIIA/AIIA-base-512/")
-
+    model = AIIABase.load("/root/vision/dataset/AIIA-base-512/")
+
     trainer = FineTuner(
         model=model,
         dataset_paths=[
-            "/root/training_data/vision-dataset/image_upscaler.parquet",
-            "/root/training_data/vision-dataset/image_vec_upscaler.parquet"
+            "/root/training_data/vision-dataset/image_upscaler.0",
+            "/root/training_data/vision-dataset/image_vec_upscaler.0"
         ],
-        batch_size=2,
-        learning_rate=0.001
+        batch_size=8,       # Increased batch size
+        learning_rate=1e-4  # Reduced initial LR
     )
 
-    trainer.train(num_epochs=3)
\ No newline at end of file
+    trainer.train(num_epochs=10)  # Extended training time
\ No newline at end of file
diff --git a/src/aiunn/inference.py b/src/aiunn/inference.py
index 12b2b76..c86905b 100644
--- a/src/aiunn/inference.py
+++ b/src/aiunn/inference.py
@@ -5,7 +5,7 @@ from torch.nn import functional as F
 from aiia.model import AIIABase
 
 class UpScaler:
-    def __init__(self, model_path="AIIA-base-512-upscaler", device="cuda"):
+    def __init__(self, model_path="aiuNN-finetuned", device="cuda"):
         self.device = torch.device(device)
         self.model = AIIABase.load(model_path).to(self.device)
         self.model.eval()
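
A quick way to sanity-check the new upsample head: the two stride-2 ConvTranspose2d layers added in _initialize_training upscale the feature map by a factor of 4 per spatial dimension while reducing the channel count from hidden_size to 3, and because the combined L1/MSE loss compares outputs directly against high_ress, the backbone's feature maps must come out at exactly one quarter of the target resolution. The standalone sketch below mirrors the layer definitions from the patch; hidden_size=512 (the fallback value) and the 64x64 input feature map are illustrative assumptions, not values fixed by the diff.

import torch
import torch.nn as nn

# Sketch only: mirrors the upsample head added in _initialize_training.
# hidden_size=512 is the patch's fallback value; the 64x64 feature map is an
# assumed example input, not something defined by the diff.
hidden_size, kernel_size = 512, 3
upsample = nn.Sequential(
    nn.ConvTranspose2d(hidden_size, hidden_size // 2, kernel_size=kernel_size,
                       stride=2, padding=1, output_padding=1),
    nn.ReLU(inplace=True),
    nn.ConvTranspose2d(hidden_size // 2, 3, kernel_size=kernel_size,
                       stride=2, padding=1, output_padding=1),
)

features = torch.randn(1, hidden_size, 64, 64)  # dummy feature map
print(upsample(features).shape)  # torch.Size([1, 3, 256, 256]): each layer doubles H and W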