added cpu support when loading the model

This commit is contained in:
Falko Victor Habel 2025-01-31 09:13:58 +01:00
parent 3c0e9e8ac1
commit f8e59c5896
1 changed files with 4 additions and 1 deletions

View File

@ -25,7 +25,10 @@ class AIIA(nn.Module):
def load(cls, path):
config = AIIAConfig.load(path)
model = cls(config)
model.load_state_dict(torch.load(f"{path}/model.pth"))
# Check if CUDA is available and set the device accordingly
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Load the state dictionary with the correct device mapping
model.load_state_dict(torch.load(f"{path}/model.pth", map_location=device))
return model