Merge pull request 'corrected model loading to also accept kwargs, e.g. when using base models in combination with expert models' (#26) from feat/bugfix_loading into develop

Reviewed-on: #26
Falko Victor Habel 2025-03-11 21:26:31 +00:00
commit 7f55ede587
4 changed files with 5 additions and 8 deletions


@@ -10,7 +10,7 @@ include = '\.pyi?$'
 [project]
 name = "aiia"
-version = "0.1.5"
+version = "0.1.6"
 description = "AIIA Deep Learning Model Implementation"
 readme = "README.md"
 authors = [


@@ -1,6 +1,6 @@
 [metadata]
 name = aiia
-version = 0.1.5
+version = 0.1.6
 author = Falko Habel
 author_email = falko.habel@gmx.de
 description = AIIA deep learning model implementation


@@ -4,4 +4,4 @@ from .data.DataLoader import DataLoader
 from .pretrain.pretrainer import Pretrainer, ProjectionHead
-__version__ = "0.1.5"
+__version__ = "0.1.6"


@@ -23,9 +23,9 @@ class AIIA(nn.Module):
         self.config.save(path)

     @classmethod
-    def load(cls, path, precision: str = None):
+    def load(cls, path, precision: str = None, **kwargs):
         config = AIIAConfig.load(path)
-        model = cls(config)
+        model = cls(config, **kwargs)  # Pass kwargs here!
         device = 'cuda' if torch.cuda.is_available() else 'cpu'
         dtype = None
@@ -41,10 +41,7 @@ class AIIA(nn.Module):
         else:
             raise ValueError("Unsupported precision. Use 'fp16', 'bf16', or leave as None.")
-        # Load the state dictionary normally (without dtype argument)
         model_dict = torch.load(f"{path}/model.pth", map_location=device)
-        # If a precision conversion is requested, cast each tensor in the state dict to the target dtype.
         if dtype is not None:
             for key, param in model_dict.items():
                 if torch.is_tensor(param):
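
With this change, extra constructor arguments are forwarded through load() to cls(config, **kwargs); before the fix, load() raised a TypeError on any keyword argument other than precision. A minimal usage sketch (not part of the diff; MyExpertModel and num_experts are hypothetical names standing in for whatever your subclass's constructor actually accepts, and the import path is assumed from the package layout above):

    from aiia import AIIA

    # A subclass whose __init__ takes arguments beyond config -- the
    # situation this PR fixes (e.g. a base model combined with experts).
    class MyExpertModel(AIIA):
        def __init__(self, config, num_experts: int = 1):
            super().__init__(config)
            self.num_experts = num_experts

    # load() now forwards num_experts to MyExpertModel.__init__ while
    # still handling device placement and optional precision casting.
    model = MyExpertModel.load("path/to/checkpoint", precision="fp16", num_experts=4)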