finetune_class #1
|
@ -1,27 +1,40 @@
|
||||||
from aiia import AIIAConfig
|
from aiia import AIIAConfig
|
||||||
import os
|
|
||||||
import json
|
|
||||||
|
|
||||||
|
|
||||||
class UpsamplerConfig(AIIAConfig):
|
class UpsamplerConfig(AIIAConfig):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
base_config=None,
|
||||||
upsample_scale: int = 2,
|
upsample_scale: int = 2,
|
||||||
upsample_mode: str = 'bilinear',
|
upsample_mode: str = 'bilinear',
|
||||||
upsample_align_corners: bool = False,
|
upsample_align_corners: bool = False,
|
||||||
layers=None,
|
layers=None,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
# Initialize base configuration.
|
# Start with a single configuration dictionary.
|
||||||
super().__init__(**kwargs)
|
config_data = {}
|
||||||
self.layers = layers if layers is not None else []
|
if base_config is not None:
|
||||||
|
# If base_config is an object with a to_dict method, use it.
|
||||||
|
if hasattr(base_config, "to_dict"):
|
||||||
|
config_data.update(base_config.to_dict())
|
||||||
|
elif isinstance(base_config, dict):
|
||||||
|
config_data.update(base_config)
|
||||||
|
|
||||||
|
# Update with any additional keyword arguments (if needed).
|
||||||
|
config_data.update(kwargs)
|
||||||
|
|
||||||
|
# Initialize base AIIAConfig with a single merged configuration.
|
||||||
|
super().__init__(**config_data)
|
||||||
|
|
||||||
# Upsampler-specific parameters.
|
# Upsampler-specific parameters.
|
||||||
self.upsample_scale = upsample_scale
|
self.upsample_scale = upsample_scale
|
||||||
self.upsample_mode = upsample_mode
|
self.upsample_mode = upsample_mode
|
||||||
self.upsample_align_corners = upsample_align_corners
|
self.upsample_align_corners = upsample_align_corners
|
||||||
|
|
||||||
# Automatically add the upsample layer details.
|
# Use layers from the argument or initialize an empty list.
|
||||||
|
self.layers = layers if layers is not None else []
|
||||||
|
|
||||||
|
# Add the upsample layer details only once.
|
||||||
self.add_upsample_layer()
|
self.add_upsample_layer()
|
||||||
|
|
||||||
def add_upsample_layer(self):
|
def add_upsample_layer(self):
|
||||||
|
@ -32,6 +45,6 @@ class UpsamplerConfig(AIIAConfig):
|
||||||
'mode': self.upsample_mode,
|
'mode': self.upsample_mode,
|
||||||
'align_corners': self.upsample_align_corners
|
'align_corners': self.upsample_align_corners
|
||||||
}
|
}
|
||||||
# Add the upsample layer only if not already present.
|
# Append the layer only if it isn’t already present.
|
||||||
if not any(layer.get('name') == 'Upsample' for layer in self.layers):
|
if not any(layer.get('name') == 'Upsample' for layer in self.layers):
|
||||||
self.layers.append(upsample_layer)
|
self.layers.append(upsample_layer)
|
||||||
|
|
|
@ -11,7 +11,10 @@ class Upsampler(AIIA):
|
||||||
def __init__(self, base_model: AIIABase):
|
def __init__(self, base_model: AIIABase):
|
||||||
super().__init__(base_model.config)
|
super().__init__(base_model.config)
|
||||||
self.base_model = base_model
|
self.base_model = base_model
|
||||||
self.config = UpsamplerConfig(kwargs=self.base_model.config)
|
|
||||||
|
# Pass the unified base configuration using the new parameter.
|
||||||
|
self.config = UpsamplerConfig(base_config=base_model.config)
|
||||||
|
|
||||||
self.upsample = nn.Upsample(
|
self.upsample = nn.Upsample(
|
||||||
scale_factor=self.config.upsample_scale,
|
scale_factor=self.config.upsample_scale,
|
||||||
mode=self.config.upsample_mode,
|
mode=self.config.upsample_mode,
|
||||||
|
@ -24,6 +27,7 @@ class Upsampler(AIIA):
|
||||||
kernel_size=1
|
kernel_size=1
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.base_model(x)
|
x = self.base_model(x)
|
||||||
x = self.upsample(x)
|
x = self.upsample(x)
|
||||||
|
|
Loading…
Reference in New Issue