diff --git a/example.py b/example.py index ae3a35a..44e336e 100644 --- a/example.py +++ b/example.py @@ -3,7 +3,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer location = "Godala-moe" -device = "cuda" # cpu when not using gpu +device = "gpu" # or cpu tokenizer = AutoTokenizer.from_pretrained(location) model = AutoModelForCausalLM.from_pretrained(location).to(device)