diff --git a/torchbenchmark/util/backends/torchdynamo.py b/torchbenchmark/util/backends/torchdynamo.py
index 6a7559c9cf..9d0b57d8a1 100644
--- a/torchbenchmark/util/backends/torchdynamo.py
+++ b/torchbenchmark/util/backends/torchdynamo.py
@@ -86,7 +86,7 @@ def parse_torchdynamo_args(dynamo_args: List[str]) -> argparse.Namespace:
     )
     parser.add_argument(
         "--quantization",
-        choices=["int8dynamic", "int8weightonly", "int4weightonly"],
+        choices=["int8dynamic", "int8weightonly", "int4weightonly", "pt2e"],
         help="Apply quantization to the model before running it",
     )
     parser.add_argument(
@@ -182,26 +182,34 @@ def apply_torchdynamo_args(
     )
 
     if args.quantization:
-        import torchao
-        from torchao.quantization import (
-            change_linear_weights_to_int4_woqtensors,
-            change_linear_weights_to_int8_dqtensors,
-            change_linear_weights_to_int8_woqtensors,
-        )
+        if model.device == "cpu":
+            if args.quantization == "pt2e":
+                enable_inductor_quant(model)
+            else:
+                raise ValueError(
+                    f"Quantization mode {args.quantization} is not supported on CPU"
+                )
+        else:
+            import torchao
+            from torchao.quantization import (
+                change_linear_weights_to_int4_woqtensors,
+                change_linear_weights_to_int8_dqtensors,
+                change_linear_weights_to_int8_woqtensors,
+            )
 
-        torch._dynamo.config.automatic_dynamic_shapes = False
-        torch._dynamo.config.force_parameter_static_shapes = False
-        torch._dynamo.config.cache_size_limit = 1000
-        assert "cuda" in model.device
-        module, example_inputs = model.get_module()
-        if args.quantization == "int8dynamic":
-            torch._inductor.config.force_fuse_int_mm_with_mul = True
-            change_linear_weights_to_int8_dqtensors(module)
-        elif args.quantization == "int8weightonly":
-            torch._inductor.config.use_mixed_mm = True
-            change_linear_weights_to_int8_woqtensors(module)
-        elif args.quantization == "int4weightonly":
-            change_linear_weights_to_int4_woqtensors(module)
+            torch._dynamo.config.automatic_dynamic_shapes = False
+            torch._dynamo.config.force_parameter_static_shapes = False
+            torch._dynamo.config.cache_size_limit = 1000
+            assert "cuda" in model.device
+            module, example_inputs = model.get_module()
+            if args.quantization == "int8dynamic":
+                torch._inductor.config.force_fuse_int_mm_with_mul = True
+                change_linear_weights_to_int8_dqtensors(module)
+            elif args.quantization == "int8weightonly":
+                torch._inductor.config.use_mixed_mm = True
+                change_linear_weights_to_int8_woqtensors(module)
+            elif args.quantization == "int4weightonly":
+                change_linear_weights_to_int4_woqtensors(module)
 
     if args.freeze_prepack_weights:
         torch._inductor.config.freezing = True
@@ -240,3 +248,61 @@
     model.eval = optimize_ctx(model.eval)
 
     torch._dynamo.reset()
+
+
+def enable_inductor_quant(model: "torchbenchmark.util.model.BenchmarkModel"):
+    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
+    import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
+    from torch.export import Dim
+
+    module, example_inputs = model.get_module()
+
+    dynamic_shapes = None
+    if isinstance(example_inputs, dict):
+        # Use a fixed-shape placeholder input: its shape drives the
+        # dynamic-shape spec, and all-ones token ids are valid for the
+        # calibration pass below.
+        input_ids = torch.ones(2, 512, dtype=torch.long)
+        example_inputs = {
+            "input_ids": input_ids,
+        }
+        input_shapes = {k: list(v.shape) for (k, v) in example_inputs.items()}
+        dims = set()
+        for v in input_shapes.values():
+            dims.update(v)
+        dims = sorted(dims)
+        dim_str_map = {x: Dim(f"dim{i}", min=1, max=1024 * 1024) for i, x in enumerate(dims)}
+        dynamic_shapes = {
+            k: {i: dim_str_map[d] for i, d in enumerate(v)}
+            for (k, v) in input_shapes.items()
+        }
+        # Keep the sequence-length dimension static; only the batch
+        # dimension remains dynamic.
+        del dynamic_shapes["input_ids"][1]
+    # Create X86InductorQuantizer
+    quantizer = xiq.X86InductorQuantizer()
+    quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())
+    # Generate the FX module
+    if isinstance(example_inputs, dict):
+        exported_model = torch.export.export_for_training(
+            module,
+            (),
+            example_inputs,
+            dynamic_shapes=dynamic_shapes,
+        ).module()
+    else:
+        exported_model = torch.export.export_for_training(
+            module,
+            example_inputs,
+        ).module()
+    # PT2E quantization flow
+    prepared_model = prepare_pt2e(exported_model, quantizer)
+    # Calibration
+    if isinstance(example_inputs, dict):
+        prepared_model(**example_inputs)
+    else:
+        prepared_model(*example_inputs)
+    with torch.no_grad():
+        converted_model = convert_pt2e(prepared_model)
+    torch.ao.quantization.move_exported_model_to_eval(converted_model)
+    model.set_module(converted_model)
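
Usage note: with this change, the new pt2e path should be reachable from the benchmark driver by running a model on CPU with the Inductor dynamo backend and passing the new flag through to parse_torchdynamo_args. The invocation below is illustrative only; the model name and exact driver flags depend on the local torchbench harness:

    python run.py BERT_pytorch -d cpu -t eval --torchdynamo inductor --quantization pt2e

On CUDA the existing torchao branch is taken unchanged; any non-pt2e quantization mode on CPU now raises a ValueError.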
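For reference, enable_inductor_quant follows the standard PT2E export-prepare-calibrate-convert recipe. A minimal standalone sketch of that recipe is below, assuming PyTorch >= 2.5 for torch.export.export_for_training; the TinyLinear module and tensor shapes are made up for illustration and are not part of this patch:

    import torch
    import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

    class TinyLinear(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = torch.nn.Linear(16, 16)

        def forward(self, x):
            return self.fc(x)

    m = TinyLinear().eval()
    example = (torch.randn(4, 16),)

    # Export to an FX graph and attach the x86 Inductor quantizer.
    exported = torch.export.export_for_training(m, example).module()
    quantizer = xiq.X86InductorQuantizer()
    quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())

    # Insert observers, run a calibration pass, then convert to a
    # quantized graph.
    prepared = prepare_pt2e(exported, quantizer)
    prepared(*example)  # calibration
    converted = convert_pt2e(prepared)
    torch.ao.quantization.move_exported_model_to_eval(converted)

    # The converted graph can then be compiled with Inductor.
    opt = torch.compile(converted)
    out = opt(*example)

The patch adds one wrinkle on top of this recipe: for dict-style inputs it builds a torch.export.Dim spec so the batch dimension stays dynamic while the sequence length is pinned, which is why dynamic_shapes is only populated in the dict branch.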