Make aiuNN a first-class PreTrainedModel, but load is still missing / not complete
Gitea Actions For AIIA / Explore-Gitea-Actions (push): failing after 42s

Falko Victor Habel 2025-04-18 23:45:33 +02:00
parent cb7a3da8a4
commit 16f8de2175
1 changed file with 5 additions and 110 deletions


@@ -2,13 +2,14 @@ import os
 import torch
 import torch.nn as nn
 import warnings
-from aiia.model.Model import AIIA, AIIAConfig, AIIABase
+from aiia.model.Model import AIIAConfig, AIIABase
+from transformers import PreTrainedModel
 from .config import aiuNNConfig
 import warnings
 
-class aiuNN(AIIA):
-    def __init__(self, base_model: AIIA, config:aiuNNConfig):
+class aiuNN(PreTrainedModel):
+    def __init__(self, base_model: PreTrainedModel, config:aiuNNConfig):
         super().__init__(base_model.config)
         self.base_model = base_model
@@ -33,112 +34,6 @@ class aiuNN(AIIA):
         x = self.pixel_shuffle(x)  # Rearrange channels into spatial dimensions
         return x
 
-    @classmethod
-    def load(cls, path, precision: str = None, **kwargs):
-        """
-        Load an aiuNN model from disk with automatic detection of base model type.
-
-        Args:
-            path (str): Directory containing the stored configuration and model parameters.
-            precision (str, optional): Desired precision for the model's parameters.
-            **kwargs: Additional keyword arguments to override configuration parameters.
-
-        Returns:
-            An instance of aiuNN with loaded weights.
-        """
-        # Load the configuration
-        config = aiuNNConfig.load(path)
-
-        # Determine the device
-        device = 'cuda' if torch.cuda.is_available() else 'cpu'
-
-        # Load the state dictionary
-        state_dict = torch.load(os.path.join(path, "model.pth"), map_location=device)
-
-        # Import all model types
-        from aiia.model.Model import AIIABase, AIIABaseShared, AIIAExpert, AIIAmoe, AIIAchunked, AIIArecursive
-
-        # Helper function to detect base class type from key patterns
-        def detect_base_class_type(keys_prefix):
-            if any(f"{keys_prefix}.shared_layer" in key for key in state_dict.keys()):
-                return AIIABaseShared
-            else:
-                return AIIABase
-
-        # Detect base model type
-        base_model = None
-
-        # Check for AIIAmoe with multiple experts
-        if any("base_model.experts" in key for key in state_dict.keys()):
-            # Count the number of experts
-            max_expert_idx = -1
-            for key in state_dict.keys():
-                if "base_model.experts." in key:
-                    try:
-                        parts = key.split("base_model.experts.")[1].split(".")
-                        expert_idx = int(parts[0])
-                        max_expert_idx = max(max_expert_idx, expert_idx)
-                    except (ValueError, IndexError):
-                        pass
-
-            if max_expert_idx >= 0:
-                # Determine the type of base_cnn each expert is using
-                base_class_for_experts = detect_base_class_type("base_model.experts.0.base_cnn")
-
-                # Create AIIAmoe with the detected expert count and base class
-                base_model = AIIAmoe(config, num_experts=max_expert_idx+1, base_class=base_class_for_experts, **kwargs)
-
-        # Check for AIIAchunked or AIIArecursive
-        elif any("base_model.chunked_cnn" in key for key in state_dict.keys()):
-            if any("recursion_depth" in key for key in state_dict.keys()):
-                # This is an AIIArecursive model
-                base_class = detect_base_class_type("base_model.chunked_cnn.base_cnn")
-                base_model = AIIArecursive(config, base_class=base_class, **kwargs)
-            else:
-                # This is an AIIAchunked model
-                base_class = detect_base_class_type("base_model.chunked_cnn.base_cnn")
-                base_model = AIIAchunked(config, base_class=base_class, **kwargs)
-
-        # Check for AIIAExpert
-        elif any("base_model.base_cnn" in key for key in state_dict.keys()):
-            # Determine which base class the expert is using
-            base_class = detect_base_class_type("base_model.base_cnn")
-            base_model = AIIAExpert(config, base_class=base_class, **kwargs)
-
-        # If none of the above, use AIIABase or AIIABaseShared directly
-        else:
-            base_class = detect_base_class_type("base_model")
-            base_model = base_class(config, **kwargs)
-
-        # Create the aiuNN model with the detected base model
-        model = cls(base_model, config=base_model.config)
-
-        # Handle precision conversion
-        dtype = None
-        if precision is not None:
-            if precision.lower() == 'fp16':
-                dtype = torch.float16
-            elif precision.lower() == 'bf16':
-                if device == 'cuda' and not torch.cuda.is_bf16_supported():
-                    warnings.warn("BF16 is not supported on this GPU. Falling back to FP16.")
-                    dtype = torch.float16
-                else:
-                    dtype = torch.bfloat16
-            else:
-                raise ValueError("Unsupported precision. Use 'fp16', 'bf16', or leave as None.")
-
-        if dtype is not None:
-            for key, param in state_dict.items():
-                if torch.is_tensor(param):
-                    state_dict[key] = param.to(dtype)
-
-        # Load the state dict
-        model.load_state_dict(state_dict)
-        return model
-
 if __name__ == "__main__":
     from aiia import AIIABase, AIIAConfig
 
     # Create a configuration and build a base model.
@@ -149,7 +44,7 @@ if __name__ == "__main__":
     upsampler = aiuNN(base_model, config=ai_config)
 
     # Save the model (both configuration and weights).
-    upsampler.save("aiunn")
+    upsampler.save_pretrained("aiunn")
 
     # Now load using the overridden load method; this will load the complete model.
     upsampler_loaded = aiuNN.load("aiunn", precision="bf16")
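
Note that after this commit the example above is broken: saving now goes through save_pretrained, but the custom aiuNN.load it calls has been deleted, which is what the commit message means by "load is still missing". A minimal sketch of what a follow-up loader could look like, assuming aiuNNConfig subclasses transformers.PretrainedConfig (so it gains from_pretrained), that save_pretrained wrote the default config.json plus model.safetensors, and that the checkpoint uses an AIIABase backbone; the helper name load_aiunn and the import paths are illustrative, not part of this commit:

# Hypothetical follow-up sketch, NOT part of this commit.
import os

import torch
from safetensors.torch import load_file  # default weight format of save_pretrained

from aiia.model.Model import AIIABase
from aiunn import aiuNN, aiuNNConfig  # illustrative import path


def load_aiunn(path: str, precision: str = None) -> aiuNN:
    # Assumes aiuNNConfig is a PretrainedConfig subclass, so the
    # config.json written by save_pretrained can be read back here.
    config = aiuNNConfig.from_pretrained(path)

    # Assumes an AIIABase backbone; the deleted load() instead sniffed
    # the base class out of the state-dict key patterns.
    model = aiuNN(AIIABase(config), config=config)

    # save_pretrained writes model.safetensors by default.
    state_dict = load_file(os.path.join(path, "model.safetensors"))

    # Optional precision cast, mirroring the deleted fp16/bf16 handling.
    if precision is not None:
        dtype = {"fp16": torch.float16, "bf16": torch.bfloat16}[precision.lower()]
        state_dict = {k: v.to(dtype) if torch.is_tensor(v) else v
                      for k, v in state_dict.items()}

    model.load_state_dict(state_dict)
    return model


upsampler_loaded = load_aiunn("aiunn", precision="bf16")

Unlike the deleted load(), this sketch does not auto-detect AIIAmoe, AIIAchunked, or the other base model variants; restoring that detection (or moving it into a proper from_pretrained override) is the remaining work the commit message points at.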