|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD 3-Clause license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +import argparse |
| 8 | +import subprocess |
| 9 | + |
| 10 | +import torch |
| 11 | +from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig |
| 12 | + |
| 13 | +from torchao.quantization import ( |
| 14 | + Float8DynamicActivationFloat8WeightConfig, |
| 15 | + Float8DynamicActivationInt4WeightConfig, |
| 16 | + Int4WeightOnlyConfig, |
| 17 | + Int8DynamicActivationInt8WeightConfig, |
| 18 | + Int8WeightOnlyConfig, |
| 19 | + PerRow, |
| 20 | +) |
| 21 | + |
| 22 | + |
| 23 | +def string_to_config(s): |
| 24 | + if s is None: |
| 25 | + return None |
| 26 | + elif s == "float8_rowwise": |
| 27 | + return Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()) |
| 28 | + elif s == "int4_groupwise_weight_float8_rowwise_activation": |
| 29 | + return Float8DynamicActivationInt4WeightConfig() |
| 30 | + elif s == "int4_groupwise_hqq_weight_only": |
| 31 | + return Int4WeightOnlyConfig( |
| 32 | + group_size=32, |
| 33 | + int4_packing_format="tile_packed_to_4d", |
| 34 | + int4_choose_qparams_algorithm="hqq", |
| 35 | + ) |
| 36 | + elif s == "int8_rowwise_weight_only": |
| 37 | + return Int8WeightOnlyConfig() |
| 38 | + elif s == "int8_rowwise": |
| 39 | + return Int8DynamicActivationInt8WeightConfig() |
| 40 | + else: |
| 41 | + raise AssertionError(f"unsupported {s}") |
| 42 | + |
| 43 | + |
| 44 | +def quantize_model_and_save(model_id, quant_config, output_dir="results"): |
| 45 | + """Quantize the model and save it to the output directory.""" |
| 46 | + print("Quantizing model with config: ", quant_config) |
| 47 | + if quant_config is None: |
| 48 | + quantization_config = None |
| 49 | + else: |
| 50 | + quantization_config = TorchAoConfig(quant_type=quant_config) |
| 51 | + quantized_model = AutoModelForCausalLM.from_pretrained( |
| 52 | + model_id, |
| 53 | + device_map="auto", |
| 54 | + dtype=torch.bfloat16, |
| 55 | + quantization_config=quantization_config, |
| 56 | + ) |
| 57 | + tokenizer = AutoTokenizer.from_pretrained(model_id) |
| 58 | + quantized_model.save_pretrained(output_dir, safe_serialization=False) |
| 59 | + tokenizer.save_pretrained(output_dir, safe_serialization=False) |
| 60 | + return quantized_model, tokenizer |
| 61 | + |
| 62 | + |
| 63 | +def run_lm_eval(model_dir, tasks_list=["hellaswag"], device="cuda:0", batch_size=8): |
| 64 | + """Run the lm_eval command using subprocess.""" |
| 65 | + tasks_str = ",".join(tasks_list) |
| 66 | + command = [ |
| 67 | + "lm_eval", |
| 68 | + "--model", |
| 69 | + "hf", |
| 70 | + "--model_args", |
| 71 | + f"pretrained={model_dir}", |
| 72 | + "--tasks", |
| 73 | + f"{tasks_str}", |
| 74 | + "--device", |
| 75 | + f"{device}", |
| 76 | + "--batch_size", |
| 77 | + f"{batch_size}", |
| 78 | + "--output_path", |
| 79 | + f"{model_dir}/lm_eval_outputs/", |
| 80 | + ] |
| 81 | + subprocess.run(command, check=True) |
| 82 | + |
| 83 | + |
| 84 | +def get_size_of_dir(model_output_dir): |
| 85 | + # get dir size from shell, to skip complexity of dealing with tensor |
| 86 | + # subclasses |
| 87 | + result = subprocess.run( |
| 88 | + ["du", "-sb", model_output_dir], capture_output=True, text=True |
| 89 | + ) |
| 90 | + size = int(result.stdout.split()[0]) |
| 91 | + return size |
| 92 | + |
| 93 | + |
| 94 | +def run( |
| 95 | + model_id: str, |
| 96 | + quant_recipe_name: str | None, |
| 97 | + tasks, |
| 98 | + device, |
| 99 | + batch_size, |
| 100 | + model_output_dir, |
| 101 | +): |
| 102 | + print(f"\nRunning {model_id=} with {quant_recipe_name=}\n") |
| 103 | + model_name = model_id.split("/")[-1] |
| 104 | + model_output_dir = ( |
| 105 | + f"benchmarks/data/quantized_model/{model_name}-{quant_recipe_name}" |
| 106 | + ) |
| 107 | + quant_config = string_to_config(quant_recipe_name) |
| 108 | + quantized_model, tokenizer = quantize_model_and_save( |
| 109 | + model_id, quant_config=quant_config, output_dir=model_output_dir |
| 110 | + ) |
| 111 | + print(quantized_model) |
| 112 | + |
| 113 | + model_size = get_size_of_dir(model_output_dir) / 1e9 |
| 114 | + print(f"checkpoint size: {model_size} GB") |
| 115 | + |
| 116 | + run_lm_eval( |
| 117 | + model_output_dir, tasks_list=tasks, device=device, batch_size=batch_size |
| 118 | + ) |
| 119 | + print("done\n") |
| 120 | + |
| 121 | + |
| 122 | +if __name__ == "__main__": |
| 123 | + try: |
| 124 | + import lm_eval # noqa: F401 |
| 125 | + except: |
| 126 | + print( |
| 127 | + "lm_eval is required to run this script. Please install it using pip install lm-eval." |
| 128 | + ) |
| 129 | + exit(0) |
| 130 | + |
| 131 | + # Set up argument parser |
| 132 | + parser = argparse.ArgumentParser( |
| 133 | + description="Quantize a model and evaluate its throughput." |
| 134 | + ) |
| 135 | + parser.add_argument( |
| 136 | + "--model_id", |
| 137 | + type=str, |
| 138 | + default="meta-llama/Llama-3.1-8B", |
| 139 | + help="The model ID to use.", |
| 140 | + ) |
| 141 | + parser.add_argument( |
| 142 | + "--quant_recipe_name", |
| 143 | + type=str, |
| 144 | + default=None, |
| 145 | + help="The quantization recipe to use.", |
| 146 | + ) |
| 147 | + parser.add_argument( |
| 148 | + "--tasks", |
| 149 | + nargs="+", |
| 150 | + type=str, |
| 151 | + default=["wikitext"], |
| 152 | + help="List of lm-eluther tasks to evaluate usage: --tasks task1 task2", |
| 153 | + ) |
| 154 | + parser.add_argument( |
| 155 | + "--device", type=str, default="cuda:0", help="Device to run the model on." |
| 156 | + ) |
| 157 | + parser.add_argument( |
| 158 | + "--batch_size", type=str, default="auto", help="Batch size for lm_eval." |
| 159 | + ) |
| 160 | + parser.add_argument( |
| 161 | + "--output_dir", |
| 162 | + type=str, |
| 163 | + default="quantized_models", |
| 164 | + help="Output directory for quantized model.", |
| 165 | + ) |
| 166 | + args = parser.parse_args() |
| 167 | + |
| 168 | + # Use parsed arguments |
| 169 | + run( |
| 170 | + model_id=args.model_id, |
| 171 | + quant_recipe_name=args.quant_recipe_name, |
| 172 | + tasks=args.tasks, |
| 173 | + device=args.device, |
| 174 | + batch_size=args.batch_size, |
| 175 | + model_output_dir=args.output_dir, |
| 176 | + ) |
0 commit comments