updated upsampler

This commit is contained in:
Falko Victor Habel 2025-02-22 17:52:37 +01:00
parent 933238a530
commit 2aba93caea
1 changed files with 54 additions and 25 deletions

View File

@ -1,40 +1,69 @@
import torch
import torch.nn as nn
from aiia import AIIA
from aiia import AIIA, AIIAConfig, AIIABase
class Upsampler(AIIA):
def __init__(self, base_model: AIIA):
super().__init__(base_model.config)
def init(self, base_model: AIIA):
# base_model must be a fully instantiated model (with a .config attribute)
super().init(base_model.config)
self.base_model = base_model
# Upsample to double the spatial dimensions using bilinear interpolation
self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
# Update the base model's configuration to include the upsample layer details
print(self.base_model.config)
if hasattr(self.base_model, 'config'):
# Check if layers attribute exists, if not create it
if not hasattr(self.base_model.config, 'layers'):
setattr(self.base_model.config, 'layers', [])
# Add the upsample layer configuration
current_layers = getattr(self.base_model.config, 'layers', [])
current_layers.append({
'name': 'Upsample',
'type': 'nn.Upsample',
'scale_factor': 2,
'mode': 'bilinear',
'align_corners': False
})
setattr(self.base_model.config, 'layers', current_layers)
self.config = self.base_model.config
else:
self.config = {}
if not hasattr(self.base_model.config, 'layers'):
self.base_model.config.layers = []
self.base_model.config.layers.append({
'name': 'Upsample',
'type': 'nn.Upsample',
'scale_factor': 2,
'mode': 'bilinear',
'align_corners': False
})
self.config = self.base_model.config
def forward(self, x):
x = self.base_model(x)
x = self.upsample(x)
return x
if __name__ == "__main__":
upsampler = Upsampler.load("test2")
print("Updated configuration:", upsampler.config.__dict__)
@classmethod
def load(cls, path: str):
"""
Override the default load method:
- First, load the base model (which includes its configuration and state_dict)
- Then instantiate the Upsampler with that base model
- Finally, load the Upsampler-specific state dictionary
"""
# Load the full base model from the given path.
# (Assuming AIIABase.load is implemented to load the base model correctly.)
base_model = AIIABase.load(path)
# Create a new instance of Upsampler using the loaded base model.
instance = cls(base_model)
# Choose your device mapping (cuda if available, otherwise cpu)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Load the saved state dictionary that contains weights for both the base model and upsample layer.
state_dict = torch.load(f"{path}/model.pth", map_location=device)
instance.load_state_dict(state_dict)
return instance
if __name__ == "main":
from aiia import AIIABase, AIIAConfig
# Create a configuration and build a base model.
config = AIIAConfig()
base_model = AIIABase("test2")
# Instantiate Upsampler from the base model (works correctly).
upsampler = Upsampler(base_model)
# Save the model (both configuration and weights).
upsampler.save("test2")
# Now load using the overridden load method; this will load the complete model.
upsampler_loaded = Upsampler.load("test2")
print("Updated configuration:", upsampler_loaded.config.__dict__)