doubled config

This commit is contained in:
Falko Victor Habel 2025-02-24 14:57:55 +01:00
parent e114023cbc
commit 8935fa5e13
2 changed files with 43 additions and 20 deletions

View File

@ -1,4 +1,6 @@
from aiia import AIIAConfig
import os
import json
class UpsamplerConfig(AIIAConfig):
@ -33,4 +35,3 @@ class UpsamplerConfig(AIIAConfig):
# Add the upsample layer only if not already present.
if not any(layer.get('name') == 'Upsample' for layer in self.layers):
self.layers.append(upsample_layer)

View File

@ -1,9 +1,12 @@
import os
import torch
import torch.nn as nn
import warnings
from aiia import AIIA, AIIAConfig, AIIABase
from config import UpsamplerConfig
import warnings
# Upsampler model that uses the configuration from the base model.
class Upsampler(AIIA):
def __init__(self, base_model: AIIABase):
super().__init__(base_model.config)
@ -14,43 +17,62 @@ class Upsampler(AIIA):
mode=self.config.upsample_mode,
align_corners=self.config.upsample_align_corners
)
# Conversion layer: change from 512 channels to 3 channels.
self.to_rgb = nn.Conv2d(in_channels=self.base_model.config.hidden_size, out_channels=3, kernel_size=1)
# Conversion layer: change from hidden size channels to 3 channels.
self.to_rgb = nn.Conv2d(
in_channels=self.base_model.config.hidden_size,
out_channels=3,
kernel_size=1
)
def forward(self, x):
    """
    Run the input through the base model, the upsample layer, and the
    RGB projection, in that order, and return the result.
    """
    features = self.base_model(x)
    enlarged = self.upsample(features)
    # to_rgb is a 1x1 convolution with out_channels=3, so the output has 3 channels.
    return self.to_rgb(enlarged)
@classmethod
def load(cls, path, precision: str = None):
    """
    Load an Upsampler that was saved under the directory ``path``.

    The configuration is read with ``AIIAConfig.load(path)``, a fresh
    ``AIIABase`` is built from it and wrapped by this class, and the
    weights stored in ``<path>/model.pth`` are loaded into the wrapper.

    Args:
        path: Directory containing the saved configuration and ``model.pth``.
        precision: Optional target precision, case-insensitive. ``'fp16'``
            selects ``torch.float16``; ``'bf16'`` selects ``torch.bfloat16``
            (falling back to ``torch.float16`` with a warning when running
            on CUDA without BF16 support). ``None`` keeps the stored dtype.

    Returns:
        The Upsampler instance with the loaded weights.

    Raises:
        ValueError: If ``precision`` is not ``'fp16'``, ``'bf16'`` or ``None``.
    """
    # Rebuild the model skeleton from the stored configuration.
    config = AIIAConfig.load(path)
    upsampler = cls(AIIABase(config))

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # NOTE(review): torch.load unpickles the file; only load checkpoints
    # from trusted sources.
    state_dict = torch.load(os.path.join(path, "model.pth"), map_location=device)

    dtype = None
    if precision is not None:
        requested = precision.lower()
        if requested == 'fp16':
            dtype = torch.float16
        elif requested == 'bf16':
            if device == 'cuda' and not torch.cuda.is_bf16_supported():
                warnings.warn("BF16 is not supported on this GPU. Falling back to FP16.")
                dtype = torch.float16
            else:
                dtype = torch.bfloat16
        else:
            raise ValueError("Unsupported precision. Use 'fp16', 'bf16', or leave as None.")

    upsampler.load_state_dict(state_dict)

    if dtype is not None:
        # Casting only the state dict has no lasting effect, because
        # load_state_dict copies values into the existing parameters and
        # keeps their dtype. Cast the module instead; Module.to(dtype) only
        # touches floating-point parameters and buffers, so integer buffers
        # are left intact.
        # NOTE(review): assumes AIIA derives from torch.nn.Module - confirm.
        upsampler.to(dtype=dtype)

    return upsampler
# Demo: build an Upsampler, save it, then reload it in BF16.
# The guard must compare against "__main__"; a guard written as "main" is
# never true, so the demo would silently never run.
if __name__ == "__main__":
    from aiia import AIIABase, AIIAConfig

    # Create a configuration and build a base model from it
    # (AIIABase takes the config object, not a directory name).
    config = AIIAConfig()
    base_model = AIIABase(config)

    # Instantiate the Upsampler from the base model.
    upsampler = Upsampler(base_model)

    # Save the model (both configuration and weights).
    upsampler.save("hehe")

    # Load it back through the overridden load method, requesting BF16.
    upsampler_loaded = Upsampler.load("hehe", precision="bf16")
    print("Updated configuration:", upsampler_loaded.config.__dict__)