Merge pull request 'added cpu support when loading the model' (#6) from cpu_support into develop

Reviewed-on: #6
This commit is contained in:
Falko Victor Habel 2025-02-24 13:13:45 +00:00
commit 027676d95e
1 changed files with 4 additions and 1 deletions

View File

@ -25,7 +25,10 @@ class AIIA(nn.Module):
def load(cls, path):
config = AIIAConfig.load(path)
model = cls(config)
model.load_state_dict(torch.load(f"{path}/model.pth"))
# Check if CUDA is available and set the device accordingly
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Load the state dictionary with the correct device mapping
model.load_state_dict(torch.load(f"{path}/model.pth", map_location=device))
return model