added cpu support when loading the model #6
|
@ -25,7 +25,10 @@ class AIIA(nn.Module):
|
|||
def load(cls, path):
|
||||
config = AIIAConfig.load(path)
|
||||
model = cls(config)
|
||||
model.load_state_dict(torch.load(f"{path}/model.pth"))
|
||||
# Check if CUDA is available and set the device accordingly
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
# Load the state dictionary with the correct device mapping
|
||||
model.load_state_dict(torch.load(f"{path}/model.pth", map_location=device))
|
||||
return model
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue