develop #4
|
@ -0,0 +1,36 @@
|
||||||
|
from aiia import AIIAConfig
|
||||||
|
|
||||||
|
|
||||||
|
class UpsamplerConfig(AIIAConfig):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
upsample_scale: int = 2,
|
||||||
|
upsample_mode: str = 'bilinear',
|
||||||
|
upsample_align_corners: bool = False,
|
||||||
|
layers=None,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
# Initialize base configuration.
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.layers = layers if layers is not None else []
|
||||||
|
|
||||||
|
# Upsampler-specific parameters.
|
||||||
|
self.upsample_scale = upsample_scale
|
||||||
|
self.upsample_mode = upsample_mode
|
||||||
|
self.upsample_align_corners = upsample_align_corners
|
||||||
|
|
||||||
|
# Automatically add the upsample layer details.
|
||||||
|
self.add_upsample_layer()
|
||||||
|
|
||||||
|
def add_upsample_layer(self):
|
||||||
|
upsample_layer = {
|
||||||
|
'name': 'Upsample',
|
||||||
|
'type': 'nn.Upsample',
|
||||||
|
'scale_factor': self.upsample_scale,
|
||||||
|
'mode': self.upsample_mode,
|
||||||
|
'align_corners': self.upsample_align_corners
|
||||||
|
}
|
||||||
|
# Add the upsample layer only if not already present.
|
||||||
|
if not any(layer.get('name') == 'Upsample' for layer in self.layers):
|
||||||
|
self.layers.append(upsample_layer)
|
||||||
|
|
|
@ -3,27 +3,19 @@ import torch.nn as nn
|
||||||
from aiia import AIIA, AIIAConfig, AIIABase
|
from aiia import AIIA, AIIAConfig, AIIABase
|
||||||
|
|
||||||
|
|
||||||
|
# Upsampler model that uses the configuration from the base model.
|
||||||
class Upsampler(AIIA):
|
class Upsampler(AIIA):
|
||||||
def init(self, base_model: AIIA):
|
def __init__(self, base_model: AIIABase):
|
||||||
# base_model must be a fully instantiated model (with a .config attribute)
|
# Assume that base_model.config is an instance of UpsamplerConfig.
|
||||||
super().init(base_model.config)
|
super().__init__(base_model.config)
|
||||||
self.base_model = base_model
|
self.base_model = base_model
|
||||||
|
|
||||||
# Upsample to double the spatial dimensions using bilinear interpolation
|
|
||||||
self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
|
|
||||||
|
|
||||||
# Update the base model's configuration to include the upsample layer details
|
# Create the upsample layer using values from the configuration.
|
||||||
if not hasattr(self.base_model.config, 'layers'):
|
self.upsample = nn.Upsample(
|
||||||
self.base_model.config.layers = []
|
scale_factor=self.config.upsample_scale,
|
||||||
|
mode=self.config.upsample_mode,
|
||||||
self.base_model.config.layers.append({
|
align_corners=self.config.upsample_align_corners
|
||||||
'name': 'Upsample',
|
)
|
||||||
'type': 'nn.Upsample',
|
|
||||||
'scale_factor': 2,
|
|
||||||
'mode': 'bilinear',
|
|
||||||
'align_corners': False
|
|
||||||
})
|
|
||||||
self.config = self.base_model.config
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.base_model(x)
|
x = self.base_model(x)
|
||||||
|
@ -33,27 +25,20 @@ class Upsampler(AIIA):
|
||||||
@classmethod
|
@classmethod
|
||||||
def load(cls, path: str):
|
def load(cls, path: str):
|
||||||
"""
|
"""
|
||||||
Override the default load method:
|
Load the model:
|
||||||
- First, load the base model (which includes its configuration and state_dict)
|
- First, load the base model (including its configuration and state_dict).
|
||||||
- Then instantiate the Upsampler with that base model
|
- Then, wrap it with the Upsampler class.
|
||||||
- Finally, load the Upsampler-specific state dictionary
|
- Finally, load the combined state dictionary.
|
||||||
"""
|
"""
|
||||||
# Load the full base model from the given path.
|
|
||||||
# (Assuming AIIABase.load is implemented to load the base model correctly.)
|
|
||||||
base_model = AIIABase.load(path)
|
base_model = AIIABase.load(path)
|
||||||
|
|
||||||
# Create a new instance of Upsampler using the loaded base model.
|
|
||||||
instance = cls(base_model)
|
instance = cls(base_model)
|
||||||
|
|
||||||
# Choose your device mapping (cuda if available, otherwise cpu)
|
|
||||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||||
|
|
||||||
# Load the saved state dictionary that contains weights for both the base model and upsample layer.
|
|
||||||
state_dict = torch.load(f"{path}/model.pth", map_location=device)
|
state_dict = torch.load(f"{path}/model.pth", map_location=device)
|
||||||
instance.load_state_dict(state_dict)
|
instance.load_state_dict(state_dict)
|
||||||
|
|
||||||
return instance
|
return instance
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "main":
|
if __name__ == "main":
|
||||||
from aiia import AIIABase, AIIAConfig
|
from aiia import AIIABase, AIIAConfig
|
||||||
# Create a configuration and build a base model.
|
# Create a configuration and build a base model.
|
||||||
|
|
Loading…
Reference in New Issue