Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 19 additions & 11 deletions dotnet/src/Agents/Orchestration/AgentActor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -92,23 +92,31 @@ protected ValueTask<ChatMessageContent> InvokeAsync(ChatMessageContent input, Ca
/// <returns>A task that returns the response <see cref="ChatMessageContent"/>.</returns>
protected async ValueTask<ChatMessageContent> InvokeAsync(IList<ChatMessageContent> input, CancellationToken cancellationToken)
{
this.Context.Cancellation.ThrowIfCancellationRequested();
try
{
this.Context.Cancellation.ThrowIfCancellationRequested();

this._lastResponse = null;
this._lastResponse = null;

AgentInvokeOptions options = this.GetInvokeOptions(HandleMessageAsync);
if (this.Context.StreamingResponseCallback == null)
{
// No need to utilize streaming if no callback is provided
await this.InvokeAsync(input, options, cancellationToken).ConfigureAwait(false);
AgentInvokeOptions options = this.GetInvokeOptions(HandleMessageAsync);
if (this.Context.StreamingResponseCallback == null)
{
// No need to utilize streaming if no callback is provided
await this.InvokeAsync(input, options, cancellationToken).ConfigureAwait(false);
}
else
{
await this.InvokeStreamingAsync(input, options, cancellationToken).ConfigureAwait(false);
}

return this._lastResponse ?? new ChatMessageContent(AuthorRole.Assistant, string.Empty);
}
else
catch (Exception exception)
{
await this.InvokeStreamingAsync(input, options, cancellationToken).ConfigureAwait(false);
this.Context.FailureCallback.Invoke(exception);
throw;
}

return this._lastResponse ?? new ChatMessageContent(AuthorRole.Assistant, string.Empty);

async Task HandleMessageAsync(ChatMessageContent message)
{
this._lastResponse = message; // Keep track of most recent response for both invocation modes
Expand Down
5 changes: 3 additions & 2 deletions dotnet/src/Agents/Orchestration/AgentOrchestration.cs
Original file line number Diff line number Diff line change
Expand Up @@ -114,18 +114,19 @@ public async ValueTask<OrchestrationResult<TOutput>> InvokeAsync(

CancellationTokenSource orchestrationCancelSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);

TaskCompletionSource<TOutput> completion = new();

OrchestrationContext context =
new(this.OrchestrationLabel,
topic,
this.ResponseCallback,
this.StreamingResponseCallback,
exception => completion.SetException(exception),
this.LoggerFactory,
cancellationToken);

ILogger logger = this.LoggerFactory.CreateLogger(this.GetType());

TaskCompletionSource<TOutput> completion = new();

AgentType orchestrationType = await this.RegisterAsync(runtime, context, completion, handoff: null).ConfigureAwait(false);

cancellationToken.ThrowIfCancellationRequested();
Expand Down
8 changes: 8 additions & 0 deletions dotnet/src/Agents/Orchestration/OrchestrationContext.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Threading;
using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel.Agents.Runtime;
Expand All @@ -16,11 +17,13 @@ internal OrchestrationContext(
TopicId topic,
OrchestrationResponseCallback? responseCallback,
OrchestrationStreamingCallback? streamingCallback,
Action<Exception> failureCallback,
ILoggerFactory loggerFactory,
CancellationToken cancellation)
{
this.Orchestration = orchestration;
this.Topic = topic;
this.FailureCallback = failureCallback;
this.ResponseCallback = responseCallback;
this.StreamingResponseCallback = streamingCallback;
this.LoggerFactory = loggerFactory;
Expand Down Expand Up @@ -59,4 +62,9 @@ internal OrchestrationContext(
/// Optional callback that is invoked for every agent response.
/// </summary>
public OrchestrationStreamingCallback? StreamingResponseCallback { get; }

/// <summary>
/// Gets the callback that is invoked when an operation fails due to an exception.
/// </summary>
public Action<Exception> FailureCallback { get; }
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,16 @@ public class OrchestrationResultTests
public void Constructor_InitializesPropertiesCorrectly()
{
// Arrange
OrchestrationContext context = new("TestOrchestration", new TopicId("testTopic"), null, null, NullLoggerFactory.Instance, CancellationToken.None);
Exception? captureException = null;
OrchestrationContext context = new("TestOrchestration", new TopicId("testTopic"), null, null, exception => captureException = exception, NullLoggerFactory.Instance, CancellationToken.None);
TaskCompletionSource<string> tcs = new();

// Act
using CancellationTokenSource cancelSource = new();
using OrchestrationResult<string> result = new(context, tcs, cancelSource, NullLogger.Instance);

// Assert
Assert.Null(captureException);
Assert.Equal("TestOrchestration", result.Orchestration);
Assert.Equal(new TopicId("testTopic"), result.Topic);
}
Expand All @@ -32,7 +34,8 @@ public void Constructor_InitializesPropertiesCorrectly()
public async Task GetValueAsync_ReturnsCompletedValue_WhenTaskIsCompletedAsync()
{
// Arrange
OrchestrationContext context = new("TestOrchestration", new TopicId("testTopic"), null, null, NullLoggerFactory.Instance, CancellationToken.None);
Exception? captureException = null;
OrchestrationContext context = new("TestOrchestration", new TopicId("testTopic"), null, null, exception => captureException = exception, NullLoggerFactory.Instance, CancellationToken.None);
TaskCompletionSource<string> tcs = new();
using CancellationTokenSource cancelSource = new();
using OrchestrationResult<string> result = new(context, tcs, cancelSource, NullLogger.Instance);
Expand All @@ -43,14 +46,16 @@ public async Task GetValueAsync_ReturnsCompletedValue_WhenTaskIsCompletedAsync()
string actualValue = await result.GetValueAsync();

// Assert
Assert.Null(captureException);
Assert.Equal(expectedValue, actualValue);
}

[Fact]
public async Task GetValueAsync_WithTimeout_ReturnsCompletedValue_WhenTaskCompletesWithinTimeoutAsync()
{
// Arrange
OrchestrationContext context = new("TestOrchestration", new TopicId("testTopic"), null, null, NullLoggerFactory.Instance, CancellationToken.None);
Exception? captureException = null;
OrchestrationContext context = new("TestOrchestration", new TopicId("testTopic"), null, null, exception => captureException = exception, NullLoggerFactory.Instance, CancellationToken.None);
TaskCompletionSource<string> tcs = new();
using CancellationTokenSource cancelSource = new();
using OrchestrationResult<string> result = new(context, tcs, cancelSource, NullLogger.Instance);
Expand All @@ -62,28 +67,32 @@ public async Task GetValueAsync_WithTimeout_ReturnsCompletedValue_WhenTaskComple
string actualValue = await result.GetValueAsync(timeout);

// Assert
Assert.Null(captureException);
Assert.Equal(expectedValue, actualValue);
}

[Fact]
public async Task GetValueAsync_WithTimeout_ThrowsTimeoutException_WhenTaskDoesNotCompleteWithinTimeoutAsync()
{
// Arrange
OrchestrationContext context = new("TestOrchestration", new TopicId("testTopic"), null, null, NullLoggerFactory.Instance, CancellationToken.None);
Exception? captureException = null;
OrchestrationContext context = new("TestOrchestration", new TopicId("testTopic"), null, null, exception => captureException = exception, NullLoggerFactory.Instance, CancellationToken.None);
TaskCompletionSource<string> tcs = new();
using CancellationTokenSource cancelSource = new();
using OrchestrationResult<string> result = new(context, tcs, cancelSource, NullLogger.Instance);
TimeSpan timeout = TimeSpan.FromMilliseconds(50);

// Act & Assert
TimeoutException exception = await Assert.ThrowsAsync<TimeoutException>(() => result.GetValueAsync(timeout).AsTask());
Assert.Null(captureException);
}

[Fact]
public async Task GetValueAsync_ReturnsCompletedValue_WhenCompletionIsDelayedAsync()
{
// Arrange
OrchestrationContext context = new("TestOrchestration", new TopicId("testTopic"), null, null, NullLoggerFactory.Instance, CancellationToken.None);
Exception? captureException = null;
OrchestrationContext context = new("TestOrchestration", new TopicId("testTopic"), null, null, exception => captureException = exception, NullLoggerFactory.Instance, CancellationToken.None);
TaskCompletionSource<int> tcs = new();
using CancellationTokenSource cancelSource = new();
using OrchestrationResult<int> result = new(context, tcs, cancelSource, NullLogger.Instance);
Expand All @@ -100,6 +109,7 @@ public async Task GetValueAsync_ReturnsCompletedValue_WhenCompletionIsDelayedAsy
int actualValue = await result.GetValueAsync();

// Assert
Assert.Null(captureException);
Assert.Equal(expectedValue, actualValue);
}
}
2 changes: 1 addition & 1 deletion dotnet/src/VectorData/SqliteVec/SqliteMapper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ public TRecord MapFromStorageToDataModel(DbDataReader reader, bool includeVector

var floats = new float[length / 4];
var bytes = MemoryMarshal.Cast<float, byte>(floats);
stream.ReadExactly(bytes);
stream.ReadExactly([.. bytes]);
#else
var floats = MemoryMarshal.Cast<byte, float>((byte[])reader[ordinal]).ToArray();
#endif
Expand Down
Loading