develop #4

Merged
Fabel merged 103 commits from develop into main 2025-03-01 21:47:17 +00:00
1 changed file with 33 additions and 58 deletions
Showing only changes of commit 50fa103579


@@ -100,7 +100,6 @@ class ImageDataset(Dataset):
             high_res_stream.close()
 class FineTuner:
     def __init__(self,
                  model: AIIA,
@@ -129,7 +128,7 @@ class FineTuner:
         self.learning_rate = learning_rate
         self.train_ratio = train_ratio
         self.model = model
-        self.ouptut_dir = output_dir
+        self.output_dir = output_dir
         # Create output directory if it doesn't exist
         os.makedirs(output_dir, exist_ok=True)
@@ -153,12 +152,22 @@ class FineTuner:
         self._initialize_training()
     def _freeze_layers(self):
-        """Freeze all layers except the upsample layer"""
+        """
+        Freeze all layers except those that are part of the decoder or upsampling.
+        We assume the last few layers are responsible for upsampling/reconstruction.
+        """
         try:
-            # Try to freeze layers based on their names
+            # Try to identify encoder layers and freeze them
             for name, param in self.model.named_parameters():
-                if 'upsample' not in name:
+                if 'encoder' in name:
                     param.requires_grad = False
+            # Unfreeze selected decoder layers (example: the last decoder blocks)
+            # Modify this based on your actual model architecture
+            for name, param in self.model.named_parameters():
+                if 'decoder' in name and ('block4' in name or 'block5' in name):
+                    param.requires_grad = True
         except Exception as e:
             print(f"Warning: Couldn't freeze layers - {str(e)}")
             pass
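For reference, a minimal standalone sketch of the selective-freezing pattern introduced above, assuming a generic torch.nn.Module whose parameter names contain 'encoder' and 'decoder' (the block names are placeholders, not necessarily the real AIIA layer names), plus a quick count of what remains trainable:

    import torch.nn as nn

    def freeze_encoder_keep_decoder_tail(model: nn.Module, trainable_blocks=('block4', 'block5')):
        # Freeze every parameter that belongs to the encoder
        for name, param in model.named_parameters():
            if 'encoder' in name:
                param.requires_grad = False
        # Re-enable gradients only for the chosen decoder blocks
        for name, param in model.named_parameters():
            if 'decoder' in name and any(block in name for block in trainable_blocks):
                param.requires_grad = True
        # Report how many parameters will actually be updated
        trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
        total = sum(p.numel() for p in model.parameters())
        print(f"Trainable parameters: {trainable}/{total}")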
@@ -221,38 +230,9 @@ class FineTuner:
         """
         Helper method to initialize training parameters
         """
-        # Freeze all layers except upsample layer
+        # Freeze layers except those we want to finetune
         self._freeze_layers()
-        # Add upscaling layer if not already present
-        if not hasattr(self.model, 'upsample'):
-            # Try to get existing configuration or set defaults
-            try:
-                hidden_size = self.model.config.hidden_size
-                kernel_size = 3  # Use odd-sized kernel for better performance
-            except AttributeError:
-                # Fallback values if config isn't available
-                hidden_size = 512
-                kernel_size = 3
-            self.model.upsample = nn.Sequential(
-                nn.ConvTranspose2d(hidden_size,
-                                   hidden_size//2,
-                                   kernel_size=kernel_size,
-                                   stride=2,
-                                   padding=1,
-                                   output_padding=1),
-                nn.ReLU(inplace=True),
-                nn.ConvTranspose2d(hidden_size//2,
-                                   3,
-                                   kernel_size=kernel_size,
-                                   stride=2,
-                                   padding=1,
-                                   output_padding=1)
-            )
-            self.model.config.upsample_hidden_size = hidden_size
-            self.model.config.upsample_kernel_size = kernel_size
         # Initialize optimizer and scheduler
         params_to_optimize = [p for p in self.model.parameters() if p.requires_grad]
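With the ad-hoc upsample head gone, the optimizer is built only over the parameters that _freeze_layers() left trainable. A hedged sketch of that pattern; the optimizer and scheduler choices shown here are assumptions, since the diff does not include the lines that construct them:

    import torch

    # Only parameters left trainable by the freezing step are optimized
    params_to_optimize = [p for p in model.parameters() if p.requires_grad]

    # Adam and ReduceLROnPlateau are illustrative choices, not confirmed by the PR
    optimizer = torch.optim.Adam(params_to_optimize, lr=1e-4)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=2)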
@@ -291,18 +271,15 @@ class FineTuner:
                 low_ress = batch['low_ress'].to(self.device)
                 high_ress = batch['high_ress'].to(self.device)
+                # Forward pass (we'll use the model's existing architecture without adding layers)
                 try:
-                    # Try using CNN layer if available
-                    features = self.model.cnn(low_ress)
-                except AttributeError:
-                    # Fallback to extract_features method
-                    features = self.model.extract_features(low_ress)
-                outputs = self.model.upsample(features)
+                    features = self.model(low_ress)
+                except Exception as e:
+                    raise RuntimeError(f"Error during forward pass: {str(e)}")
                 # Calculate loss with different scaling for L1 and MSE components
-                l1_loss = self.criterion(outputs, high_ress) * 0.5
-                mse_loss = self.mse_criterion(outputs, high_ress) * 0.5
+                l1_loss = self.criterion(features, high_ress) * 0.5
+                mse_loss = self.mse_criterion(features, high_ress) * 0.5
                 total_loss = l1_loss + mse_loss
                 # Backward pass and optimize
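The new forward pass treats the model's direct output as the reconstruction and scores it with an equally weighted L1 + MSE combination. A small sketch of that loss, assuming self.criterion is nn.L1Loss() and self.mse_criterion is nn.MSELoss() (their construction is outside this diff):

    import torch
    import torch.nn as nn

    l1_criterion = nn.L1Loss()
    mse_criterion = nn.MSELoss()

    def combined_loss(outputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        # Equal 0.5 weighting of L1 and MSE, matching the training loop above
        return 0.5 * l1_criterion(outputs, targets) + 0.5 * mse_criterion(outputs, targets)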
@@ -332,15 +309,13 @@ class FineTuner:
                 high_ress = batch['high_ress'].to(self.device)
                 try:
-                    features = self.model.cnn(low_ress)
-                except AttributeError:
-                    features = self.model.extract_features(low_ress)
-                outputs = self.model.upsample(features)
+                    features = self.model(low_ress)
+                except Exception as e:
+                    raise RuntimeError(f"Error during validation forward pass: {str(e)}")
                 # Calculate same loss combination
-                l1_loss = self.criterion(outputs, high_ress) * 0.5
-                mse_loss = self.mse_criterion(outputs, high_ress) * 0.5
+                l1_loss = self.criterion(features, high_ress) * 0.5
+                mse_loss = self.mse_criterion(features, high_ress) * 0.5
                 total_loss = l1_loss + mse_loss
                 val_loss += total_loss.item()
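The validation pass mirrors the training loss, but should run with gradients disabled and the model in eval mode. A minimal sketch under the assumption that val_loader yields the same batch dict as the training loader:

    import torch

    def validate(model, val_loader, device, loss_fn):
        model.eval()
        val_loss = 0.0
        with torch.no_grad():  # no gradients needed during validation
            for batch in val_loader:
                low_ress = batch['low_ress'].to(device)
                high_ress = batch['high_ress'].to(device)
                outputs = model(low_ress)
                val_loss += loss_fn(outputs, high_ress).item()
        model.train()  # restore training mode afterwards
        return val_loss / max(len(val_loader), 1)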
@@ -380,12 +355,12 @@ class FineTuner:
             if self.current_val_loss < self.best_val_loss:
                 print(f"Validation loss improved from {self.best_val_loss:.4f} to {self.current_val_loss:.4f}")
                 self.best_val_loss = self.current_val_loss
-                model_save_path = os.path.join(self.ouptut_dir, "aiuNN-optimized")
+                model_save_path = os.path.join(self.output_dir, "aiuNN-optimized")
                 self.model.save(model_save_path)
                 print(f"Model saved to: {model_save_path}")
         # After training, save the final model
-        final_model_path = os.path.join(self.ouptut_dir, "aiuNN-final")
+        final_model_path = os.path.join(self.output_dir, "aiuNN-final")
         self.model.save(final_model_path)
         print(f"\nFinal model saved to: {final_model_path}")