develop #4

@@ -122,12 +122,17 @@ class ModelTrainer:
 
         # Add upscaling layer if not already present
         if not hasattr(self.model, 'upsample'):
+            # Get existing configuration values or set defaults if necessary
             hidden_size = self.model.config.hidden_size
             kernel_size = self.model.config.kernel_size
 
             self.model.upsample = nn.Sequential(
                 nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
                 nn.Conv2d(hidden_size, 3, kernel_size=kernel_size, padding=1)
             )
+            # Update the model's configuration with new parameters
+            self.model.config.upsample_hidden_size = hidden_size
+            self.model.config.upsample_kernel_size = kernel_size
+
         # Initialize optimizer and loss function
         self.criterion = nn.MSELoss()
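For reference, a minimal, self-contained sketch of what the new upsample head computes, outside of ModelTrainer. The hidden_size and kernel_size values below are illustrative assumptions, not values read from the real model config; note that padding=1 preserves the spatial size only when the kernel size is 3.

import torch
import torch.nn as nn

# Illustrative assumptions, not values from the actual model config
hidden_size = 256   # channels coming out of the backbone
kernel_size = 3     # with padding=1, only a 3x3 kernel keeps H and W unchanged

# Same head as in the diff: 2x bilinear upscale followed by a projection to RGB
upsample = nn.Sequential(
    nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
    nn.Conv2d(hidden_size, 3, kernel_size=kernel_size, padding=1),
)

features = torch.randn(1, hidden_size, 32, 32)   # dummy backbone feature map
output = upsample(features)                      # shape: (1, 3, 64, 64)

# The MSE criterion from the diff would compare this output to a target image
criterion = nn.MSELoss()
target = torch.rand(1, 3, 64, 64)
loss = criterion(output, target)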