diff --git a/src/main/java/org/apache/sysds/common/Builtins.java b/src/main/java/org/apache/sysds/common/Builtins.java index dc1f23b83fc..82eccbec021 100644 --- a/src/main/java/org/apache/sysds/common/Builtins.java +++ b/src/main/java/org/apache/sysds/common/Builtins.java @@ -226,6 +226,7 @@ public enum Builtins { LMDS("lmDS", true), LMPREDICT("lmPredict", true), LMPREDICT_STATS("lmPredictStats", true), + LLMPREDICT("llmPredict", false, true), LOCAL("local", false), LOG("log", false), LOGSUMEXP("logSumExp", true), diff --git a/src/main/java/org/apache/sysds/common/Opcodes.java b/src/main/java/org/apache/sysds/common/Opcodes.java index 1b0536416d6..94055d055c5 100644 --- a/src/main/java/org/apache/sysds/common/Opcodes.java +++ b/src/main/java/org/apache/sysds/common/Opcodes.java @@ -204,6 +204,7 @@ public enum Opcodes { GROUPEDAGG("groupedagg", InstructionType.ParameterizedBuiltin), RMEMPTY("rmempty", InstructionType.ParameterizedBuiltin), REPLACE("replace", InstructionType.ParameterizedBuiltin), + LLMPREDICT("llmpredict", InstructionType.ParameterizedBuiltin), LOWERTRI("lowertri", InstructionType.ParameterizedBuiltin), UPPERTRI("uppertri", InstructionType.ParameterizedBuiltin), REXPAND("rexpand", InstructionType.ParameterizedBuiltin), diff --git a/src/main/java/org/apache/sysds/common/Types.java b/src/main/java/org/apache/sysds/common/Types.java index 2e3543882d2..3414614991c 100644 --- a/src/main/java/org/apache/sysds/common/Types.java +++ b/src/main/java/org/apache/sysds/common/Types.java @@ -805,7 +805,7 @@ public static ReOrgOp valueOfByOpcode(String opcode) { /** Parameterized operations that require named variable arguments */ public enum ParamBuiltinOp { - AUTODIFF, CDF, CONTAINS, INVALID, INVCDF, GROUPEDAGG, RMEMPTY, REPLACE, REXPAND, + AUTODIFF, CDF, CONTAINS, INVALID, INVCDF, GROUPEDAGG, LLMPREDICT, RMEMPTY, REPLACE, REXPAND, LOWER_TRI, UPPER_TRI, TRANSFORMAPPLY, TRANSFORMDECODE, TRANSFORMCOLMAP, TRANSFORMMETA, TOKENIZE, TOSTRING, LIST, PARAMSERV diff --git a/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java b/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java index 61a4b8b8f91..b791478214b 100644 --- a/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java +++ b/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java @@ -187,6 +187,7 @@ public Lop constructLops() case LOWER_TRI: case UPPER_TRI: case TOKENIZE: + case LLMPREDICT: case TRANSFORMAPPLY: case TRANSFORMDECODE: case TRANSFORMCOLMAP: @@ -758,7 +759,7 @@ && getTargetHop().areDimsBelowThreshold() ) { if (_op == ParamBuiltinOp.TRANSFORMCOLMAP || _op == ParamBuiltinOp.TRANSFORMMETA || _op == ParamBuiltinOp.TOSTRING || _op == ParamBuiltinOp.LIST || _op == ParamBuiltinOp.CDF || _op == ParamBuiltinOp.INVCDF - || _op == ParamBuiltinOp.PARAMSERV) { + || _op == ParamBuiltinOp.PARAMSERV || _op == ParamBuiltinOp.LLMPREDICT) { _etype = ExecType.CP; } @@ -768,7 +769,7 @@ && getTargetHop().areDimsBelowThreshold() ) { switch(_op) { case CONTAINS: if(getTargetHop().optFindExecType() == ExecType.SPARK) - _etype = ExecType.SPARK; + _etype = ExecType.SPARK; break; default: // Do not change execution type. diff --git a/src/main/java/org/apache/sysds/lops/ParameterizedBuiltin.java b/src/main/java/org/apache/sysds/lops/ParameterizedBuiltin.java index 3604121aac8..dcec28f76ca 100644 --- a/src/main/java/org/apache/sysds/lops/ParameterizedBuiltin.java +++ b/src/main/java/org/apache/sysds/lops/ParameterizedBuiltin.java @@ -176,6 +176,7 @@ public String getInstructions(String output) case CONTAINS: case REPLACE: case TOKENIZE: + case LLMPREDICT: case TRANSFORMAPPLY: case TRANSFORMDECODE: case TRANSFORMCOLMAP: diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java b/src/main/java/org/apache/sysds/parser/DMLTranslator.java index c6e7188d7bc..b1536371711 100644 --- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java +++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java @@ -2007,6 +2007,7 @@ private Hop processParameterizedBuiltinFunctionExpression(ParameterizedBuiltinFu case LOWER_TRI: case UPPER_TRI: case TOKENIZE: + case LLMPREDICT: case TRANSFORMAPPLY: case TRANSFORMDECODE: case TRANSFORMCOLMAP: diff --git a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java index 314440628e0..cd9699a1082 100644 --- a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java +++ b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java @@ -61,6 +61,7 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier pbHopMap.put(Builtins.GROUPEDAGG, ParamBuiltinOp.GROUPEDAGG); pbHopMap.put(Builtins.RMEMPTY, ParamBuiltinOp.RMEMPTY); pbHopMap.put(Builtins.REPLACE, ParamBuiltinOp.REPLACE); + pbHopMap.put(Builtins.LLMPREDICT, ParamBuiltinOp.LLMPREDICT); pbHopMap.put(Builtins.LOWER_TRI, ParamBuiltinOp.LOWER_TRI); pbHopMap.put(Builtins.UPPER_TRI, ParamBuiltinOp.UPPER_TRI); @@ -211,6 +212,10 @@ public void validateExpression(HashMap ids, HashMap valid = new HashSet<>(Arrays.asList( + "target", "url", "model", "max_tokens", "temperature", "top_p", "concurrency")); + checkInvalidParameters(getOpCode(), getVarParams(), valid); + checkDataType(false, "llmPredict", TF_FN_PARAM_DATA, DataType.FRAME, conditional); + checkStringParam(false, "llmPredict", "url", conditional); + + // validate numeric parameter types at compile time (when literal). + // Note: no range validation -- different LLM servers accept different + // ranges (e.g. vLLM allows temperature=0.0, OpenAI requires >0). + // Runtime errors from the server are more informative than + // compile-time checks locked to one server's rules. + checkNumericScalarParam("llmPredict", "max_tokens", conditional); + checkNumericScalarParam("llmPredict", "temperature", conditional); + checkNumericScalarParam("llmPredict", "top_p", conditional); + checkNumericScalarParam("llmPredict", "concurrency", conditional); + + output.setDataType(DataType.FRAME); + output.setValueType(ValueType.STRING); + output.setDimensions(-1, -1); + } + + private void checkNumericScalarParam(String fname, String pname, boolean conditional) { + Expression expr = getVarParam(pname); + if(expr == null) return; + if(expr instanceof DataIdentifier) { + DataIdentifier di = (DataIdentifier) expr; + if(di.getDataType() != null && !di.getDataType().isScalar()) { + raiseValidateError( + String.format("Function %s: parameter '%s' must be a scalar, got %s.", + fname, pname, di.getDataType()), conditional); + } + } + } + // example: A = transformapply(target=X, meta=M, spec=s) private void validateTransformApply(DataIdentifier output, boolean conditional) { diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/LlmPredictCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/LlmPredictCPInstruction.java new file mode 100644 index 00000000000..da2c123e89a --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/LlmPredictCPInstruction.java @@ -0,0 +1,226 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.instructions.cp; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.ConnectException; +import java.net.HttpURLConnection; +import java.net.MalformedURLException; +import java.net.SocketTimeoutException; +import java.net.URI; +import java.net.URISyntaxException; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; + +import org.apache.commons.lang3.tuple.Pair; +import org.apache.sysds.common.Types.DataType; +import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.lineage.LineageItem; +import org.apache.sysds.runtime.lineage.LineageItemUtils; +import org.apache.wink.json4j.JSONObject; + +public class LlmPredictCPInstruction extends ParameterizedBuiltinCPInstruction { + + protected LlmPredictCPInstruction(LinkedHashMap paramsMap, + CPOperand out, String opcode, String istr) { + super(null, paramsMap, out, opcode, istr); + } + + @Override + public void processInstruction(ExecutionContext ec) { + FrameBlock prompts = ec.getFrameInput(params.get("target")); + String url = params.get("url"); + String model = params.containsKey("model") ? + params.get("model") : null; + int maxTokens = params.containsKey("max_tokens") ? + Integer.parseInt(params.get("max_tokens")) : 512; + double temperature = params.containsKey("temperature") ? + Double.parseDouble(params.get("temperature")) : 0.0; + double topP = params.containsKey("top_p") ? + Double.parseDouble(params.get("top_p")) : 0.9; + int concurrency = params.containsKey("concurrency") ? + Integer.parseInt(params.get("concurrency")) : 1; + concurrency = Math.max(1, Math.min(concurrency, 128)); + + int n = prompts.getNumRows(); + String[][] data = new String[n][]; + + List> tasks = new ArrayList<>(n); + for(int i = 0; i < n; i++) { + String prompt = prompts.get(i, 0).toString(); + tasks.add(() -> callLlmEndpoint(prompt, url, model, maxTokens, temperature, topP)); + } + + try { + if(concurrency <= 1) { + for(int i = 0; i < n; i++) + data[i] = tasks.get(i).call(); + } + else { + ExecutorService pool = Executors.newFixedThreadPool( + Math.min(concurrency, n)); + List> futures = pool.invokeAll(tasks); + pool.shutdown(); + for(int i = 0; i < n; i++) + data[i] = futures.get(i).get(); + } + } + catch(DMLRuntimeException e) { + throw e; + } + catch(Exception e) { + throw new DMLRuntimeException("llmPredict failed: " + e.getMessage(), e); + } + + ValueType[] schema = {ValueType.STRING, ValueType.STRING, + ValueType.INT64, ValueType.INT64, ValueType.INT64}; + String[] colNames = {"prompt", "generated_text", "time_ms", "input_tokens", "output_tokens"}; + FrameBlock fbout = new FrameBlock(schema, colNames); + for(String[] row : data) + fbout.appendRow(row); + + ec.setFrameOutput(output.getName(), fbout); + ec.releaseFrameInput(params.get("target")); + } + + // No retry logic by design: as a database built-in, llmPredict should + // fail fast on transient errors and let the caller (DML script) decide + // whether and how to retry. Silent retries with backoff would make + // execution time unpredictable. + private static String[] callLlmEndpoint(String prompt, String url, + String model, int maxTokens, double temperature, double topP) { + long t0 = System.nanoTime(); + + // validate URL and open connection + HttpURLConnection conn; + try { + conn = (HttpURLConnection) new URI(url).toURL().openConnection(); + } + catch(URISyntaxException | MalformedURLException | IllegalArgumentException e) { + throw new DMLRuntimeException( + "llmPredict: invalid URL '" + url + "'. " + + "Expected format: http://host:port/v1/completions", e); + } + catch(IOException e) { + throw new DMLRuntimeException( + "llmPredict: cannot open connection to '" + url + "'.", e); + } + + try { + JSONObject req = new JSONObject(); + if(model != null) + req.put("model", model); + req.put("prompt", prompt); + req.put("max_tokens", maxTokens); + req.put("temperature", temperature); + req.put("top_p", topP); + + conn.setRequestMethod("POST"); + conn.setRequestProperty("Content-Type", "application/json"); + conn.setConnectTimeout(10_000); + conn.setReadTimeout(300_000); + conn.setDoOutput(true); + + try(OutputStream os = conn.getOutputStream()) { + os.write(req.toString().getBytes(StandardCharsets.UTF_8)); + } + + int httpCode = conn.getResponseCode(); + if(httpCode != 200) { + String errBody = ""; + try(InputStream es = conn.getErrorStream()) { + if(es != null) + errBody = new String(es.readAllBytes(), StandardCharsets.UTF_8); + } + catch(Exception ignored) {} + throw new DMLRuntimeException( + "llmPredict: endpoint returned HTTP " + httpCode + + " for '" + url + "'." + + (errBody.isEmpty() ? "" : " Response: " + errBody)); + } + + String body; + try(InputStream is = conn.getInputStream()) { + body = new String(is.readAllBytes(), StandardCharsets.UTF_8); + } + + JSONObject resp = new JSONObject(body); + if(!resp.has("choices") || resp.getJSONArray("choices").length() == 0) { + String errMsg = resp.has("error") ? resp.optString("error") : body; + throw new DMLRuntimeException( + "llmPredict: server response missing 'choices'. Response: " + errMsg); + } + String text = resp.getJSONArray("choices") + .getJSONObject(0).getString("text"); + long elapsed = (System.nanoTime() - t0) / 1_000_000; + int inTok = 0, outTok = 0; + if(resp.has("usage")) { + JSONObject usage = resp.getJSONObject("usage"); + inTok = usage.has("prompt_tokens") ? usage.getInt("prompt_tokens") : 0; + outTok = usage.has("completion_tokens") ? usage.getInt("completion_tokens") : 0; + } + return new String[]{prompt, text, + String.valueOf(elapsed), String.valueOf(inTok), String.valueOf(outTok)}; + } + catch(ConnectException e) { + throw new DMLRuntimeException( + "llmPredict: connection refused to '" + url + "'. " + + "Ensure the LLM server is running and reachable.", e); + } + catch(SocketTimeoutException e) { + throw new DMLRuntimeException( + "llmPredict: timed out connecting to '" + url + "'. " + + "Ensure the LLM server is running and reachable.", e); + } + catch(IOException e) { + throw new DMLRuntimeException( + "llmPredict: I/O error communicating with '" + url + "'.", e); + } + catch(DMLRuntimeException e) { + throw e; + } + catch(Exception e) { + throw new DMLRuntimeException( + "llmPredict: failed to get response from '" + url + "'.", e); + } + finally { + conn.disconnect(); + } + } + + @Override + public Pair getLineageItem(ExecutionContext ec) { + CPOperand target = new CPOperand(params.get("target"), ValueType.STRING, DataType.FRAME); + CPOperand urlOp = new CPOperand(params.get("url"), ValueType.STRING, DataType.SCALAR, true); + return Pair.of(output.getName(), + new LineageItem(getOpcode(), LineageItemUtils.getLineage(ec, target, urlOp))); + } +} diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java index 119589a3033..ac2f527f06c 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java @@ -158,6 +158,9 @@ else if(opcode.equals(Opcodes.TRANSFORMAPPLY.toString()) || opcode.equals(Opcode || opcode.equals(Opcodes.TOSTRING.toString()) || opcode.equals(Opcodes.NVLIST.toString()) || opcode.equals(Opcodes.AUTODIFF.toString())) { return new ParameterizedBuiltinCPInstruction(null, paramsMap, out, opcode, str); } + else if(opcode.equals(Opcodes.LLMPREDICT.toString())) { + return new LlmPredictCPInstruction(paramsMap, out, opcode, str); + } else if(Opcodes.PARAMSERV.toString().equals(opcode)) { return new ParamservBuiltinCPInstruction(null, paramsMap, out, opcode, str); } @@ -324,6 +327,7 @@ else if(opcode.equalsIgnoreCase(Opcodes.TOKENIZE.toString())) { ec.setFrameOutput(output.getName(), fbout); ec.releaseFrameInput(params.get("target")); } + else if(opcode.equalsIgnoreCase(Opcodes.TRANSFORMAPPLY.toString())) { // acquire locks FrameBlock data = ec.getFrameInput(params.get("target")); diff --git a/src/main/python/llm_server.py b/src/main/python/llm_server.py new file mode 100644 index 00000000000..b538d871ba8 --- /dev/null +++ b/src/main/python/llm_server.py @@ -0,0 +1,117 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +"""Local inference server for llmPredict. Loads a HuggingFace model +and serves it at http://localhost:PORT/v1/completions. + +Usage: python llm_server.py distilgpt2 --port 8080 +""" + +import argparse +import json +import sys +import time +from http.server import HTTPServer, BaseHTTPRequestHandler + +import torch +from transformers import AutoTokenizer, AutoModelForCausalLM + + +class InferenceHandler(BaseHTTPRequestHandler): + + def do_POST(self): + if self.path != "/v1/completions": + self.send_error(404) + return + length = int(self.headers.get("Content-Length", 0)) + body = json.loads(self.rfile.read(length)) + + prompt = body.get("prompt", "") + max_tokens = int(body.get("max_tokens", 512)) + temperature = float(body.get("temperature", 0.0)) + top_p = float(body.get("top_p", 0.9)) + + model = self.server.model + tokenizer = self.server.tokenizer + + inputs = tokenizer(prompt, return_tensors="pt").to(model.device) + input_len = inputs["input_ids"].shape[1] + with torch.no_grad(): + outputs = model.generate( + **inputs, + max_new_tokens=max_tokens, + temperature=temperature if temperature > 0 else 1.0, + top_p=top_p, + do_sample=temperature > 0, + ) + new_tokens = outputs[0][input_len:] + text = tokenizer.decode(new_tokens, skip_special_tokens=True) + + resp = { + "choices": [{"text": text}], + "usage": { + "prompt_tokens": input_len, + "completion_tokens": len(new_tokens), + }, + } + payload = json.dumps(resp).encode("utf-8") + self.send_response(200) + self.send_header("Content-Type", "application/json") + self.send_header("Content-Length", str(len(payload))) + self.end_headers() + self.wfile.write(payload) + + def log_message(self, fmt, *args): + sys.stderr.write("[llm_server] %s\n" % (fmt % args)) + + +def main(): + parser = argparse.ArgumentParser(description="OpenAI-compatible LLM server") + parser.add_argument("model", help="HuggingFace model name") + parser.add_argument("--port", type=int, default=8080) + args = parser.parse_args() + + print(f"Loading model: {args.model}", flush=True) + tokenizer = AutoTokenizer.from_pretrained(args.model) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + if torch.cuda.is_available(): + print(f"CUDA available: {torch.cuda.device_count()} GPU(s)", flush=True) + model = AutoModelForCausalLM.from_pretrained( + args.model, device_map="auto", torch_dtype=torch.float16) + else: + model = AutoModelForCausalLM.from_pretrained(args.model) + model.eval() + print(f"Model loaded on {next(model.parameters()).device}", flush=True) + + server = HTTPServer(("0.0.0.0", args.port), InferenceHandler) + server.model = model + server.tokenizer = tokenizer + print(f"Serving on http://0.0.0.0:{args.port}/v1/completions", flush=True) + try: + server.serve_forever() + except KeyboardInterrupt: + print("Shutting down", flush=True) + server.server_close() + + +if __name__ == "__main__": + main() diff --git a/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java b/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java new file mode 100644 index 00000000000..bc7817a7d17 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java @@ -0,0 +1,572 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.jmlc; + +import java.io.OutputStream; +import java.net.InetSocketAddress; +import java.nio.charset.StandardCharsets; +import java.util.HashMap; +import java.util.Map; + +import com.sun.net.httpserver.HttpServer; + +import org.apache.sysds.api.jmlc.Connection; +import org.apache.sysds.api.jmlc.PreparedScript; +import org.apache.sysds.api.jmlc.ResultVariables; +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.test.AutomatedTestBase; +import org.junit.Assert; +import org.junit.Test; + +/** + * Tests for llmPredict built-in via JMLC. + * Needs an OpenAI-compatible server on localhost:8080. + */ +public class JMLCLLMInferenceTest extends AutomatedTestBase { + private final static String TEST_NAME = "JMLCLLMInferenceTest"; + private final static String TEST_DIR = "functions/jmlc/"; + private final static String LLM_URL = "http://localhost:8080/v1/completions"; + + private final static String DML_SCRIPT = + "prompts = read(\"prompts\", data_type=\"frame\")\n" + + + "results = llmPredict(target=prompts, url=$url, max_tokens=$mt, temperature=$temp, top_p=$tp)\n" + + "write(results, \"results\")"; + + @Override + public void setUp() { + addTestConfiguration(TEST_DIR, TEST_NAME); + getAndLoadTestConfiguration(TEST_NAME); + } + + @Test + public void testSinglePrompt() { + Connection conn = null; + try { + conn = new Connection(); + Map args = new HashMap<>(); + args.put("$url", LLM_URL); + args.put("$mt", "20"); + args.put("$temp", "0.7"); + args.put("$tp", "0.9"); + PreparedScript ps = conn.prepareScript(DML_SCRIPT, args, + new String[]{"prompts"}, new String[]{"results"}); + + String[][] promptData = new String[][]{{"The meaning of life is"}}; + ps.setFrame("prompts", promptData); + + ResultVariables rv = ps.executeScript(); + FrameBlock result = rv.getFrameBlock("results"); + + Assert.assertNotNull("Result should not be null", result); + Assert.assertEquals("Should have 1 row", 1, result.getNumRows()); + Assert.assertEquals("Should have 5 columns", 5, result.getNumColumns()); + String generated = result.get(0, 1).toString(); + Assert.assertFalse("Generated text should not be empty", generated.isEmpty()); + + System.out.println("Prompt: " + promptData[0][0]); + System.out.println("Generated: " + generated); + } catch (Exception e) { + e.printStackTrace(); + org.junit.Assume.assumeNoException("LLM server not available", e); + } finally { + if (conn != null) conn.close(); + } + } + + @Test + public void testServerUnreachable() { + // should throw DMLRuntimeException, not hang + Connection conn = null; + try { + conn = new Connection(); + String deadUrl = "http://localhost:19999/v1/completions"; + Map args = new HashMap<>(); + args.put("$url", deadUrl); + args.put("$mt", "20"); + args.put("$temp", "0.0"); + args.put("$tp", "0.9"); + PreparedScript ps = conn.prepareScript(DML_SCRIPT, args, + new String[]{"prompts"}, new String[]{"results"}); + + String[][] promptData = new String[][]{{"Hello"}}; + ps.setFrame("prompts", promptData); + + try { + ps.executeScript(); + Assert.fail("Expected DMLRuntimeException for unreachable server"); + } + catch (DMLRuntimeException e) { + String fullMsg = getExceptionChainMessage(e); + System.out.println("Correctly caught: " + fullMsg); + Assert.assertTrue("Error should mention connection issue", + fullMsg.contains("connection refused") + || fullMsg.contains("Connection refused") + || fullMsg.contains("server is running")); + } + } + catch (Exception e) { + e.printStackTrace(); + org.junit.Assume.assumeNoException( + "Could not set up negative test", e); + } + finally { + if (conn != null) conn.close(); + } + } + + @Test + public void testInvalidUrl() { + Connection conn = null; + try { + conn = new Connection(); + Map args = new HashMap<>(); + args.put("$url", "not-a-valid-url"); + args.put("$mt", "20"); + args.put("$temp", "0.0"); + args.put("$tp", "0.9"); + PreparedScript ps = conn.prepareScript(DML_SCRIPT, args, + new String[]{"prompts"}, new String[]{"results"}); + + String[][] promptData = new String[][]{{"Hello"}}; + ps.setFrame("prompts", promptData); + + try { + ps.executeScript(); + Assert.fail("Expected DMLRuntimeException for invalid URL"); + } + catch (DMLRuntimeException e) { + String fullMsg = getExceptionChainMessage(e); + System.out.println("Correctly caught: " + fullMsg); + Assert.assertTrue("Error should mention invalid URL", + fullMsg.contains("invalid URL") + || fullMsg.contains("Invalid URL")); + } + } + catch (Exception e) { + e.printStackTrace(); + org.junit.Assume.assumeNoException( + "Could not set up negative test", e); + } + finally { + if (conn != null) conn.close(); + } + } + + private static String getExceptionChainMessage(Throwable t) { + StringBuilder sb = new StringBuilder(); + while(t != null) { + if(sb.length() > 0) sb.append(" | "); + if(t.getMessage() != null) sb.append(t.getMessage()); + t = t.getCause(); + } + return sb.toString(); + } + + @Test + public void testConcurrency() { + Connection conn = null; + try { + conn = new Connection(); + String dmlConc = + "prompts = read(\"prompts\", data_type=\"frame\")\n" + + "results = llmPredict(target=prompts, url=$url, max_tokens=$mt, " + + "temperature=$temp, top_p=$tp, concurrency=$conc)\n" + + "write(results, \"results\")"; + Map args = new HashMap<>(); + args.put("$url", LLM_URL); + args.put("$mt", "20"); + args.put("$temp", "0.0"); + args.put("$tp", "0.9"); + args.put("$conc", "2"); + PreparedScript ps = conn.prepareScript(dmlConc, args, + new String[]{"prompts"}, new String[]{"results"}); + + String[][] promptData = new String[][]{ + {"Hello world"}, {"Test prompt"}, {"Another test"} + }; + ps.setFrame("prompts", promptData); + + ResultVariables rv = ps.executeScript(); + FrameBlock result = rv.getFrameBlock("results"); + + Assert.assertNotNull("Result should not be null", result); + Assert.assertEquals("Should have 3 rows", 3, result.getNumRows()); + Assert.assertEquals("Should have 5 columns", 5, result.getNumColumns()); + } catch (Exception e) { + e.printStackTrace(); + org.junit.Assume.assumeNoException("LLM server not available", e); + } finally { + if (conn != null) conn.close(); + } + } + + @Test + public void testHttpErrorResponse() { + // mock server that returns HTTP 500 + HttpServer server = null; + Connection conn = null; + try { + server = HttpServer.create(new InetSocketAddress(0), 0); + int port = server.getAddress().getPort(); + server.createContext("/v1/completions", exchange -> { + byte[] resp = "{\"error\": \"internal server error\"}".getBytes(StandardCharsets.UTF_8); + exchange.sendResponseHeaders(500, resp.length); + try(OutputStream os = exchange.getResponseBody()) { + os.write(resp); + } + }); + server.start(); + + conn = new Connection(); + Map args = new HashMap<>(); + args.put("$url", "http://localhost:" + port + "/v1/completions"); + args.put("$mt", "20"); + args.put("$temp", "0.0"); + args.put("$tp", "0.9"); + PreparedScript ps = conn.prepareScript(DML_SCRIPT, args, + new String[]{"prompts"}, new String[]{"results"}); + ps.setFrame("prompts", new String[][]{{"Hello"}}); + + try { + ps.executeScript(); + Assert.fail("Expected DMLRuntimeException for HTTP 500"); + } + catch (DMLRuntimeException e) { + String fullMsg = getExceptionChainMessage(e); + System.out.println("Correctly caught HTTP 500: " + fullMsg); + Assert.assertTrue("Error should mention HTTP 500", + fullMsg.contains("HTTP 500")); + } + } + catch (Exception e) { + e.printStackTrace(); + org.junit.Assume.assumeNoException( + "Could not set up mock server", e); + } + finally { + if (server != null) server.stop(0); + if (conn != null) conn.close(); + } + } + + @Test + public void testMalformedJsonResponse() { + // mock server that returns HTTP 200 with invalid JSON + HttpServer server = null; + Connection conn = null; + try { + server = HttpServer.create(new InetSocketAddress(0), 0); + int port = server.getAddress().getPort(); + server.createContext("/v1/completions", exchange -> { + byte[] resp = "this is not json at all".getBytes(StandardCharsets.UTF_8); + exchange.sendResponseHeaders(200, resp.length); + try(OutputStream os = exchange.getResponseBody()) { + os.write(resp); + } + }); + server.start(); + + conn = new Connection(); + Map args = new HashMap<>(); + args.put("$url", "http://localhost:" + port + "/v1/completions"); + args.put("$mt", "20"); + args.put("$temp", "0.0"); + args.put("$tp", "0.9"); + PreparedScript ps = conn.prepareScript(DML_SCRIPT, args, + new String[]{"prompts"}, new String[]{"results"}); + ps.setFrame("prompts", new String[][]{{"Hello"}}); + + try { + ps.executeScript(); + Assert.fail("Expected DMLRuntimeException for malformed JSON"); + } + catch (DMLRuntimeException e) { + String fullMsg = getExceptionChainMessage(e); + System.out.println("Correctly caught malformed JSON: " + fullMsg); + Assert.assertTrue("Error should mention response issue", + fullMsg.contains("failed") || fullMsg.contains("response")); + } + } + catch (Exception e) { + e.printStackTrace(); + org.junit.Assume.assumeNoException( + "Could not set up mock server", e); + } + finally { + if (server != null) server.stop(0); + if (conn != null) conn.close(); + } + } + + @Test + public void testMissingChoicesInResponse() { + // mock server that returns valid JSON but no "choices" array + HttpServer server = null; + Connection conn = null; + try { + server = HttpServer.create(new InetSocketAddress(0), 0); + int port = server.getAddress().getPort(); + server.createContext("/v1/completions", exchange -> { + byte[] resp = "{\"id\": \"test\", \"object\": \"text_completion\"}" + .getBytes(StandardCharsets.UTF_8); + exchange.sendResponseHeaders(200, resp.length); + try(OutputStream os = exchange.getResponseBody()) { + os.write(resp); + } + }); + server.start(); + + conn = new Connection(); + Map args = new HashMap<>(); + args.put("$url", "http://localhost:" + port + "/v1/completions"); + args.put("$mt", "20"); + args.put("$temp", "0.0"); + args.put("$tp", "0.9"); + PreparedScript ps = conn.prepareScript(DML_SCRIPT, args, + new String[]{"prompts"}, new String[]{"results"}); + ps.setFrame("prompts", new String[][]{{"Hello"}}); + + try { + ps.executeScript(); + Assert.fail("Expected DMLRuntimeException for missing choices"); + } + catch (DMLRuntimeException e) { + String fullMsg = getExceptionChainMessage(e); + System.out.println("Correctly caught missing choices: " + fullMsg); + Assert.assertTrue("Error should mention missing choices", + fullMsg.contains("choices")); + } + } + catch (Exception e) { + e.printStackTrace(); + org.junit.Assume.assumeNoException( + "Could not set up mock server", e); + } + finally { + if (server != null) server.stop(0); + if (conn != null) conn.close(); + } + } + + @Test + public void testBatchInference() { + Connection conn = null; + try { + conn = new Connection(); + Map args = new HashMap<>(); + args.put("$url", LLM_URL); + args.put("$mt", "20"); + args.put("$temp", "0.7"); + args.put("$tp", "0.9"); + PreparedScript ps = conn.prepareScript(DML_SCRIPT, args, + new String[]{"prompts"}, new String[]{"results"}); + + String[] prompts = { + "The meaning of life is", + "Machine learning is", + "Apache SystemDS enables" + }; + String[][] promptData = new String[prompts.length][1]; + for (int i = 0; i < prompts.length; i++) + promptData[i][0] = prompts[i]; + ps.setFrame("prompts", promptData); + + ResultVariables rv = ps.executeScript(); + FrameBlock result = rv.getFrameBlock("results"); + + Assert.assertNotNull("Result should not be null", result); + Assert.assertEquals("Should have 3 rows", 3, result.getNumRows()); + Assert.assertEquals("Should have 5 columns", 5, result.getNumColumns()); + + for (int i = 0; i < prompts.length; i++) { + String prompt = result.get(i, 0).toString(); + String generated = result.get(i, 1).toString(); + long timeMs = Long.parseLong(result.get(i, 2).toString()); + Assert.assertEquals("Prompt should match", prompts[i], prompt); + Assert.assertFalse("Generated text should not be empty", generated.isEmpty()); + Assert.assertTrue("Time should be positive", timeMs > 0); + System.out.println("Prompt: " + prompt); + System.out.println("Generated: " + generated + " (" + timeMs + "ms)"); + } + } catch (Exception e) { + e.printStackTrace(); + org.junit.Assume.assumeNoException("LLM server not available", e); + } finally { + if (conn != null) conn.close(); + } + } + + @Test + public void testMockSinglePrompt() { + // mock server that returns a valid OpenAI-compatible response + // runs in CI without a real LLM server + HttpServer server = null; + Connection conn = null; + try { + server = HttpServer.create(new InetSocketAddress(0), 0); + int port = server.getAddress().getPort(); + server.createContext("/v1/completions", exchange -> { + String body = "{\"choices\":[{\"text\":\"42 is the answer\"}]," + + "\"usage\":{\"prompt_tokens\":5,\"completion_tokens\":4}}"; + byte[] resp = body.getBytes(StandardCharsets.UTF_8); + exchange.sendResponseHeaders(200, resp.length); + try(OutputStream os = exchange.getResponseBody()) { + os.write(resp); + } + }); + server.start(); + + conn = new Connection(); + Map args = new HashMap<>(); + args.put("$url", "http://localhost:" + port + "/v1/completions"); + args.put("$mt", "20"); + args.put("$temp", "0.0"); + args.put("$tp", "0.9"); + PreparedScript ps = conn.prepareScript(DML_SCRIPT, args, + new String[]{"prompts"}, new String[]{"results"}); + ps.setFrame("prompts", new String[][]{{"What is 6 times 7?"}}); + + ResultVariables rv = ps.executeScript(); + FrameBlock result = rv.getFrameBlock("results"); + + Assert.assertNotNull("Result should not be null", result); + Assert.assertEquals("Should have 1 row", 1, result.getNumRows()); + Assert.assertEquals("Should have 5 columns", 5, result.getNumColumns()); + Assert.assertEquals("Generated text should match", "42 is the answer", + result.get(0, 1).toString()); + Assert.assertEquals("Input tokens should be 5", "5", + result.get(0, 3).toString()); + Assert.assertEquals("Output tokens should be 4", "4", + result.get(0, 4).toString()); + } + catch (Exception e) { + e.printStackTrace(); + org.junit.Assume.assumeNoException( + "Could not set up mock server", e); + } + finally { + if (server != null) server.stop(0); + if (conn != null) conn.close(); + } + } + + @Test + public void testMockBatchPrompts() { + // mock server returning different responses per prompt + HttpServer server = null; + Connection conn = null; + try { + server = HttpServer.create(new InetSocketAddress(0), 0); + int port = server.getAddress().getPort(); + server.createContext("/v1/completions", exchange -> { + // read request to get prompt + String reqBody = new String(exchange.getRequestBody().readAllBytes(), + StandardCharsets.UTF_8); + String response; + if (reqBody.contains("first")) + response = "response-1"; + else if (reqBody.contains("second")) + response = "response-2"; + else + response = "response-3"; + String body = "{\"choices\":[{\"text\":\"" + response + "\"}]," + + "\"usage\":{\"prompt_tokens\":3,\"completion_tokens\":1}}"; + byte[] resp = body.getBytes(StandardCharsets.UTF_8); + exchange.sendResponseHeaders(200, resp.length); + try(OutputStream os = exchange.getResponseBody()) { + os.write(resp); + } + }); + server.start(); + + conn = new Connection(); + Map args = new HashMap<>(); + args.put("$url", "http://localhost:" + port + "/v1/completions"); + args.put("$mt", "20"); + args.put("$temp", "0.0"); + args.put("$tp", "0.9"); + PreparedScript ps = conn.prepareScript(DML_SCRIPT, args, + new String[]{"prompts"}, new String[]{"results"}); + ps.setFrame("prompts", new String[][]{ + {"first prompt"}, {"second prompt"}, {"third prompt"} + }); + + ResultVariables rv = ps.executeScript(); + FrameBlock result = rv.getFrameBlock("results"); + + Assert.assertEquals("Should have 3 rows", 3, result.getNumRows()); + Assert.assertEquals("Row 0 text", "response-1", result.get(0, 1).toString()); + Assert.assertEquals("Row 1 text", "response-2", result.get(1, 1).toString()); + Assert.assertEquals("Row 2 text", "response-3", result.get(2, 1).toString()); + } + catch (Exception e) { + e.printStackTrace(); + org.junit.Assume.assumeNoException( + "Could not set up mock server", e); + } + finally { + if (server != null) server.stop(0); + if (conn != null) conn.close(); + } + } + + @Test + public void testEmptyPromptFrame() { + // empty frame (0 rows) should produce empty result, not crash + HttpServer server = null; + Connection conn = null; + try { + server = HttpServer.create(new InetSocketAddress(0), 0); + int port = server.getAddress().getPort(); + server.createContext("/v1/completions", exchange -> { + // should never be called for 0 prompts + Assert.fail("Server should not be called for empty frame"); + }); + server.start(); + + conn = new Connection(); + Map args = new HashMap<>(); + args.put("$url", "http://localhost:" + port + "/v1/completions"); + args.put("$mt", "20"); + args.put("$temp", "0.0"); + args.put("$tp", "0.9"); + PreparedScript ps = conn.prepareScript(DML_SCRIPT, args, + new String[]{"prompts"}, new String[]{"results"}); + ps.setFrame("prompts", new String[0][1]); + + ResultVariables rv = ps.executeScript(); + FrameBlock result = rv.getFrameBlock("results"); + + Assert.assertNotNull("Result should not be null", result); + Assert.assertEquals("Should have 0 rows", 0, result.getNumRows()); + } + catch (Exception e) { + e.printStackTrace(); + org.junit.Assume.assumeNoException( + "Could not set up test", e); + } + finally { + if (server != null) server.stop(0); + if (conn != null) conn.close(); + } + } +}