feat: add max-num-batched-tokens configuration and implement request handling constraints (#83) #97
base: main
Changes from 1 commit
@@ -60,6 +60,13 @@ const (
	toolChoiceRequired = "required"
)

+// runningRequest tracks token usage for a currently running request
+type runningRequest struct {
+	promptTokens int
+	maxTokens    int
+	totalTokens  int
+}
+
// VllmSimulator simulates vLLM server supporting OpenAI API
type VllmSimulator struct {
	// logger is used for information and errors logging
@@ -76,6 +83,10 @@ type VllmSimulator struct {
	nRunningReqs int64
	// nWaitingReqs is the number of inference requests that are waiting to be processed
	nWaitingReqs int64
+	// runningRequestsMap tracks token usage for currently running requests
+	runningRequestsMap sync.Map
+	// processingTokensCount tracks the total number of tokens being processed by running requests
+	processingTokensCount int64
	// loraInfo is prometheus gauge
	loraInfo *prometheus.GaugeVec
	// runningRequests is prometheus gauge
@@ -86,6 +97,8 @@ type VllmSimulator struct {
	kvCacheUsagePercentage *prometheus.GaugeVec
	// channel for requests to be passed to workers
	reqChan chan *completionReqCtx
+	// channel for processing queue, managed by queue manager
+	processingChan chan *completionReqCtx
	// schema validator for tools parameters
	toolsValidator *validator
}
@@ -99,6 +112,7 @@ func New(logger logr.Logger) (*VllmSimulator, error) {
	return &VllmSimulator{
		logger:         logger,
		reqChan:        make(chan *completionReqCtx, 1000),
+		processingChan: make(chan *completionReqCtx, 1000),
		toolsValidator: toolsValidtor,
	}, nil
}
@@ -117,6 +131,9 @@ func (s *VllmSimulator) Start(ctx context.Context) error {
		return err
	}

+	// run queue manager that handles request constraints
+	go s.queueManager(ctx)
+
	// run request processing workers
	for i := 1; i <= s.config.MaxNumSeqs; i++ {
		go s.reqProcessingWorker(ctx, i)
@@ -149,6 +166,7 @@ func (s *VllmSimulator) parseCommandParamsAndLoadConfig() error {
	f.IntVar(&config.Port, "port", config.Port, "Port")
	f.StringVar(&config.Model, "model", config.Model, "Currently 'loaded' model")
	f.IntVar(&config.MaxNumSeqs, "max-num-seqs", config.MaxNumSeqs, "Maximum number of inference requests that could be processed at the same time (parameter to simulate requests waiting queue)")
+	f.IntVar(&config.MaxNumBatchedTokens, "max-num-batched-tokens", config.MaxNumBatchedTokens, "Maximum number of batched tokens per iteration")
	f.IntVar(&config.MaxLoras, "max-loras", config.MaxLoras, "Maximum number of LoRAs in a single batch")
	f.IntVar(&config.MaxCPULoras, "max-cpu-loras", config.MaxCPULoras, "Maximum number of LoRAs to store in CPU memory")
	f.IntVar(&config.MaxModelLen, "max-model-len", config.MaxModelLen, "Model's context window, maximum number of tokens in a single request including input and output")
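The new flag assumes a MaxNumBatchedTokens field on the simulator configuration. The configuration type itself is not part of this diff, so the sketch below is only an assumption: the field names are taken from the flag definitions above, while the struct name and layout are guesses.

```go
// Assumed shape of the simulator configuration referenced by the flags above.
// Only the field names are grounded in the diff; everything else is illustrative.
type configuration struct {
	Port                int
	Model               string
	MaxNumSeqs          int // maximum number of requests processed concurrently
	MaxNumBatchedTokens int // new: cap on total tokens across running requests; 0 disables the check
	MaxLoras            int
	MaxCPULoras         int
	MaxModelLen         int
}
```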
@@ -375,6 +393,72 @@ func (s *VllmSimulator) isLora(model string) bool {
	return false
}

+// calculateProcessingTokens calculates the total number of processing tokens for a request
+// Returns prompt tokens + max output tokens
+func (s *VllmSimulator) calculateProcessingTokens(req completionRequest) int {
+	promptTokens := req.getNumberOfPromptTokens()
+	maxCompletionTokens := req.getMaxCompletionTokens()
+
+	// If max_tokens is not specified, calculate it as max-model-len - prompt-len
+	outputTokens := 0
+	if maxCompletionTokens != nil {
+		outputTokens = int(*maxCompletionTokens)
+	} else {
+		outputTokens = s.config.MaxModelLen - promptTokens
+		if outputTokens < 0 {
+			outputTokens = 0
+		}
+	}

Collaborator: Isn't it that if maxCompletionTokens is nil, this function should just return s.config.MaxModelLen?

Contributor (author): resolved

+
+	return promptTokens + outputTokens
+}
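The exchange above points out that the else branch collapses: promptTokens + (MaxModelLen - promptTokens) is just MaxModelLen. The resolved code is not shown in this commit, so the following is only a sketch of what the post-review helper might look like:

```go
// calculateProcessingTokens returns prompt tokens plus the maximum number of output
// tokens the request may generate. When max_tokens is absent the request can use the
// whole context window, so the total is simply max-model-len.
// (Sketch of the post-review version; the committed code may differ.)
func (s *VllmSimulator) calculateProcessingTokens(req completionRequest) int {
	maxCompletionTokens := req.getMaxCompletionTokens()
	if maxCompletionTokens == nil {
		return s.config.MaxModelLen
	}
	return req.getNumberOfPromptTokens() + int(*maxCompletionTokens)
}
```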
+
+// canAcceptRequest checks if a new request can be accepted based on max-num-seqs and max-num-batched-tokens constraints
+func (s *VllmSimulator) canAcceptRequest(req completionRequest) bool {
+	currentRunning := atomic.LoadInt64(&s.nRunningReqs)
+
+	// Check max-num-seqs constraint
+	if currentRunning >= int64(s.config.MaxNumSeqs) {
+		return false
+	}
+
+	// If max-num-batched-tokens is not configured (0), only check max-num-seqs
+	if s.config.MaxNumBatchedTokens <= 0 {
+		return true
+	}
+
+	// Calculate tokens needed for this request
+	requestTokens := s.calculateProcessingTokens(req)
+	currentTokens := atomic.LoadInt64(&s.processingTokensCount)
+
+	// Check max-num-batched-tokens constraint
+	return currentTokens+int64(requestTokens) <= int64(s.config.MaxNumBatchedTokens)
+}
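To make the arithmetic concrete (the numbers are invented for illustration): with max-num-batched-tokens set to 2048, a request with 600 prompt tokens and max_tokens=512 needs 1112 processing tokens, so it is admitted only while the tokens already in flight total 936 or fewer. The standalone snippet below mirrors the final check without depending on the simulator types:

```go
package main

import "fmt"

func main() {
	const maxNumBatchedTokens = 2048 // illustrative value for the new flag

	currentTokens := int64(1200)      // tokens already held by running requests
	requestTokens := int64(600 + 512) // prompt tokens + max_tokens of the incoming request

	accepted := currentTokens+requestTokens <= maxNumBatchedTokens
	fmt.Println(accepted) // false: 1200 + 1112 = 2312 > 2048, so the request keeps waiting
}
```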
+
+// addRunningRequest adds a request to the running requests tracking
+func (s *VllmSimulator) addRunningRequest(reqID string, req completionRequest) {
+	processingTokens := s.calculateProcessingTokens(req)
+
+	runningReq := runningRequest{
+		promptTokens: req.getNumberOfPromptTokens(),
+		maxTokens:    processingTokens,
+		totalTokens:  processingTokens,
+	}
+
+	s.runningRequestsMap.Store(reqID, runningReq)
+	atomic.AddInt64(&s.processingTokensCount, int64(processingTokens))
+	atomic.AddInt64(&s.nRunningReqs, 1)
+}
+
+// removeRunningRequest removes a request from the running requests tracking
+func (s *VllmSimulator) removeRunningRequest(reqID string) {
+	if value, ok := s.runningRequestsMap.LoadAndDelete(reqID); ok {
+		runningReq := value.(runningRequest)
+		atomic.AddInt64(&s.processingTokensCount, -int64(runningReq.totalTokens))
+		atomic.AddInt64(&s.nRunningReqs, -1)
+	}
+}
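For readers less familiar with the pattern, the bookkeeping above is a sync.Map keyed by request ID plus an atomic token counter. The standalone snippet below (names and numbers are illustrative, not taken from the repository) shows the same add/remove cycle in isolation:

```go
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

func main() {
	var inFlight sync.Map // request ID -> processing tokens
	var totalTokens int64

	// Admit a request that needs 1112 processing tokens.
	inFlight.Store("req-1", 1112)
	atomic.AddInt64(&totalTokens, 1112)

	// The request finishes. LoadAndDelete returns the stored value exactly once,
	// so the counter is decremented once even if cleanup were attempted twice.
	if v, ok := inFlight.LoadAndDelete("req-1"); ok {
		atomic.AddInt64(&totalTokens, -int64(v.(int)))
	}

	fmt.Println(atomic.LoadInt64(&totalTokens)) // 0
}
```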
+
// handleCompletions general completion requests handler, support both text and chat completion APIs
func (s *VllmSimulator) handleCompletions(ctx *fasthttp.RequestCtx, isChatCompletion bool) {
	vllmReq, err := s.readRequest(ctx, isChatCompletion)
@@ -400,6 +484,16 @@ func (s *VllmSimulator) handleCompletions(ctx *fasthttp.RequestCtx, isChatComple
		return
	}

+	// Validate max-num-batched-tokens constraint - reject requests that would never be accepted
+	if s.config.MaxNumBatchedTokens > 0 {
+		requestTokens := s.calculateProcessingTokens(vllmReq)
+		if requestTokens > s.config.MaxNumBatchedTokens {
+			s.sendCompletionError(ctx, fmt.Sprintf("Request requires %d tokens, but max-num-batched-tokens is set to %d. This request would never be accepted. Please reduce max_tokens or increase max-num-batched-tokens",
+				requestTokens, s.config.MaxNumBatchedTokens), "BadRequestError", fasthttp.StatusBadRequest)
+			return
+		}
+	}
+
	var wg sync.WaitGroup
	wg.Add(1)
	reqCtx := &completionReqCtx{
@@ -414,15 +508,60 @@ func (s *VllmSimulator) handleCompletions(ctx *fasthttp.RequestCtx, isChatComple
	wg.Wait()
}

+func (s *VllmSimulator) queueManager(ctx context.Context) {
+	// Use a slice to maintain the queue of waiting requests
+	var waitingQueue []*completionReqCtx
+	ticker := time.NewTicker(10 * time.Millisecond) // Check every 10ms if we can process waiting requests
+	defer ticker.Stop()
+
+	for {
+		select {
+		case <-ctx.Done():
+			s.logger.Info("queueManager stopped")
+			return
+		case reqCtx := <-s.reqChan:
+			// Add new request to the waiting queue
+			waitingQueue = append(waitingQueue, reqCtx)
+		case <-ticker.C:
+			// Periodically check if we can process waiting requests
+			if len(waitingQueue) == 0 {
+				continue
+			}
+
+			// Try to process requests from the front of the queue
+			var newQueue []*completionReqCtx
+			for _, reqCtx := range waitingQueue {
+				if s.canAcceptRequest(reqCtx.completionReq) {
+					// Generate a unique ID for this request
+					reqID := uuid.New().String()
+
+					// Add to running requests tracking
+					s.addRunningRequest(reqID, reqCtx.completionReq)
+
+					// Add the request ID to the context so workers can use it
+					reqCtx.requestID = reqID
+
+					// Send to processing channel
+					s.processingChan <- reqCtx
+				} else {
+					// Can't process yet, keep in queue
+					newQueue = append(newQueue, reqCtx)
+				}
+			}
+			waitingQueue = newQueue
+		}
+	}
+}
+
func (s *VllmSimulator) reqProcessingWorker(ctx context.Context, id int) {
	for {
		select {
		case <-ctx.Done():
			s.logger.Info("reqProcessingWorker stopped:", "worker id", id)
			return
-		case reqCtx, ok := <-s.reqChan:
+		case reqCtx, ok := <-s.processingChan:
			if !ok {
-				s.logger.Info("reqProcessingWorker worker exiting: reqChan closed")
+				s.logger.Info("reqProcessingWorker worker exiting: processingChan closed")
				return
			}
			atomic.StoreInt64(&(s.nWaitingReqs), int64(len(s.reqChan)))
@@ -449,7 +588,8 @@ func (s *VllmSimulator) reqProcessingWorker(ctx context.Context, id int) {
				// TODO - check if this request went to the waiting queue - add it to waiting map
				s.reportLoras()
			}
-			atomic.AddInt64(&(s.nRunningReqs), 1)
+
+			// Note: we don't increment nRunningReqs here because it's already done in addRunningRequest
			s.reportRunningRequests()

			var responseTokens []string
@@ -514,15 +654,18 @@ func (s *VllmSimulator) reqProcessingWorker(ctx context.Context, id int) {
					req.doRemotePrefill())
				}
			}

+			// Clean up the running request tracking
+			s.removeRunningRequest(reqCtx.requestID)
+
			reqCtx.wg.Done()
		}
	}
}

// decrease model usage reference number
func (s *VllmSimulator) responseSentCallback(model string) {

-	atomic.AddInt64(&(s.nRunningReqs), -1)
+	// Note: nRunningReqs is now decremented in removeRunningRequest
	s.reportRunningRequests()

	// Only LoRA models require reference-count handling.
Collaborator: Please move this (and all other occurrences) to createDefaultConfig().

Contributor (author): done
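The resolved version is not part of this diff. As a rough sketch only: the review asks for the default to be set in createDefaultConfig(), whose name comes from the comment above; the struct name and every default value below are assumptions, with 0 chosen so the max-num-batched-tokens constraint stays disabled unless the flag is passed.

```go
// Hypothetical sketch - the actual defaults in the repository may differ.
func createDefaultConfig() *configuration {
	return &configuration{
		Port:                8000,
		MaxNumSeqs:          5,
		MaxNumBatchedTokens: 0, // 0 disables the max-num-batched-tokens constraint
		MaxModelLen:         1024,
	}
}
```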