base model fix
This commit is contained in:
parent
c88961aee7
commit
a51300c77c
|
@ -14,16 +14,15 @@ class Upsampler(AIIA):
|
|||
mode=self.config.upsample_mode,
|
||||
align_corners=self.config.upsample_align_corners
|
||||
)
|
||||
# Add a final conversion layer to change channels from 512 to 3.
|
||||
self.to_rgb = nn.Conv2d(in_channels=512, out_channels=3, kernel_size=1)
|
||||
# Conversion layer: change from 512 channels to 3 channels.
|
||||
self.to_rgb = nn.Conv2d(in_channels=self.base_model.config.hidden_size, out_channels=3, kernel_size=1)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.base_model(x)
|
||||
x = self.upsample(x)
|
||||
x = self.to_rgb(x) # Convert feature map to RGB image.
|
||||
x = self.to_rgb(x) # Ensures output has 3 channels.
|
||||
return x
|
||||
|
||||
|
||||
|
||||
@classmethod
|
||||
def load(cls, path: str):
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue