From 040ac478b994262238193ef15d730b33016de73c Mon Sep 17 00:00:00 2001
From: Falko Habel
Date: Fri, 11 Apr 2025 22:37:44 +0200
Subject: [PATCH] removed loading and saving functions since tf will take over

---
 src/aiia/model/Model.py | 54 +----------------------------------------
 1 file changed, 1 insertion(+), 53 deletions(-)

diff --git a/src/aiia/model/Model.py b/src/aiia/model/Model.py
index abcc34a..118acb0 100644
--- a/src/aiia/model/Model.py
+++ b/src/aiia/model/Model.py
@@ -1,5 +1,6 @@
 from .config import AIIAConfig
 from torch import nn
+from transformers import PreTrainedModel
 import torch
 import os
 import copy
@@ -22,59 +23,6 @@ class AIIA(nn.Module):
         torch.save(self.state_dict(), f"{path}/model.pth")
         self.config.save(path)
 
-    @classmethod
-    def load(cls, path, precision: str = None, strict: bool = True, **kwargs):
-        config = AIIAConfig.load(path)
-        device = 'cuda' if torch.cuda.is_available() else 'cpu'
-
-        # Load the state dict to analyze structure
-        model_dict = torch.load(f"{path}/model.pth", map_location=device)
-
-        # Special handling for AIIAmoe - detect number of experts from state_dict
-        if cls.__name__ == "AIIAmoe" and "num_experts" not in kwargs:
-            # Find maximum expert index
-            max_expert_idx = -1
-            for key in model_dict.keys():
-                if key.startswith("experts."):
-                    parts = key.split(".")
-                    if len(parts) > 1:
-                        try:
-                            expert_idx = int(parts[1])
-                            max_expert_idx = max(max_expert_idx, expert_idx)
-                        except ValueError:
-                            pass
-
-            if max_expert_idx >= 0:
-                # experts.X keys found, use max_expert_idx + 1 as num_experts
-                kwargs["num_experts"] = max_expert_idx + 1
-
-        # Create model with detected structural parameters
-        model = cls(config, **kwargs)
-
-        # Handle precision conversion
-        dtype = None
-        if precision is not None:
-            if precision.lower() == 'fp16':
-                dtype = torch.float16
-            elif precision.lower() == 'bf16':
-                if device == 'cuda' and not torch.cuda.is_bf16_supported():
-                    warnings.warn("BF16 is not supported on this GPU. Falling back to FP16.")
-                    dtype = torch.float16
-                else:
-                    dtype = torch.bfloat16
-            else:
-                raise ValueError("Unsupported precision. Use 'fp16', 'bf16', or leave as None.")
-
-        if dtype is not None:
-            for key, param in model_dict.items():
-                if torch.is_tensor(param):
-                    model_dict[key] = param.to(dtype)
-
-        # Load state dict with strict parameter for flexibility
-        model.load_state_dict(model_dict, strict=strict)
-        return model
-
-
 class AIIABase(AIIA):
     def __init__(self, config: AIIAConfig, **kwargs):
         super().__init__(config=config, **kwargs)
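
The replacement serialization workflow is not part of this patch. As a minimal
sketch of what it could look like once AIIA migrates to the Hugging Face API
(assuming a follow-up change makes AIIAConfig subclass PretrainedConfig and
AIIA subclass PreTrainedModel; the hidden_size field, the proj layer, and the
checkpoints/aiia path are illustrative assumptions, not code from this repo):

    # Sketch only: assumes AIIA adopts transformers' save_pretrained /
    # from_pretrained in place of the removed custom save/load methods.
    import torch
    from torch import nn
    from transformers import PretrainedConfig, PreTrainedModel


    class AIIAConfig(PretrainedConfig):
        model_type = "aiia"  # hypothetical model_type for this sketch

        def __init__(self, hidden_size: int = 512, **kwargs):
            super().__init__(**kwargs)
            self.hidden_size = hidden_size


    class AIIA(PreTrainedModel):
        config_class = AIIAConfig

        def __init__(self, config: AIIAConfig):
            super().__init__(config)
            self.proj = nn.Linear(config.hidden_size, config.hidden_size)

        def forward(self, x):
            return self.proj(x)


    model = AIIA(AIIAConfig())
    # Writes config.json plus the weight file, replacing the removed save():
    model.save_pretrained("checkpoints/aiia")
    # Replaces the removed load() classmethod:
    restored = AIIA.from_pretrained("checkpoints/aiia")
    # torch_dtype covers what the removed fp16/bf16 precision branch did:
    restored_bf16 = AIIA.from_pretrained("checkpoints/aiia", torch_dtype=torch.bfloat16)

Note that from_pretrained reconstructs the model from the saved config, so the
removed ad-hoc detection of num_experts from state_dict keys also becomes
unnecessary once structural parameters live in the config.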