@@ -24,10 +24,11 @@ import com.autodeployai.serving.utils.Utils.toOption
 import com.typesafe.config.{Config, ConfigFactory, ConfigRenderOptions}
 import org.slf4j.{Logger, LoggerFactory}
 
-import java.util.{Timer, TimerTask}
+import java.util.concurrent.{Executors, ScheduledExecutorService, ThreadFactory, TimeUnit}
 import scala.collection.concurrent.TrieMap
 import scala.collection.mutable.ArrayBuffer
 import scala.concurrent.{ExecutionContext, Future, Promise}
+import scala.util.Using
 
 /**
  * Main entry of models validation, management and deployment.
@@ -49,6 +50,14 @@ object InferenceService extends JsonSupport {
 
   private val repositories: TrieMap[String, ModelRepository] = TrieMap.empty
 
+  // Timeout scheduler
+  private val timeoutScheduler: ScheduledExecutorService = Executors.newSingleThreadScheduledExecutor(
+    (r: Runnable) => {
+      val thread = new Thread(r, "ai-serving-timeout-scheduler")
+      thread.setDaemon(true)
+      thread
+    })
+
   // A flag if the inference service is ready to response requests
   var isReady = false
@@ -82,16 +91,16 @@ object InferenceService extends JsonSupport {
   def loadModels()(implicit ec: ExecutionContext): Unit = {
     log.info(s"Loading models under the directory: $modelsPath")
 
-    val files = Files.list(modelsPath)
-    val it = files.iterator()
-    while (it.hasNext) {
-      val path = it.next()
-      if (Files.isDirectory(path)) {
-        val modelRepository = loadModel(path)
-        modelRepository.foreach(x => repositories.put(x.modelName, x))
+    Using(Files.list(modelsPath)) { files =>
+      val it = files.iterator()
+      while (it.hasNext) {
+        val path = it.next()
+        if (Files.isDirectory(path)) {
+          val modelRepository = loadModel(path)
+          modelRepository.foreach(x => repositories.put(x.modelName, x))
+        }
       }
     }
-
     isReady = true
   }
 
@@ -204,7 +213,7 @@ object InferenceService extends JsonSupport {
       model.predict(request, runOptions)
     }
 
-    withTimeout(futureResult, modelName, modelVersion, runOptions)
+    withTimeout(futureResult, model, runOptions)
   }
 
   /**
@@ -229,7 +238,7 @@ object InferenceService extends JsonSupport {
       }
     }
 
-    withTimeout(futureResult, modelName, modelVersion, runOptions)
+    withTimeout(futureResult, model, runOptions)
   }
 
   /**
@@ -519,26 +528,27 @@ object InferenceService extends JsonSupport {
     val modelRepository = new ModelRepository(modelName, modelConfig)
 
     val versions = ArrayBuffer.empty[String]
-    val files = Files.list(modelPath)
-    val it = files.iterator()
-    while (it.hasNext) {
-      val path = it.next()
-      if (Files.isDirectory(path)) {
-        val modelVersion = path.getFileName.toString
-        versions += modelVersion
-        val (modelObjectPath, modelType) = getModelObjectPath(path)
-        if (modelObjectPath != null) {
-          try {
-            log.info(s"Loading model: $modelName with the version $modelVersion")
-
-            // Load the version config
-            val versionConfig = getModelConfig(path)
-
-            val model = PredictModel.load(modelObjectPath, modelType, modelName, modelVersion, versionConfig.orElse(modelConfig))
-            modelRepository.put(modelVersion, model)
-          } catch {
-            case e: Exception =>
-              log.error(s"Failed to load model $modelName with the version $modelVersion caused: $e")
+    Using(Files.list(modelPath)) { files =>
+      val it = files.iterator()
+      while (it.hasNext) {
+        val path = it.next()
+        if (Files.isDirectory(path)) {
+          val modelVersion = path.getFileName.toString
+          versions += modelVersion
+          val (modelObjectPath, modelType) = getModelObjectPath(path)
+          if (modelObjectPath != null) {
+            try {
+              log.info(s"Loading model: $modelName with the version $modelVersion")
+
+              // Load the version config
+              val versionConfig = getModelConfig(path)
+
+              val model = PredictModel.load(modelObjectPath, modelType, modelName, modelVersion, versionConfig.orElse(modelConfig))
+              modelRepository.put(modelVersion, model)
+            } catch {
+              case e: Exception =>
+                log.error(s"Failed to load model $modelName with the version $modelVersion caused: $e")
+            }
           }
         }
       }
@@ -590,29 +600,27 @@ object InferenceService extends JsonSupport {
    * @tparam T
    * @return
    */
-  private def withTimeout[T](future: Future[T], modelName: String, modelVersion: Option[String], runOptions: RunOptions)(implicit ec: ExecutionContext): Future[T] = {
-    val repository = repositories.get(modelName)
-    repository.flatMap(_.getTimeoutDuration(modelVersion)) match {
-      case Some(timeout) =>
-        val promise = Promise[T]()
-        val task = new TimerTask {
+  private def withTimeout[T](future: Future[T], model: PredictModel, runOptions: Option[RunOptions])(implicit ec: ExecutionContext): Future[T] = {
+    val timeout = model.timeout
+    if (timeout > 0) {
+      val promise = Promise[T]()
+      val timeoutFuture = timeoutScheduler.schedule(
+        new Runnable {
           override def run(): Unit = {
             if (!promise.isCompleted) {
-              runOptions.terminate()
-              promise.tryFailure(InferTimeoutException(modelName, modelVersion, timeout.toMillis))
+              runOptions.foreach(_.terminate())
+              promise.tryFailure(InferTimeoutException(model.modelName, Option(model.modelVersion), timeout))
             }
           }
-        }
-        val timer = new Timer(true)
-        timer.schedule(task, timeout.toMillis)
-        future.onComplete { result =>
-          timer.cancel()
-          promise.tryComplete(result)
-        }
-        promise.future
-      case _ => future
-    }
+        },
+        timeout,
+        TimeUnit.MILLISECONDS
+      )
+      future.onComplete { result =>
+        timeoutFuture.cancel(false)
+        promise.tryComplete(result)
+      }
+      promise.future
+    } else future
   }
-
 }
-
0 commit comments