Skip to content

Commit 63d644a

Browse files
committed
Automatically sort segments for embedding with gemini into batches of 100
1 parent 55a3f3a commit 63d644a

File tree

1 file changed

+20
-9
lines changed
  • model-providers/google/gemini/gemini-common/runtime/src/main/java/io/quarkiverse/langchain4j/gemini/common

1 file changed

+20
-9
lines changed

model-providers/google/gemini/gemini-common/runtime/src/main/java/io/quarkiverse/langchain4j/gemini/common/GeminiEmbeddingModel.java

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package io.quarkiverse.langchain4j.gemini.common;
22

3+
import java.util.ArrayList;
34
import java.util.List;
45

56
import dev.langchain4j.data.embedding.Embedding;
@@ -9,6 +10,7 @@
910

1011
public abstract class GeminiEmbeddingModel implements EmbeddingModel {
1112

13+
private static final int MAX_NUMBER_OF_SEGMENTS_PER_BATCH = 100;
1214
private final String modelId;
1315
private final Integer dimension;
1416
private final String taskType;
@@ -34,15 +36,24 @@ public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) {
3436
List<EmbedContentRequest> embedContentRequests = textSegments.stream()
3537
.map(textSegment -> getEmbedContentRequest(modelId, textSegment.text()))
3638
.toList();
37-
38-
EmbedContentResponses embedContentResponses = batchEmbedContents(
39-
new EmbedContentRequests(embedContentRequests));
40-
41-
List<Embedding> embeddings = embedContentResponses.embeddings()
42-
.stream()
43-
.map(embedding -> Embedding.from(embedding.values()))
44-
.toList();
45-
return Response.from(embeddings);
39+
List<Embedding> allEmbeddings = new ArrayList<>();
40+
int numberOfEmbeddings = embedContentRequests.size();
41+
int numberOfBatches = 1 + numberOfEmbeddings / MAX_NUMBER_OF_SEGMENTS_PER_BATCH;
42+
43+
for (int i = 0; i < numberOfBatches; i++) {
44+
int startIndex = MAX_NUMBER_OF_SEGMENTS_PER_BATCH * i;
45+
int lastIndex = Math.min(startIndex + MAX_NUMBER_OF_SEGMENTS_PER_BATCH, numberOfEmbeddings);
46+
47+
if (startIndex >= numberOfEmbeddings)
48+
break;
49+
50+
EmbedContentResponses embedContentResponses = batchEmbedContents(
51+
new EmbedContentRequests(embedContentRequests.subList(startIndex, lastIndex)));
52+
embedContentResponses.embeddings().stream()
53+
.map(embedding -> Embedding.from(embedding.values()))
54+
.forEach(allEmbeddings::add);
55+
}
56+
return Response.from(allEmbeddings);
4657
}
4758

4859
private EmbedContentRequest getEmbedContentRequest(String model, String text) {

0 commit comments

Comments
 (0)