Fix model loading to support all models

This commit is contained in:
Falko Victor Habel 2025-03-15 18:53:26 +01:00
parent 5f16b6ca6c
commit 19bffa99d6
1 changed file with 31 additions and 7 deletions

View File

@ -23,12 +23,36 @@ class AIIA(nn.Module):
self.config.save(path) self.config.save(path)
@classmethod
def load(cls, path, precision: str = None, strict: bool = True, **kwargs):
    """Load a saved model from *path*.

    Reads the config and the ``model.pth`` state dict written by ``save``,
    rebuilds the model, and optionally casts the weights to half precision.

    Args:
        path: Directory containing the saved config and ``model.pth``.
        precision: Optional ``'fp16'`` or ``'bf16'`` (case-insensitive) to
            convert checkpoint tensors; ``None`` keeps the stored dtype.
        strict: Forwarded to ``load_state_dict`` so checkpoints with
            missing/unexpected keys can still be loaded when ``False``.
        **kwargs: Extra constructor arguments passed to the concrete
            subclass ``cls(config, **kwargs)``.

    Returns:
        The reconstructed model instance.

    Raises:
        ValueError: If *precision* is not ``'fp16'``, ``'bf16'``, or ``None``.
    """
    config = AIIAConfig.load(path)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Load the state dict FIRST so its structure can be inspected before
    # the model object is constructed (needed for the MoE detection below).
    model_dict = torch.load(f"{path}/model.pth", map_location=device)

    # Special handling for AIIAmoe: the checkpoint stores one sub-module
    # per expert under keys shaped like "experts.<idx>....". If the caller
    # did not pass num_experts, infer it from the highest index present.
    if cls.__name__ == "AIIAmoe" and "num_experts" not in kwargs:
        max_expert_idx = -1
        for key in model_dict:
            if key.startswith("experts."):
                parts = key.split(".")
                if len(parts) > 1:
                    try:
                        max_expert_idx = max(max_expert_idx, int(parts[1]))
                    except ValueError:
                        # Segment after "experts." was not numeric; skip it.
                        pass
        if max_expert_idx >= 0:
            # Expert indices are zero-based, so count = max index + 1.
            kwargs["num_experts"] = max_expert_idx + 1

    # Create the model with any detected structural parameters.
    model = cls(config, **kwargs)

    # Handle optional precision conversion of the checkpoint tensors.
    dtype = None
    if precision is not None:
        if precision.lower() == 'fp16':
            dtype = torch.float16
        # NOTE(review): the diff view hides a few lines between the two
        # hunks here; this 'bf16' branch is reconstructed from the visible
        # bfloat16 assignment and the error message — confirm against the
        # full file before relying on it.
        elif precision.lower() == 'bf16':
            dtype = torch.bfloat16
        else:
            raise ValueError("Unsupported precision. Use 'fp16', 'bf16', or leave as None.")

    if dtype is not None:
        for key, param in model_dict.items():
            if torch.is_tensor(param):
                model_dict[key] = param.to(dtype)

    # strict=False lets checkpoints with key drift (e.g. older layouts)
    # still load; default True preserves the original safe behavior.
    model.load_state_dict(model_dict, strict=strict)
    return model