diff --git a/chat.py b/chat.py index 27e4fe0..4f310f6 100644 --- a/chat.py +++ b/chat.py @@ -93,7 +93,8 @@ class DOLPHIN: ckpt = try_rename_lagacy_weights(ckpt) self.model.load_state_dict(ckpt, strict=True) - self.model.to("cuda") + device = "cuda" if torch.cuda.is_available() else "cpu" + self.model.to(device) self.model.eval() transform_args = { "input_size": self.swin_args["img_size"],