11 changes: 11 additions & 0 deletions Ironwood/configs/host_device/host_device_single_chip.yaml
@@ -0,0 +1,11 @@
benchmarks:
  - benchmark_name: host_device
    num_runs: 20
    benchmark_sweep_params:
      # Single Chip (1 Chip, 2 Devices)
      - {
          num_devices: 2,
          data_size_mb_list: [1, 16, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768]
        }
    csv_path: "../microbenchmarks/host_device/single_chip"
    trace_dir: "../microbenchmarks/host_device/single_chip/trace"
30 changes: 30 additions & 0 deletions Ironwood/guides/host_device/host_device.md
@@ -0,0 +1,30 @@
# Host Device Microbenchmarks on tpu7x-2x2x1

This guide provides instructions for running Host Device (Host-to-Device and Device-to-Host) microbenchmarks on tpu7x-2x2x1 Google Kubernetes Engine (GKE) clusters. It covers creating a node pool, running the benchmarks, and viewing the output.

> [!WARNING]
> This benchmark is currently a Work In Progress (WIP). Expected bandwidth numbers are not yet finalized.

## Create Node Pools

Follow the [Setup section](../../Ironwood_Microbenchmarks_readme.md#setup) to create a GKE cluster with one 2x2x1 node pool.

## Run Host Device Microbenchmarks

To run the microbenchmarks, apply the following Kubernetes configuration:
```bash
kubectl apply -f tpu7x-host-device-benchmark.yaml
```

To view the benchmark logs, use `kubectl logs`:
```bash
kubectl logs tpu7x-host-device-benchmark
```

Once the benchmark completes, the logs report H2D and D2H bandwidth statistics for each configured data size.
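
For reference, the reported bandwidth figures are derived from the per-transfer wall-clock time. Below is a minimal sketch of the conversion, mirroring the calculation in `Ironwood/src/benchmark_host_device.py` from this PR (the helper name here is illustrative):

```python
def bandwidth_gib_s(data_size_mb: float, ms: float) -> float:
  """Converts one transfer's duration in milliseconds into bandwidth in GiB/s."""
  # Guard against division by zero, as the benchmark code does.
  return (data_size_mb / 1024) / (ms / 1000) if ms > 0 else 0.0
```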

To retrieve the complete results, including the trace and CSV output files, keep the pod running after the benchmark completes by adding a `sleep` command to `tpu7x-host-device-benchmark.yaml`; a sketch of this change follows the command below. You can then use `kubectl cp` to copy the output from the pod.

```bash
kubectl cp tpu7x-host-device-benchmark:/microbenchmarks/host_device host_device
```
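
For example, one way to keep the pod alive is to append a `sleep` to the container command in the manifest. This is a minimal sketch, not the only option, and the exact placement in your copy of the file may differ:

```yaml
    command:
    - bash
    - -c
    - |
      set -ex
      # ... clone the repo, install requirements, and run the benchmark as above ...
      bash ./Ironwood/scripts/run_host_device_benchmark.sh --config Ironwood/configs/host_device/host_device_single_chip.yaml
      # Keep the pod running after the benchmark finishes so `kubectl cp` can copy the results.
      sleep infinity
```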
34 changes: 34 additions & 0 deletions Ironwood/guides/host_device/tpu7x-host-device-benchmark.yaml
@@ -0,0 +1,34 @@
apiVersion: v1
kind: Pod
metadata:
  name: tpu7x-host-device-benchmark
spec:
  restartPolicy: Never
  nodeSelector:
    cloud.google.com/gke-tpu-accelerator: tpu7x
    cloud.google.com/gke-tpu-topology: 2x2x1
  containers:
  - name: tpu-job
    image: python:3.12
    ports:
    - containerPort: 8431
    securityContext:
      privileged: false
    command:
    - bash
    - -c
    - |
      set -ex

      git clone https://github.com/AI-Hypercomputer/accelerator-microbenchmarks.git
      cd accelerator-microbenchmarks
      pip install -r requirements.txt

      export TPU_VISIBLE_CHIPS=0
      bash ./Ironwood/scripts/run_host_device_benchmark.sh --config Ironwood/configs/host_device/host_device_single_chip.yaml

    resources:
      requests:
        google.com/tpu: 4
      limits:
        google.com/tpu: 4
71 changes: 71 additions & 0 deletions Ironwood/scripts/run_host_device_benchmark.sh
@@ -0,0 +1,71 @@
#!/bin/bash

# Default values
CONFIG_DIR="Ironwood/configs/host_device"
SPECIFIC_CONFIG=""
INTERLEAVED=false

# Helper function for usage
usage() {
  echo "Usage: $0 [OPTIONS]"
  echo "Options:"
  echo " --config <path> Path to specific config file (optional)"
  echo " --interleaved Run with numactl --interleave=all"
  echo " --help Show this help message"
  exit 1
}

# Parse arguments
while [[ "$#" -gt 0 ]]; do
  case $1 in
    --config) SPECIFIC_CONFIG="$2"; shift ;;
    --interleaved) INTERLEAVED=true ;;
    --help) usage ;;
    *) echo "Unknown parameter passed: $1"; usage ;;
  esac
  shift
done

echo "--- Starting Host-Device Transfer Benchmark (H2D/D2H) ---"
echo "********************************************************"
echo "WARNING: This benchmark is currently a WORK IN PROGRESS"
echo "********************************************************"
echo ""
echo "Configuration:"
echo " Interleaved: $INTERLEAVED"
echo ""

if [ -n "$SPECIFIC_CONFIG" ]; then
  CONFIGS=("$SPECIFIC_CONFIG")
else
  # Use nullglob to handle case where no files match (though unlikely here)
  shopt -s nullglob
  CONFIGS=("$CONFIG_DIR"/*.yaml)
  shopt -u nullglob
fi

if [ ${#CONFIGS[@]} -eq 0 ]; then
  echo "No configuration files found!"
  exit 1
fi

for CONFIG_FILE in "${CONFIGS[@]}"; do
  echo "--- Running Config: $CONFIG_FILE ---"
  CMD="python Ironwood/src/run_benchmark.py --config=${CONFIG_FILE}"

  if [ "$INTERLEAVED" = true ]; then
    if command -v numactl &> /dev/null; then
      echo "Running with numactl --interleave=all"
      numactl --interleave=all $CMD
    else
      echo "Warning: numactl not found. Running without interleaving."
      $CMD
    fi
  else
    $CMD
  fi
  echo "--- Finished Config: $CONFIG_FILE ---"
  echo ""
done

echo "--- All Benchmarks Finished ---"
146 changes: 146 additions & 0 deletions Ironwood/src/benchmark_host_device.py
@@ -0,0 +1,146 @@
"""Benchmarks Host-to-Device and Device-to-Host transfer performance (Simple Baseline)."""

import time
import os
from typing import Any, Dict, Tuple, List

import jax
from jax import sharding
import numpy as np
from benchmark_utils import MetricsStatistics


libtpu_init_args = [
"--xla_tpu_dvfs_p_state=7",
]
os.environ["LIBTPU_INIT_ARGS"] = " ".join(libtpu_init_args)
# 64 GiB
os.environ["TPU_PREMAPPED_BUFFER_SIZE"] = "68719476736"
os.environ["TPU_PREMAPPED_BUFFER_TRANSFER_THRESHOLD_BYTES"] = "68719476736"

def get_tpu_devices(num_devices: int):
devices = jax.devices()
if len(devices) < num_devices:
raise RuntimeError(f"Require {num_devices} devices, found {len(devices)}")
return devices[:num_devices]

def benchmark_host_device(
num_devices: int,
data_size_mb: int,
num_runs: int = 100,
trace_dir: str = None,
) -> Dict[str, Any]:
"""Benchmarks H2D/D2H transfer using simple device_put/device_get."""
tpu_devices = get_tpu_devices(num_devices)

num_elements = 1024 * 1024 * data_size_mb // np.dtype(np.float32).itemsize

# Allocate Host Source Buffer
host_data = np.random.normal(size=(num_elements,)).astype(np.float32)

print(
f"Benchmarking (Simple) Transfer with Data Size: {data_size_mb} MB on"
f" {num_devices} devices for {num_runs} iterations"
)

# Setup Mesh Sharding (1D)
mesh = sharding.Mesh(
np.array(tpu_devices).reshape((num_devices,)), axis_names=("x",)
)
# Shard the 1D array across "x"
partition_spec = sharding.PartitionSpec("x")

data_sharding = sharding.NamedSharding(mesh, partition_spec)

# Performance Lists
h2d_perf, d2h_perf = [], []

# Profiling Context
import contextlib
if trace_dir:
profiler_context = jax.profiler.trace(trace_dir)
else:
profiler_context = contextlib.nullcontext()

with profiler_context:
# Warmup
for _ in range(2):
device_array = jax.device_put(host_data, data_sharding)
device_array.block_until_ready()
host_out = np.array(device_array)
device_array.delete()
del host_out

for i in range(num_runs):
# Step Context
if trace_dir:
step_context = jax.profiler.StepTraceAnnotation("host_device", step_num=i)
else:
step_context = contextlib.nullcontext()

with step_context:
# H2D
t0 = time.perf_counter()

# Simple device_put
device_array = jax.device_put(host_data, data_sharding)
device_array.block_until_ready()

t1 = time.perf_counter()
h2d_perf.append((t1 - t0) * 1000)

# Verify H2D shape/sharding
assert device_array.shape == host_data.shape
assert device_array.sharding == data_sharding

# D2H
t2 = time.perf_counter()

# Simple device_get
# Note: device_get returns a numpy array (copy)
_ = jax.device_get(device_array)

t3 = time.perf_counter()
d2h_perf.append((t3 - t2) * 1000)

device_array.delete()

return {
"H2D_Bandwidth_ms": h2d_perf,
"D2H_Bandwidth_ms": d2h_perf,
}

def benchmark_host_device_calculate_metrics(
num_devices: int,
data_size_mb: int,
H2D_Bandwidth_ms: List[float],
D2H_Bandwidth_ms: List[float],
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""Calculates metrics for Host-Device transfer."""
params = locals().items()

data_size_mib = data_size_mb

# Filter out list params from metadata to avoid explosion
metadata_keys = {
"num_devices",
"data_size_mib",
}
metadata = {k: v for k, v in params if k in metadata_keys}

metrics = {}

def add_metric(name, ms_list):
# Report Bandwidth (GiB/s)
# Handle division by zero if ms is 0
bw_list = [
((data_size_mb / 1024) / (ms / 1000)) if ms > 0 else 0.0
for ms in ms_list
]
stats_bw = MetricsStatistics(bw_list, f"{name}_bw (GiB/s)")
metrics.update(stats_bw.serialize_statistics())

add_metric("H2D", H2D_Bandwidth_ms)
add_metric("D2H", D2H_Bandwidth_ms)

return metadata, metrics
16 changes: 16 additions & 0 deletions Ironwood/src/run_benchmark.py
@@ -95,13 +95,17 @@
"inference_silu_mul": "benchmark_inference_compute.silu_mul",
"inference_sigmoid": "benchmark_inference_compute.sigmoid",
}
HOST_DEVICE_BENCHMARK_MAP = {
"host_device": "benchmark_host_device.benchmark_host_device",
}
BENCHMARK_MAP = {}
BENCHMARK_MAP.update(COLLECTIVE_BENCHMARK_MAP)
BENCHMARK_MAP.update(MATMUL_BENCHMARK_MAP)
BENCHMARK_MAP.update(CONVOLUTION_BENCHMARK_MAP)
BENCHMARK_MAP.update(ATTENTION_BENCHMARK_MAP)
BENCHMARK_MAP.update(HBM_BENCHMARK_MAP)
BENCHMARK_MAP.update(COMPUTE_BENCHMARK_MAP)
BENCHMARK_MAP.update(HOST_DEVICE_BENCHMARK_MAP)


# Mapping from dtype string to actual dtype object
@@ -326,6 +330,12 @@ def run_single_benchmark(benchmark_config: Dict[str, Any], output_path: str):
  # csv_path = os.path.join(output_path, benchmark_name)
  trace_dir = os.path.join(output_path, benchmark_name, "trace")
  xla_dump_dir = os.path.join(output_path, benchmark_name, "hlo_graphs")
  # Inject num_runs from config if not present in params
  global_num_runs = benchmark_config.get("num_runs")
  if global_num_runs is not None:
    for param in benchmark_params:
      if "num_runs" not in param:
        param["num_runs"] = global_num_runs

  if not benchmark_name:
    raise ValueError("Each benchmark must have a 'benchmark_name'.")
@@ -467,6 +477,12 @@ def run_benchmark_multithreaded(benchmark_config, output_path):
  if output_path != "":
    csv_path = os.path.join(output_path, benchmark_name)
    os.makedirs(csv_path, exist_ok=True)
  # Inject num_runs from config if not present in params
  global_num_runs = benchmark_config.get("num_runs")
  if global_num_runs is not None:
    for param in benchmark_params:
      if "num_runs" not in param:
        param["num_runs"] = global_num_runs

  # Get the benchmark function
  benchmark_func, calculate_metrics_func = get_benchmark_functions(benchmark_name)