develop #41

Merged
Fabel merged 27 commits from develop into main on 2025-04-17 17:08:57 +00:00
1 changed file with 1 addition and 53 deletions
Showing only changes of commit 040ac478b9

@@ -1,5 +1,6 @@
from .config import AIIAConfig
from torch import nn
from transformers import PreTrainedModel
import torch
import os
import copy
@@ -22,59 +23,6 @@ class AIIA(nn.Module):
        torch.save(self.state_dict(), f"{path}/model.pth")
        self.config.save(path)

    @classmethod
    def load(cls, path, precision: str = None, strict: bool = True, **kwargs):
        config = AIIAConfig.load(path)
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

        # Load the state dict to analyze structure
        model_dict = torch.load(f"{path}/model.pth", map_location=device)

        # Special handling for AIIAmoe - detect number of experts from state_dict
        if cls.__name__ == "AIIAmoe" and "num_experts" not in kwargs:
            # Find maximum expert index
            max_expert_idx = -1
            for key in model_dict.keys():
                if key.startswith("experts."):
                    parts = key.split(".")
                    if len(parts) > 1:
                        try:
                            expert_idx = int(parts[1])
                            max_expert_idx = max(max_expert_idx, expert_idx)
                        except ValueError:
                            pass

            if max_expert_idx >= 0:
                # experts.X keys found, use max_expert_idx + 1 as num_experts
                kwargs["num_experts"] = max_expert_idx + 1

        # Create model with detected structural parameters
        model = cls(config, **kwargs)

        # Handle precision conversion
        dtype = None
        if precision is not None:
            if precision.lower() == 'fp16':
                dtype = torch.float16
            elif precision.lower() == 'bf16':
                if device == 'cuda' and not torch.cuda.is_bf16_supported():
                    warnings.warn("BF16 is not supported on this GPU. Falling back to FP16.")
                    dtype = torch.float16
                else:
                    dtype = torch.bfloat16
            else:
                raise ValueError("Unsupported precision. Use 'fp16', 'bf16', or leave as None.")

        if dtype is not None:
            for key, param in model_dict.items():
                if torch.is_tensor(param):
                    model_dict[key] = param.to(dtype)

        # Load state dict with strict parameter for flexibility
        model.load_state_dict(model_dict, strict=strict)
        return model


class AIIABase(AIIA):
    def __init__(self, config: AIIAConfig, **kwargs):
        super().__init__(config=config, **kwargs)
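
The num_experts detection in the removed load() classmethod relies purely on the names of the state_dict keys. A minimal sketch of the same idea on a toy state_dict (the key names and tensor shapes below are made up for illustration):

import torch

# Toy state_dict with two experts; the "experts.<idx>." naming is what load() scanned for.
state_dict = {
    "experts.0.weight": torch.zeros(4, 4),
    "experts.1.weight": torch.zeros(4, 4),
    "gate.weight": torch.zeros(2, 4),
}

max_expert_idx = -1
for key in state_dict:
    if key.startswith("experts."):
        parts = key.split(".")
        if len(parts) > 1:
            try:
                max_expert_idx = max(max_expert_idx, int(parts[1]))
            except ValueError:
                pass  # non-numeric segment, e.g. a hypothetical "experts.shared.weight"

num_experts = max_expert_idx + 1 if max_expert_idx >= 0 else None
print(num_experts)  # prints 2

Deriving num_experts from the checkpoint rather than requiring it as an argument keeps load() usable when the caller does not know how the model was originally configured.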
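
For reference, a hedged sketch of how the removed load() would have been called before this commit; the checkpoint directory name is hypothetical, and AIIAmoe is assumed to be the mixture-of-experts subclass of AIIA that the method checks for by class name:

# "checkpoints/aiia-moe" is a hypothetical directory containing model.pth and the saved config.
# num_experts is inferred from the "experts.<idx>." keys in the checkpoint, precision="bf16"
# converts every tensor in the state dict (falling back to FP16 on GPUs without BF16 support),
# and strict=False tolerates missing or unexpected keys in load_state_dict.
model = AIIAmoe.load("checkpoints/aiia-moe", precision="bf16", strict=False)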