75 lines
2.8 KiB
Python
75 lines
2.8 KiB
Python
import torch.nn as nn
|
|
from aiia.model.Model import AIIAConfig, AIIABase
|
|
from transformers import PreTrainedModel
|
|
from .config import aiuNNConfig
|
|
|
|
|
|
class aiuNN(PreTrainedModel):
    """Upsampling model: a stack of base layers, a bilinear upsample, then a 3x3 convolution.

    The input is passed through ``base_layers``, enlarged by
    ``config.upsample_scale`` using bilinear interpolation, and finally mapped
    from ``config.hidden_size`` channels to ``config.num_channels`` channels.
    """

    config_class = aiuNNConfig

    def __init__(self, config: aiuNNConfig, base_model: PreTrainedModel = None):
        """Assemble the base layers and the upsampling head.

        Args:
            config: Supplies ``upsample_scale``, ``hidden_size`` and
                ``num_channels`` for the head, plus ``num_hidden_layers``,
                ``kernel_size`` and ``activation_function`` whenever the base
                layers have to be built from scratch.
            base_model: Optional model whose layers are adopted as
                ``base_layers``. Two layouts are recognised: one exposing a
                ``cnn`` attribute, and one exposing ``shared_layer`` together
                with ``unique_layers`` (``activation`` and ``max_pool`` are
                then read from it as well). Any other model, or ``None``,
                results in base layers built from ``config``.
        """
        super().__init__(config)
        self.config = config

        # Take over the base model's layers so this model carries them itself.
        # The layer objects are placed into a new nn.Sequential as-is; they are
        # not cloned.
        adopted = None
        if base_model is not None:
            if hasattr(base_model, 'cnn'):
                adopted = nn.Sequential(*base_model.cnn)
            elif hasattr(base_model, 'shared_layer') and hasattr(base_model, 'unique_layers'):
                shared, act, pool = (
                    base_model.shared_layer,
                    base_model.activation,
                    base_model.max_pool,
                )
                stack = [shared, act, pool]
                # Every unique layer is followed by the same activation and
                # pooling modules that follow the shared layer.
                for unique in base_model.unique_layers:
                    stack += [unique, base_model.activation, base_model.max_pool]
                adopted = nn.Sequential(*stack)

        # No base model, or one with an unrecognised layout: build from config.
        # Identity check on purpose - an empty nn.Sequential must not trigger
        # the fallback.
        if adopted is None:
            adopted = self._build_base_layers_from_config(config)
        self.base_layers = adopted

        # Bilinear upsampling head
        self.upsample = nn.Upsample(
            scale_factor=self.config.upsample_scale,
            mode='bilinear',
            align_corners=False,
        )
        self.final_conv = nn.Conv2d(
            in_channels=self.config.hidden_size,
            out_channels=self.config.num_channels,
            kernel_size=3,
            padding=1,
        )

    def _build_base_layers_from_config(self, config):
        """Create ``config.num_hidden_layers`` conv/activation/pool triples.

        The first convolution takes ``config.num_channels`` input channels;
        every later one takes ``config.hidden_size``. All convolutions output
        ``config.hidden_size`` channels and use ``config.kernel_size`` with a
        fixed padding of 1.

        Returns:
            An ``nn.Sequential`` holding the triples in order (empty when
            ``config.num_hidden_layers`` is 0).
        """
        modules = []
        channels_in = config.num_channels
        for _ in range(config.num_hidden_layers):
            modules.append(
                nn.Conv2d(
                    channels_in,
                    config.hidden_size,
                    kernel_size=config.kernel_size,
                    padding=1,
                )
            )
            # The activation class is looked up on torch.nn by its name.
            modules.append(getattr(nn, config.activation_function)())
            # A 1x1 window with stride 1 leaves values and spatial size as they are.
            # NOTE(review): presumably kept to mirror the base model's layer
            # layout - confirm.
            modules.append(nn.MaxPool2d(kernel_size=1, stride=1))
            channels_in = config.hidden_size
        return nn.Sequential(*modules)

    def forward(self, x):
        """Run ``x`` through the base layers, the upsampler and the final conv.

        Args:
            x: Input tensor. Assumes a 4-D image batch whose channel count
                matches the first base layer - TODO confirm against callers.

        Returns:
            The output of ``final_conv`` applied to the upsampled features.
        """
        features = self.base_layers(x)
        enlarged = self.upsample(features)
        return self.final_conv(enlarged)
|
|
|
|
|
|
if __name__ == "__main__":
    from aiia import AIIABase, AIIAConfig

    # Directory used for the save / reload round trip below.
    checkpoint_dir = "aiunn"

    # Default configurations for the base model and for the upsampler.
    config = AIIAConfig()
    ai_config = aiuNNConfig()

    # Build the base model, then wrap its layers in an aiuNN instance.
    base_model = AIIABase(config)
    upsampler = aiuNN(config=ai_config, base_model=base_model)

    # Write the configuration and the weights to disk.
    upsampler.save_pretrained(checkpoint_dir)

    # Read the model back from the same directory and show its configuration.
    upsampler_loaded = aiuNN.from_pretrained(checkpoint_dir)
    print("Updated configuration:", vars(upsampler_loaded.config))
|