finetune_class #1
|
@ -14,16 +14,15 @@ class Upsampler(AIIA):
|
||||||
mode=self.config.upsample_mode,
|
mode=self.config.upsample_mode,
|
||||||
align_corners=self.config.upsample_align_corners
|
align_corners=self.config.upsample_align_corners
|
||||||
)
|
)
|
||||||
# Add a final conversion layer to change channels from 512 to 3.
|
# Conversion layer: change from 512 channels to 3 channels.
|
||||||
self.to_rgb = nn.Conv2d(in_channels=512, out_channels=3, kernel_size=1)
|
self.to_rgb = nn.Conv2d(in_channels=self.base_model.config.hidden_size, out_channels=3, kernel_size=1)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.base_model(x)
|
x = self.base_model(x)
|
||||||
x = self.upsample(x)
|
x = self.upsample(x)
|
||||||
x = self.to_rgb(x) # Convert feature map to RGB image.
|
x = self.to_rgb(x) # Ensures output has 3 channels.
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load(cls, path: str):
|
def load(cls, path: str):
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Reference in New Issue