finetune_class #1

Merged
Fabel merged 96 commits from finetune_class into develop 2025-02-26 12:13:09 +00:00
1 changed files with 4 additions and 5 deletions
Showing only changes of commit a51300c77c - Show all commits

View File

@ -14,16 +14,15 @@ class Upsampler(AIIA):
mode=self.config.upsample_mode,
align_corners=self.config.upsample_align_corners
)
# Add a final conversion layer to change channels from 512 to 3.
self.to_rgb = nn.Conv2d(in_channels=512, out_channels=3, kernel_size=1)
# Conversion layer: change from 512 channels to 3 channels.
self.to_rgb = nn.Conv2d(in_channels=self.base_model.config.hidden_size, out_channels=3, kernel_size=1)
def forward(self, x):
x = self.base_model(x)
x = self.upsample(x)
x = self.to_rgb(x) # Convert feature map to RGB image.
x = self.to_rgb(x) # Ensures output has 3 channels.
return x
@classmethod
def load(cls, path: str):
"""