Skip to content

Commit ab2b900

Browse files
authored
Merge pull request #1937 from orionpapadakis/fix/gpu-resources-utilization
Fix GPU resources utilization
2 parents 11b8d5e + 704c4e5 commit ab2b900

File tree

4 files changed

+231
-106
lines changed

4 files changed

+231
-106
lines changed

docs/modules/ROOT/pages/gpullama3-chat-model.adoc

Lines changed: 61 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,7 @@ The above steps:
2929
3030
- Set the `TORNADOVM_SDK` environment variable to the TornadoVM SDK path.
3131
- Create a `tornado-argfile` under `~/TornadoVM` containing the JVM arguments required to enable TornadoVM.
32-
- The `tornado-argfile` is automatically used in Quarkus *dev mode*.
33-
- For *production mode*, you must manually pass the argfile to the JVM (see step 3).
32+
- ⚠️ The `tornado-argfile` should be used for *building* and *running* the Quarkus application (see section Building & Running the Quarkus Application).
3433
3534
== Using GPULlama3.java
3635
@@ -130,6 +129,66 @@ quarkus.langchain4j.gpu-llama3.chat-model.max-tokens=1024
130129
131130
Model files are automatically downloaded from https://huggingface.co/beehive-lab[Beehive Lab HuggingFace] if not available locally.
132131
132+
== Building & Running the Quarkus Application
133+
134+
=== Dev Mode
135+
136+
To run your Quarkus application in **dev mode** with TornadoVM:
137+
138+
1. Ensure your `pom.xml` contains the `quarkus-langchain4j-gpu-llama3` dependency (shown earlier).
139+
140+
2. Add the TornadoVM argfile as a Maven property:
141+
142+
[source,xml]
143+
----
144+
<properties>
145+
<tornado.argfile>/path/to/tornado-argfile</tornado.argfile>
146+
</properties>
147+
----
148+
149+
3. Pass the argfile to the JVM in the plugin configuration for dev mode:
150+
151+
[source,xml]
152+
----
153+
<plugin>
154+
<groupId>io.quarkus</groupId>
155+
<artifactId>quarkus-maven-plugin</artifactId>
156+
<configuration>
157+
<jvmArgs>@${tornado.argfile}</jvmArgs>
158+
</configuration>
159+
</plugin>
160+
----
161+
162+
4. Launch dev mode explicitly:
163+
164+
[source,shell]
165+
----
166+
mvn quarkus:dev
167+
----
168+
169+
---
170+
171+
=== Production Mode
172+
173+
To build and run your application in **production mode**:
174+
175+
1. Build the Quarkus application:
176+
177+
[source,shell]
178+
----
179+
mvn clean package
180+
----
181+
182+
2. Run the generated jar with the TornadoVM argfile:
183+
184+
[source,shell]
185+
----
186+
java @/path/to/tornado-argfile -jar target/quarkus-app/quarkus-run.jar
187+
----
188+
189+
⚠ **Important:** Ensure `TORNADOVM_SDK` and the `tornado-argfile` path are correctly set.
190+
191+
133192
== Supported Models and Quantizations
134193
135194
The following models have been tested with GPULlama3.java and can be found in link:++https://huggingface.co/beehive-lab/collections[Beehive Lab's HuggingFace Collections].

model-providers/gpu-llama3/runtime/src/main/java/io/quarkiverse/langchain4j/gpullama3/GPULlama3ChatModel.java

Lines changed: 76 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
import java.nio.file.Path;
88
import java.util.Optional;
99

10+
import org.jboss.logging.Logger;
11+
1012
import dev.langchain4j.data.message.AiMessage;
1113
import dev.langchain4j.internal.ChatRequestValidationUtils;
1214
import dev.langchain4j.model.chat.ChatModel;
@@ -16,19 +18,83 @@
1618

1719
public class GPULlama3ChatModel extends GPULlama3BaseModel implements ChatModel {
1820

19-
// @formatter:off
21+
private static final Logger LOG = Logger.getLogger(GPULlama3ChatModel.class);
22+
23+
private final Builder builderConfig;
24+
private volatile boolean initialized = false;
25+
26+
/**
27+
* Default constructor.
28+
*
29+
* @param builder
30+
*/
2031
private GPULlama3ChatModel(Builder builder) {
32+
this(builder, false);
33+
}
34+
35+
/**
36+
* Constructor with lazy initialization.
37+
*
38+
* @param builder the builder used to configure the model.
39+
* @param lazy if true, the model is not initialized until the first call to doChat.
40+
*/
41+
private GPULlama3ChatModel(Builder builder, boolean lazy) {
42+
if (lazy) {
43+
// lazy initialization
44+
this.builderConfig = builder;
45+
} else {
46+
this.builderConfig = null;
47+
// original immediate initialization
48+
doInitialization(builder);
49+
}
50+
}
51+
52+
/**
53+
* The factory method for creating a lazy initialized model.
54+
*
55+
* @param builder the builder used to configure the model.
56+
* @return the model.
57+
*/
58+
public static GPULlama3ChatModel createLazy(Builder builder) {
59+
return new GPULlama3ChatModel(builder, true);
60+
}
61+
62+
/**
63+
* Ensure that the model is initialized.
64+
*/
65+
private void ensureInitialized() {
66+
if (!initialized && builderConfig != null) {
67+
if (!initialized) {
68+
doInitialization(builderConfig);
69+
initialized = true;
70+
}
71+
}
72+
}
73+
74+
// @formatter:off
75+
/**
76+
* Performs the actual initialization.
77+
*/
78+
private void doInitialization(Builder builder) {
2179
GPULlama3ModelRegistry gpuLlama3ModelRegistry = GPULlama3ModelRegistry.getOrCreate(builder.modelCachePath);
2280
try {
2381
Path modelPath = gpuLlama3ModelRegistry.downloadModel(builder.modelName, builder.quantization,
2482
Optional.empty(), Optional.empty());
25-
init(
26-
modelPath,
27-
getOrDefault(builder.temperature, 0.1),
28-
getOrDefault(builder.topP, 1.0),
29-
getOrDefault(builder.seed, 12345),
30-
getOrDefault(builder.maxTokens, 512),
31-
getOrDefault(builder.onGPU, Boolean.TRUE));
83+
Double temp = getOrDefault(builder.temperature, 0.1);
84+
Double topP = getOrDefault(builder.topP, 1.0);
85+
Integer seed = getOrDefault(builder.seed, 12345);
86+
Integer maxTokens = getOrDefault(builder.maxTokens, 512);
87+
Boolean onGPU = getOrDefault(builder.onGPU, Boolean.TRUE);
88+
89+
LOG.info("GPULlama3ChatModel Instantiation {modelPath=" + modelPath +
90+
", temperature=" + temp +
91+
", topP=" + topP +
92+
", seed=" + seed +
93+
", maxTokens=" + maxTokens +
94+
", onGPU=" + onGPU + "}...");
95+
96+
init(modelPath, temp, topP, seed, maxTokens, onGPU);
97+
LOG.info("GPULlama3ChatModel Instantiation Complete!");
3298
} catch (IOException e) {
3399
throw new UncheckedIOException(e);
34100
} catch (InterruptedException e) {
@@ -43,6 +109,8 @@ public static Builder builder() {
43109

44110
@Override
45111
public ChatResponse doChat(ChatRequest chatRequest) {
112+
ensureInitialized(); // If in lazy path, init model
113+
46114
ChatRequestValidationUtils.validateMessages(chatRequest.messages());
47115
ChatRequestParameters parameters = chatRequest.parameters();
48116
ChatRequestValidationUtils.validateParameters(parameters);

model-providers/gpu-llama3/runtime/src/main/java/io/quarkiverse/langchain4j/gpullama3/GPULlama3StreamingChatModel.java

Lines changed: 69 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@
77
import java.io.UncheckedIOException;
88
import java.nio.file.Path;
99
import java.util.Optional;
10-
import java.util.concurrent.CompletableFuture;
11-
import java.util.concurrent.atomic.AtomicBoolean;
1210

1311
import org.jboss.logging.Logger;
1412

@@ -50,20 +48,68 @@ public class GPULlama3StreamingChatModel extends GPULlama3BaseModel implements S
5048

5149
private static final Logger LOG = Logger.getLogger(GPULlama3StreamingChatModel.class);
5250

53-
// Fields to track initialization state
54-
private final CompletableFuture<Void> initializationFuture = new CompletableFuture<>();
55-
private final AtomicBoolean initialized = new AtomicBoolean(false);
51+
private final Builder builderConfig;
52+
private volatile boolean initialized = false;
53+
54+
private GPULlama3StreamingChatModel(Builder builder, boolean lazy) {
55+
if (lazy) {
56+
this.builderConfig = builder;
57+
// Don't initialize yet!
58+
} else {
59+
this.builderConfig = null;
60+
// Original background initialization
61+
runOutEventLoop(() -> {
62+
LOG.debug("Starting GPULlama3 StreamingChatModel initialization on worker thread");
63+
doInitialization(builder);
64+
initialized = true;
65+
});
66+
}
67+
}
5668

5769
private GPULlama3StreamingChatModel(Builder builder) {
58-
// Schedule the initialization to happen on a background thread
59-
runOutEventLoop(() -> {
60-
LOG.debug("Starting GPULlama3 model initialization on worker thread");
61-
coreInit(builder);
62-
});
70+
this(builder, false); // Default to original background initialization
6371
}
6472

65-
public static Builder builder() {
66-
return new Builder();
73+
// Add factory method for lazy initialization
74+
public static GPULlama3StreamingChatModel createLazy(Builder builder) {
75+
return new GPULlama3StreamingChatModel(builder, true);
76+
}
77+
78+
private void ensureInitialized() {
79+
if (!initialized && builderConfig != null) {
80+
if (!initialized) {
81+
LOG.debug("Lazy initialization of GPULlama3StreamingChatModel");
82+
doInitialization(builderConfig);
83+
initialized = true;
84+
}
85+
}
86+
}
87+
88+
private void doInitialization(Builder builder) {
89+
GPULlama3ModelRegistry gpuLlama3ModelRegistry = GPULlama3ModelRegistry.getOrCreate(builder.modelCachePath);
90+
try {
91+
Path modelPath = gpuLlama3ModelRegistry.downloadModel(builder.modelName, builder.quantization,
92+
Optional.empty(), Optional.empty());
93+
Double temp = getOrDefault(builder.temperature, 0.1);
94+
Double topP = getOrDefault(builder.topP, 1.0);
95+
Integer seed = getOrDefault(builder.seed, 12345);
96+
Integer maxTokens = getOrDefault(builder.maxTokens, 512);
97+
Boolean onGPU = getOrDefault(builder.onGPU, Boolean.TRUE);
98+
99+
LOG.info("GPULlama3StreamingChatModel Instantiation {modelPath=" + modelPath +
100+
", temperature=" + temp +
101+
", topP=" + topP +
102+
", seed=" + seed +
103+
", maxTokens=" + maxTokens +
104+
", onGPU=" + onGPU + "}...");
105+
106+
init(modelPath, temp, topP, seed, maxTokens, onGPU);
107+
LOG.info("GPULlama3StreamingChatModel Instantiation Complete!");
108+
} catch (IOException e) {
109+
throw new UncheckedIOException(e);
110+
} catch (InterruptedException e) {
111+
throw new RuntimeException(e);
112+
}
67113
}
68114

69115
@Override
@@ -75,59 +121,18 @@ public void doChat(ChatRequest chatRequest, StreamingChatResponseHandler handler
75121
ChatRequestValidationUtils.validate(parameters.responseFormat());
76122

77123
// Run the GPU operations on a worker thread using runOutEventLoop
78-
runOutEventLoop(new Runnable() {
79-
@Override
80-
public void run() {
81-
// Wait for initialization to complete if it hasn't yet
82-
if (!initialized.get()) {
83-
LOG.debug("Waiting for model initialization to complete");
84-
try {
85-
initializationFuture.get();
86-
} catch (Exception e) {
87-
LOG.error("Failed to initialize model", e);
88-
handler.onError(e);
89-
return;
90-
}
91-
}
124+
runOutEventLoop(() -> {
125+
try {
126+
ensureInitialized(); // Build happens HERE on first call!
92127
LOG.debug("Executing GPU Llama inference on worker thread");
93128
coreDoChat(chatRequest, handler);
94-
LOG.debug("GPULlama3 model initialization completed");
129+
} catch (Exception e) {
130+
LOG.error("Failed during lazy initialization or inference", e);
131+
handler.onError(e);
95132
}
96133
});
97134
}
98135

99-
/**
100-
* The actual initialization logic.
101-
* It is called by a worker thread in a non-blocking manner.
102-
*/
103-
private void coreInit(Builder builder) {
104-
GPULlama3ModelRegistry gpuLlama3ModelRegistry = GPULlama3ModelRegistry.getOrCreate(builder.modelCachePath);
105-
try {
106-
Path modelPath = gpuLlama3ModelRegistry.downloadModel(builder.modelName, builder.quantization,
107-
Optional.empty(), Optional.empty());
108-
init(
109-
modelPath,
110-
getOrDefault(builder.temperature, 0.1),
111-
getOrDefault(builder.topP, 1.0),
112-
getOrDefault(builder.seed, 12345),
113-
getOrDefault(builder.maxTokens, 512),
114-
getOrDefault(builder.onGPU, Boolean.TRUE));
115-
116-
// Mark initialization as complete
117-
initialized.set(true);
118-
initializationFuture.complete(null);
119-
} catch (IOException e) {
120-
initializationFuture.completeExceptionally(new UncheckedIOException(e));
121-
throw new UncheckedIOException(e);
122-
} catch (InterruptedException e) {
123-
initializationFuture.completeExceptionally(e);
124-
throw new RuntimeException(e);
125-
} catch (Exception e) {
126-
initializationFuture.completeExceptionally(e);
127-
throw e;
128-
}
129-
}
130-
131136
/**
132137
* The actual doChat logic.
133138
* It is called by a worker thread in a non-blocking manner.
@@ -152,11 +157,15 @@ private void coreDoChat(ChatRequest chatRequest, StreamingChatResponseHandler ha
152157

153158
handler.onCompleteResponse(chatResponse);
154159
} catch (Exception e) {
155-
LOG.error("Error in GPULlama3 asyncDoChat", e);
160+
LOG.error("Error in GPULlama3 coreDoChat", e);
156161
handler.onError(e);
157162
}
158163
}
159164

165+
public static Builder builder() {
166+
return new Builder();
167+
}
168+
160169
public static class Builder {
161170

162171
private Optional<Path> modelCachePath;

0 commit comments

Comments
 (0)