1 change: 1 addition & 0 deletions README.md
@@ -94,6 +94,7 @@ For more details see the <a href="https://docs.vllm.ai/en/stable/getting_started
- `max-cpu-loras`: maximum number of LoRAs to store in CPU memory, optional, must be >= than max-loras, default is max-loras
- `max-model-len`: model's context window, maximum number of tokens in a single request including input and output, optional, default is 1024
- `max-num-seqs`: maximum number of sequences per iteration (maximum number of inference requests that could be processed at the same time), default is 5
- `max-num-batched-tokens`: maximum number of batched tokens per iteration. If set, limits the total number of tokens (prompt plus maximum output tokens) that can be processed simultaneously across all running requests (e.g., with a limit of 2048, a running request holding 1500 tokens leaves room only for requests needing at most 548 more). When unset or set to 0, only the `max-num-seqs` constraint is enforced, optional, default is 0 (disabled)
- `mode`: the simulator mode, optional, by default `random`
- `echo`: returns the same text that was sent in the request
- `random`: returns a sentence chosen at random from a set of pre-defined sentences
1 change: 1 addition & 0 deletions manifests/basic-config.yaml
@@ -1,6 +1,7 @@
port: 8001
model: "Qwen/Qwen2-0.5B"
max-num-seqs: 5
max-num-batched-tokens: 1024
mode: "random"
time-to-first-token: 2000
inter-token-latency: 1000
1 change: 1 addition & 0 deletions manifests/config.yaml
@@ -6,6 +6,7 @@ served-model-name:
max-loras: 2
max-cpu-loras: 5
max-num-seqs: 5
max-num-batched-tokens: 2048
lora-modules:
- '{"name":"lora1","path":"/path/to/lora1"}'
- '{"name":"lora2","path":"/path/to/lora2"}'
5 changes: 5 additions & 0 deletions pkg/llm-d-inference-sim/config.go
@@ -41,6 +41,8 @@ type configuration struct {
// MaxNumSeqs is maximum number of sequences per iteration (the maximum
// number of inference requests that could be processed at the same time)
MaxNumSeqs int `yaml:"max-num-seqs"`
// MaxNumBatchedTokens is maximum number of batched tokens per iteration
MaxNumBatchedTokens int `yaml:"max-num-batched-tokens"`
// MaxModelLen is the model's context window, the maximum number of tokens
// in a single request including input and output. Default value is 1024.
MaxModelLen int `yaml:"max-model-len"`
@@ -164,6 +166,9 @@ func (c *configuration) validate() error {
if c.MaxModelLen < 1 {
return errors.New("max model len cannot be less than 1")
}
if c.MaxNumBatchedTokens < 0 {
return errors.New("max num batched tokens cannot be negative")
}

for _, lora := range c.LoraModules {
if lora.Name == "" {
36 changes: 36 additions & 0 deletions pkg/llm-d-inference-sim/config_test.go
@@ -88,6 +88,7 @@ var _ = Describe("Simulator configuration", func() {
c = createDefaultConfig(qwenModelName)
c.Port = 8001
c.ServedModelNames = []string{"model1", "model2"}
c.MaxNumBatchedTokens = 2048
Collaborator:

Please move this (and all other occurrences) to createDefaultConfig().

Contributor Author:

done

c.LoraModules = []loraModule{{Name: "lora1", Path: "/path/to/lora1"}, {Name: "lora2", Path: "/path/to/lora2"}}
test = testCase{
name: "config file",
@@ -105,6 +106,7 @@ var _ = Describe("Simulator configuration", func() {
c.Port = 8002
c.ServedModelNames = []string{"alias1", "alias2"}
c.Seed = 100
c.MaxNumBatchedTokens = 2048
c.LoraModules = []loraModule{{Name: "lora3", Path: "/path/to/lora3"}, {Name: "lora4", Path: "/path/to/lora4"}}
c.LoraModulesString = []string{
"{\"name\":\"lora3\",\"path\":\"/path/to/lora3\"}",
@@ -123,6 +125,7 @@ var _ = Describe("Simulator configuration", func() {
// Config from config.yaml file plus command line args with different format
c = createDefaultConfig(model)
c.Port = 8002
c.MaxNumBatchedTokens = 2048
c.LoraModules = []loraModule{{Name: "lora3", Path: "/path/to/lora3"}}
c.LoraModulesString = []string{
"{\"name\":\"lora3\",\"path\":\"/path/to/lora3\"}",
@@ -140,6 +143,7 @@ var _ = Describe("Simulator configuration", func() {
// Config from config.yaml file plus command line args with empty string
c = createDefaultConfig(model)
c.Port = 8002
c.MaxNumBatchedTokens = 2048
c.LoraModules = []loraModule{{Name: "lora3", Path: "/path/to/lora3"}}
c.LoraModulesString = []string{
"{\"name\":\"lora3\",\"path\":\"/path/to/lora3\"}",
@@ -158,6 +162,7 @@ var _ = Describe("Simulator configuration", func() {
c = createDefaultConfig(qwenModelName)
c.Port = 8001
c.ServedModelNames = []string{"model1", "model2"}
c.MaxNumBatchedTokens = 2048
c.LoraModulesString = []string{}
test = testCase{
name: "config file with command line args with empty string for loras",
@@ -170,6 +175,7 @@ var _ = Describe("Simulator configuration", func() {
c = createDefaultConfig(qwenModelName)
c.Port = 8001
c.ServedModelNames = []string{"model1", "model2"}
c.MaxNumBatchedTokens = 2048
c.LoraModulesString = []string{}
test = testCase{
name: "config file with command line args with empty parameter for loras",
@@ -184,6 +190,7 @@ var _ = Describe("Simulator configuration", func() {
// basic config file does not contain properties related to lora
c.MaxLoras = 1
c.MaxCPULoras = 1
c.MaxNumBatchedTokens = 1024
c.KVCacheTransferLatency = 50
test = testCase{
name: "config file with command line args with time to transfer kv-cache",
@@ -258,4 +265,33 @@ var _ = Describe("Simulator configuration", func() {
Entry(tests[12].name, tests[12].args),
Entry(tests[13].name, tests[13].args),
)

It("should accept max-num-batched-tokens parameter", func() {
config, err := createSimConfig([]string{
"test",
"--model", qwenModelName,
"--max-num-batched-tokens", "1024",
})
Expect(err).NotTo(HaveOccurred())
Expect(config.MaxNumBatchedTokens).Should(Equal(1024))
})

It("should validate max-num-batched-tokens cannot be negative", func() {
config := newConfig()
config.Model = qwenModelName
config.MaxNumBatchedTokens = -1

err := config.validate()
Expect(err).To(HaveOccurred())
Expect(err.Error()).Should(ContainSubstring("max num batched tokens cannot be negative"))
})

It("should allow max-num-batched-tokens to be zero (disabled)", func() {
config := newConfig()
config.Model = qwenModelName
config.MaxNumBatchedTokens = 0

err := config.validate()
Expect(err).NotTo(HaveOccurred())
})
})
1 change: 1 addition & 0 deletions pkg/llm-d-inference-sim/request.go
@@ -105,6 +105,7 @@ type completionReqCtx struct {
httpReqCtx *fasthttp.RequestCtx
isChatCompletion bool
wg *sync.WaitGroup
requestID string
}

// chatCompletionRequest defines structure of /chat/completion request
153 changes: 148 additions & 5 deletions pkg/llm-d-inference-sim/simulator.go
@@ -60,6 +60,13 @@ const (
toolChoiceRequired = "required"
)

// runningRequest tracks token usage for a currently running request
type runningRequest struct {
promptTokens int
maxTokens int
totalTokens int
}

// VllmSimulator simulates vLLM server supporting OpenAI API
type VllmSimulator struct {
// logger is used for information and errors logging
@@ -76,6 +83,10 @@ type VllmSimulator struct {
nRunningReqs int64
// nWaitingReqs is the number of inference requests that are waiting to be processed
nWaitingReqs int64
// runningRequestsMap tracks token usage for currently running requests
runningRequestsMap sync.Map
// processingTokensCount tracks the total number of tokens being processed by running requests
processingTokensCount int64
// loraInfo is prometheus gauge
loraInfo *prometheus.GaugeVec
// runningRequests is prometheus gauge
Expand All @@ -86,6 +97,8 @@ type VllmSimulator struct {
kvCacheUsagePercentage *prometheus.GaugeVec
// channel for requests to be passed to workers
reqChan chan *completionReqCtx
// channel for processing queue, managed by queue manager
processingChan chan *completionReqCtx
// schema validator for tools parameters
toolsValidator *validator
}
@@ -99,6 +112,7 @@ func New(logger logr.Logger) (*VllmSimulator, error) {
return &VllmSimulator{
logger: logger,
reqChan: make(chan *completionReqCtx, 1000),
processingChan: make(chan *completionReqCtx, 1000),
toolsValidator: toolsValidtor,
}, nil
}
@@ -117,6 +131,9 @@ func (s *VllmSimulator) Start(ctx context.Context) error {
return err
}

// run queue manager that handles request constraints
go s.queueManager(ctx)

// run request processing workers
for i := 1; i <= s.config.MaxNumSeqs; i++ {
go s.reqProcessingWorker(ctx, i)
@@ -149,6 +166,7 @@ func (s *VllmSimulator) parseCommandParamsAndLoadConfig() error {
f.IntVar(&config.Port, "port", config.Port, "Port")
f.StringVar(&config.Model, "model", config.Model, "Currently 'loaded' model")
f.IntVar(&config.MaxNumSeqs, "max-num-seqs", config.MaxNumSeqs, "Maximum number of inference requests that could be processed at the same time (parameter to simulate requests waiting queue)")
f.IntVar(&config.MaxNumBatchedTokens, "max-num-batched-tokens", config.MaxNumBatchedTokens, "Maximum number of batched tokens per iteration")
f.IntVar(&config.MaxLoras, "max-loras", config.MaxLoras, "Maximum number of LoRAs in a single batch")
f.IntVar(&config.MaxCPULoras, "max-cpu-loras", config.MaxCPULoras, "Maximum number of LoRAs to store in CPU memory")
f.IntVar(&config.MaxModelLen, "max-model-len", config.MaxModelLen, "Model's context window, maximum number of tokens in a single request including input and output")
@@ -375,6 +393,72 @@ func (s *VllmSimulator) isLora(model string) bool {
return false
}

// calculateProcessingTokens calculates the total number of processing tokens for a request
// Returns prompt tokens + max output tokens
func (s *VllmSimulator) calculateProcessingTokens(req completionRequest) int {
promptTokens := req.getNumberOfPromptTokens()
maxCompletionTokens := req.getMaxCompletionTokens()

// If max_tokens is not specified, calculate it as max-model-len - prompt-len
outputTokens := 0
if maxCompletionTokens != nil {
outputTokens = int(*maxCompletionTokens)
} else {
outputTokens = s.config.MaxModelLen - promptTokens
if outputTokens < 0 {
outputTokens = 0
}
}
Collaborator:

Isn't it that if maxCompletionTokens is nil, this function should just return s.config.MaxModelLen?

Contributor Author:

resolved


return promptTokens + outputTokens
}
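
If the reviewer's suggestion above was adopted as-is, the resolved logic might look roughly like the sketch below. This is an assumption: the follow-up commit is not part of this diff, and only the nil branch changes.

// Hypothetical resolved version (not the committed code): when max_tokens is
// absent, reserve the entire context window for the request.
func (s *VllmSimulator) calculateProcessingTokens(req completionRequest) int {
	maxCompletionTokens := req.getMaxCompletionTokens()
	if maxCompletionTokens == nil {
		return s.config.MaxModelLen
	}
	return req.getNumberOfPromptTokens() + int(*maxCompletionTokens)
}

For a non-nil max_tokens this matches the original arithmetic; for nil it returns MaxModelLen directly, which equals promptTokens + (MaxModelLen - promptTokens) whenever the prompt fits in the context window.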

// canAcceptRequest checks if a new request can be accepted based on max-num-seqs and max-num-batched-tokens constraints
func (s *VllmSimulator) canAcceptRequest(req completionRequest) bool {
currentRunning := atomic.LoadInt64(&s.nRunningReqs)

// Check max-num-seqs constraint
if currentRunning >= int64(s.config.MaxNumSeqs) {
return false
}

// If max-num-batched-tokens is not configured (0), only check max-num-seqs
if s.config.MaxNumBatchedTokens <= 0 {
return true
}

// Calculate tokens needed for this request
requestTokens := s.calculateProcessingTokens(req)
currentTokens := atomic.LoadInt64(&s.processingTokensCount)

// Check max-num-batched-tokens constraint
return currentTokens+int64(requestTokens) <= int64(s.config.MaxNumBatchedTokens)
}

// addRunningRequest adds a request to the running requests tracking
func (s *VllmSimulator) addRunningRequest(reqID string, req completionRequest) {
processingTokens := s.calculateProcessingTokens(req)

runningReq := runningRequest{
promptTokens: req.getNumberOfPromptTokens(),
maxTokens: processingTokens,
totalTokens: processingTokens,
}
Collaborator:

I could only find where 'totalTokens' is used (to update processingTokensCount), so why are 'promptTokens' and 'maxTokens' needed? And if they are not used, I guess we don't need runningRequestsMap at all? (And requestID)

Contributor Author:

You're right, that was unnecessary. I was biased toward a general, one-size-fits-all approach in case we want finer control over parallel requests in the near future, but removing it for now is better and leaner.

Removed unused fields and structures:

  • runningRequest struct - completely removed since promptTokens and maxTokens were never used
  • runningRequestsMap sync.Map - removed since we no longer need to map request IDs to token counts
  • requestID field - removed from completionReqCtx since we no longer need unique request tracking

Simplified token tracking:

  • Before: store a runningRequest struct with three fields in a map, indexed by requestID
  • After: store just the processingTokens count directly in the completionReqCtx

Updated method signatures:

  • addRunningRequest() now takes *completionReqCtx instead of (reqID, req)
  • removeRunningRequest() now takes *completionReqCtx instead of reqID
  • Both methods are simpler and more direct


s.runningRequestsMap.Store(reqID, runningReq)
atomic.AddInt64(&s.processingTokensCount, int64(processingTokens))
atomic.AddInt64(&s.nRunningReqs, 1)
}

// removeRunningRequest removes a request from the running requests tracking
func (s *VllmSimulator) removeRunningRequest(reqID string) {
if value, ok := s.runningRequestsMap.LoadAndDelete(reqID); ok {
runningReq := value.(runningRequest)
atomic.AddInt64(&s.processingTokensCount, -int64(runningReq.totalTokens))
atomic.AddInt64(&s.nRunningReqs, -1)
}
}
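
Per the author's reply above, a rough sketch of the simplified tracking it describes, assuming the token count is stored on the request context itself (the processingTokens field and these exact signatures are assumptions, since the follow-up commit is not shown in this diff):

// addRunningRequest accounts for a request that starts processing.
func (s *VllmSimulator) addRunningRequest(reqCtx *completionReqCtx) {
	reqCtx.processingTokens = s.calculateProcessingTokens(reqCtx.completionReq)
	atomic.AddInt64(&s.processingTokensCount, int64(reqCtx.processingTokens))
	atomic.AddInt64(&s.nRunningReqs, 1)
}

// removeRunningRequest releases the tokens held by a finished request.
func (s *VllmSimulator) removeRunningRequest(reqCtx *completionReqCtx) {
	atomic.AddInt64(&s.processingTokensCount, -int64(reqCtx.processingTokens))
	atomic.AddInt64(&s.nRunningReqs, -1)
}

With this shape, queueManager and reqProcessingWorker would pass the reqCtx itself rather than a generated UUID, so the uuid dependency and the requestID field disappear.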

// handleCompletions general completion requests handler, supports both text and chat completion APIs
func (s *VllmSimulator) handleCompletions(ctx *fasthttp.RequestCtx, isChatCompletion bool) {
vllmReq, err := s.readRequest(ctx, isChatCompletion)
@@ -400,6 +484,16 @@ func (s *VllmSimulator) handleCompletions(ctx *fasthttp.RequestCtx, isChatComple
return
}

// Validate max-num-batched-tokens constraint - reject requests that would never be accepted
if s.config.MaxNumBatchedTokens > 0 {
requestTokens := s.calculateProcessingTokens(vllmReq)
if requestTokens > s.config.MaxNumBatchedTokens {
s.sendCompletionError(ctx, fmt.Sprintf("Request requires %d tokens, but max-num-batched-tokens is set to %d. This request would never be accepted. Please reduce max_tokens or increase max-num-batched-tokens",
requestTokens, s.config.MaxNumBatchedTokens), "BadRequestError", fasthttp.StatusBadRequest)
return
}
}

var wg sync.WaitGroup
wg.Add(1)
reqCtx := &completionReqCtx{
@@ -414,15 +508,60 @@ func (s *VllmSimulator) handleCompletions(ctx *fasthttp.RequestCtx, isChatComple
wg.Wait()
}

func (s *VllmSimulator) queueManager(ctx context.Context) {
// Use a slice to maintain the queue of waiting requests
var waitingQueue []*completionReqCtx
ticker := time.NewTicker(10 * time.Millisecond) // Check every 10ms if we can process waiting requests
defer ticker.Stop()

for {
select {
case <-ctx.Done():
s.logger.Info("queueManager stopped")
return
case reqCtx := <-s.reqChan:
// Add new request to the waiting queue
waitingQueue = append(waitingQueue, reqCtx)
case <-ticker.C:
// Periodically check if we can process waiting requests
if len(waitingQueue) == 0 {
continue
}

// Try to process requests from the front of the queue
var newQueue []*completionReqCtx
for _, reqCtx := range waitingQueue {
if s.canAcceptRequest(reqCtx.completionReq) {
// Generate a unique ID for this request
reqID := uuid.New().String()

// Add to running requests tracking
s.addRunningRequest(reqID, reqCtx.completionReq)

// Add the request ID to the context so workers can use it
reqCtx.requestID = reqID

// Send to processing channel
s.processingChan <- reqCtx
} else {
// Can't process yet, keep in queue
newQueue = append(newQueue, reqCtx)
}
}
waitingQueue = newQueue
}
}
}

func (s *VllmSimulator) reqProcessingWorker(ctx context.Context, id int) {
for {
select {
case <-ctx.Done():
s.logger.Info("reqProcessingWorker stopped:", "worker id", id)
return
case reqCtx, ok := <-s.reqChan:
case reqCtx, ok := <-s.processingChan:
if !ok {
s.logger.Info("reqProcessingWorker worker exiting: reqChan closed")
s.logger.Info("reqProcessingWorker worker exiting: processingChan closed")
return
}
atomic.StoreInt64(&(s.nWaitingReqs), int64(len(s.reqChan)))
@@ -449,7 +588,8 @@ func (s *VllmSimulator) reqProcessingWorker(ctx context.Context, id int) {
// TODO - check if this request went to the waiting queue - add it to waiting map
s.reportLoras()
}
atomic.AddInt64(&(s.nRunningReqs), 1)

// Note: we don't increment nRunningReqs here because it's already done in addRunningRequest
s.reportRunningRequests()

var responseTokens []string
@@ -514,15 +654,18 @@ func (s *VllmSimulator) reqProcessingWorker(ctx context.Context, id int) {
req.doRemotePrefill())
}
}

// Clean up the running request tracking
s.removeRunningRequest(reqCtx.requestID)

reqCtx.wg.Done()
}
}
}

// decrease model usage reference number
func (s *VllmSimulator) responseSentCallback(model string) {

atomic.AddInt64(&(s.nRunningReqs), -1)
// Note: nRunningReqs is now decremented in removeRunningRequest
s.reportRunningRequests()

// Only LoRA models require reference-count handling.
Expand Down