finetune_class #1

Merged
Fabel merged 96 commits from finetune_class into develop 2025-02-26 12:13:09 +00:00
1 changed files with 5 additions and 0 deletions
Showing only changes of commit 73d52f733c - Show all commits

View File

@ -122,12 +122,17 @@ class ModelTrainer:
# Add upscaling layer if not already present
if not hasattr(self.model, 'upsample'):
# Get existing configuration values or set defaults if necessary
hidden_size = self.model.config.hidden_size
kernel_size = self.model.config.kernel_size
self.model.upsample = nn.Sequential(
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
nn.Conv2d(hidden_size, 3, kernel_size=kernel_size, padding=1)
)
# Update the model's configuration with new parameters
self.model.config.upsample_hidden_size = hidden_size
self.model.config.upsample_kernel_size = kernel_size
# Initialize optimizer and loss function
self.criterion = nn.MSELoss()