Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions plugins/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -559,15 +559,15 @@ impl ServiceGenerator<'_> {
)| {
quote! {
#( #attrs )*
async fn #ident(self, context: ::tarpc::context::Context, #( #args ),*) -> #output;
fn #ident(self, context: ::tarpc::context::Context, #( #args ),*) -> impl ::core::future::Future<Output = #output> + ::core::marker::Send;
}
},
);

let stub_doc = format!("The stub trait for service [`{service_ident}`].");
quote! {
#( #attrs )*
#vis trait #service_ident: ::core::marker::Sized {
#vis trait #service_ident: ::core::marker::Sized + ::core::marker::Send {
#( #rpc_fns )*

/// Returns a serving function to use with
Expand All @@ -578,11 +578,11 @@ impl ServiceGenerator<'_> {
}

#[doc = #stub_doc]
#vis trait #client_stub_ident: ::tarpc::client::stub::Stub<Req = #request_ident, Resp = #response_ident> {
#vis trait #client_stub_ident: ::tarpc::client::stub::SendStub<Req = #request_ident, Resp = #response_ident> {
}

impl<S> #client_stub_ident for S
where S: ::tarpc::client::stub::Stub<Req = #request_ident, Resp = #response_ident>
where S: ::tarpc::client::stub::SendStub<Req = #request_ident, Resp = #response_ident>
{
}
}
Expand Down Expand Up @@ -616,7 +616,7 @@ impl ServiceGenerator<'_> {
} = self;

quote! {
impl<S> ::tarpc::server::Serve for #server_ident<S>
impl<S> ::tarpc::server::SendServe for #server_ident<S>
where S: #service_ident
{
type Req = #request_ident;
Expand Down Expand Up @@ -780,7 +780,7 @@ impl ServiceGenerator<'_> {

quote! {
impl<Stub> #client_ident<Stub>
where Stub: ::tarpc::client::stub::Stub<
where Stub: ::tarpc::client::stub::SendStub<
Req = #request_ident,
Resp = #response_ident>
{
Expand Down
55 changes: 51 additions & 4 deletions tarpc/src/client/stub.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
//! Provides a Stub trait, implemented by types that can call remote services.

use std::future::Future;

use crate::{
client::{Channel, RpcError},
context,
server::Serve,
server::{SendServe, Serve},
RequestName,
};

Expand All @@ -15,7 +17,6 @@ mod mock;

/// A connection to a remote service.
/// Calls the service with requests of type `Req` and receives responses of type `Resp`.
#[allow(async_fn_in_trait)]
pub trait Stub {
/// The service request type.
type Req: RequestName;
Expand All @@ -24,8 +25,28 @@ pub trait Stub {
type Resp;

/// Calls a remote service.
async fn call(&self, ctx: context::Context, request: Self::Req)
-> Result<Self::Resp, RpcError>;
fn call(
&self,
ctx: context::Context,
request: Self::Req,
) -> impl Future<Output = Result<Self::Resp, RpcError>>;
}

/// A connection to a remote service.
/// Calls the service with requests of type `Req` and receives responses of type `Resp`.
pub trait SendStub: Send {
/// The service request type.
type Req: RequestName;

/// The service response type.
type Resp;

/// Calls a remote service.
fn call(
&self,
ctx: context::Context,
request: Self::Req,
) -> impl Future<Output = Result<Self::Resp, RpcError>> + Send;
}

impl<Req, Resp> Stub for Channel<Req, Resp>
Expand All @@ -40,6 +61,19 @@ where
}
}

impl<Req, Resp> SendStub for Channel<Req, Resp>
where
Req: RequestName + Send,
Resp: Send,
{
type Req = Req;
type Resp = Resp;

async fn call(&self, ctx: context::Context, request: Req) -> Result<Self::Resp, RpcError> {
Self::call(self, ctx, request).await
}
}

impl<S> Stub for S
where
S: Serve + Clone,
Expand All @@ -50,3 +84,16 @@ where
self.clone().serve(ctx, req).await.map_err(RpcError::Server)
}
}

impl<S> SendStub for S
where
S: SendServe + Clone + Sync,
S::Req: Send + Sync,
S::Resp: Send,
{
type Req = S::Req;
type Resp = S::Resp;
async fn call(&self, ctx: context::Context, req: Self::Req) -> Result<Self::Resp, RpcError> {
self.clone().serve(ctx, req).await.map_err(RpcError::Server)
}
}
18 changes: 18 additions & 0 deletions tarpc/src/client/stub/load_balance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,24 @@ mod round_robin {
}
}

impl<Stub> stub::SendStub for RoundRobin<Stub>
where
Stub: stub::SendStub + Send + Sync,
Stub::Req: Send,
{
type Req = Stub::Req;
type Resp = Stub::Resp;

async fn call(
&self,
ctx: context::Context,
request: Self::Req,
) -> Result<Stub::Resp, RpcError> {
let next = self.stubs.next();
next.call(ctx, request).await
}
}

/// A Stub that load-balances across backing stubs by round robin.
#[derive(Clone, Debug)]
pub struct RoundRobin<Stub> {
Expand Down
27 changes: 26 additions & 1 deletion tarpc/src/client/stub/mock.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use crate::{
client::{stub::Stub, RpcError},
client::{
stub::{SendStub, Stub},
RpcError,
},
context, RequestName, ServerError,
};
use std::{collections::HashMap, hash::Hash, io};
Expand Down Expand Up @@ -42,3 +45,25 @@ where
})
}
}

impl<Req, Resp> SendStub for Mock<Req, Resp>
where
Req: Eq + Hash + RequestName + Send + Sync,
Resp: Clone + Send + Sync,
{
type Req = Req;
type Resp = Resp;

async fn call(&self, _: context::Context, request: Self::Req) -> Result<Resp, RpcError> {
self.responses
.get(&request)
.cloned()
.map(Ok)
.unwrap_or_else(|| {
Err(RpcError::Server(ServerError {
kind: io::ErrorKind::NotFound,
detail: "mock (request, response) entry not found".into(),
}))
})
}
}
27 changes: 27 additions & 0 deletions tarpc/src/client/stub/retry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,33 @@ where
}
}

impl<Stub, Req, F> stub::SendStub for Retry<F, Stub>
where
Req: RequestName + Send + Sync,
Stub: stub::SendStub<Req = Arc<Req>> + Send + Sync,
F: Fn(&Result<Stub::Resp, RpcError>, u32) -> bool + Send + Sync,
{
type Req = Req;
type Resp = Stub::Resp;

async fn call(
&self,
ctx: context::Context,
request: Self::Req,
) -> Result<Stub::Resp, RpcError> {
let request = Arc::new(request);
for i in 1.. {
let result = self.stub.call(ctx, Arc::clone(&request)).await;
if (self.should_retry)(&result, i) {
tracing::trace!("Retrying on attempt {i}");
continue;
}
return result;
}
unreachable!("Wow, that was a lot of attempts!");
}
}

/// A Stub that retries requests based on response contents.
/// Note: to use this stub with Serde serialization, the "rc" feature of Serde needs to be enabled.
#[derive(Clone, Debug)]
Expand Down
37 changes: 31 additions & 6 deletions tarpc/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ impl Config {
}

/// Equivalent to a `FnOnce(Req) -> impl Future<Output = Resp>`.
#[allow(async_fn_in_trait)]
pub trait Serve {
/// Type of request.
type Req: RequestName;
Expand All @@ -76,7 +75,33 @@ pub trait Serve {
type Resp;

/// Responds to a single request.
async fn serve(self, ctx: context::Context, req: Self::Req) -> Result<Self::Resp, ServerError>;
fn serve(
self,
ctx: context::Context,
req: Self::Req,
) -> impl Future<Output = Result<Self::Resp, ServerError>>;
}

/// Equivalent to a `FnOnce(Req) -> impl Future<Output = Resp>`.
pub trait SendServe: Send {
/// Type of request.
type Req: RequestName;
/// Type of response.
type Resp;
/// Responds to a single request.
fn serve(
self,
ctx: context::Context,
req: Self::Req,
) -> impl Future<Output = Result<Self::Resp, ServerError>> + Send;
}

impl<S: SendServe> Serve for S {
type Req = <Self as SendServe>::Req;
type Resp = <Self as SendServe>::Resp;
async fn serve(self, ctx: context::Context, req: Self::Req) -> Result<Self::Resp, ServerError> {
<Self as SendServe>::serve(self, ctx, req).await
}
}

/// A Serve wrapper around a Fn.
Expand Down Expand Up @@ -113,11 +138,11 @@ where
}
}

impl<Req, Resp, Fut, F> Serve for ServeFn<Req, Resp, F>
impl<Req, Resp, Fut, F> SendServe for ServeFn<Req, Resp, F>
where
Req: RequestName,
F: FnOnce(context::Context, Req) -> Fut,
Fut: Future<Output = Result<Resp, ServerError>>,
Req: RequestName + Send,
F: FnOnce(context::Context, Req) -> Fut + Send,
Fut: Future<Output = Result<Resp, ServerError>> + Send,
{
type Req = Req;
type Resp = Resp;
Expand Down