base model fix

This commit is contained in:
Falko Victor Habel 2025-02-23 22:50:17 +01:00
parent c88961aee7
commit a51300c77c
1 changed files with 4 additions and 5 deletions

View File

@ -14,16 +14,15 @@ class Upsampler(AIIA):
mode=self.config.upsample_mode,
align_corners=self.config.upsample_align_corners
)
# Add a final conversion layer to change channels from 512 to 3.
self.to_rgb = nn.Conv2d(in_channels=512, out_channels=3, kernel_size=1)
# Conversion layer: change from 512 channels to 3 channels.
self.to_rgb = nn.Conv2d(in_channels=self.base_model.config.hidden_size, out_channels=3, kernel_size=1)
def forward(self, x):
x = self.base_model(x)
x = self.upsample(x)
x = self.to_rgb(x) # Convert feature map to RGB image.
x = self.to_rgb(x) # Ensures output has 3 channels.
return x
@classmethod
def load(cls, path: str):
"""