Skip to content

Commit dad659e

Browse files
authored
Refactor timeout and batch inference (#20)
* Refactor timeout implementation * Refactor batch inference * Add unittest action
1 parent 84de66f commit dad659e

23 files changed

Lines changed: 16037 additions & 391 deletions

.github/workflows/unittest.yml

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
name: Unittest
2+
on:
3+
push:
4+
branches: [master, main]
5+
permissions:
6+
contents: read
7+
8+
jobs:
9+
unittest:
10+
runs-on: ubuntu-22.04
11+
steps:
12+
- uses: actions/checkout@v4
13+
with:
14+
fetch-depth: 0
15+
- uses: actions/setup-java@v4
16+
with:
17+
distribution: temurin
18+
java-version: 17
19+
cache: sbt
20+
- run: sbt test

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,6 @@
99
/src/test/resources/response_0.pb
1010
/.bloop/
1111
/.metals/
12+
/.bsp/
13+
/examples/.ipynb_checkpoints
14+
/examples/__pycache__/

build.sbt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ name := (sys.props.getOrElse("gpu", "false") match {
33
case _ => "ai-serving"
44
})
55

6-
version := "2.1.0"
6+
version := "2.2.0"
77

88
organization := "com.autodeployai"
99

examples/models/xgb-iris.pmml

Lines changed: 7631 additions & 227 deletions
Large diffs are not rendered by default.

src/main/resources/application.conf

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,11 @@ service {
4242
}
4343

4444
home = "/opt/ai-serving"
45+
46+
logging {
47+
request-timing-enabled = false
48+
request-timing-level = "DEBUG" // one of "DEBUG", "INFO", "WARNING", "ERROR"
49+
}
4550
}
4651

4752
onnxruntime {
@@ -53,4 +58,3 @@ onnxruntime {
5358
logger-id = "onnxruntime"
5459
logging-level = 3 // 0: VERBOSE, 1: INFO, 2: WARNING, 3: ERROR, 4: FATAL
5560
}
56-

src/main/resources/logback.xml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@
1111
<appender-ref ref="STDOUT"/>
1212
</logger>
1313

14+
<logger name="akka" level="${LOG_LEVEL:-INFO}" additivity="false">
15+
<appender-ref ref="STDOUT"/>
16+
</logger>
17+
1418
<logger name="io.netty" level="${LOG_LEVEL:-INFO}" additivity="false">
1519
<appender-ref ref="STDOUT"/>
1620
</logger>

src/main/scala/com/autodeployai/serving/AIServer.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,12 @@ object AIServer extends Endpoints with EndpointsV2 {
3939
if (config.hasPath(defaultFixedPoolSizePath)) {
4040
var numCores = config.getInt(defaultFixedPoolSizePath)
4141
if (numCores == -1) {
42+
val onnxBackend = if (config.hasPath("onnxruntime.backend")) config.getString("onnxruntime.backend").toLowerCase else "cpu"
43+
val onnxThreads = if (config.hasPath("onnxruntime.cpu-num-threads")) config.getInt("onnxruntime.cpu-num-threads") else -1
44+
45+
if (onnxBackend == "cpu" && onnxThreads == -1) {
46+
log.warn("Please reserve sufficient CPU capacity for ONNX Runtime to prevent oversubscription when serving ONNX models on CPU.")
47+
}
4248
numCores = Utils.getNumCores
4349
}
4450
config = config.withValue(defaultFixedPoolSizePath, ConfigValueFactory.fromAnyRef(numCores))

src/main/scala/com/autodeployai/serving/deploy/BatchProcessor.scala

Lines changed: 39 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -21,26 +21,27 @@ import com.autodeployai.serving.utils.{DataUtils, Utils}
2121
import org.slf4j.{Logger, LoggerFactory}
2222

2323
import java.nio.{ByteBuffer, ByteOrder, DoubleBuffer, FloatBuffer, IntBuffer, LongBuffer, ShortBuffer}
24+
import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger}
2425
import java.util.concurrent.{ConcurrentLinkedQueue, Executors, TimeUnit}
2526
import scala.concurrent.{ExecutionContext, Future, Promise}
2627

2728
trait BatchRequest[Request, Response] {
2829
def request: Request
2930
def promise: Promise[Response]
30-
def options: RunOptions
31+
def options: Option[RunOptions]
3132
def timestamp: Long
3233
}
3334

3435
case class BatchRequestV2(
3536
request: InferenceRequest,
3637
promise: Promise[InferenceResponse],
37-
options: RunOptions,
38+
options: Option[RunOptions],
3839
timestamp: Long = System.currentTimeMillis(),
3940
) extends BatchRequest[InferenceRequest, InferenceResponse]
4041

4142
trait BatchProcessor[Request, Response] extends AutoCloseable {
4243

43-
def predict(request: Request, options: RunOptions): Future[Response]
44+
def predict(request: Request, options: Option[RunOptions]): Future[Response]
4445

4546
def merge(requests: Array[Request]): Request
4647

@@ -53,6 +54,8 @@ class BatchProcessorV2(model: PredictModel,
5354

5455
val log: Logger = LoggerFactory.getLogger(this.getClass)
5556
private val queue = new ConcurrentLinkedQueue[BatchRequestV2]()
57+
private val queueLength = new AtomicInteger(0)
58+
private val processing = new AtomicBoolean(false)
5659
private val enabled: Boolean = maxBatchSize > 1
5760
private val checkInterval: Long = Math.max(maxBatchDelayMs / 2, 1L)
5861

@@ -76,16 +79,15 @@ class BatchProcessorV2(model: PredictModel,
7679

7780
log.info(s"BatchProcessor for model ${model.modelName}:${model.modelVersion} initialized: max-batch-size=$maxBatchSize, max-batch-delay-ms=$maxBatchDelayMs")
7881

79-
override def predict(request: InferenceRequest, options: RunOptions): Future[InferenceResponse] = {
82+
override def predict(request: InferenceRequest, options: Option[RunOptions]): Future[InferenceResponse] = {
8083
if (enabled) {
8184
val promise = Promise[InferenceResponse]()
8285
val batchRequest = BatchRequestV2(request=request, promise=promise, options=options)
8386
queue.offer(batchRequest)
87+
val currentLength = queueLength.incrementAndGet()
8488

85-
if (queue.size() >= maxBatchSize) {
86-
Future{
87-
processBatch()
88-
}
89+
if (currentLength >= maxBatchSize) {
90+
processBatchAsync()
8991
}
9092
promise.future
9193
} else {
@@ -99,18 +101,38 @@ class BatchProcessorV2(model: PredictModel,
99101
if (!queue.isEmpty) {
100102
val oldestRequest = queue.peek()
101103
if (oldestRequest != null && (timestamp - oldestRequest.timestamp) >= maxBatchDelayMs) {
102-
processBatch()
104+
processBatchAsync()
105+
}
106+
}
107+
}
108+
109+
private def processBatchAsync(): Unit = {
110+
if (processing.compareAndSet(false, true)) {
111+
Future {
112+
try {
113+
var keepRunning = true
114+
while (keepRunning) {
115+
val processed = processBatch()
116+
keepRunning = processed && queueLength.get() >= maxBatchSize
117+
}
118+
} finally {
119+
processing.set(false)
120+
if (queueLength.get() >= maxBatchSize) {
121+
processBatchAsync()
122+
}
123+
}
103124
}
104125
}
105126
}
106127

107-
private def processBatch(): Unit = this.synchronized {
128+
private def processBatch(): Boolean = this.synchronized {
108129
val builder = Array.newBuilder[BatchRequestV2]
109130
builder.sizeHint(maxBatchSize)
110131
var exit = false
111132
while (builder.length < maxBatchSize && !exit) {
112133
val item = queue.poll()
113134
if (item != null) {
135+
queueLength.decrementAndGet()
114136
builder += item
115137
} else {
116138
exit = true
@@ -128,7 +150,7 @@ class BatchProcessorV2(model: PredictModel,
128150

129151
val startTime = System.currentTimeMillis()
130152
val batchResponse = model.predict(mergedRequest, options)
131-
log.info(s"Batched ${batch.length} requests elapsed time: ${System.currentTimeMillis() - startTime}")
153+
log.debug(s"Batched ${batch.length} requests elapsed time: ${System.currentTimeMillis() - startTime}")
132154

133155
val recordCounts = requests.map(req => {
134156
if (req.inputs.nonEmpty) {
@@ -147,8 +169,11 @@ class BatchProcessorV2(model: PredictModel,
147169
batch.foreach(_.promise.failure(ex))
148170
} finally {
149171
// Close options of other requests
150-
batch.tail.foreach(x => Utils.safeClose(x.options))
172+
batch.tail.foreach(x => Utils.safeClose(x.options.orNull))
151173
}
174+
true
175+
} else {
176+
false
152177
}
153178
}
154179

@@ -313,7 +338,7 @@ class BatchProcessorV2(model: PredictModel,
313338
scheduler.foreach(s =>{
314339
s.shutdown()
315340
try {
316-
if (s.awaitTermination(5, TimeUnit.SECONDS)) {
341+
if (!s.awaitTermination(5, TimeUnit.SECONDS)) {
317342
s.shutdownNow()
318343
}
319344
} catch {
@@ -323,7 +348,7 @@ class BatchProcessorV2(model: PredictModel,
323348
})
324349

325350
if (!queue.isEmpty) {
326-
processBatch()
351+
while (processBatch()) {}
327352
}
328353
}
329354
}

src/main/scala/com/autodeployai/serving/deploy/InferenceService.scala

Lines changed: 59 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,11 @@ import com.autodeployai.serving.utils.Utils.toOption
2424
import com.typesafe.config.{Config, ConfigFactory, ConfigRenderOptions}
2525
import org.slf4j.{Logger, LoggerFactory}
2626

27-
import java.util.{Timer, TimerTask}
27+
import java.util.concurrent.{Executors, ScheduledExecutorService, ThreadFactory, TimeUnit}
2828
import scala.collection.concurrent.TrieMap
2929
import scala.collection.mutable.ArrayBuffer
3030
import scala.concurrent.{ExecutionContext, Future, Promise}
31+
import scala.util.Using
3132

3233
/**
3334
* Main entry of models validation, management and deployment.
@@ -49,6 +50,14 @@ object InferenceService extends JsonSupport {
4950

5051
private val repositories: TrieMap[String, ModelRepository] = TrieMap.empty
5152

53+
// Timeout scheduler
54+
private val timeoutScheduler: ScheduledExecutorService = Executors.newSingleThreadScheduledExecutor(
55+
(r: Runnable) => {
56+
val thread = new Thread(r, "ai-serving-timeout-scheduler")
57+
thread.setDaemon(true)
58+
thread
59+
})
60+
5261
// A flag if the inference service is ready to response requests
5362
var isReady = false
5463

@@ -82,16 +91,16 @@ object InferenceService extends JsonSupport {
8291
def loadModels()(implicit ec: ExecutionContext): Unit = {
8392
log.info(s"Loading models under the directory: $modelsPath")
8493

85-
val files = Files.list(modelsPath)
86-
val it = files.iterator()
87-
while (it.hasNext) {
88-
val path = it.next()
89-
if (Files.isDirectory(path)) {
90-
val modelRepository = loadModel(path)
91-
modelRepository.foreach(x => repositories.put(x.modelName, x))
94+
Using(Files.list(modelsPath)) { files =>
95+
val it = files.iterator()
96+
while (it.hasNext) {
97+
val path = it.next()
98+
if (Files.isDirectory(path)) {
99+
val modelRepository = loadModel(path)
100+
modelRepository.foreach(x => repositories.put(x.modelName, x))
101+
}
92102
}
93103
}
94-
95104
isReady = true
96105
}
97106

@@ -204,7 +213,7 @@ object InferenceService extends JsonSupport {
204213
model.predict(request, runOptions)
205214
}
206215

207-
withTimeout(futureResult, modelName, modelVersion, runOptions)
216+
withTimeout(futureResult, model, runOptions)
208217
}
209218

210219
/**
@@ -229,7 +238,7 @@ object InferenceService extends JsonSupport {
229238
}
230239
}
231240

232-
withTimeout(futureResult, modelName, modelVersion, runOptions)
241+
withTimeout(futureResult, model, runOptions)
233242
}
234243

235244
/**
@@ -519,26 +528,27 @@ object InferenceService extends JsonSupport {
519528
val modelRepository = new ModelRepository(modelName, modelConfig)
520529

521530
val versions = ArrayBuffer.empty[String]
522-
val files = Files.list(modelPath)
523-
val it = files.iterator()
524-
while (it.hasNext) {
525-
val path = it.next()
526-
if (Files.isDirectory(path)) {
527-
val modelVersion = path.getFileName.toString
528-
versions += modelVersion
529-
val (modelObjectPath, modelType) = getModelObjectPath(path)
530-
if (modelObjectPath != null) {
531-
try {
532-
log.info(s"Loading model: $modelName with the version $modelVersion")
533-
534-
// Load the version config
535-
val versionConfig = getModelConfig(path)
536-
537-
val model = PredictModel.load(modelObjectPath, modelType, modelName, modelVersion, versionConfig.orElse(modelConfig))
538-
modelRepository.put(modelVersion, model)
539-
} catch {
540-
case e: Exception =>
541-
log.error(s"Failed to load model $modelName with the version $modelVersion caused: $e")
531+
Using(Files.list(modelPath)) { files =>
532+
val it = files.iterator()
533+
while (it.hasNext) {
534+
val path = it.next()
535+
if (Files.isDirectory(path)) {
536+
val modelVersion = path.getFileName.toString
537+
versions += modelVersion
538+
val (modelObjectPath, modelType) = getModelObjectPath(path)
539+
if (modelObjectPath != null) {
540+
try {
541+
log.info(s"Loading model: $modelName with the version $modelVersion")
542+
543+
// Load the version config
544+
val versionConfig = getModelConfig(path)
545+
546+
val model = PredictModel.load(modelObjectPath, modelType, modelName, modelVersion, versionConfig.orElse(modelConfig))
547+
modelRepository.put(modelVersion, model)
548+
} catch {
549+
case e: Exception =>
550+
log.error(s"Failed to load model $modelName with the version $modelVersion caused: $e")
551+
}
542552
}
543553
}
544554
}
@@ -590,29 +600,27 @@ object InferenceService extends JsonSupport {
590600
* @tparam T
591601
* @return
592602
*/
593-
private def withTimeout[T](future: Future[T], modelName: String, modelVersion: Option[String], runOptions: RunOptions)(implicit ec: ExecutionContext): Future[T] = {
594-
val repository = repositories.get(modelName)
595-
repository.flatMap(_.getTimeoutDuration(modelVersion)) match {
596-
case Some(timeout) =>
597-
val promise = Promise[T]()
598-
val task = new TimerTask {
603+
private def withTimeout[T](future: Future[T], model: PredictModel, runOptions: Option[RunOptions])(implicit ec: ExecutionContext): Future[T] = {
604+
val timeout = model.timeout
605+
if (timeout > 0) {
606+
val promise = Promise[T]()
607+
val timeoutFuture = timeoutScheduler.schedule(
608+
new Runnable {
599609
override def run(): Unit = {
600610
if (!promise.isCompleted) {
601-
runOptions.terminate()
602-
promise.tryFailure(InferTimeoutException(modelName, modelVersion, timeout.toMillis))
611+
runOptions.foreach(_.terminate())
612+
promise.tryFailure(InferTimeoutException(model.modelName, Option(model.modelVersion), timeout))
603613
}
604614
}
605-
}
606-
val timer = new Timer(true)
607-
timer.schedule(task, timeout.toMillis)
608-
future.onComplete { result =>
609-
timer.cancel()
610-
promise.tryComplete(result)
611-
}
612-
promise.future
613-
case _ => future
614-
}
615+
},
616+
timeout,
617+
TimeUnit.MILLISECONDS
618+
)
619+
future.onComplete { result =>
620+
timeoutFuture.cancel(false)
621+
promise.tryComplete(result)
622+
}
623+
promise.future
624+
} else future
615625
}
616-
617626
}
618-

0 commit comments

Comments (0)