doubled config

Falko Victor Habel 2025-02-24 14:57:55 +01:00
parent e114023cbc
commit 8935fa5e13
2 changed files with 43 additions and 20 deletions

View File

@@ -1,4 +1,6 @@
 from aiia import AIIAConfig
+import os
+import json

 class UpsamplerConfig(AIIAConfig):
@@ -33,4 +35,3 @@ class UpsamplerConfig(AIIAConfig):
         # Add the upsample layer only if not already present.
         if not any(layer.get('name') == 'Upsample' for layer in self.layers):
             self.layers.append(upsample_layer)
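Aside: the guard in the context lines above is a simple idempotence check on the layer list. A minimal standalone sketch of the same pattern, using plain dicts rather than the actual AIIAConfig layer entries (the field names here are illustrative only):

layers = [{'name': 'Conv', 'out_channels': 64}]
upsample_layer = {'name': 'Upsample', 'scale_factor': 2}

# Append the upsample layer only if no 'Upsample' entry exists yet.
if not any(layer.get('name') == 'Upsample' for layer in layers):
    layers.append(upsample_layer)

# Running the same guard a second time leaves the list unchanged,
# so the config cannot accumulate duplicate Upsample entries.
if not any(layer.get('name') == 'Upsample' for layer in layers):
    layers.append(upsample_layer)

print(layers)  # exactly one 'Upsample' entry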

View File

@@ -1,9 +1,12 @@
+import os
 import torch
 import torch.nn as nn
+import warnings
 from aiia import AIIA, AIIAConfig, AIIABase
 from config import UpsamplerConfig
-import warnings

+# Upsampler model that uses the configuration from the base model.
 class Upsampler(AIIA):
     def __init__(self, base_model: AIIABase):
         super().__init__(base_model.config)
@@ -14,8 +17,12 @@
             mode=self.config.upsample_mode,
             align_corners=self.config.upsample_align_corners
         )
-        # Conversion layer: change from 512 channels to 3 channels.
-        self.to_rgb = nn.Conv2d(in_channels=self.base_model.config.hidden_size, out_channels=3, kernel_size=1)
+        # Conversion layer: change from hidden size channels to 3 channels.
+        self.to_rgb = nn.Conv2d(
+            in_channels=self.base_model.config.hidden_size,
+            out_channels=3,
+            kernel_size=1
+        )

     def forward(self, x):
         x = self.base_model(x)
@@ -24,33 +31,48 @@
         return x

     @classmethod
-    def load(cls, path: str):
-        """
-        Load the model:
-        - First, load the base model (including its configuration and state_dict).
-        - Then, wrap it with the Upsampler class.
-        - Finally, load the combined state dictionary.
-        """
-        base_model = AIIABase.load(path)
-        instance = cls(base_model)
+    def load(cls, path, precision: str = None):
+        # Load the configuration from disk.
+        config = AIIAConfig.load(path)
+        # Reconstruct the base model from the loaded configuration.
+        base_model = AIIABase(config)
+        # Instantiate the Upsampler using the proper base model.
+        upsampler = cls(base_model)
+        # Load state dict and handle precision conversion if needed.
         device = 'cuda' if torch.cuda.is_available() else 'cpu'
         state_dict = torch.load(f"{path}/model.pth", map_location=device)
-        instance.load_state_dict(state_dict)
-        return instance
+        if precision is not None:
+            if precision.lower() == 'fp16':
+                dtype = torch.float16
+            elif precision.lower() == 'bf16':
+                if device == 'cuda' and not torch.cuda.is_bf16_supported():
+                    warnings.warn("BF16 is not supported on this GPU. Falling back to FP16.")
+                    dtype = torch.float16
+                else:
+                    dtype = torch.bfloat16
+            else:
+                raise ValueError("Unsupported precision. Use 'fp16', 'bf16', or leave as None.")
+            for key, param in state_dict.items():
+                if torch.is_tensor(param):
+                    state_dict[key] = param.to(dtype)
+        upsampler.load_state_dict(state_dict)
+        return upsampler

-if __name__ == "main":
+if __name__ == "__main__":
     from aiia import AIIABase, AIIAConfig
     # Create a configuration and build a base model.
     config = AIIAConfig()
-    base_model = AIIABase("test2")
+    base_model = AIIABase(config)
     # Instantiate Upsampler from the base model (works correctly).
     upsampler = Upsampler(base_model)
     # Save the model (both configuration and weights).
-    upsampler.save("test2")
+    upsampler.save("hehe")
     # Now load using the overridden load method; this will load the complete model.
-    upsampler_loaded = Upsampler.load("test2")
+    upsampler_loaded = Upsampler.load("hehe", precision="bf16")
     print("Updated configuration:", upsampler_loaded.config.__dict__)