improved model performance
This commit is contained in:
parent
a00ffac8c5
commit
25f4225baf
|
@ -15,17 +15,16 @@ class aiuNN(AIIA):
|
||||||
# Pass the unified base configuration using the new parameter.
|
# Pass the unified base configuration using the new parameter.
|
||||||
self.config = aiuNNConfig(base_config=base_model.config)
|
self.config = aiuNNConfig(base_config=base_model.config)
|
||||||
|
|
||||||
self.upsample = nn.Upsample(
|
# Enhanced approach
|
||||||
scale_factor=self.config.upsample_scale,
|
scale_factor = self.config.upsample_scale
|
||||||
mode=self.config.upsample_mode,
|
out_channels = self.base_model.config.num_channels * (scale_factor ** 2)
|
||||||
align_corners=self.config.upsample_align_corners
|
self.pixel_shuffle_conv = nn.Conv2d(
|
||||||
)
|
|
||||||
# Conversion layer: change from hidden size channels to number of channels from the config.
|
|
||||||
self.to_rgb = nn.Conv2d(
|
|
||||||
in_channels=self.base_model.config.hidden_size,
|
in_channels=self.base_model.config.hidden_size,
|
||||||
out_channels=self.base_model.config.num_channels,
|
out_channels=out_channels,
|
||||||
kernel_size=1
|
kernel_size=self.base_model.config.kernel_size,
|
||||||
|
padding=1
|
||||||
)
|
)
|
||||||
|
self.pixel_shuffle = nn.PixelShuffle(scale_factor)
|
||||||
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
|
Loading…
Reference in New Issue