From e5181c30663dc61e3f568a9c9114f482a8c12d8f Mon Sep 17 00:00:00 2001 From: Falko Habel Date: Tue, 11 Mar 2025 22:13:17 +0100 Subject: [PATCH] corrected model loading to also accept kwargs when e.g. using base models in combination with expert models --- pyproject.toml | 2 +- setup.cfg | 2 +- src/aiia/__init__.py | 2 +- src/aiia/model/Model.py | 7 ++----- 4 files changed, 5 insertions(+), 8 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 62715b6..00579ad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ include = '\.pyi?$' [project] name = "aiia" -version = "0.1.5" +version = "0.1.6" description = "AIIA Deep Learning Model Implementation" readme = "README.md" authors = [ diff --git a/setup.cfg b/setup.cfg index 0bd097e..60dc7ec 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = aiia -version = 0.1.5 +version = 0.1.6 author = Falko Habel author_email = falko.habel@gmx.de description = AIIA deep learning model implementation diff --git a/src/aiia/__init__.py b/src/aiia/__init__.py index 19e0fe8..5015fe7 100644 --- a/src/aiia/__init__.py +++ b/src/aiia/__init__.py @@ -4,4 +4,4 @@ from .data.DataLoader import DataLoader from .pretrain.pretrainer import Pretrainer, ProjectionHead -__version__ = "0.1.5" +__version__ = "0.1.6" diff --git a/src/aiia/model/Model.py b/src/aiia/model/Model.py index 6c497bb..ad6d032 100644 --- a/src/aiia/model/Model.py +++ b/src/aiia/model/Model.py @@ -23,9 +23,9 @@ class AIIA(nn.Module): self.config.save(path) @classmethod - def load(cls, path, precision: str = None): + def load(cls, path, precision: str = None, **kwargs): config = AIIAConfig.load(path) - model = cls(config) + model = cls(config, **kwargs) # Pass kwargs here! device = 'cuda' if torch.cuda.is_available() else 'cpu' dtype = None @@ -41,10 +41,7 @@ else: raise ValueError("Unsupported precision. 
Use 'fp16', 'bf16', or leave as None.") - # Load the state dictionary normally (without dtype argument) model_dict = torch.load(f"{path}/model.pth", map_location=device) - - # If a precision conversion is requested, cast each tensor in the state dict to the target dtype. if dtype is not None: for key, param in model_dict.items(): if torch.is_tensor(param):