first class but load is still missing, not complete
Gitea Actions For AIIA / Explore-Gitea-Actions (push) Failing after 42s
Details
Gitea Actions For AIIA / Explore-Gitea-Actions (push) Failing after 42s
Details
This commit is contained in:
parent
cb7a3da8a4
commit
16f8de2175
|
@ -2,13 +2,14 @@ import os
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import warnings
|
import warnings
|
||||||
from aiia.model.Model import AIIA, AIIAConfig, AIIABase
|
from aiia.model.Model import AIIAConfig, AIIABase
|
||||||
|
from transformers import PreTrainedModel
|
||||||
from .config import aiuNNConfig
|
from .config import aiuNNConfig
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
|
|
||||||
class aiuNN(AIIA):
|
class aiuNN(PreTrainedModel):
|
||||||
def __init__(self, base_model: AIIA, config:aiuNNConfig):
|
def __init__(self, base_model: PreTrainedModel, config:aiuNNConfig):
|
||||||
super().__init__(base_model.config)
|
super().__init__(base_model.config)
|
||||||
self.base_model = base_model
|
self.base_model = base_model
|
||||||
|
|
||||||
|
@ -33,112 +34,6 @@ class aiuNN(AIIA):
|
||||||
x = self.pixel_shuffle(x) # Rearrange channels into spatial dimensions
|
x = self.pixel_shuffle(x) # Rearrange channels into spatial dimensions
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def load(cls, path, precision: str = None, **kwargs):
|
|
||||||
"""
|
|
||||||
Load a aiuNN model from disk with automatic detection of base model type.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
path (str): Directory containing the stored configuration and model parameters.
|
|
||||||
precision (str, optional): Desired precision for the model's parameters.
|
|
||||||
**kwargs: Additional keyword arguments to override configuration parameters.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
An instance of aiuNN with loaded weights.
|
|
||||||
"""
|
|
||||||
# Load the configuration
|
|
||||||
config = aiuNNConfig.load(path)
|
|
||||||
|
|
||||||
# Determine the device
|
|
||||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|
||||||
|
|
||||||
# Load the state dictionary
|
|
||||||
state_dict = torch.load(os.path.join(path, "model.pth"), map_location=device)
|
|
||||||
|
|
||||||
# Import all model types
|
|
||||||
from aiia.model.Model import AIIABase, AIIABaseShared, AIIAExpert, AIIAmoe, AIIAchunked, AIIArecursive
|
|
||||||
|
|
||||||
# Helper function to detect base class type from key patterns
|
|
||||||
def detect_base_class_type(keys_prefix):
|
|
||||||
if any(f"{keys_prefix}.shared_layer" in key for key in state_dict.keys()):
|
|
||||||
return AIIABaseShared
|
|
||||||
else:
|
|
||||||
return AIIABase
|
|
||||||
|
|
||||||
# Detect base model type
|
|
||||||
base_model = None
|
|
||||||
|
|
||||||
# Check for AIIAmoe with multiple experts
|
|
||||||
if any("base_model.experts" in key for key in state_dict.keys()):
|
|
||||||
# Count the number of experts
|
|
||||||
max_expert_idx = -1
|
|
||||||
for key in state_dict.keys():
|
|
||||||
if "base_model.experts." in key:
|
|
||||||
try:
|
|
||||||
parts = key.split("base_model.experts.")[1].split(".")
|
|
||||||
expert_idx = int(parts[0])
|
|
||||||
max_expert_idx = max(max_expert_idx, expert_idx)
|
|
||||||
except (ValueError, IndexError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
if max_expert_idx >= 0:
|
|
||||||
# Determine the type of base_cnn each expert is using
|
|
||||||
base_class_for_experts = detect_base_class_type("base_model.experts.0.base_cnn")
|
|
||||||
|
|
||||||
# Create AIIAmoe with the detected expert count and base class
|
|
||||||
base_model = AIIAmoe(config, num_experts=max_expert_idx+1, base_class=base_class_for_experts, **kwargs)
|
|
||||||
|
|
||||||
# Check for AIIAchunked or AIIArecursive
|
|
||||||
elif any("base_model.chunked_cnn" in key for key in state_dict.keys()):
|
|
||||||
if any("recursion_depth" in key for key in state_dict.keys()):
|
|
||||||
# This is an AIIArecursive model
|
|
||||||
base_class = detect_base_class_type("base_model.chunked_cnn.base_cnn")
|
|
||||||
base_model = AIIArecursive(config, base_class=base_class, **kwargs)
|
|
||||||
else:
|
|
||||||
# This is an AIIAchunked model
|
|
||||||
base_class = detect_base_class_type("base_model.chunked_cnn.base_cnn")
|
|
||||||
base_model = AIIAchunked(config, base_class=base_class, **kwargs)
|
|
||||||
|
|
||||||
# Check for AIIAExpert
|
|
||||||
elif any("base_model.base_cnn" in key for key in state_dict.keys()):
|
|
||||||
# Determine which base class the expert is using
|
|
||||||
base_class = detect_base_class_type("base_model.base_cnn")
|
|
||||||
base_model = AIIAExpert(config, base_class=base_class, **kwargs)
|
|
||||||
|
|
||||||
# If none of the above, use AIIABase or AIIABaseShared directly
|
|
||||||
else:
|
|
||||||
base_class = detect_base_class_type("base_model")
|
|
||||||
base_model = base_class(config, **kwargs)
|
|
||||||
|
|
||||||
# Create the aiuNN model with the detected base model
|
|
||||||
model = cls(base_model, config=base_model.config)
|
|
||||||
|
|
||||||
# Handle precision conversion
|
|
||||||
dtype = None
|
|
||||||
if precision is not None:
|
|
||||||
if precision.lower() == 'fp16':
|
|
||||||
dtype = torch.float16
|
|
||||||
elif precision.lower() == 'bf16':
|
|
||||||
if device == 'cuda' and not torch.cuda.is_bf16_supported():
|
|
||||||
warnings.warn("BF16 is not supported on this GPU. Falling back to FP16.")
|
|
||||||
dtype = torch.float16
|
|
||||||
else:
|
|
||||||
dtype = torch.bfloat16
|
|
||||||
else:
|
|
||||||
raise ValueError("Unsupported precision. Use 'fp16', 'bf16', or leave as None.")
|
|
||||||
|
|
||||||
if dtype is not None:
|
|
||||||
for key, param in state_dict.items():
|
|
||||||
if torch.is_tensor(param):
|
|
||||||
state_dict[key] = param.to(dtype)
|
|
||||||
|
|
||||||
# Load the state dict
|
|
||||||
model.load_state_dict(state_dict)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
from aiia import AIIABase, AIIAConfig
|
from aiia import AIIABase, AIIAConfig
|
||||||
# Create a configuration and build a base model.
|
# Create a configuration and build a base model.
|
||||||
|
@ -149,7 +44,7 @@ if __name__ == "__main__":
|
||||||
upsampler = aiuNN(base_model, config=ai_config)
|
upsampler = aiuNN(base_model, config=ai_config)
|
||||||
|
|
||||||
# Save the model (both configuration and weights).
|
# Save the model (both configuration and weights).
|
||||||
upsampler.save("aiunn")
|
upsampler.save_pretrained("aiunn")
|
||||||
|
|
||||||
# Now load using the overridden load method; this will load the complete model.
|
# Now load using the overridden load method; this will load the complete model.
|
||||||
upsampler_loaded = aiuNN.load("aiunn", precision="bf16")
|
upsampler_loaded = aiuNN.load("aiunn", precision="bf16")
|
||||||
|
|
Loading…
Reference in New Issue