From 19bffa99d67b83999ba5547505044da7be9063eb Mon Sep 17 00:00:00 2001
From: Falko Habel
Date: Sat, 15 Mar 2025 18:53:26 +0100
Subject: [PATCH] fixed model loading to support all models

AIIA.load() now loads the checkpoint's state dict *before* constructing
the model so structural hyperparameters can be inferred from it: for
AIIAmoe checkpoints the number of experts is detected from the highest
"experts.<i>." key index when the caller did not pass num_experts.
load() also gains a backward-compatible strict flag (default True) that
is forwarded to load_state_dict().

---
 src/aiia/model/Model.py | 38 +++++++++++++++++++++++++++++++-------
 1 file changed, 31 insertions(+), 7 deletions(-)

diff --git a/src/aiia/model/Model.py b/src/aiia/model/Model.py
index ad6d032..06e2924 100644
--- a/src/aiia/model/Model.py
+++ b/src/aiia/model/Model.py
@@ -23,12 +23,36 @@ class AIIA(nn.Module):
         self.config.save(path)
 
     @classmethod
-    def load(cls, path, precision: str = None, **kwargs):
+    def load(cls, path, precision: str = None, strict: bool = True, **kwargs):
         config = AIIAConfig.load(path)
-        model = cls(config, **kwargs) # Pass kwargs here!
         device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+        # Load the state dict to analyze structure
+        model_dict = torch.load(f"{path}/model.pth", map_location=device)
+
+        # Special handling for AIIAmoe - detect number of experts from state_dict
+        if cls.__name__ == "AIIAmoe" and "num_experts" not in kwargs:
+            # Find maximum expert index
+            max_expert_idx = -1
+            for key in model_dict.keys():
+                if key.startswith("experts."):
+                    parts = key.split(".")
+                    if len(parts) > 1:
+                        try:
+                            expert_idx = int(parts[1])
+                            max_expert_idx = max(max_expert_idx, expert_idx)
+                        except ValueError:
+                            pass
+
+            if max_expert_idx >= 0:
+                # experts.X keys found, use max_expert_idx + 1 as num_experts
+                kwargs["num_experts"] = max_expert_idx + 1
+
+        # Create model with detected structural parameters
+        model = cls(config, **kwargs)
+
+        # Handle precision conversion
         dtype = None
-
         if precision is not None:
             if precision.lower() == 'fp16':
                 dtype = torch.float16
@@ -40,14 +64,14 @@ class AIIA(nn.Module):
                 dtype = torch.bfloat16
             else:
                 raise ValueError("Unsupported precision. Use 'fp16', 'bf16', or leave as None.")
-
-        model_dict = torch.load(f"{path}/model.pth", map_location=device)
+
         if dtype is not None:
             for key, param in model_dict.items():
                 if torch.is_tensor(param):
                     model_dict[key] = param.to(dtype)
-
-        model.load_state_dict(model_dict)
+
+        # Load state dict with strict parameter for flexibility
+        model.load_state_dict(model_dict, strict=strict)
         return model