develop #4
|
@ -6,23 +6,24 @@ from config import UpsamplerConfig
|
||||||
# Upsampler model that uses the configuration from the base model.
|
# Upsampler model that uses the configuration from the base model.
|
||||||
class Upsampler(AIIA):
|
class Upsampler(AIIA):
|
||||||
def __init__(self, base_model: AIIABase):
|
def __init__(self, base_model: AIIABase):
|
||||||
# Assume that base_model.config is an instance of UpsamplerConfig.
|
|
||||||
super().__init__(base_model.config)
|
super().__init__(base_model.config)
|
||||||
self.base_model = base_model
|
self.base_model = base_model
|
||||||
self.config = UpsamplerConfig(kwargs=self.base_model.config)
|
self.config = UpsamplerConfig(kwargs=self.base_model.config)
|
||||||
# Create the upsample layer using values from the configuration.
|
|
||||||
print(self.config.upsample_scale)
|
|
||||||
self.upsample = nn.Upsample(
|
self.upsample = nn.Upsample(
|
||||||
scale_factor=self.config.upsample_scale,
|
scale_factor=self.config.upsample_scale,
|
||||||
mode=self.config.upsample_mode,
|
mode=self.config.upsample_mode,
|
||||||
align_corners=self.config.upsample_align_corners
|
align_corners=self.config.upsample_align_corners
|
||||||
)
|
)
|
||||||
|
# Add a final conversion layer to change channels from 512 to 3.
|
||||||
|
self.to_rgb = nn.Conv2d(in_channels=512, out_channels=3, kernel_size=1)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.base_model(x)
|
x = self.base_model(x)
|
||||||
x = self.upsample(x)
|
x = self.upsample(x)
|
||||||
|
x = self.to_rgb(x) # Convert feature map to RGB image.
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load(cls, path: str):
|
def load(cls, path: str):
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Reference in New Issue