finetune_class #1

Merged
Fabel merged 96 commits from finetune_class into develop 2025-02-26 12:13:09 +00:00
1 changed file with 33 additions and 58 deletions
Showing only changes of commit 50fa103579

@@ -100,7 +100,6 @@ class ImageDataset(Dataset):
             high_res_stream.close()
 class FineTuner:
     def __init__(self,
                  model: AIIA,
@@ -129,7 +128,7 @@ class FineTuner:
         self.learning_rate = learning_rate
         self.train_ratio = train_ratio
         self.model = model
-        self.ouptut_dir = output_dir
+        self.output_dir = output_dir
         # Create output directory if it doesn't exist
         os.makedirs(output_dir, exist_ok=True)
@@ -153,12 +152,22 @@
         self._initialize_training()
     def _freeze_layers(self):
-        """Freeze all layers except the upsample layer"""
+        """
+        Freeze all layers except those that are part of the decoder or upsampling
+        We'll assume the last few layers are responsible for upsampling/reconstruction
+        """
         try:
-            # Try to freeze layers based on their names
+            # Try to identify encoder layers and freeze them
             for name, param in self.model.named_parameters():
-                if 'upsample' not in name:
+                if 'encoder' in name:
                     param.requires_grad = False
+            # Unfreeze certain layers (example: last 3 decoder layers)
+            # Modify this based on your actual model architecture
+            for name, param in self.model.named_parameters():
+                if 'decoder' in name and 'block4' in name or 'block5' in name:
+                    param.requires_grad = True
         except Exception as e:
             print(f"Warning: Couldn't freeze layers - {str(e)}")
             pass
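
Because of Python's operator precedence, the new condition "'decoder' in name and 'block4' in name or 'block5' in name" groups as "('decoder' in name and 'block4' in name) or ('block5' in name)", so any parameter whose name contains 'block5' is unfrozen even if it is not in the decoder. A minimal sketch of the same selective-freezing pattern with explicit grouping, assuming the substrings 'encoder', 'decoder', 'block4' and 'block5' actually occur in the model's parameter names:

import torch.nn as nn

def freeze_for_finetuning(model: nn.Module) -> None:
    """Freeze encoder parameters, then unfreeze the last decoder blocks."""
    for name, param in model.named_parameters():
        if 'encoder' in name:
            param.requires_grad = False
    for name, param in model.named_parameters():
        # Explicit parentheses keep the unfreezing limited to decoder blocks.
        if 'decoder' in name and ('block4' in name or 'block5' in name):
            param.requires_grad = True
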
@@ -221,38 +230,9 @@ class FineTuner:
         """
         Helper method to initialize training parameters
         """
-        # Freeze all layers except upsample layer
+        # Freeze layers except those we want to finetune
         self._freeze_layers()
-        # Add upscaling layer if not already present
-        if not hasattr(self.model, 'upsample'):
-            # Try to get existing configuration or set defaults
-            try:
-                hidden_size = self.model.config.hidden_size
-                kernel_size = 3 # Use odd-sized kernel for better performance
-            except AttributeError:
-                # Fallback values if config isn't available
-                hidden_size = 512
-                kernel_size = 3
-            self.model.upsample = nn.Sequential(
-                nn.ConvTranspose2d(hidden_size,
-                                   hidden_size//2,
-                                   kernel_size=kernel_size,
-                                   stride=2,
-                                   padding=1,
-                                   output_padding=1),
-                nn.ReLU(inplace=True),
-                nn.ConvTranspose2d(hidden_size//2,
-                                   3,
-                                   kernel_size=kernel_size,
-                                   stride=2,
-                                   padding=1,
-                                   output_padding=1)
-            )
-            self.model.config.upsample_hidden_size = hidden_size
-            self.model.config.upsample_kernel_size = kernel_size
         # Initialize optimizer and scheduler
         params_to_optimize = [p for p in self.model.parameters() if p.requires_grad]
@@ -269,8 +249,8 @@ class FineTuner:
         # Reduce learning rate when validation loss plateaus
         self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
             self.optimizer,
-            factor=0.1, # Multiply LR by this factor on plateau
-            patience=3, # Number of epochs to wait before reducing LR
+            factor=0.1,  # Multiply LR by this factor on plateau
+            patience=3,  # Number of epochs to wait before reducing LR
             verbose=True
         )
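
The diff does not show which optimizer consumes params_to_optimize or where the scheduler is stepped. A minimal sketch of the usual pattern, assuming Adam as the optimizer and a stand-in model (neither is confirmed by this commit); ReduceLROnPlateau only reacts when it is stepped with the tracked metric once per epoch:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.Conv2d(16, 3, 3, padding=1))  # stand-in for the AIIA model

# Hand the optimizer only the parameters left trainable by the freezing step.
params_to_optimize = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(params_to_optimize, lr=1e-4)

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    factor=0.1,  # multiply the learning rate by 0.1 on plateau
    patience=3,  # wait 3 epochs without improvement before reducing
)

val_loss = 0.123  # placeholder; use the real epoch-level validation loss
scheduler.step(val_loss)  # must be called with the monitored metric
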
@@ -291,18 +271,15 @@ class FineTuner:
                 low_ress = batch['low_ress'].to(self.device)
                 high_ress = batch['high_ress'].to(self.device)
+                # Forward pass (we'll use the model's existing architecture without adding layers)
                 try:
-                    # Try using CNN layer if available
-                    features = self.model.cnn(low_ress)
-                except AttributeError:
-                    # Fallback to extract_features method
-                    features = self.model.extract_features(low_ress)
-                outputs = self.model.upsample(features)
+                    features = self.model(low_ress)
+                except Exception as e:
+                    raise RuntimeError(f"Error during forward pass: {str(e)}")
                 # Calculate loss with different scaling for L1 and MSE components
-                l1_loss = self.criterion(outputs, high_ress) * 0.5
-                mse_loss = self.mse_criterion(outputs, high_ress) * 0.5
+                l1_loss = self.criterion(features, high_ress) * 0.5
+                mse_loss = self.mse_criterion(features, high_ress) * 0.5
                 total_loss = l1_loss + mse_loss
                 # Backward pass and optimize
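
With the equal 0.5 weights, the total loss is simply the average of the L1 and MSE terms. A short sketch of one training step under the assumption that self.criterion is nn.L1Loss() and self.mse_criterion is nn.MSELoss() (their construction is not part of this diff):

import torch
import torch.nn as nn

criterion = nn.L1Loss()       # assumed; not visible in this commit
mse_criterion = nn.MSELoss()  # assumed; not visible in this commit

def training_step(model, optimizer, low_ress, high_ress):
    outputs = model(low_ress)  # the model output must already match the high-res shape
    # Equal 0.5 weights: the combined loss is the mean of the L1 and MSE losses.
    total_loss = 0.5 * criterion(outputs, high_ress) + 0.5 * mse_criterion(outputs, high_ress)
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()
    return total_loss.item()
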
@@ -332,15 +309,13 @@ class FineTuner:
                     high_ress = batch['high_ress'].to(self.device)
                     try:
-                        features = self.model.cnn(low_ress)
-                    except AttributeError:
-                        features = self.model.extract_features(low_ress)
-                    outputs = self.model.upsample(features)
+                        features = self.model(low_ress)
+                    except Exception as e:
+                        raise RuntimeError(f"Error during validation forward pass: {str(e)}")
                     # Calculate same loss combination
-                    l1_loss = self.criterion(outputs, high_ress) * 0.5
-                    mse_loss = self.mse_criterion(outputs, high_ress) * 0.5
+                    l1_loss = self.criterion(features, high_ress) * 0.5
+                    mse_loss = self.mse_criterion(features, high_ress) * 0.5
                     total_loss = l1_loss + mse_loss
                     val_loss += total_loss.item()
@@ -380,12 +355,12 @@ class FineTuner:
             if self.current_val_loss < self.best_val_loss:
                 print(f"Validation loss improved from {self.best_val_loss:.4f} to {self.current_val_loss:.4f}")
                 self.best_val_loss = self.current_val_loss
-                model_save_path = os.path.join(self.ouptut_dir, "aiuNN-optimized")
+                model_save_path = os.path.join(self.output_dir, "aiuNN-optimized")
                 self.model.save(model_save_path)
                 print(f"Model saved to: {model_save_path}")
         # After training, save the final model
-        final_model_path = os.path.join(self.ouptut_dir, "aiuNN-final")
+        final_model_path = os.path.join(self.output_dir, "aiuNN-final")
         self.model.save(final_model_path)
         print(f"\nFinal model saved to: {final_model_path}")
@@ -400,8 +375,8 @@ if __name__ == "__main__":
             "/root/training_data/vision-dataset/image_upscaler.parquet",
             "/root/training_data/vision-dataset/image_vec_upscaler.parquet"
         ],
-        batch_size=8, # Increased batch size
-        learning_rate=1e-4 # Reduced initial LR
+        batch_size=8,  # Increased batch size
+        learning_rate=1e-4  # Reduced initial LR
     )
-    trainer.train(num_epochs=10) # Extended training time
+    trainer.train(num_epochs=10)  # Extended training time