finetune_class #1

Merged
Fabel merged 96 commits from finetune_class into develop 2025-02-26 12:13:09 +00:00
1 changed files with 4 additions and 3 deletions
Showing only changes of commit c88961aee7 - Show all commits

View File

@ -6,23 +6,24 @@ from config import UpsamplerConfig
# Upsampler model that uses the configuration from the base model. # Upsampler model that uses the configuration from the base model.
class Upsampler(AIIA): class Upsampler(AIIA):
def __init__(self, base_model: AIIABase): def __init__(self, base_model: AIIABase):
# Assume that base_model.config is an instance of UpsamplerConfig.
super().__init__(base_model.config) super().__init__(base_model.config)
self.base_model = base_model self.base_model = base_model
self.config = UpsamplerConfig(kwargs=self.base_model.config) self.config = UpsamplerConfig(kwargs=self.base_model.config)
# Create the upsample layer using values from the configuration.
print(self.config.upsample_scale)
self.upsample = nn.Upsample( self.upsample = nn.Upsample(
scale_factor=self.config.upsample_scale, scale_factor=self.config.upsample_scale,
mode=self.config.upsample_mode, mode=self.config.upsample_mode,
align_corners=self.config.upsample_align_corners align_corners=self.config.upsample_align_corners
) )
# Add a final conversion layer to change channels from 512 to 3.
self.to_rgb = nn.Conv2d(in_channels=512, out_channels=3, kernel_size=1)
def forward(self, x): def forward(self, x):
x = self.base_model(x) x = self.base_model(x)
x = self.upsample(x) x = self.upsample(x)
x = self.to_rgb(x) # Convert feature map to RGB image.
return x return x
@classmethod @classmethod
def load(cls, path: str): def load(cls, path: str):
""" """