From f8e59c589655013d5bc63c9da60e0dad58dbb60b Mon Sep 17 00:00:00 2001 From: Falko Habel Date: Fri, 31 Jan 2025 09:13:58 +0100 Subject: [PATCH] added cpu support when loading the model --- src/aiia/model/Model.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/aiia/model/Model.py b/src/aiia/model/Model.py index f0e65ff..c188464 100644 --- a/src/aiia/model/Model.py +++ b/src/aiia/model/Model.py @@ -25,7 +25,10 @@ class AIIA(nn.Module): def load(cls, path): config = AIIAConfig.load(path) model = cls(config) - model.load_state_dict(torch.load(f"{path}/model.pth")) + # Check if CUDA is available and set the device accordingly + device = 'cuda' if torch.cuda.is_available() else 'cpu' + # Load the state dictionary with the correct device mapping + model.load_state_dict(torch.load(f"{path}/model.pth", map_location=device)) return model