bilinear upsampling followed by a convolution instead

This commit is contained in:
Falko Victor Habel 2025-07-03 09:53:10 +02:00
parent ff45f54860
commit 6d1fc4c88d
1 changed files with 25 additions and 22 deletions

View File

@ -1,39 +1,44 @@
import os
import torch
import torch.nn as nn import torch.nn as nn
import warnings
from aiia.model.Model import AIIAConfig, AIIABase from aiia.model.Model import AIIAConfig, AIIABase
from transformers import PreTrainedModel from transformers import PreTrainedModel
from .config import aiuNNConfig from .config import aiuNNConfig
import warnings
class aiuNN(PreTrainedModel): class aiuNN(PreTrainedModel):
config_class = aiuNNConfig config_class = aiuNNConfig
def __init__(self, config: aiuNNConfig, base_model: PreTrainedModel = None):
    """Build the upscaler: base feature layers plus a bilinear-upsample head.

    Args:
        config: aiuNNConfig providing ``num_channels``, ``hidden_size`` and
            ``upsample_scale`` (other fields are used by the base-layer builder).
        base_model: optional pretrained backbone whose feature layers are
            copied in for self-containment; when ``None`` the layers are
            rebuilt from ``config`` and weights are expected from a state_dict.
    """
    super().__init__(config)
    self.config = config

    if base_model is None:
        # Inference path: reconstruct architecture from config only;
        # parameters will be populated when the state_dict is loaded.
        self.base_layers = self._build_base_layers_from_config(config)
    elif hasattr(base_model, 'cnn'):
        # AIIABase-style backbone: copy its CNN feature stack verbatim.
        self.base_layers = nn.Sequential(*list(base_model.cnn))
    elif hasattr(base_model, 'shared_layer') and hasattr(base_model, 'unique_layers'):
        # Shared-layer backbone: interleave each unique layer with the
        # backbone's activation and pooling modules.
        stack = [base_model.shared_layer, base_model.activation, base_model.max_pool]
        for unique in base_model.unique_layers:
            stack += [unique, base_model.activation, base_model.max_pool]
        self.base_layers = nn.Sequential(*stack)
    else:
        # Unknown backbone shape — fall back to a config-driven rebuild.
        self.base_layers = self._build_base_layers_from_config(config)

    # Upsampling head: bilinear interpolation followed by a 3x3 convolution
    # (this commit's replacement for the previous pixel-shuffle head).
    self.upsample = nn.Upsample(
        scale_factor=self.config.upsample_scale,
        mode='bilinear',
        align_corners=False,
    )
    self.final_conv = nn.Conv2d(
        in_channels=self.config.hidden_size,
        out_channels=self.config.num_channels,
        kernel_size=3,
        padding=1,
    )
def _build_base_layers_from_config(self, config): def _build_base_layers_from_config(self, config):
"""
Reconstruct the base layers (e.g., CNN) using only the config.
This must match exactly how your base model builds its layers!
"""
layers = [] layers = []
in_channels = config.num_channels in_channels = config.num_channels
for _ in range(config.num_hidden_layers): for _ in range(config.num_hidden_layers):
@ -47,14 +52,12 @@ class aiuNN(PreTrainedModel):
return nn.Sequential(*layers) return nn.Sequential(*layers)
def forward(self, x):
    """Super-resolve ``x``: base features -> bilinear upsample -> final conv.

    Args:
        x: input image tensor (presumably NCHW with ``config.num_channels``
            channels — confirm against the training pipeline).

    Returns:
        The upscaled tensor produced by ``final_conv``.
    """
    features = self.base_layers(x)
    upsampled = self.upsample(features)
    return self.final_conv(upsampled)
if __name__ == "__main__": if __name__ == "__main__":
from aiia import AIIABase, AIIAConfig from aiia import AIIABase, AIIAConfig
# Create a configuration and build a base model. # Create a configuration and build a base model.