scripts/playground/bench_speculative.py (26 changes: 19 additions & 7 deletions)
@@ -13,6 +13,7 @@
 import os
 import time
 from types import SimpleNamespace
+from typing import List

 import numpy as np
 import requests
@@ -54,17 +55,18 @@ def encode(self, text: str, add_special_tokens: bool = False):
         return []


-def send_one_batch(base_url, num_prompts, batch_size, tokenizer, is_multimodal):
+def send_one_batch(base_url, num_prompts, batch_size, processor, is_multimodal):
     # format: (prompt, input_len, output len). We set input_len as a dummy value 0.
     if is_multimodal:
         backend = "sglang-oai-chat"
         api_url = f"{base_url}/v1/chat/completions"
         input_requests = sample_mmmu_requests(
             num_prompts,
-            tokenizer,
+            processor,
             backend=backend,
             fixed_output_len=512,
         )
+        tokenizer = processor.tokenizer
     else:
         padded_prompts = (prompts * ((num_prompts + len(prompts) - 1) // len(prompts)))[
             :num_prompts
@@ -74,6 +76,7 @@ def send_one_batch(base_url, num_prompts, batch_size, tokenizer, is_multimodal):
         ]
         backend = "sglang"
         api_url = f"{base_url}/generate"
+        tokenizer = processor

     # We need to set some dummy values in order to call `benchmark` below.
     args = SimpleNamespace(
@@ -108,6 +111,8 @@ def send_one_batch(base_url, num_prompts, batch_size, tokenizer, is_multimodal):
             lora_names=None,
             extra_request_body={},
             profile=None,
+            lora_request_distribution=None,
+            lora_zipf_alpha=None,
         )
     )

@@ -225,22 +230,29 @@ def main(args, server_args):
         },
     )

-    tokenizer = AutoTokenizer.from_pretrained(
-        args.model_path, trust_remote_code=server_args.trust_remote_code
-    )
+    if args.is_multimodal:
+        from transformers import AutoProcessor
+
+        processor = AutoProcessor.from_pretrained(
+            args.model_path, trust_remote_code=server_args.trust_remote_code
+        )
+    else:
+        processor = AutoTokenizer.from_pretrained(
+            args.model_path, trust_remote_code=server_args.trust_remote_code
+        )

     try:
         # Warmup
         send_one_batch(
-            base_url, batch_size, batch_size, tokenizer, args.is_multimodal
+            base_url, batch_size, batch_size, processor, args.is_multimodal
         )

         # Benchmark
         acc_length, step_time, speed, completion_tokens = send_one_batch(
             base_url,
             max(args.num_prompts, batch_size),
             batch_size,
-            tokenizer,
+            processor,
             args.is_multimodal,
         )
     finally:
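Taken together, the change makes send_one_batch agnostic to the model type: callers pass a single processor object, and the function derives a plain tokenizer from it (processor.tokenizer for multimodal models, the processor itself for text-only models). A minimal sketch of this selection pattern, standalone rather than the script itself; the model name is a placeholder and transformers is the only assumed dependency:

from transformers import AutoProcessor, AutoTokenizer

def load_processor(model_path: str, is_multimodal: bool, trust_remote_code: bool = False):
    """Mirror the selection logic added in main(): a processor for
    multimodal models, a plain tokenizer for text-only models."""
    if is_multimodal:
        # AutoProcessor bundles image preprocessing with a tokenizer;
        # the tokenizer stays reachable via `processor.tokenizer`.
        return AutoProcessor.from_pretrained(
            model_path, trust_remote_code=trust_remote_code
        )
    return AutoTokenizer.from_pretrained(
        model_path, trust_remote_code=trust_remote_code
    )

# Hypothetical usage; any Hugging Face model path would work here.
processor = load_processor("llava-hf/llava-1.5-7b-hf", is_multimodal=True)
tokenizer = getattr(processor, "tokenizer", processor)  # what send_one_batch does internally

Deriving the tokenizer inside send_one_batch keeps the two call sites in main() identical for both model types, which is why only the argument name changed there.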