Merge pull request 'corrected modelloading to also accept kwargs when e.g. using base models in Combination with expert models' (#26) from feat/bugfix_loading into develop
Reviewed-on: #26
This commit is contained in:
commit
7f55ede587
|
@ -10,7 +10,7 @@ include = '\.pyi?$'
|
||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "aiia"
|
name = "aiia"
|
||||||
version = "0.1.5"
|
version = "0.1.6"
|
||||||
description = "AIIA Deep Learning Model Implementation"
|
description = "AIIA Deep Learning Model Implementation"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
authors = [
|
authors = [
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
[metadata]
|
[metadata]
|
||||||
name = aiia
|
name = aiia
|
||||||
version = 0.1.5
|
version = 0.1.6
|
||||||
author = Falko Habel
|
author = Falko Habel
|
||||||
author_email = falko.habel@gmx.de
|
author_email = falko.habel@gmx.de
|
||||||
description = AIIA deep learning model implementation
|
description = AIIA deep learning model implementation
|
||||||
|
|
|
@ -4,4 +4,4 @@ from .data.DataLoader import DataLoader
|
||||||
from .pretrain.pretrainer import Pretrainer, ProjectionHead
|
from .pretrain.pretrainer import Pretrainer, ProjectionHead
|
||||||
|
|
||||||
|
|
||||||
__version__ = "0.1.5"
|
__version__ = "0.1.6"
|
||||||
|
|
|
@ -23,9 +23,9 @@ class AIIA(nn.Module):
|
||||||
self.config.save(path)
|
self.config.save(path)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load(cls, path, precision: str = None):
|
def load(cls, path, precision: str = None, **kwargs):
|
||||||
config = AIIAConfig.load(path)
|
config = AIIAConfig.load(path)
|
||||||
model = cls(config)
|
model = cls(config, **kwargs) # Pass kwargs here!
|
||||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||||
dtype = None
|
dtype = None
|
||||||
|
|
||||||
|
@ -41,10 +41,7 @@ class AIIA(nn.Module):
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unsupported precision. Use 'fp16', 'bf16', or leave as None.")
|
raise ValueError("Unsupported precision. Use 'fp16', 'bf16', or leave as None.")
|
||||||
|
|
||||||
# Load the state dictionary normally (without dtype argument)
|
|
||||||
model_dict = torch.load(f"{path}/model.pth", map_location=device)
|
model_dict = torch.load(f"{path}/model.pth", map_location=device)
|
||||||
|
|
||||||
# If a precision conversion is requested, cast each tensor in the state dict to the target dtype.
|
|
||||||
if dtype is not None:
|
if dtype is not None:
|
||||||
for key, param in model_dict.items():
|
for key, param in model_dict.items():
|
||||||
if torch.is_tensor(param):
|
if torch.is_tensor(param):
|
||||||
|
|
Loading…
Reference in New Issue