diff --git a/src/aiunn/finetune.py b/src/aiunn/finetune.py
index 1fefbe4..e9c13d6 100644
--- a/src/aiunn/finetune.py
+++ b/src/aiunn/finetune.py
@@ -100,7 +100,6 @@ class ImageDataset(Dataset):
         high_res_stream.close()
 
 
-
 class FineTuner:
     def __init__(self,
                  model: AIIA,
@@ -129,7 +128,7 @@ class FineTuner:
         self.learning_rate = learning_rate
         self.train_ratio = train_ratio
         self.model = model
-        self.ouptut_dir = output_dir
+        self.output_dir = output_dir
 
         # Create output directory if it doesn't exist
         os.makedirs(output_dir, exist_ok=True)
@@ -153,12 +152,22 @@ class FineTuner:
         self._initialize_training()
 
     def _freeze_layers(self):
-        """Freeze all layers except the upsample layer"""
+        """
+        Freeze all layers except those that are part of the decoder or upsampling.
+        We'll assume the last few layers are responsible for upsampling/reconstruction.
+        """
         try:
-            # Try to freeze layers based on their names
+            # Identify encoder layers and freeze them
             for name, param in self.model.named_parameters():
-                if 'upsample' not in name:
+                if 'encoder' in name:
                     param.requires_grad = False
+
+            # Unfreeze selected decoder blocks (example: the last two decoder blocks)
+            # Adjust this to match your actual model architecture
+            for name, param in self.model.named_parameters():
+                if 'decoder' in name and ('block4' in name or 'block5' in name):
+                    param.requires_grad = True
+
         except Exception as e:
             print(f"Warning: Couldn't freeze layers - {str(e)}")
             pass
@@ -221,38 +230,9 @@ class FineTuner:
         """
         Helper method to initialize training parameters
         """
-        # Freeze all layers except upsample layer
+        # Freeze layers except those we want to finetune
         self._freeze_layers()
 
-        # Add upscaling layer if not already present
-        if not hasattr(self.model, 'upsample'):
-            # Try to get existing configuration or set defaults
-            try:
-                hidden_size = self.model.config.hidden_size
-                kernel_size = 3  # Use odd-sized kernel for better performance
-            except AttributeError:
-                # Fallback values if config isn't available
-                hidden_size = 512
-                kernel_size = 3
-
-            self.model.upsample = nn.Sequential(
-                nn.ConvTranspose2d(hidden_size,
-                                   hidden_size//2,
-                                   kernel_size=kernel_size,
-                                   stride=2,
-                                   padding=1,
-                                   output_padding=1),
-                nn.ReLU(inplace=True),
-                nn.ConvTranspose2d(hidden_size//2,
-                                   3,
-                                   kernel_size=kernel_size,
-                                   stride=2,
-                                   padding=1,
-                                   output_padding=1)
-            )
-            self.model.config.upsample_hidden_size = hidden_size
-            self.model.config.upsample_kernel_size = kernel_size
-
         # Initialize optimizer and scheduler
         params_to_optimize = [p for p in self.model.parameters() if p.requires_grad]
 
@@ -269,8 +249,8 @@ class FineTuner:
         # Reduce learning rate when validation loss plateaus
         self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
             self.optimizer,
-            factor=0.1, # Multiply LR by this factor on plateau
-            patience=3, # Number of epochs to wait before reducing LR
+            factor=0.1,   # Multiply LR by this factor on plateau
+            patience=3,   # Number of epochs to wait before reducing LR
             verbose=True
         )
 
@@ -291,18 +271,15 @@ class FineTuner:
            low_ress = batch['low_ress'].to(self.device)
            high_ress = batch['high_ress'].to(self.device)
 
+           # Forward pass (we'll use the model's existing architecture without adding layers)
           try:
-                # Try using CNN layer if available
-                features = self.model.cnn(low_ress)
-            except AttributeError:
-                # Fallback to extract_features method
-                features = self.model.extract_features(low_ress)
-
-            outputs = self.model.upsample(features)
+               features = self.model(low_ress)
+           except Exception as e:
+               raise RuntimeError(f"Error during forward pass: {str(e)}")
 
            # Calculate loss with different scaling for L1 and MSE components
-            l1_loss = self.criterion(outputs, high_ress) * 0.5
-            mse_loss = self.mse_criterion(outputs, high_ress) * 0.5
+           l1_loss = self.criterion(features, high_ress) * 0.5
+           mse_loss = self.mse_criterion(features, high_ress) * 0.5
            total_loss = l1_loss + mse_loss
 
            # Backward pass and optimize
@@ -332,15 +309,13 @@ class FineTuner:
                high_ress = batch['high_ress'].to(self.device)
 
                try:
-                    features = self.model.cnn(low_ress)
-                except AttributeError:
-                    features = self.model.extract_features(low_ress)
-
-                outputs = self.model.upsample(features)
+                   features = self.model(low_ress)
+               except Exception as e:
+                   raise RuntimeError(f"Error during validation forward pass: {str(e)}")
 
                # Calculate same loss combination
-                l1_loss = self.criterion(outputs, high_ress) * 0.5
-                mse_loss = self.mse_criterion(outputs, high_ress) * 0.5
+               l1_loss = self.criterion(features, high_ress) * 0.5
+               mse_loss = self.mse_criterion(features, high_ress) * 0.5
                total_loss = l1_loss + mse_loss
 
                val_loss += total_loss.item()
@@ -380,12 +355,12 @@ class FineTuner:
            if self.current_val_loss < self.best_val_loss:
                print(f"Validation loss improved from {self.best_val_loss:.4f} to {self.current_val_loss:.4f}")
                self.best_val_loss = self.current_val_loss
-                model_save_path = os.path.join(self.ouptut_dir, "aiuNN-optimized")
+               model_save_path = os.path.join(self.output_dir, "aiuNN-optimized")
                self.model.save(model_save_path)
                print(f"Model saved to: {model_save_path}")
 
        # After training, save the final model
-        final_model_path = os.path.join(self.ouptut_dir, "aiuNN-final")
+       final_model_path = os.path.join(self.output_dir, "aiuNN-final")
        self.model.save(final_model_path)
        print(f"\nFinal model saved to: {final_model_path}")
 
@@ -400,8 +375,8 @@ if __name__ == "__main__":
            "/root/training_data/vision-dataset/image_upscaler.parquet",
            "/root/training_data/vision-dataset/image_vec_upscaler.parquet"
        ],
-        batch_size=8, # Increased batch size
-        learning_rate=1e-4 # Reduced initial LR
+       batch_size=8,       # Increased batch size
+       learning_rate=1e-4  # Reduced initial LR
    )
 
-    trainer.train(num_epochs=10) # Extended training time
\ No newline at end of file
+   trainer.train(num_epochs=10)  # Extended training time
\ No newline at end of file
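
A quick way to sanity-check the new `_freeze_layers` behaviour is to count which parameters remain trainable after it runs. A minimal sketch, assuming only that the model is a `torch.nn.Module`; the `summarize_trainable` helper is hypothetical and not part of this patch:

    import torch.nn as nn

    def summarize_trainable(model: nn.Module) -> None:
        """Report which parameters will actually receive gradients."""
        total = trainable = 0
        for name, param in model.named_parameters():
            total += param.numel()
            if param.requires_grad:
                trainable += param.numel()
                print(f"trainable: {name} ({param.numel()})")
        pct = 100.0 * trainable / max(total, 1)
        print(f"{trainable}/{total} parameters ({pct:.1f}%) will be updated")

Called right after `_initialize_training()`, this should list only decoder `block4`/`block5` parameters if the name-matching in `_freeze_layers` fits the model's actual naming scheme; note that the unfreeze condition needs the parentheses shown above, since bare `a and b or c` binds as `(a and b) or c` in Python.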
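
The training and validation paths both weight L1 and MSE at 0.5 each. If that combination is reused elsewhere, it could live in a single module; a sketch under the assumption that `self.criterion` is `nn.L1Loss()` and `self.mse_criterion` is `nn.MSELoss()` (the `CombinedLoss` name is hypothetical):

    import torch
    import torch.nn as nn

    class CombinedLoss(nn.Module):
        """0.5 * L1 + 0.5 * MSE, matching the weighting used in FineTuner."""
        def __init__(self, l1_weight: float = 0.5, mse_weight: float = 0.5):
            super().__init__()
            self.l1 = nn.L1Loss()
            self.mse = nn.MSELoss()
            self.l1_weight = l1_weight
            self.mse_weight = mse_weight

        def forward(self, outputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
            return (self.l1_weight * self.l1(outputs, targets)
                    + self.mse_weight * self.mse(outputs, targets))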