11package io .quarkiverse .langchain4j .gemini .common ;
22
3+ import java .util .ArrayList ;
34import java .util .List ;
45
56import dev .langchain4j .data .embedding .Embedding ;
910
1011public abstract class GeminiEmbeddingModel implements EmbeddingModel {
1112
13+ private static final int MAX_NUMBER_OF_SEGMENTS_PER_BATCH = 100 ;
1214 private final String modelId ;
1315 private final Integer dimension ;
1416 private final String taskType ;
@@ -34,15 +36,24 @@ public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) {
3436 List <EmbedContentRequest > embedContentRequests = textSegments .stream ()
3537 .map (textSegment -> getEmbedContentRequest (modelId , textSegment .text ()))
3638 .toList ();
37-
38- EmbedContentResponses embedContentResponses = batchEmbedContents (
39- new EmbedContentRequests (embedContentRequests ));
40-
41- List <Embedding > embeddings = embedContentResponses .embeddings ()
42- .stream ()
43- .map (embedding -> Embedding .from (embedding .values ()))
44- .toList ();
45- return Response .from (embeddings );
39+ List <Embedding > allEmbeddings = new ArrayList <>();
40+ int numberOfEmbeddings = embedContentRequests .size ();
41+ int numberOfBatches = 1 + numberOfEmbeddings / MAX_NUMBER_OF_SEGMENTS_PER_BATCH ;
42+
43+ for (int i = 0 ; i < numberOfBatches ; i ++) {
44+ int startIndex = MAX_NUMBER_OF_SEGMENTS_PER_BATCH * i ;
45+ int lastIndex = Math .min (startIndex + MAX_NUMBER_OF_SEGMENTS_PER_BATCH , numberOfEmbeddings );
46+
47+ if (startIndex >= numberOfEmbeddings )
48+ break ;
49+
50+ EmbedContentResponses embedContentResponses = batchEmbedContents (
51+ new EmbedContentRequests (embedContentRequests .subList (startIndex , lastIndex )));
52+ embedContentResponses .embeddings ().stream ()
53+ .map (embedding -> Embedding .from (embedding .values ()))
54+ .forEach (allEmbeddings ::add );
55+ }
56+ return Response .from (allEmbeddings );
4657 }
4758
4859 private EmbedContentRequest getEmbedContentRequest (String model , String text ) {
0 commit comments