diff --git a/benchmarks/microbenchmarks/benchmark_inference.py b/benchmarks/microbenchmarks/benchmark_inference.py index c084d18d3a..3af0ceb57b 100644 --- a/benchmarks/microbenchmarks/benchmark_inference.py +++ b/benchmarks/microbenchmarks/benchmark_inference.py @@ -15,6 +15,9 @@ import torch +from benchmarks.microbenchmarks.profiler import ( + generate_model_profile, +) from benchmarks.microbenchmarks.utils import ( BenchmarkConfig, BenchmarkResult, @@ -29,70 +32,77 @@ def run(config: BenchmarkConfig) -> BenchmarkResult: """Run inference benchmarks""" - clean_caches() # Clean caches - - # Create output directory if it doesn't exist - Path(config.output_dir).mkdir(parents=True, exist_ok=True) - - base_model, input_data = create_model_and_input( - config.model_type, - config.m, - config.k, - config.n, - high_precision_dtype=config.high_precision_dtype, - device=config.device, - ) - - # Use quantize_ to apply each quantization function to the model - m_copy = deepcopy(base_model).eval().to(config.device) - ao_base_config = string_to_config( - config.quantization, - config.sparsity, - high_precision_dtype=config.high_precision_dtype, - ) - - # Check if sparsity is requested and if the device is CUDA (sparsity operations require CUDA) - is_cuda = config.device == "cuda" and torch.cuda.is_available() - - if config.sparsity is not None and ( - config.quantization is None or "baseline" in config.quantization - ): - if is_cuda: - print(f"Applying {config.sparsity} sparsity to model") - sparsify_(m_copy, ao_base_config) + try: + clean_caches() # Clean caches + + # Create output directory if it doesn't exist + Path(config.output_dir).mkdir(parents=True, exist_ok=True) + + base_model, input_data = create_model_and_input( + config.model_type, + config.m, + config.k, + config.n, + high_precision_dtype=config.high_precision_dtype, + device=config.device, + ) + + # Use quantize_ to apply each quantization function to the model + m_copy = deepcopy(base_model).eval().to(config.device) + ao_base_config = string_to_config( + config.quantization, + config.sparsity, + high_precision_dtype=config.high_precision_dtype, + ) + + # Check if sparsity is requested and if the device is CUDA (sparsity operations require CUDA) + is_cuda = config.device == "cuda" and torch.cuda.is_available() + + if config.sparsity is not None and ( + config.quantization is None or "baseline" in config.quantization + ): + if is_cuda: + print(f"Applying {config.sparsity} sparsity to model") + sparsify_(m_copy, ao_base_config) + else: + print( + f"Warning: Skipping {config.sparsity} sparsity as it requires CUDA, but device is {config.device}" + ) + elif config.sparsity is None and ( + config.quantization is None or "baseline" in config.quantization + ): + pass # No quantization or sparsity specified, do nothing else: - print( - f"Warning: Skipping {config.sparsity} sparsity as it requires CUDA, but device is {config.device}" + print("Quantizing model....") + quantize_(m_copy, ao_base_config) + + if config.use_torch_compile: + print("Compiling model....") + m_copy = torch.compile( + m_copy, mode=config.torch_compile_mode, fullgraph=True ) - elif config.sparsity is None and ( - config.quantization is None or "baseline" in config.quantization - ): - pass # No quantization or sparsity specified, do nothing - else: - print("Quantizing model....") - quantize_(m_copy, ao_base_config) - - if config.use_torch_compile: - print("Compiling model....") - m_copy = torch.compile(m_copy, mode=config.torch_compile_mode, fullgraph=True) - - # Run benchmarks 
- result = BenchmarkResult(config=config) - - # Benchmark time to run an inference call for quantized model - result.model_inference_time_in_ms = model_inference_time_in_ms( - model=m_copy, input_data=input_data - ) - - # TODO: Benchmark time using profiler - # Profile dtype model evaluation - # prof_dtype = benchmark_model_op_with_profiler_in_microseconds(m_copy, input_data, quantized_dtype) - # prof_dtype.export_chrome_trace(f"{quantization}_model_{input_data[0].size()[0]}.json") # Save profiling details - - # TODO: Benchmark gemm time using cuda graph - # gemm_time = benchmark_torch_function_in_microseconds(gemm_op, *args, **kwargs) - - # TODO: Benchmark op with cuda graph - # time = benchmark_op_with_cuda_graph(op, args) - - return result + + # Run benchmarks + result = BenchmarkResult(config=config) + # Store result in model for memory profiling + m_copy._benchmark_result = result + + # Benchmark time to run an inference call for quantized model + result.model_inference_time_in_ms = model_inference_time_in_ms( + model=m_copy, input_data=input_data + ) + + # Run profiler if enabled + if config.enable_profiler: + print("Running profiler...") + try: + result.profiler_json_path = generate_model_profile( + m_copy, input_data, config.profiler_file_name + ) + except Exception as e: + print(f"Error running profiler for {config.name} with error: {e}") + + return result + except Exception as e: + print(f"Error in benchmark run: {config.name} with error: {e}") + return None diff --git a/benchmarks/microbenchmarks/benchmark_runner.py b/benchmarks/microbenchmarks/benchmark_runner.py index 7152542eec..e38fc93819 100644 --- a/benchmarks/microbenchmarks/benchmark_runner.py +++ b/benchmarks/microbenchmarks/benchmark_runner.py @@ -164,16 +164,19 @@ def run_inference_benchmarks_from_config(configs: List[BenchmarkConfig]) -> None f"Running: {config.name} for Quantization: {config.quantization} and Sparsity: {config.sparsity}" ) result = run_inference(config) # Pass the config object directly - results.append(result) - except Exception: - print(f"Error running benchmark {config.name}") + if result is not None: # Only add successful results + results.append(result) + except Exception as e: + print(f"Error running benchmark {config.name} with error: {e}") continue - # Add results to csv - generate_results_csv(results, configs[0].output_dir) - - # Print results - print_results(results) + # Add results to csv if there are any + if results: + generate_results_csv(results, configs[0].output_dir) + # Print results + print_results(results) + else: + print("No benchmark results were collected. All benchmarks failed.") # TODO: Process results: Speedups: # 1. For different shapes for same model and quantization diff --git a/benchmarks/microbenchmarks/profiler.py b/benchmarks/microbenchmarks/profiler.py new file mode 100644 index 0000000000..3687116ef1 --- /dev/null +++ b/benchmarks/microbenchmarks/profiler.py @@ -0,0 +1,60 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +import os + +import torch +from torch.profiler import ProfilerActivity + + +def generate_model_profile(model, input_data, profile_file_path): + """Function to benchmark model evaluation with profiling. 
+ + Args: + model: The model to profile + input_data: Input data for the model + profile_file_path: Path to save the profiler output + + Returns: + profile_file_path + """ + # Create parent directory if it doesn't exist + os.makedirs(os.path.dirname(profile_file_path), exist_ok=True) + + # Set up profiler activities based on device + activities = [ProfilerActivity.CPU] + device = next(model.parameters()).device + if device.type == "cuda" and torch.cuda.is_available(): + activities.append(ProfilerActivity.CUDA) + + # Warm up + with torch.no_grad(): + for _ in range(3): + _ = model(input_data) + if device.type == "cuda": + torch.cuda.synchronize() + + # Run profiler with minimal settings to ensure compatibility + with torch.profiler.profile( + activities=activities, + record_shapes=True, + with_stack=True, + profile_memory=True, + with_flops=True, # Experimental; might be unreliable for some layers + ) as prof: + with torch.no_grad(): + for _ in range(3): + _ = model(input_data) + if device.type == "cuda": + torch.cuda.synchronize() + + # Save profiling details + prof.export_chrome_trace(profile_file_path) + print(f"Chrome trace saved at: {profile_file_path}") + print("You can now visualize it using:") + print("1. Chrome Trace Viewer: chrome://tracing") + print("2. Perfetto UI: https://ui.perfetto.dev") + + return profile_file_path diff --git a/benchmarks/microbenchmarks/test/benchmark_config.yml b/benchmarks/microbenchmarks/test/benchmark_config.yml index 97a38469de..5ea3f5d642 100644 --- a/benchmarks/microbenchmarks/test/benchmark_config.yml +++ b/benchmarks/microbenchmarks/test/benchmark_config.yml @@ -2,12 +2,14 @@ benchmark_mode: "inference" quantization_config_recipe_names: # Will run a baseline inference for model by default, without quantization for comparison - - "int4wo-32" - - "marlin" -sparsity_config_recipe_names: + - "int8wo" + - "int8dq" + - "float8dq" + - "float8wo" +# sparsity_config_recipe_names: # Will run a baseline inference for model by default, without sparsity for comparison - - "semi-sparse" - - "block" + # - "semi-sparse" + # - "block" output_dir: "benchmarks/microbenchmarks/results" model_params: - name: "small_bf16_linear" @@ -15,17 +17,6 @@ model_params: - name: "custom" shapes: [ [1024, 1024, 1024], # [m, k, n] - ] - high_precision_dtype: "torch.bfloat16" - use_torch_compile: true - torch_compile_mode: "max-autotune" - device: "cuda" - model_type: "linear" - - - name: "large_bf16_ln_linear" - matrix_shapes: - - name: "custom" - shapes: [ [2048, 4096, 1024], [4096, 4096, 1024] ] @@ -33,15 +24,5 @@ model_params: use_torch_compile: true torch_compile_mode: "max-autotune" device: "cuda" - model_type: "ln_linear_sigmoid" - - - name: "cpu_fp32_linear" - matrix_shapes: - - name: "custom" - shapes: [ - [4096, 4096, 1024] - ] - high_precision_dtype: "torch.float32" - use_torch_compile: false - device: "cpu" model_type: "linear" + enable_profiler: true # Enable profiling for this model diff --git a/benchmarks/microbenchmarks/test/test_benchmark_profiler.py b/benchmarks/microbenchmarks/test/test_benchmark_profiler.py new file mode 100644 index 0000000000..0e398b4899 --- /dev/null +++ b/benchmarks/microbenchmarks/test/test_benchmark_profiler.py @@ -0,0 +1,156 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. 
+ +import json +import os +import unittest + +import torch + +from benchmarks.microbenchmarks.profiler import ( + generate_model_profile, +) +from benchmarks.microbenchmarks.utils import ( + BenchmarkConfig, + ToyLinearModel, +) + + +class TestBenchmarkProfiler(unittest.TestCase): + def setUp(self): + self.test_dir = os.path.dirname(os.path.abspath(__file__)) + self.results_dir = os.path.join(self.test_dir, "results") + os.makedirs(self.results_dir, exist_ok=True) + + # Set up a simple model and input for testing + self.m, self.k, self.n = 1024, 1024, 1024 + self.dtype = torch.bfloat16 + self.model = ToyLinearModel(k=self.k, n=self.n, dtype=self.dtype) + self.input_data = torch.randn(1, self.k, dtype=self.dtype) + + # Move to appropriate device + self.device = "cuda" if torch.cuda.is_available() else "cpu" + self.model = self.model.to(self.device) + self.input_data = self.input_data.to(self.device) + + def tearDown(self): + # Clean up any generated files + import shutil + + if os.path.exists(self.results_dir): + shutil.rmtree(self.results_dir) + + def test_profiler_enabled(self): + """Test that profiler works when enabled""" + config = BenchmarkConfig( + quantization=None, + sparsity=None, + params={ + "enable_profiler": True, + "device": self.device, + }, + shape_name="test", + shape=[self.m, self.k, self.n], + output_dir=self.results_dir, + benchmark_mode="inference", + ) + + profile_path = os.path.join( + self.results_dir, + "profiler", + f"{config.name}_{self.m}_{self.k}_{self.n}_profile.json", + ) + + # Generate profile + result_path = generate_model_profile(self.model, self.input_data, profile_path) + + # Check that profile file exists and is not empty + self.assertTrue(os.path.exists(result_path)) + self.assertGreater(os.path.getsize(result_path), 0) + + # Verify it's valid JSON + with open(result_path) as f: + profile_data = json.load(f) + self.assertIsInstance(profile_data, dict) + + def test_profiler_basic_output(self): + """Test that profiler output contains expected basic fields""" + config = BenchmarkConfig( + quantization=None, + sparsity=None, + params={ + "enable_profiler": True, + "device": self.device, + }, + shape_name="test", + shape=[self.m, self.k, self.n], + output_dir=self.results_dir, + benchmark_mode="inference", + ) + + profile_path = os.path.join( + self.results_dir, + "profiler", + f"{config.name}_{self.m}_{self.k}_{self.n}_profile.json", + ) + + result_path = generate_model_profile(self.model, self.input_data, profile_path) + + with open(result_path) as f: + data = json.load(f) + + # Check for required Chrome Trace Event format fields + self.assertIn("traceEvents", data) + self.assertTrue(isinstance(data["traceEvents"], list)) + + # Check that we have some events + self.assertGreater(len(data["traceEvents"]), 0) + + # Check event format + event = data["traceEvents"][0] + self.assertIn("name", event) + self.assertIn("ph", event) # Phase + self.assertIn("ts", event) # Timestamp + self.assertIn("pid", event) # Process ID + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + def test_cuda_profiling(self): + """Test CUDA profiling when available""" + config = BenchmarkConfig( + quantization=None, + sparsity=None, + params={ + "enable_profiler": True, + "device": "cuda", + }, + shape_name="test", + shape=[self.m, self.k, self.n], + output_dir=self.results_dir, + benchmark_mode="inference", + ) + + profile_path = os.path.join( + self.results_dir, + "profiler", + f"{config.name}_{self.m}_{self.k}_{self.n}_profile.json", + ) + + result_path = 
generate_model_profile( + self.model.cuda(), self.input_data.cuda(), profile_path + ) + + with open(result_path) as f: + data = json.load(f) + + # Check for CUDA events + cuda_events = [ + event for event in data["traceEvents"] if "cuda" in event.get("name", "") + ] + self.assertGreater(len(cuda_events), 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/benchmarks/microbenchmarks/utils.py b/benchmarks/microbenchmarks/utils.py index 754d8a0c92..e46a859586 100644 --- a/benchmarks/microbenchmarks/utils.py +++ b/benchmarks/microbenchmarks/utils.py @@ -84,6 +84,14 @@ def __init__( "name", f"benchmark_{self.quantization}_{self.model_type}_m{self.m}_k{self.k}_n{self.n}{'_compile' if self.use_torch_compile else ''}", ) + self.enable_profiler = bool(params.get("enable_profiler", False)) + # Create profiler directory path without leading slash + profiler_dir = os.path.join(self.output_dir, "profiler") + os.makedirs(profiler_dir, exist_ok=True) + file_name = f"{self.name}_{self.m}_{self.k}_{self.n}_quant_{self.quantization}_sparsity_{self.sparsity}" + self.profiler_file_name = os.path.join( + profiler_dir, f"{file_name}_profile.json" + ) @staticmethod def _parse_precision(precision_str: str) -> torch.dtype: @@ -105,6 +113,7 @@ def to_dict(self) -> Dict[str, Any]: "device": self.device, "model_type": self.model_type, "output_dir": self.output_dir, + "enable_profiler": self.enable_profiler, } @@ -116,13 +125,16 @@ def __init__( self.config = config self.output_dir = config.output_dir self.model_inference_time_in_ms = 0.0 + self.profiler_json_path: Optional[str] = None def to_dict(self) -> Dict[str, Any]: """Convert result to dictionary for main function""" - return { + result_dict = { **self.config.to_dict(), "model_inference_time_in_ms": self.model_inference_time_in_ms, + "profiler_json_path": self.profiler_json_path, } + return result_dict class ToyLinearModel(torch.nn.Module): @@ -379,6 +391,11 @@ def generate_results_csv( output_dir (str): Directory to save the CSV file. file_name (str, optional): Name of the CSV file. Defaults to "results.csv". """ + # Check if results list is empty + if len(results) == 0: + print("No results to save to CSV.") + return + # Create the output directory if it doesn't exist os.makedirs(output_dir, exist_ok=True) file_path = os.path.join(output_dir, file_name) @@ -396,68 +413,39 @@ def generate_results_csv( def print_results(results: List[BenchmarkResult]): - """Print benchmark results in a formatted table. 
- - Args: - results (List[BenchmarkResult]): List of benchmark results - """ + """Print results in a table format""" if not results: print("No results to display") return - # Extract relevant columns for display - display_columns = [ - "quantization", - "sparsity", - "model_type", - "m", - "k", - "n", - "model_inference_time_in_ms", - "use_torch_compile", - ] - - # Format data for tabulate - headers = { - "quantization": "Quantization", - "sparsity": "Sparsity", - "model_type": "Model Type", - "m": "M", - "k": "K", - "n": "N", - "model_inference_time_in_ms": "Time (μs)", - "use_torch_compile": "Compile Mode", - } - - # Extract and format data table_data = [] for result in results: - result_dict = result.to_dict() - row = [] - for col in display_columns: - value = result_dict.get(col, "N/A") - if value is None: - value = "N/A" - if col == "model_inference_time_in_ms": - value = f"{value:.2f}" if isinstance(value, (int, float)) else value - elif col == "use_torch_compile": - # Show compile mode if compile is True, otherwise show False - value = ( - result_dict.get("torch_compile_mode", "default") - if result_dict.get("use_torch_compile") - else "False" - ) - row.append(value) + if result is None: + continue + + row = [ + result.config.name, + result.config.quantization or "baseline", + result.config.sparsity or "none", + f"{result.config.shape_name} ({result.config.m}, {result.config.k}, {result.config.n})", + f"{result.model_inference_time_in_ms:.2f}", + str(result.config.enable_profiler), + ] + table_data.append(row) - # Print formatted table - print("\nBenchmark Results:") - print( - tabulate( - table_data, - headers=[headers[col] for col in display_columns], - tablefmt="grid", - floatfmt=".2f", - ) - ) - print() + # Define headers + headers = [ + "Name", + "Quantization", + "Sparsity", + "Shape", + "Inference Time (ms)", + "Profiler Enabled", + ] + + if table_data: + print("\nBenchmark Results:") + print(tabulate(table_data, headers=headers, tablefmt="grid")) + else: + print("\nNo valid results to display")
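For reference, below is a minimal sketch (not part of the patch) of how the new `generate_model_profile` helper from `benchmarks/microbenchmarks/profiler.py` can be driven directly, outside the benchmark runner. The toy `torch.nn.Linear` model, the 1024-wide shapes, and the output path are illustrative assumptions only; in the benchmark flow the path comes from `BenchmarkConfig.profiler_file_name` and the feature is gated on `enable_profiler: true` in the YAML config.

```python
# Illustrative sketch only: drives the new profiler helper directly.
# The model, shapes, and output path are assumptions, not part of the patch.
import os

import torch

from benchmarks.microbenchmarks.profiler import generate_model_profile

# Hypothetical toy model and input; any eval-mode module with parameters works,
# since the helper infers the device from next(model.parameters()).
model = torch.nn.Linear(1024, 1024, dtype=torch.bfloat16).eval()
input_data = torch.randn(1, 1024, dtype=torch.bfloat16)

# Hypothetical output location; the helper creates the parent directory,
# warms the model up, profiles three forward passes, and exports a Chrome trace.
trace_path = os.path.join(
    "benchmarks", "microbenchmarks", "results", "profiler", "example_profile.json"
)
saved_path = generate_model_profile(model, input_data, trace_path)

# The returned path is what BenchmarkResult.profiler_json_path records; open it
# in chrome://tracing or https://ui.perfetto.dev to inspect the trace.
print(saved_path)
```

In the benchmark flow the same call is made from `run()` in `benchmark_inference.py` when `config.enable_profiler` is set, with the output path taken from `config.profiler_file_name`.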