From a51300c77ca7d78b979bae882c4aacc6a7aa4267 Mon Sep 17 00:00:00 2001 From: Falko Habel Date: Sun, 23 Feb 2025 22:50:17 +0100 Subject: [PATCH] base model fix --- src/aiunn/upsampler.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/aiunn/upsampler.py b/src/aiunn/upsampler.py index e5f17ce..d2eeb82 100644 --- a/src/aiunn/upsampler.py +++ b/src/aiunn/upsampler.py @@ -14,16 +14,15 @@ class Upsampler(AIIA): mode=self.config.upsample_mode, align_corners=self.config.upsample_align_corners ) - # Add a final conversion layer to change channels from 512 to 3. - self.to_rgb = nn.Conv2d(in_channels=512, out_channels=3, kernel_size=1) + # Conversion layer: change from 512 channels to 3 channels. + self.to_rgb = nn.Conv2d(in_channels=self.base_model.config.hidden_size, out_channels=3, kernel_size=1) def forward(self, x): x = self.base_model(x) x = self.upsample(x) - x = self.to_rgb(x) # Convert feature map to RGB image. + x = self.to_rgb(x) # Ensures output has 3 channels. return x - - + @classmethod def load(cls, path: str): """