Skip to content

Commit ad97c3a

Browse files
authored
.Net: [MEVD] Support BinaryEmbedding with PostgreSQL (#13322)
The main motivation here was to ensure that our infrastructure and tests properly support binary embeddings (and, more generally, embedding types other than float/Half). Closes #13321
1 parent b03c3d0 commit ad97c3a

File tree

6 files changed

+48
-12
lines changed

6 files changed

+48
-12
lines changed

dotnet/src/VectorData/PgVector/PostgresCollection.cs

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,11 @@ public override async Task UpsertAsync(TRecord record, CancellationToken cancell
199199
generatedEmbeddings[vectorProperty] = [await halfTask.ConfigureAwait(false)];
200200
}
201201
#endif
202+
else if (vectorProperty.TryGenerateEmbedding<TRecord, BinaryEmbedding>(record, cancellationToken, out var binaryTask))
203+
{
204+
generatedEmbeddings ??= new Dictionary<VectorPropertyModel, IReadOnlyList<Embedding>>(vectorPropertyCount);
205+
generatedEmbeddings[vectorProperty] = [await binaryTask.ConfigureAwait(false)];
206+
}
202207
else
203208
{
204209
throw new InvalidOperationException(
@@ -420,10 +425,9 @@ _ when vectorProperty.EmbeddingGenerator is IEmbeddingGenerator<TInput, Embeddin
420425

421426
// Dense Binary
422427
BitArray b => b,
423-
// TODO: Uncomment once we sync to the latest MEAI
424-
// BinaryEmbedding e => e.Vector,
425-
// _ when vectorProperty.EmbeddingGenerator is IEmbeddingGenerator<TVector, BinaryEmbedding> generator
426-
// => (await generator.GenerateEmbeddingAsync(value, cancellationToken: cancellationToken).ConfigureAwait(false)).Vector,
428+
BinaryEmbedding e => e.Vector,
429+
_ when vectorProperty.EmbeddingGenerator is IEmbeddingGenerator<TInput, BinaryEmbedding> generator
430+
=> await generator.GenerateAsync(searchValue, cancellationToken: cancellationToken).ConfigureAwait(false),
427431

428432
// Sparse
429433
SparseVector sv => sv,

dotnet/src/VectorData/PgVector/PostgresMapper.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,10 @@ public TRecord MapFromStorageToDataModel(NpgsqlDataReader reader, bool includeVe
7777
}
7878
#endif
7979

80+
case BitArray bitArray when vectorProperty.Type == typeof(BinaryEmbedding):
81+
vectorProperty.SetValueAsObject(record, new BinaryEmbedding(bitArray));
82+
continue;
83+
8084
case BitArray bitArray:
8185
vectorProperty.SetValueAsObject(record, bitArray);
8286
continue;

dotnet/src/VectorData/PgVector/PostgresModelBuilder.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ namespace Microsoft.SemanticKernel.Connectors.PgVector;
1212

1313
internal class PostgresModelBuilder() : CollectionModelBuilder(PostgresModelBuilder.ModelBuildingOptions)
1414
{
15-
internal const string SupportedVectorTypes = "ReadOnlyMemory<float>, Embedding<float>, float[], ReadOnlyMemory<Half>, Embedding<Half>, Half[], BitArray, or SparseVector";
15+
internal const string SupportedVectorTypes = "ReadOnlyMemory<float>, Embedding<float>, float[], ReadOnlyMemory<Half>, Embedding<Half>, Half[], BinaryEmbedding, BitArray, or SparseVector";
1616

1717
public static readonly CollectionModelBuildingOptions ModelBuildingOptions = new()
1818
{
@@ -80,6 +80,7 @@ internal static bool IsVectorPropertyTypeValidCore(Type type, [NotNullWhen(false
8080
type == typeof(Embedding<Half>) ||
8181
type == typeof(Half[]) ||
8282
#endif
83+
type == typeof(BinaryEmbedding) ||
8384
type == typeof(BitArray) ||
8485
type == typeof(SparseVector);
8586
}
@@ -93,5 +94,5 @@ internal static bool IsVectorPropertyTypeValidCore(Type type, [NotNullWhen(false
9394
#if NET8_0_OR_GREATER
9495
?? vectorProperty.ResolveEmbeddingType<Embedding<Half>>(embeddingGenerator, userRequestedEmbeddingType)
9596
#endif
96-
;
97+
?? vectorProperty.ResolveEmbeddingType<BinaryEmbedding>(embeddingGenerator, userRequestedEmbeddingType);
9798
}

dotnet/src/VectorData/PgVector/PostgresPropertyMapping.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ internal static class PostgresPropertyMapping
3030
#endif
3131

3232
BitArray bitArray => bitArray,
33+
BinaryEmbedding binaryEmbedding => binaryEmbedding.Vector,
3334
SparseVector sparseVector => sparseVector,
3435

3536
null => null,
@@ -136,6 +137,7 @@ public static (string PgType, bool IsNullable) GetPgVectorTypeName(VectorPropert
136137

137138
Type t when t == typeof(SparseVector) => "SPARSEVEC",
138139
Type t when t == typeof(BitArray) => "BIT",
140+
Type t when t == typeof(BinaryEmbedding) => "BIT",
139141

140142
_ => throw new NotSupportedException($"Type {vectorProperty.EmbeddingType.Name} is not supported by this store.")
141143
};

dotnet/test/VectorData/PgVector.ConformanceTests/TypeTests/PostgresEmbeddingTypeTests.cs

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
// Copyright (c) Microsoft. All rights reserved.
22

33
using System.Collections;
4-
#if NET8_0_OR_GREATER
54
using Microsoft.Extensions.AI;
6-
#endif
75
using Microsoft.Extensions.VectorData;
86
using Pgvector;
97
using PgVector.ConformanceTests.Support;
@@ -41,16 +39,27 @@ public virtual Task Array_of_Half()
4139
new ReadOnlyMemoryEmbeddingGenerator<Half>([(byte)1, (byte)2, (byte)3]));
4240
#endif
4341

44-
// TODO: Figure out the embedding generation story for binaryvec/sparsevec - need an Embedding wrapper
45-
4642
/// <summary>
/// Runs the shared <c>Test</c> helper with a <see cref="BitArray"/> vector using Hamming distance.
/// The removed version passed <c>embeddingGenerator: null</c>; the added version supplies a
/// <c>BinaryEmbeddingGenerator</c> built from the same three bits.
/// </summary>
[ConditionalFact]
4743
public virtual Task BitArray()
48-
=> this.Test<BitArray>(new BitArray(new bool[] { true, false, true }), distanceFunction: DistanceFunction.HammingDistance, embeddingGenerator: null);
44+
=> this.Test<BitArray>(
45+
new BitArray([true, false, true]),
46+
new BinaryEmbeddingGenerator(new BitArray([true, false, true])),
47+
distanceFunction: DistanceFunction.HammingDistance);
48+
49+
/// <summary>
/// Runs the shared <c>Test</c> helper with a <see cref="BinaryEmbedding"/> vector using Hamming distance
/// and a <c>BinaryEmbeddingGenerator</c> built from the same three bits.
/// </summary>
// NOTE(review): the custom asserter compares the wrapped Vector values rather than the embedding
// objects themselves - presumably because BinaryEmbedding does not define value equality; confirm.
[ConditionalFact]
50+
public virtual Task BinaryEmbedding()
51+
=> this.Test<BinaryEmbedding>(
52+
new BinaryEmbedding(new([true, false, true])),
53+
new BinaryEmbeddingGenerator(new BitArray([true, false, true])),
54+
distanceFunction: DistanceFunction.HammingDistance,
55+
vectorEqualityAsserter: (e, a) => Assert.Equal(e.Vector, a.Vector));
4956

5057
/// <summary>
/// Runs the shared <c>Test</c> helper with a <see cref="SparseVector"/> built from the floats 1, 2, 3.
/// No embedding generator is supplied (<c>embeddingGenerator: null</c>).
/// </summary>
[ConditionalFact]
5158
public virtual Task SparseVector()
5259
=> this.Test<SparseVector>(new SparseVector(new ReadOnlyMemory<float>([1, 2, 3])), embeddingGenerator: null);
5360

61+
// TODO: Figure out the embedding generation story for sparsevec - need an Embedding wrapper
62+
5463
public new class Fixture : EmbeddingTypeTests<int>.Fixture
5564
{
5665
public override TestStore TestStore => PostgresTestStore.Instance;

dotnet/test/VectorData/VectorData.ConformanceTests/TypeTests/EmbeddingTypeTests.cs

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
// Copyright (c) Microsoft. All rights reserved.
22

3+
using System.Collections;
34
using Microsoft.Extensions.AI;
45
using Microsoft.Extensions.VectorData;
56
using VectorData.ConformanceTests.Support;
@@ -178,13 +179,28 @@ protected virtual async Task Test<TVector>(
178179

179180
/// <summary>
/// Test embedding generator that always returns a single <see cref="Embedding{T}"/> built from the
/// array passed to the constructor; the input strings and options are not consulted.
/// </summary>
protected sealed class ReadOnlyMemoryEmbeddingGenerator<T>(T[] data) : IEmbeddingGenerator<string, Embedding<T>>
180181
{
181-
public Task<GeneratedEmbeddings<Embedding<T>>> GenerateAsync(IEnumerable<string> values, EmbeddingGenerationOptions? options = null, CancellationToken cancellationToken = default)
182+
public Task<GeneratedEmbeddings<Embedding<T>>> GenerateAsync(
183+
IEnumerable<string> values,
184+
EmbeddingGenerationOptions? options = null,
185+
CancellationToken cancellationToken = default)
182186
// Completes synchronously with exactly one embedding wrapping the constructor's data.
=> Task.FromResult(new GeneratedEmbeddings<Embedding<T>>([new(data)]));
183187

184188
// No additional services are exposed by this test double.
public object? GetService(Type serviceType, object? serviceKey = null) => null;
185189
// Nothing to release.
public void Dispose() { }
186190
}
187191

192+
/// <summary>
/// Test embedding generator that always returns a single <see cref="BinaryEmbedding"/> wrapping the
/// <see cref="BitArray"/> passed to the constructor; the input strings and options are not consulted.
/// </summary>
protected sealed class BinaryEmbeddingGenerator(BitArray data) : IEmbeddingGenerator<string, BinaryEmbedding>
193+
{
194+
// Completes synchronously with exactly one embedding wrapping the constructor's data.
public Task<GeneratedEmbeddings<BinaryEmbedding>> GenerateAsync(
195+
IEnumerable<string> values,
196+
EmbeddingGenerationOptions? options = null,
197+
CancellationToken cancellationToken = default)
198+
=> Task.FromResult(new GeneratedEmbeddings<BinaryEmbedding>([new(data)]));
199+
200+
// No additional services are exposed by this test double.
public object? GetService(Type serviceType, object? serviceKey = null) => null;
201+
// Nothing to release.
public void Dispose() { }
202+
}
203+
188204
public class RecordWithString
189205
{
190206
public TKey Key { get; set; }

0 commit comments

Comments
 (0)