Skip to content

Commit acdc197

Browse files
committed
Update
1 parent 45e75e7 commit acdc197

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

quantize.py

Lines changed: 9 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -402,7 +402,14 @@ def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
402 402
assert inner_k_tiles in [2, 4, 8]
403 403

404 404
@torch.no_grad()
405-
def create_quantized_state_dict(self):
405+
def create_quantized_state_dict(self, use_cuda = True):
406+
device="cpu"
407+
if use_cuda:
408+
if torch.cuda.is_available():
409+
device="cuda"
410+
else:
411+
print(f"Warning: CUDA not available, running CPU")
412+
406 413
cur_state_dict = self.mod.state_dict()
407 414
for fqn, mod in self.mod.named_modules():
408 415
if isinstance(mod, torch.nn.Linear):
@@ -425,7 +432,7 @@ def create_quantized_state_dict(self):
425 432
"and that groupsize and inner_k_tiles*16 evenly divide into it")
426 433
continue
427 434
weight_int4pack, scales_and_zeros = prepare_int4_weight_and_scales_and_zeros(
428-
weight.to(torch.bfloat16), self.groupsize, self.inner_k_tiles
435+
weight.to(torch.bfloat16).to(device=device), self.groupsize, self.inner_k_tiles
429 436
)
430 437
cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to('cpu')
431 438
cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to('cpu')

0 commit comments

Comments
 (0)