Merge pull request 'corrected loading / saving model' (#5) from feat/saving_model into develop

Reviewed-on: #5
This commit is contained in:
Falko Victor Habel 2025-03-26 12:29:42 +00:00
commit a00ffac8c5
1 changed files with 87 additions and 14 deletions

View File

@ -20,10 +20,10 @@ class aiuNN(AIIA):
mode=self.config.upsample_mode,
align_corners=self.config.upsample_align_corners
)
# Conversion layer: change from hidden size channels to 3 channels.
# Conversion layer: change from hidden size channels to number of channels from the config.
self.to_rgb = nn.Conv2d(
in_channels=self.base_model.config.hidden_size,
out_channels=3,
out_channels=self.base_model.config.num_channels,
kernel_size=1
)
@ -35,17 +35,87 @@ class aiuNN(AIIA):
return x
@classmethod
def load(cls, path, precision: str = None):
# Load the configuration from disk.
config = AIIAConfig.load(path)
# Reconstruct the base model from the loaded configuration.
base_model = AIIABase(config)
# Instantiate the Upsampler using the proper base model.
upsampler = cls(base_model)
def load(cls, path, precision: str = None, **kwargs):
"""
Load a aiuNN model from disk with automatic detection of base model type.
# Load state dict and handle precision conversion if needed.
Args:
path (str): Directory containing the stored configuration and model parameters.
precision (str, optional): Desired precision for the model's parameters.
**kwargs: Additional keyword arguments to override configuration parameters.
Returns:
An instance of aiuNN with loaded weights.
"""
# Load the configuration
config = aiuNNConfig.load(path)
# Determine the device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
state_dict = torch.load(f"{path}/model.pth", map_location=device)
# Load the state dictionary
state_dict = torch.load(os.path.join(path, "model.pth"), map_location=device)
# Import all model types
from aiia import AIIABase, AIIABaseShared, AIIAExpert, AIIAmoe, AIIAchunked, AIIArecursive
# Helper function to detect base class type from key patterns
def detect_base_class_type(keys_prefix):
if any(f"{keys_prefix}.shared_layer" in key for key in state_dict.keys()):
return AIIABaseShared
else:
return AIIABase
# Detect base model type
base_model = None
# Check for AIIAmoe with multiple experts
if any("base_model.experts" in key for key in state_dict.keys()):
# Count the number of experts
max_expert_idx = -1
for key in state_dict.keys():
if "base_model.experts." in key:
try:
parts = key.split("base_model.experts.")[1].split(".")
expert_idx = int(parts[0])
max_expert_idx = max(max_expert_idx, expert_idx)
except (ValueError, IndexError):
pass
if max_expert_idx >= 0:
# Determine the type of base_cnn each expert is using
base_class_for_experts = detect_base_class_type("base_model.experts.0.base_cnn")
# Create AIIAmoe with the detected expert count and base class
base_model = AIIAmoe(config, num_experts=max_expert_idx+1, base_class=base_class_for_experts, **kwargs)
# Check for AIIAchunked or AIIArecursive
elif any("base_model.chunked_cnn" in key for key in state_dict.keys()):
if any("recursion_depth" in key for key in state_dict.keys()):
# This is an AIIArecursive model
base_class = detect_base_class_type("base_model.chunked_cnn.base_cnn")
base_model = AIIArecursive(config, base_class=base_class, **kwargs)
else:
# This is an AIIAchunked model
base_class = detect_base_class_type("base_model.chunked_cnn.base_cnn")
base_model = AIIAchunked(config, base_class=base_class, **kwargs)
# Check for AIIAExpert
elif any("base_model.base_cnn" in key for key in state_dict.keys()):
# Determine which base class the expert is using
base_class = detect_base_class_type("base_model.base_cnn")
base_model = AIIAExpert(config, base_class=base_class, **kwargs)
# If none of the above, use AIIABase or AIIABaseShared directly
else:
base_class = detect_base_class_type("base_model")
base_model = base_class(config, **kwargs)
# Create the aiuNN model with the detected base model
model = cls(base_model)
# Handle precision conversion
dtype = None
if precision is not None:
if precision.lower() == 'fp16':
dtype = torch.float16
@ -57,12 +127,15 @@ class aiuNN(AIIA):
dtype = torch.bfloat16
else:
raise ValueError("Unsupported precision. Use 'fp16', 'bf16', or leave as None.")
if dtype is not None:
for key, param in state_dict.items():
if torch.is_tensor(param):
state_dict[key] = param.to(dtype)
upsampler.load_state_dict(state_dict)
return upsampler
# Load the state dict
model.load_state_dict(state_dict)
return model