Commit fe1f95b

Merge pull request #1223 from evan-cao-wb/master
Add MMP Embedding method
2 parents 1e511b7 + 2a8669e commit fe1f95b

6 files changed: +294 −0 lines
BotSharp.sln

Lines changed: 10 additions & 0 deletions
@@ -149,6 +149,7 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "BotSharp.Plugin.ImageHandle
 EndProject
 Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "BotSharp.Plugin.FuzzySharp", "src\Plugins\BotSharp.Plugin.FuzzySharp\BotSharp.Plugin.FuzzySharp.csproj", "{E7C243B9-E751-B3B4-8F16-95C76CA90D31}"
 EndProject
+Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "BotSharp.Plugin.MMPEmbedding", "src\Plugins\BotSharp.Plugin.MMPEmbedding\BotSharp.Plugin.MMPEmbedding.csproj", "{394B858B-9C26-B977-A2DA-8CC7BE5914CB}"
 Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "BotSharp.Plugin.Membase", "src\Plugins\BotSharp.Plugin.Membase\BotSharp.Plugin.Membase.csproj", "{13223C71-9EAC-9835-28ED-5A4833E6F915}"
 EndProject
 Global
@@ -631,6 +632,14 @@ Global
 	{E7C243B9-E751-B3B4-8F16-95C76CA90D31}.Release|Any CPU.Build.0 = Release|Any CPU
 	{E7C243B9-E751-B3B4-8F16-95C76CA90D31}.Release|x64.ActiveCfg = Release|Any CPU
 	{E7C243B9-E751-B3B4-8F16-95C76CA90D31}.Release|x64.Build.0 = Release|Any CPU
+	{394B858B-9C26-B977-A2DA-8CC7BE5914CB}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+	{394B858B-9C26-B977-A2DA-8CC7BE5914CB}.Debug|Any CPU.Build.0 = Debug|Any CPU
+	{394B858B-9C26-B977-A2DA-8CC7BE5914CB}.Debug|x64.ActiveCfg = Debug|Any CPU
+	{394B858B-9C26-B977-A2DA-8CC7BE5914CB}.Debug|x64.Build.0 = Debug|Any CPU
+	{394B858B-9C26-B977-A2DA-8CC7BE5914CB}.Release|Any CPU.ActiveCfg = Release|Any CPU
+	{394B858B-9C26-B977-A2DA-8CC7BE5914CB}.Release|Any CPU.Build.0 = Release|Any CPU
+	{394B858B-9C26-B977-A2DA-8CC7BE5914CB}.Release|x64.ActiveCfg = Release|Any CPU
+	{394B858B-9C26-B977-A2DA-8CC7BE5914CB}.Release|x64.Build.0 = Release|Any CPU
 	{13223C71-9EAC-9835-28ED-5A4833E6F915}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
 	{13223C71-9EAC-9835-28ED-5A4833E6F915}.Debug|Any CPU.Build.0 = Debug|Any CPU
 	{13223C71-9EAC-9835-28ED-5A4833E6F915}.Debug|x64.ActiveCfg = Debug|Any CPU
@@ -711,6 +720,7 @@ Global
 	{FC63C875-E880-D8BB-B8B5-978AB7B62983} = {51AFE054-AE99-497D-A593-69BAEFB5106F}
 	{242F2D93-FCCE-4982-8075-F3052ECCA92C} = {51AFE054-AE99-497D-A593-69BAEFB5106F}
 	{E7C243B9-E751-B3B4-8F16-95C76CA90D31} = {51AFE054-AE99-497D-A593-69BAEFB5106F}
+	{394B858B-9C26-B977-A2DA-8CC7BE5914CB} = {2635EC9B-2E5F-4313-AC21-0B847F31F36C}
 	{13223C71-9EAC-9835-28ED-5A4833E6F915} = {53E7CD86-0D19-40D9-A0FA-AB4613837E89}
 	EndGlobalSection
 	GlobalSection(ExtensibilityGlobals) = postSolution

src\Plugins\BotSharp.Plugin.MMPEmbedding\BotSharp.Plugin.MMPEmbedding.csproj

Lines changed: 18 additions & 0 deletions

@@ -0,0 +1,18 @@
<Project Sdk="Microsoft.NET.Sdk">

  <PropertyGroup>
    <TargetFramework>$(TargetFramework)</TargetFramework>
    <ImplicitUsings>enable</ImplicitUsings>
    <Nullable>enable</Nullable>
  </PropertyGroup>

  <ItemGroup>
    <PackageReference Include="Azure.AI.OpenAI" />
    <PackageReference Include="OpenAI" />
  </ItemGroup>

  <ItemGroup>
    <ProjectReference Include="..\..\Infrastructure\BotSharp.Core\BotSharp.Core.csproj" />
  </ItemGroup>

</Project>

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
using BotSharp.Abstraction.Plugins;
using BotSharp.Plugin.MMPEmbedding.Providers;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.DependencyInjection;

namespace BotSharp.Plugin.MMPEmbedding
{
    public class MMPEmbeddingPlugin : IBotSharpPlugin
    {
        public string Id => "54d04e10-fc84-493e-a8c9-39da1c83f45a";
        public string Name => "MMPEmbedding";
        public string Description => "MMP Embedding Service";

        public void RegisterDI(IServiceCollection services, IConfiguration config)
        {
            services.AddScoped<ITextEmbedding, MMPEmbeddingProvider>();
        }
    }
}
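
For orientation, here is a minimal consumer sketch, not part of this commit, showing how the registration above might be used. It assumes ITextEmbedding (from BotSharp.Abstraction.MLTasks) exposes the SetModelName, SetDimension, and GetVectorAsync members that MMPEmbeddingProvider implements below; the class name, model name, and dimension are illustrative only.

using System;
using System.Threading.Tasks;
using BotSharp.Abstraction.MLTasks;
using Microsoft.Extensions.DependencyInjection;

public static class MmpEmbeddingUsageSketch
{
    public static async Task<float[]> EmbedAsync(IServiceProvider services, string text)
    {
        // RegisterDI above maps ITextEmbedding to MMPEmbeddingProvider, so resolving the
        // interface (typically via constructor injection in a scoped service) returns the
        // mean-max pooling implementation.
        var embedding = services.GetRequiredService<ITextEmbedding>();

        embedding.SetModelName("text-embedding-3-small"); // illustrative; matches the provider's default
        embedding.SetDimension(1536);                      // illustrative; matches DEFAULT_DIMENSION

        return await embedding.GetVectorAsync(text);
    }
}
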
Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
using OpenAI;
using Azure.AI.OpenAI;
using System.ClientModel;
using Microsoft.Extensions.DependencyInjection;

namespace BotSharp.Plugin.MMPEmbedding;

/// <summary>
/// Helper class to get the appropriate client based on provider type
/// Supports multiple providers: OpenAI, Azure OpenAI, DeepSeek, etc.
/// </summary>
public static class ProviderHelper
{
    /// <summary>
    /// Gets an OpenAI-compatible client based on the provider name
    /// </summary>
    /// <param name="provider">Provider name (e.g., "openai", "azure-openai")</param>
    /// <param name="model">Model name</param>
    /// <param name="services">Service provider for dependency injection</param>
    /// <returns>OpenAIClient instance configured for the specified provider</returns>
    public static OpenAIClient GetClient(string provider, string model, IServiceProvider services)
    {
        var settingsService = services.GetRequiredService<ILlmProviderService>();
        var settings = settingsService.GetSetting(provider, model);

        if (settings == null)
        {
            throw new InvalidOperationException($"Cannot find settings for provider '{provider}' and model '{model}'");
        }

        // Handle Azure OpenAI separately as it uses AzureOpenAIClient
        if (provider.Equals("azure-openai", StringComparison.OrdinalIgnoreCase))
        {
            return GetAzureOpenAIClient(settings);
        }

        // For OpenAI, DeepSeek, and other OpenAI-compatible providers
        return GetOpenAICompatibleClient(settings);
    }

    /// <summary>
    /// Gets an Azure OpenAI client
    /// </summary>
    private static OpenAIClient GetAzureOpenAIClient(LlmModelSetting settings)
    {
        if (string.IsNullOrEmpty(settings.Endpoint))
        {
            throw new InvalidOperationException("Azure OpenAI endpoint is required");
        }

        var client = new AzureOpenAIClient(
            new Uri(settings.Endpoint),
            new ApiKeyCredential(settings.ApiKey)
        );

        return client;
    }

    /// <summary>
    /// Gets an OpenAI-compatible client (OpenAI, DeepSeek, etc.)
    /// </summary>
    private static OpenAIClient GetOpenAICompatibleClient(LlmModelSetting settings)
    {
        var options = !string.IsNullOrEmpty(settings.Endpoint)
            ? new OpenAIClientOptions { Endpoint = new Uri(settings.Endpoint) }
            : null;

        return new OpenAIClient(new ApiKeyCredential(settings.ApiKey), options);
    }
}
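
A brief usage sketch, not part of this commit: it assumes the code sits in the same project (so the global usings added below apply) and uses only the OpenAI SDK calls that GetTokenEmbeddingsAsync relies on in the next file. Because AzureOpenAIClient derives from OpenAIClient, GetClient can expose a single OpenAIClient return type for Azure, OpenAI, and other OpenAI-compatible endpoints; the provider key and model name here are placeholder values.

using System;
using System.Threading.Tasks;

namespace BotSharp.Plugin.MMPEmbedding;

public static class ProviderHelperUsageSketch
{
    public static async Task<float[]> EmbedOneAsync(IServiceProvider services, string text)
    {
        // GetClient resolves the matching LlmModelSetting and builds the right client
        // for the named provider ("openai" here is only an example).
        var client = ProviderHelper.GetClient("openai", "text-embedding-3-small", services);
        var embeddingClient = client.GetEmbeddingClient("text-embedding-3-small");

        // Embed a single string and convert the result to a float array.
        var embedding = (await embeddingClient.GenerateEmbeddingAsync(text)).Value;
        return embedding.ToFloats().ToArray();
    }
}
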
Lines changed: 167 additions & 0 deletions
@@ -0,0 +1,167 @@
using System.Collections.Generic;
using System.Text.RegularExpressions;
using BotSharp.Plugin.MMPEmbedding;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using OpenAI.Embeddings;

namespace BotSharp.Plugin.MMPEmbedding.Providers;

/// <summary>
/// Text embedding provider that uses Mean-Max Pooling strategy
/// This provider gets embeddings for individual tokens and combines them using mean and max pooling
/// </summary>
public class MMPEmbeddingProvider : ITextEmbedding
{
    protected readonly IServiceProvider _serviceProvider;
    protected readonly ILogger<MMPEmbeddingProvider> _logger;

    private const int DEFAULT_DIMENSION = 1536;
    protected string _model = "text-embedding-3-small";
    protected int _dimension = DEFAULT_DIMENSION;

    // The underlying provider to use (e.g., "openai", "azure-openai", "deepseek-ai")
    protected string _underlyingProvider = "openai";

    public string Provider => "mmp-embedding";
    public string Model => _model;

    private static readonly Regex WordRegex = new(@"\b\w+\b", RegexOptions.Compiled);

    public MMPEmbeddingProvider(IServiceProvider serviceProvider, ILogger<MMPEmbeddingProvider> logger)
    {
        _serviceProvider = serviceProvider;
        _logger = logger;
    }

    /// <summary>
    /// Gets a single embedding vector using mean-max pooling
    /// </summary>
    public async Task<float[]> GetVectorAsync(string text)
    {
        if (string.IsNullOrWhiteSpace(text))
        {
            return new float[_dimension];
        }

        var tokens = Tokenize(text).ToList();

        if (tokens.Count == 0)
        {
            return new float[_dimension];
        }

        // Get embeddings for all tokens
        var tokenEmbeddings = await GetTokenEmbeddingsAsync(tokens);

        // Apply mean-max pooling
        var pooledEmbedding = MeanMaxPooling(tokenEmbeddings);

        return pooledEmbedding;
    }

    /// <summary>
    /// Gets multiple embedding vectors using mean-max pooling
    /// </summary>
    public async Task<List<float[]>> GetVectorsAsync(List<string> texts)
    {
        var results = new List<float[]>();

        foreach (var text in texts)
        {
            var embedding = await GetVectorAsync(text);
            results.Add(embedding);
        }

        return results;
    }

    /// <summary>
    /// Gets embeddings for individual tokens using the underlying provider
    /// </summary>
    private async Task<List<float[]>> GetTokenEmbeddingsAsync(List<string> tokens)
    {
        try
        {
            // Get the appropriate client based on the underlying provider
            var client = ProviderHelper.GetClient(_underlyingProvider, _model, _serviceProvider);
            var embeddingClient = client.GetEmbeddingClient(_model);

            // Prepare options
            var options = new EmbeddingGenerationOptions
            {
                Dimensions = _dimension > 0 ? _dimension : null
            };

            // Get embeddings for all tokens in batch
            var response = await embeddingClient.GenerateEmbeddingsAsync(tokens, options);
            var embeddings = response.Value;

            return embeddings.Select(e => e.ToFloats().ToArray()).ToList();
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Error getting token embeddings from provider {Provider} with model {Model}",
                _underlyingProvider, _model);
            throw;
        }
    }

    /// <summary>
    /// Applies mean-max pooling to combine token embeddings
    /// Mean pooling: average of all token embeddings
    /// Max pooling: element-wise maximum of all token embeddings
    /// Result: element-wise weighted combination of the mean-pooled and max-pooled vectors, keeping the original embedding dimension
    /// </summary>
private float[] MeanMaxPooling(IReadOnlyList<float[]> vectors, double meanWeight = 0.5, double maxWeight = 0.5)
117+
{
118+
var numTokens = vectors.Count;
119+
120+
if (numTokens == 0)
121+
return [];
122+
123+
var meanPooled = Enumerable.Range(0, _dimension)
124+
.Select(i => vectors.Average(v => v[i]))
125+
.ToArray();
126+
var maxPooled = Enumerable.Range(0, _dimension)
127+
.Select(i => vectors.Max(v => v[i]))
128+
.ToArray();
129+
130+
return Enumerable.Range(0, _dimension)
131+
.Select(i => (float)meanWeight * meanPooled[i] + (float)maxWeight * maxPooled[i])
132+
.ToArray();
133+
}
134+
135+
public void SetDimension(int dimension)
136+
{
137+
_dimension = dimension > 0 ? dimension : DEFAULT_DIMENSION;
138+
}
139+
140+
public int GetDimension()
141+
{
142+
return _dimension;
143+
}
144+
145+
public void SetModelName(string model)
146+
{
147+
_model = model;
148+
}
149+
150+
/// <summary>
151+
/// Sets the underlying provider to use for getting token embeddings
152+
/// </summary>
153+
/// <param name="provider">Provider name (e.g., "openai", "azure-openai", "deepseek-ai")</param>
154+
public void SetUnderlyingProvider(string provider)
155+
{
156+
_underlyingProvider = provider;
157+
}
158+
159+
/// <summary>
160+
/// Tokenizes text into individual words
161+
/// </summary>
162+
public static IEnumerable<string> Tokenize(string text, string? pattern = null)
163+
{
164+
var patternRegex = string.IsNullOrEmpty(pattern) ? WordRegex : new(pattern, RegexOptions.Compiled);
165+
return patternRegex.Matches(text).Cast<Match>().Select(m => m.Value);
166+
}
167+
}
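
To make the pooling concrete: with the default weights of 0.5 each, every output dimension i is computed as 0.5 * mean over tokens of v[i] plus 0.5 * max over tokens of v[i], so the result keeps the same dimensionality as the underlying token embeddings (a true concatenation of the mean and max vectors would double it). As a small worked example with illustrative numbers, for two token vectors [1, 4] and [3, 0] the mean-pooled vector is [2, 2], the max-pooled vector is [3, 4], and the combined result is [2.5, 3.0].
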
Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
global using System;
global using System.Collections.Generic;
global using System.Linq;
global using System.Text;
global using System.Threading.Tasks;

global using BotSharp.Abstraction.MLTasks;
global using BotSharp.Abstraction.MLTasks.Settings;
global using Microsoft.Extensions.Logging;