diff --git a/src/aiia/model/Model.py b/src/aiia/model/Model.py index f0e65ff..c188464 100644 --- a/src/aiia/model/Model.py +++ b/src/aiia/model/Model.py @@ -25,7 +25,10 @@ class AIIA(nn.Module): def load(cls, path): config = AIIAConfig.load(path) model = cls(config) - model.load_state_dict(torch.load(f"{path}/model.pth")) + # Check if CUDA is available and set the device accordingly + device = 'cuda' if torch.cuda.is_available() else 'cpu' + # Load the state dictionary with the correct device mapping + model.load_state_dict(torch.load(f"{path}/model.pth", map_location=device)) return model