develop #41

Merged
Fabel merged 27 commits from develop into main 2025-04-17 17:08:57 +00:00
1 changed file with 1 addition and 53 deletions
Showing only changes of commit 040ac478b9


@@ -1,5 +1,6 @@
 from .config import AIIAConfig
 from torch import nn
+from transformers import PreTrainedModel
 import torch
 import os
 import copy
@@ -22,59 +23,6 @@ class AIIA(nn.Module):
         torch.save(self.state_dict(), f"{path}/model.pth")
         self.config.save(path)
-    @classmethod
-    def load(cls, path, precision: str = None, strict: bool = True, **kwargs):
-        config = AIIAConfig.load(path)
-        device = 'cuda' if torch.cuda.is_available() else 'cpu'
-        # Load the state dict to analyze structure
-        model_dict = torch.load(f"{path}/model.pth", map_location=device)
-        # Special handling for AIIAmoe - detect number of experts from state_dict
-        if cls.__name__ == "AIIAmoe" and "num_experts" not in kwargs:
-            # Find maximum expert index
-            max_expert_idx = -1
-            for key in model_dict.keys():
-                if key.startswith("experts."):
-                    parts = key.split(".")
-                    if len(parts) > 1:
-                        try:
-                            expert_idx = int(parts[1])
-                            max_expert_idx = max(max_expert_idx, expert_idx)
-                        except ValueError:
-                            pass
-            if max_expert_idx >= 0:
-                # experts.X keys found, use max_expert_idx + 1 as num_experts
-                kwargs["num_experts"] = max_expert_idx + 1
-        # Create model with detected structural parameters
-        model = cls(config, **kwargs)
-        # Handle precision conversion
-        dtype = None
-        if precision is not None:
-            if precision.lower() == 'fp16':
-                dtype = torch.float16
-            elif precision.lower() == 'bf16':
-                if device == 'cuda' and not torch.cuda.is_bf16_supported():
-                    warnings.warn("BF16 is not supported on this GPU. Falling back to FP16.")
-                    dtype = torch.float16
-                else:
-                    dtype = torch.bfloat16
-            else:
-                raise ValueError("Unsupported precision. Use 'fp16', 'bf16', or leave as None.")
-        if dtype is not None:
-            for key, param in model_dict.items():
-                if torch.is_tensor(param):
-                    model_dict[key] = param.to(dtype)
-        # Load state dict with strict parameter for flexibility
-        model.load_state_dict(model_dict, strict=strict)
-        return model
 class AIIABase(AIIA):
     def __init__(self, config: AIIAConfig, **kwargs):
         super().__init__(config=config, **kwargs)
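
For reference, a minimal usage sketch of the load() classmethod that this commit removes. The checkpoint directory and precision value below are illustrative assumptions, not paths or settings taken from this repository.

# Hypothetical checkpoint directory; the removed load() read AIIAConfig and model.pth
# from it, optionally casting the state dict to fp16/bf16 before load_state_dict().
model = AIIABase.load("checkpoints/aiia-base", precision="bf16", strict=True)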