Added CPU support when loading the model
parent 3c0e9e8ac1
commit f8e59c5896
@@ -25,7 +25,10 @@ class AIIA(nn.Module):
     def load(cls, path):
         config = AIIAConfig.load(path)
         model = cls(config)
-        model.load_state_dict(torch.load(f"{path}/model.pth"))
+        # Check if CUDA is available and set the device accordingly
+        device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        # Load the state dictionary with the correct device mapping
+        model.load_state_dict(torch.load(f"{path}/model.pth", map_location=device))
         return model
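The substance of the change is the map_location argument to torch.load: it remaps tensors stored in a checkpoint onto whatever device is actually available, so a model saved on a GPU machine can still be loaded on a CPU-only one. Below is a minimal, self-contained sketch of the same pattern; the nn.Linear stand-in and the model.pth file name are illustrative, not the repository's actual AIIA model or layout.

import torch
import torch.nn as nn

# Small stand-in model; the real AIIA architecture lives in the repository.
model = nn.Linear(4, 2)

# Save a checkpoint (this would normally happen on the training machine,
# where the tensors may be CUDA tensors).
torch.save(model.state_dict(), "model.pth")

# Pick whichever device is actually available on the loading machine.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# map_location remaps the stored tensors onto that device, so a checkpoint
# written on a GPU box loads on a CPU-only machine without a device error.
state_dict = torch.load("model.pth", map_location=device)
model.load_state_dict(state_dict)

Without map_location, torch.load tries to restore tensors to the device they were saved from, which fails on machines without CUDA; the patched load classmethod avoids that.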