|
| 1 | +using System.Collections.Generic; |
| 2 | +using System.Text.RegularExpressions; |
| 3 | +using BotSharp.Plugin.MMPEmbedding; |
| 4 | +using Microsoft.Extensions.DependencyInjection; |
| 5 | +using Microsoft.Extensions.Logging; |
| 6 | +using OpenAI.Embeddings; |
| 7 | + |
| 8 | +namespace BotSharp.Plugin.MMPEmbedding.Providers; |
| 9 | + |
| 10 | +/// <summary> |
| 11 | +/// Text embedding provider that uses Mean-Max Pooling strategy |
| 12 | +/// This provider gets embeddings for individual tokens and combines them using mean and max pooling |
| 13 | +/// </summary> |
| 14 | +public class MMPEmbeddingProvider : ITextEmbedding |
| 15 | +{ |
| 16 | + protected readonly IServiceProvider _serviceProvider; |
| 17 | + protected readonly ILogger<MMPEmbeddingProvider> _logger; |
| 18 | + |
| 19 | + private const int DEFAULT_DIMENSION = 1536; |
| 20 | + protected string _model = "text-embedding-3-small"; |
| 21 | + protected int _dimension = DEFAULT_DIMENSION; |
| 22 | + |
| 23 | + // The underlying provider to use (e.g., "openai", "azure-openai", "deepseek-ai") |
| 24 | + protected string _underlyingProvider = "openai"; |
| 25 | + |
| 26 | + public string Provider => "mmp-embedding"; |
| 27 | + public string Model => _model; |
| 28 | + |
| 29 | + private static readonly Regex WordRegex = new(@"\b\w+\b", RegexOptions.Compiled); |
| 30 | + |
| 31 | + public MMPEmbeddingProvider(IServiceProvider serviceProvider, ILogger<MMPEmbeddingProvider> logger) |
| 32 | + { |
| 33 | + _serviceProvider = serviceProvider; |
| 34 | + _logger = logger; |
| 35 | + } |
| 36 | + |
| 37 | + /// <summary> |
| 38 | + /// Gets a single embedding vector using mean-max pooling |
| 39 | + /// </summary> |
| 40 | + public async Task<float[]> GetVectorAsync(string text) |
| 41 | + { |
| 42 | + if (string.IsNullOrWhiteSpace(text)) |
| 43 | + { |
| 44 | + return new float[_dimension]; |
| 45 | + } |
| 46 | + |
| 47 | + var tokens = Tokenize(text).ToList(); |
| 48 | + |
| 49 | + if (tokens.Count == 0) |
| 50 | + { |
| 51 | + return new float[_dimension]; |
| 52 | + } |
| 53 | + |
| 54 | + // Get embeddings for all tokens |
| 55 | + var tokenEmbeddings = await GetTokenEmbeddingsAsync(tokens); |
| 56 | + |
| 57 | + // Apply mean-max pooling |
| 58 | + var pooledEmbedding = MeanMaxPooling(tokenEmbeddings); |
| 59 | + |
| 60 | + return pooledEmbedding; |
| 61 | + } |
| 62 | + |
| 63 | + /// <summary> |
| 64 | + /// Gets multiple embedding vectors using mean-max pooling |
| 65 | + /// </summary> |
| 66 | + public async Task<List<float[]>> GetVectorsAsync(List<string> texts) |
| 67 | + { |
| 68 | + var results = new List<float[]>(); |
| 69 | + |
| 70 | + foreach (var text in texts) |
| 71 | + { |
| 72 | + var embedding = await GetVectorAsync(text); |
| 73 | + results.Add(embedding); |
| 74 | + } |
| 75 | + |
| 76 | + return results; |
| 77 | + } |
| 78 | + |
| 79 | + /// <summary> |
| 80 | + /// Gets embeddings for individual tokens using the underlying provider |
| 81 | + /// </summary> |
| 82 | + private async Task<List<float[]>> GetTokenEmbeddingsAsync(List<string> tokens) |
| 83 | + { |
| 84 | + try |
| 85 | + { |
| 86 | + // Get the appropriate client based on the underlying provider |
| 87 | + var client = ProviderHelper.GetClient(_underlyingProvider, _model, _serviceProvider); |
| 88 | + var embeddingClient = client.GetEmbeddingClient(_model); |
| 89 | + |
| 90 | + // Prepare options |
| 91 | + var options = new EmbeddingGenerationOptions |
| 92 | + { |
| 93 | + Dimensions = _dimension > 0 ? _dimension : null |
| 94 | + }; |
| 95 | + |
| 96 | + // Get embeddings for all tokens in batch |
| 97 | + var response = await embeddingClient.GenerateEmbeddingsAsync(tokens, options); |
| 98 | + var embeddings = response.Value; |
| 99 | + |
| 100 | + return embeddings.Select(e => e.ToFloats().ToArray()).ToList(); |
| 101 | + } |
| 102 | + catch (Exception ex) |
| 103 | + { |
| 104 | + _logger.LogError(ex, "Error getting token embeddings from provider {Provider} with model {Model}", |
| 105 | + _underlyingProvider, _model); |
| 106 | + throw; |
| 107 | + } |
| 108 | + } |
| 109 | + |
| 110 | + /// <summary> |
| 111 | + /// Applies mean-max pooling to combine token embeddings |
| 112 | + /// Mean pooling: average of all token embeddings |
| 113 | + /// Max pooling: element-wise maximum of all token embeddings |
| 114 | + /// Result: concatenation of mean and max pooled vectors |
| 115 | + /// </summary> |
| 116 | + private float[] MeanMaxPooling(IReadOnlyList<float[]> vectors, double meanWeight = 0.5, double maxWeight = 0.5) |
| 117 | + { |
| 118 | + var numTokens = vectors.Count; |
| 119 | + |
| 120 | + if (numTokens == 0) |
| 121 | + return []; |
| 122 | + |
| 123 | + var meanPooled = Enumerable.Range(0, _dimension) |
| 124 | + .Select(i => vectors.Average(v => v[i])) |
| 125 | + .ToArray(); |
| 126 | + var maxPooled = Enumerable.Range(0, _dimension) |
| 127 | + .Select(i => vectors.Max(v => v[i])) |
| 128 | + .ToArray(); |
| 129 | + |
| 130 | + return Enumerable.Range(0, _dimension) |
| 131 | + .Select(i => (float)meanWeight * meanPooled[i] + (float)maxWeight * maxPooled[i]) |
| 132 | + .ToArray(); |
| 133 | + } |
| 134 | + |
| 135 | + public void SetDimension(int dimension) |
| 136 | + { |
| 137 | + _dimension = dimension > 0 ? dimension : DEFAULT_DIMENSION; |
| 138 | + } |
| 139 | + |
| 140 | + public int GetDimension() |
| 141 | + { |
| 142 | + return _dimension; |
| 143 | + } |
| 144 | + |
| 145 | + public void SetModelName(string model) |
| 146 | + { |
| 147 | + _model = model; |
| 148 | + } |
| 149 | + |
| 150 | + /// <summary> |
| 151 | + /// Sets the underlying provider to use for getting token embeddings |
| 152 | + /// </summary> |
| 153 | + /// <param name="provider">Provider name (e.g., "openai", "azure-openai", "deepseek-ai")</param> |
| 154 | + public void SetUnderlyingProvider(string provider) |
| 155 | + { |
| 156 | + _underlyingProvider = provider; |
| 157 | + } |
| 158 | + |
| 159 | + /// <summary> |
| 160 | + /// Tokenizes text into individual words |
| 161 | + /// </summary> |
| 162 | + public static IEnumerable<string> Tokenize(string text, string? pattern = null) |
| 163 | + { |
| 164 | + var patternRegex = string.IsNullOrEmpty(pattern) ? WordRegex : new(pattern, RegexOptions.Compiled); |
| 165 | + return patternRegex.Matches(text).Cast<Match>().Select(m => m.Value); |
| 166 | + } |
| 167 | +} |
0 commit comments