develop #4
|
@ -0,0 +1,40 @@
|
||||||
|
import torch.nn as nn
|
||||||
|
from aiia import AIIA
|
||||||
|
|
||||||
|
class Upsampler(AIIA):
|
||||||
|
def __init__(self, base_model: AIIA):
|
||||||
|
super().__init__(base_model.config)
|
||||||
|
self.base_model = base_model
|
||||||
|
|
||||||
|
# Upsample to double the spatial dimensions using bilinear interpolation
|
||||||
|
self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
|
||||||
|
|
||||||
|
# Update the base model's configuration to include the upsample layer details
|
||||||
|
print(self.base_model.config)
|
||||||
|
if hasattr(self.base_model, 'config'):
|
||||||
|
# Check if layers attribute exists, if not create it
|
||||||
|
if not hasattr(self.base_model.config, 'layers'):
|
||||||
|
setattr(self.base_model.config, 'layers', [])
|
||||||
|
|
||||||
|
# Add the upsample layer configuration
|
||||||
|
current_layers = getattr(self.base_model.config, 'layers', [])
|
||||||
|
current_layers.append({
|
||||||
|
'name': 'Upsample',
|
||||||
|
'type': 'nn.Upsample',
|
||||||
|
'scale_factor': 2,
|
||||||
|
'mode': 'bilinear',
|
||||||
|
'align_corners': False
|
||||||
|
})
|
||||||
|
setattr(self.base_model.config, 'layers', current_layers)
|
||||||
|
self.config = self.base_model.config
|
||||||
|
else:
|
||||||
|
self.config = {}
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.base_model(x)
|
||||||
|
x = self.upsample(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
upsampler = Upsampler.load("test2")
|
||||||
|
print("Updated configuration:", upsampler.config.__dict__)
|
Loading…
Reference in New Issue