updated config for updated model

This commit is contained in:
Falko Victor Habel 2025-01-30 10:41:37 +01:00
parent 2121316e3b
commit 73d52f733c
1 changed files with 5 additions and 0 deletions

View File

@ -122,12 +122,17 @@ class ModelTrainer:
# Add upscaling layer if not already present # Add upscaling layer if not already present
if not hasattr(self.model, 'upsample'): if not hasattr(self.model, 'upsample'):
# Get existing configuration values or set defaults if necessary
hidden_size = self.model.config.hidden_size hidden_size = self.model.config.hidden_size
kernel_size = self.model.config.kernel_size kernel_size = self.model.config.kernel_size
self.model.upsample = nn.Sequential( self.model.upsample = nn.Sequential(
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
nn.Conv2d(hidden_size, 3, kernel_size=kernel_size, padding=1) nn.Conv2d(hidden_size, 3, kernel_size=kernel_size, padding=1)
) )
# Update the model's configuration with new parameters
self.model.config.upsample_hidden_size = hidden_size
self.model.config.upsample_kernel_size = kernel_size
# Initialize optimizer and loss function # Initialize optimizer and loss function
self.criterion = nn.MSELoss() self.criterion = nn.MSELoss()