diff --git a/README.md b/README.md
index 93388ac7..3e6aa3b0 100644
--- a/README.md
+++ b/README.md
@@ -94,6 +94,7 @@ For more details see the
+    if atomic.LoadInt64(&s.nRunningReqs) >= int64(s.config.MaxNumSeqs) {
+        return false
+    }
+
+    // If max-num-batched-tokens is not configured (0), only check max-num-seqs
+    if s.config.MaxNumBatchedTokens <= 0 {
+        return true
+    }
+
+    // Calculate tokens needed for this request
+    requestTokens := s.calculateProcessingTokens(req)
+    currentTokens := atomic.LoadInt64(&s.processingTokensCount)
+
+    // Check max-num-batched-tokens constraint
+    return currentTokens+int64(requestTokens) <= int64(s.config.MaxNumBatchedTokens)
+}
+
+// addRunningRequest adds a request to the running requests tracking
+func (s *VllmSimulator) addRunningRequest(reqCtx *completionReqCtx) {
+    processingTokens := s.calculateProcessingTokens(reqCtx.completionReq)
+    reqCtx.processingTokens = processingTokens
+
+    atomic.AddInt64(&s.processingTokensCount, int64(processingTokens))
+    atomic.AddInt64(&s.nRunningReqs, 1)
+}
+
+// removeRunningRequest removes a request from the running requests tracking
+func (s *VllmSimulator) removeRunningRequest(reqCtx *completionReqCtx) {
+    atomic.AddInt64(&s.processingTokensCount, -int64(reqCtx.processingTokens))
+    atomic.AddInt64(&s.nRunningReqs, -1)
+}
+
 // handleCompletions general completion requests handler, support both text and chat completion APIs
 func (s *VllmSimulator) handleCompletions(ctx *fasthttp.RequestCtx, isChatCompletion bool) {
     vllmReq, err := s.readRequest(ctx, isChatCompletion)
@@ -400,6 +461,16 @@ func (s *VllmSimulator) handleCompletions(ctx *fasthttp.RequestCtx, isChatComple
         return
     }
 
+    // Validate max-num-batched-tokens constraint - reject requests that would never be accepted
+    if s.config.MaxNumBatchedTokens > 0 {
+        requestTokens := s.calculateProcessingTokens(vllmReq)
+        if requestTokens > s.config.MaxNumBatchedTokens {
+            s.sendCompletionError(ctx, fmt.Sprintf("Request requires %d tokens, but max-num-batched-tokens is set to %d. This request would never be accepted. Please reduce max_tokens or increase max-num-batched-tokens",
+                requestTokens, s.config.MaxNumBatchedTokens), "BadRequestError", fasthttp.StatusBadRequest)
+            return
+        }
+    }
+
     var wg sync.WaitGroup
     wg.Add(1)
     reqCtx := &completionReqCtx{
@@ -414,15 +485,54 @@ func (s *VllmSimulator) handleCompletions(ctx *fasthttp.RequestCtx, isChatComple
     wg.Wait()
 }
 
+func (s *VllmSimulator) queueManager(ctx context.Context) {
+    // Use a slice to maintain the queue of waiting requests
+    var waitingQueue []*completionReqCtx
+    ticker := time.NewTicker(10 * time.Millisecond) // Check every 10ms if we can process waiting requests
+    defer ticker.Stop()
+
+    for {
+        select {
+        case <-ctx.Done():
+            s.logger.Info("queueManager stopped")
+            return
+        case reqCtx := <-s.reqChan:
+            // Add new request to the waiting queue
+            waitingQueue = append(waitingQueue, reqCtx)
+        case <-ticker.C:
+            // Periodically check if we can process waiting requests
+            if len(waitingQueue) == 0 {
+                continue
+            }
+
+            // Try to process requests from the front of the queue
+            var newQueue []*completionReqCtx
+            for _, reqCtx := range waitingQueue {
+                if s.canAcceptRequest(reqCtx.completionReq) {
+                    // Add to running requests tracking
+                    s.addRunningRequest(reqCtx)
+
+                    // Send to processing channel
+                    s.processingChan <- reqCtx
+                } else {
+                    // Can't process yet, keep in queue
+                    newQueue = append(newQueue, reqCtx)
+                }
+            }
+            waitingQueue = newQueue
+        }
+    }
+}
+
 func (s *VllmSimulator) reqProcessingWorker(ctx context.Context, id int) {
     for {
         select {
         case <-ctx.Done():
             s.logger.Info("reqProcessingWorker stopped:", "worker id", id)
             return
-        case reqCtx, ok := <-s.reqChan:
+        case reqCtx, ok := <-s.processingChan:
             if !ok {
-                s.logger.Info("reqProcessingWorker worker exiting: reqChan closed")
+                s.logger.Info("reqProcessingWorker worker exiting: processingChan closed")
                 return
             }
             atomic.StoreInt64(&(s.nWaitingReqs), int64(len(s.reqChan)))
@@ -449,7 +559,8 @@ func (s *VllmSimulator) reqProcessingWorker(ctx context.Context, id int) {
                 // TODO - check if this request went to the waiting queue - add it to waiting map
                 s.reportLoras()
             }
-            atomic.AddInt64(&(s.nRunningReqs), 1)
+
+            // Note: we don't increment nRunningReqs here because it's already done in addRunningRequest
             s.reportRunningRequests()
 
             var responseTokens []string
@@ -514,6 +625,10 @@ func (s *VllmSimulator) reqProcessingWorker(ctx context.Context, id int) {
                     req.doRemotePrefill())
                 }
             }
+
+            // Clean up the running request tracking
+            s.removeRunningRequest(reqCtx)
+
             reqCtx.wg.Done()
         }
     }
@@ -521,8 +636,7 @@
 
 // decrease model usage reference number
 func (s *VllmSimulator) responseSentCallback(model string) {
-
-    atomic.AddInt64(&(s.nRunningReqs), -1)
+    // Note: nRunningReqs is now decremented in removeRunningRequest
     s.reportRunningRequests()
 
     // Only LoRA models require reference-count handling.
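The admission logic above reduces to two independent budgets: a sequence budget (`max-num-seqs`) and an optional token budget (`max-num-batched-tokens`, disabled when 0), with `queueManager` retrying queued requests on a 10 ms ticker until both budgets have room. The sketch below is a minimal, self-contained illustration of that rule, not the simulator's code: the `admitter` and `request` types are invented for the example, and `tokensFor` mirrors the behaviour the unit tests expect from `calculateProcessingTokens` (prompt tokens plus `max_tokens`, falling back to `max-model-len` when `max_tokens` is unset).

```go
package main

import (
	"fmt"
	"sync/atomic"
)

// request is an invented stand-in for the simulator's completion request.
type request struct {
	promptTokens int
	maxTokens    int // 0 means "not set": budget the rest of the context window
}

// admitter is an invented stand-in for the simulator's config plus counters.
type admitter struct {
	maxNumSeqs          int64
	maxNumBatchedTokens int64 // 0 disables the token budget
	maxModelLen         int

	runningSeqs      int64
	processingTokens int64
}

// tokensFor mirrors the behaviour the unit tests expect from
// calculateProcessingTokens: prompt tokens plus max_tokens, or the full
// context window when max_tokens is not set.
func (a *admitter) tokensFor(r request) int64 {
	if r.maxTokens > 0 {
		return int64(r.promptTokens + r.maxTokens)
	}
	return int64(a.maxModelLen)
}

// canAccept applies the two budgets: max-num-seqs first, then (if enabled)
// max-num-batched-tokens.
func (a *admitter) canAccept(r request) bool {
	if atomic.LoadInt64(&a.runningSeqs) >= a.maxNumSeqs {
		return false
	}
	if a.maxNumBatchedTokens <= 0 {
		return true
	}
	return atomic.LoadInt64(&a.processingTokens)+a.tokensFor(r) <= a.maxNumBatchedTokens
}

// add reserves capacity for an admitted request (the counterpart of a
// remove call once the request finishes).
func (a *admitter) add(r request) {
	atomic.AddInt64(&a.processingTokens, a.tokensFor(r))
	atomic.AddInt64(&a.runningSeqs, 1)
}

func main() {
	a := &admitter{maxNumSeqs: 2, maxNumBatchedTokens: 500, maxModelLen: 1024}
	first := request{promptTokens: 4, maxTokens: 300}
	second := request{promptTokens: 4, maxTokens: 300}

	fmt.Println(a.canAccept(first)) // true: 304 <= 500
	a.add(first)
	fmt.Println(a.canAccept(second)) // false: 304 + 304 > 500
}
```

`add` here plays the role of `addRunningRequest`; in the diff, the matching `removeRunningRequest` call in the worker releases both budgets once the response has been sent, which is what lets the ticker admit the next queued request.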
diff --git a/pkg/llm-d-inference-sim/simulator_test.go b/pkg/llm-d-inference-sim/simulator_test.go
index 22a507ae..573688f2 100644
--- a/pkg/llm-d-inference-sim/simulator_test.go
+++ b/pkg/llm-d-inference-sim/simulator_test.go
@@ -65,6 +65,9 @@ func startServerWithArgs(ctx context.Context, mode string, args []string) (*http
         return nil, err
     }
 
+    // run queue manager that handles request constraints
+    go s.queueManager(ctx)
+
     // run request processing workers
     for i := 1; i <= s.config.MaxNumSeqs; i++ {
         go s.reqProcessingWorker(ctx, i)
@@ -489,4 +492,188 @@ var _ = Describe("Simulator", func() {
             Expect(string(body)).To(ContainSubstring("BadRequestError"))
         })
     })
+
+    Context("max-num-batched-tokens functionality", func() {
+        var simulator *VllmSimulator
+
+        BeforeEach(func() {
+            var err error
+            simulator, err = New(klog.Background())
+            Expect(err).NotTo(HaveOccurred())
+
+            // Set up basic configuration
+            simulator.config = newConfig()
+            simulator.config.Model = "test-model"
+            simulator.config.MaxModelLen = 1024
+            simulator.config.MaxNumSeqs = 5
+            simulator.config.MaxNumBatchedTokens = 2048
+        })
+
+        Describe("calculateProcessingTokens", func() {
+            It("should calculate tokens with explicit max_tokens", func() {
+                req := &chatCompletionRequest{
+                    baseCompletionRequest: baseCompletionRequest{
+                        Model: "test-model",
+                    },
+                    Messages: []message{
+                        {Role: "user", Content: content{Raw: "Hello world"}},
+                    },
+                    MaxTokens: int64Ptr(100),
+                }
+
+                // Mock the token counting (in real implementation, this would tokenize the message)
+                // For test purposes, assume "Hello world" = 2 tokens
+                tokens := simulator.calculateProcessingTokens(req)
+
+                // Should be prompt tokens (2) + max tokens (100) = 102
+                // Note: In real implementation, this depends on the actual tokenization
+                Expect(tokens).To(BeNumerically(">=", 100))
+            })
+
+            It("should calculate tokens without max_tokens using max-model-len", func() {
+                req := &chatCompletionRequest{
+                    baseCompletionRequest: baseCompletionRequest{
+                        Model: "test-model",
+                    },
+                    Messages: []message{
+                        {Role: "user", Content: content{Raw: "Hello world"}},
+                    },
+                }
+
+                tokens := simulator.calculateProcessingTokens(req)
+
+                // Should be prompt tokens + (max-model-len - prompt tokens),
+                // which equals max-model-len = 1024
+                Expect(tokens).To(Equal(1024))
+            })
+        })
+
+        Describe("canAcceptRequest", func() {
+            It("should accept request when within both constraints", func() {
+                simulator.config.MaxNumSeqs = 2
+                simulator.config.MaxNumBatchedTokens = 2048
+
+                req := &chatCompletionRequest{
+                    baseCompletionRequest: baseCompletionRequest{
+                        Model: "test-model",
+                    },
+                    Messages: []message{
+                        {Role: "user", Content: content{Raw: "Hello"}},
+                    },
+                    MaxTokens: int64Ptr(100),
+                }
+
+                canAccept := simulator.canAcceptRequest(req)
+                Expect(canAccept).To(BeTrue())
+            })
+
+            It("should reject request when max-num-seqs is exceeded", func() {
+                simulator.config.MaxNumSeqs = 1
+                simulator.config.MaxNumBatchedTokens = 2048
+
+                // Simulate one request already running
+                simulator.nRunningReqs = 1
+
+                req := &chatCompletionRequest{
+                    baseCompletionRequest: baseCompletionRequest{
+                        Model: "test-model",
+                    },
+                    Messages: []message{
+                        {Role: "user", Content: content{Raw: "Hello"}},
+                    },
+                    MaxTokens: int64Ptr(100),
+                }
+
+                canAccept := simulator.canAcceptRequest(req)
+                Expect(canAccept).To(BeFalse())
+            })
+
+            It("should reject request when max-num-batched-tokens would be exceeded", func() {
+                simulator.config.MaxNumSeqs = 5
+                simulator.config.MaxNumBatchedTokens = 500
+
+                // Simulate tokens already being used
+                simulator.processingTokensCount = 400
+
+                req := &chatCompletionRequest{
+                    baseCompletionRequest: baseCompletionRequest{
+                        Model: "test-model",
+                    },
+                    Messages: []message{
+                        {Role: "user", Content: content{Raw: "Hello"}},
+                    },
+                    MaxTokens: int64Ptr(200), // This would exceed the limit (400 + 200 + prompt > 500)
+                }
+
+                canAccept := simulator.canAcceptRequest(req)
+                Expect(canAccept).To(BeFalse())
+            })
+
+            It("should ignore batched tokens constraint when MaxNumBatchedTokens is 0", func() {
+                simulator.config.MaxNumSeqs = 5
+                simulator.config.MaxNumBatchedTokens = 0 // Disabled
+
+                // Simulate a lot of tokens being used
+                simulator.processingTokensCount = 10000
+
+                req := &chatCompletionRequest{
+                    baseCompletionRequest: baseCompletionRequest{
+                        Model: "test-model",
+                    },
+                    Messages: []message{
+                        {Role: "user", Content: content{Raw: "Hello"}},
+                    },
+                    MaxTokens: int64Ptr(200),
+                }
+
+                canAccept := simulator.canAcceptRequest(req)
+                Expect(canAccept).To(BeTrue()) // Should only check max-num-seqs
+            })
+        })
+
+        It("Should start with max-num-batched-tokens parameter", func() {
+            ctx := context.TODO()
+            args := []string{"cmd", "--model", model, "--mode", modeRandom, "--max-num-batched-tokens", "1024"}
+            client, err := startServerWithArgs(ctx, modeRandom, args)
+            Expect(err).NotTo(HaveOccurred())
+            Expect(client).NotTo(BeNil())
+        })
+
+        It("Should reject requests that exceed max-num-batched-tokens immediately", func() {
+            ctx := context.TODO()
+            args := []string{"cmd", "--model", model, "--mode", modeRandom, "--max-num-batched-tokens", "10"}
+            client, err := startServerWithArgs(ctx, modeRandom, args)
+            Expect(err).NotTo(HaveOccurred())
+            Expect(client).NotTo(BeNil())
+
+            // Create a request that requires more than 10 tokens (4 prompt + 20 max_tokens = 24 tokens)
+            reqBody := `{
+                "messages": [
+                    {"role": "user", "content": "Hello world test prompt"}
+                ],
+                "model": "my_model",
+                "max_tokens": 20
+            }`
+
+            resp, err := client.Post("http://localhost/v1/chat/completions", "application/json", strings.NewReader(reqBody))
+            Expect(err).NotTo(HaveOccurred())
+            defer func() {
+                err := resp.Body.Close()
+                Expect(err).NotTo(HaveOccurred())
+            }()
+
+            body, err := io.ReadAll(resp.Body)
+            Expect(err).NotTo(HaveOccurred())
+
+            Expect(resp.StatusCode).To(Equal(400))
+            Expect(string(body)).To(ContainSubstring("Request requires"))
+            Expect(string(body)).To(ContainSubstring("max-num-batched-tokens is set to 10"))
+            Expect(string(body)).To(ContainSubstring("would never be accepted"))
+        })
+    })
 })
+
+// Helper function to create int64 pointer
+func int64Ptr(i int64) *int64 {
+    return &i
+}
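Taken together with the `go s.queueManager(ctx)` line added to `startServerWithArgs`, these tests exercise a two-stage pipeline: requests land in a waiting channel, a ticker-driven dispatcher admits them when capacity allows, and workers drain a separate processing channel. The sketch below is illustrative only and shares no types with the simulator; `waiting`, `processing`, and the plain `running` counter are stand-ins for `reqChan`, `processingChan`, and the real capacity check.

```go
package main

import (
	"context"
	"fmt"
	"sync"
	"sync/atomic"
	"time"
)

func main() {
	const maxNumSeqs = 2

	var running int64
	waiting := make(chan int, 16)    // stands in for reqChan
	processing := make(chan int, 16) // stands in for processingChan

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	var wg sync.WaitGroup

	// Dispatcher: drain the waiting queue whenever the running count permits.
	go func() {
		ticker := time.NewTicker(10 * time.Millisecond)
		defer ticker.Stop()
		var queue []int
		for {
			select {
			case <-ctx.Done():
				return
			case id := <-waiting:
				queue = append(queue, id)
			case <-ticker.C:
				var keep []int
				for _, id := range queue {
					if atomic.LoadInt64(&running) < maxNumSeqs {
						atomic.AddInt64(&running, 1)
						processing <- id
					} else {
						keep = append(keep, id)
					}
				}
				queue = keep
			}
		}
	}()

	// Workers: simulate handling a request, then release capacity.
	for w := 0; w < maxNumSeqs; w++ {
		go func() {
			for id := range processing {
				time.Sleep(20 * time.Millisecond)
				fmt.Println("finished request", id)
				atomic.AddInt64(&running, -1)
				wg.Done()
			}
		}()
	}

	for i := 1; i <= 5; i++ {
		wg.Add(1)
		waiting <- i
	}
	wg.Wait()
}
```

Keeping admission inside a single goroutine, as `queueManager` does, means the check-then-admit pair cannot race with another admitter; workers only ever release capacity, so a request admitted under the budget stays under it.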