fixed config

This commit is contained in:
Falko Victor Habel 2025-02-24 15:16:05 +01:00
parent 8935fa5e13
commit 33a3626b74
2 changed files with 25 additions and 8 deletions

View File

@ -1,27 +1,40 @@
from aiia import AIIAConfig
import os
import json
class UpsamplerConfig(AIIAConfig):
def __init__(
self,
base_config=None,
upsample_scale: int = 2,
upsample_mode: str = 'bilinear',
upsample_align_corners: bool = False,
layers=None,
**kwargs
):
# Initialize base configuration.
super().__init__(**kwargs)
self.layers = layers if layers is not None else []
# Start with a single configuration dictionary.
config_data = {}
if base_config is not None:
# If base_config is an object with a to_dict method, use it.
if hasattr(base_config, "to_dict"):
config_data.update(base_config.to_dict())
elif isinstance(base_config, dict):
config_data.update(base_config)
# Update with any additional keyword arguments (if needed).
config_data.update(kwargs)
# Initialize base AIIAConfig with a single merged configuration.
super().__init__(**config_data)
# Upsampler-specific parameters.
self.upsample_scale = upsample_scale
self.upsample_mode = upsample_mode
self.upsample_align_corners = upsample_align_corners
# Automatically add the upsample layer details.
# Use layers from the argument or initialize an empty list.
self.layers = layers if layers is not None else []
# Add the upsample layer details only once.
self.add_upsample_layer()
def add_upsample_layer(self):
@ -32,6 +45,6 @@ class UpsamplerConfig(AIIAConfig):
'mode': self.upsample_mode,
'align_corners': self.upsample_align_corners
}
# Add the upsample layer only if not already present.
# Append the layer only if it isnt already present.
if not any(layer.get('name') == 'Upsample' for layer in self.layers):
self.layers.append(upsample_layer)

View File

@ -11,7 +11,10 @@ class Upsampler(AIIA):
def __init__(self, base_model: AIIABase):
super().__init__(base_model.config)
self.base_model = base_model
self.config = UpsamplerConfig(kwargs=self.base_model.config)
# Pass the unified base configuration using the new parameter.
self.config = UpsamplerConfig(base_config=base_model.config)
self.upsample = nn.Upsample(
scale_factor=self.config.upsample_scale,
mode=self.config.upsample_mode,
@ -24,6 +27,7 @@ class Upsampler(AIIA):
kernel_size=1
)
def forward(self, x):
x = self.base_model(x)
x = self.upsample(x)