From 92ce2b3ea2d16b8ee7f6e6b1504d20397215c858 Mon Sep 17 00:00:00 2001 From: Steve Fan <29133953+stevefan1999-personal@users.noreply.github.com> Date: Tue, 8 Oct 2024 03:50:55 +0800 Subject: [PATCH] let serve and stub traits be send --- plugins/src/lib.rs | 4 +- tarpc/src/client/stub.rs | 18 ++++-- tarpc/src/client/stub/load_balance.rs | 17 +++--- tarpc/src/client/stub/mock.rs | 4 +- tarpc/src/client/stub/retry.rs | 6 +- tarpc/src/lib.rs | 9 +-- tarpc/src/server.rs | 20 +++--- tarpc/src/server/request_hook/after.rs | 11 ++-- tarpc/src/server/request_hook/before.rs | 61 +++++++++++++------ .../server/request_hook/before_and_after.rs | 5 +- 10 files changed, 97 insertions(+), 58 deletions(-) diff --git a/plugins/src/lib.rs b/plugins/src/lib.rs index c423644ba..72987e747 100644 --- a/plugins/src/lib.rs +++ b/plugins/src/lib.rs @@ -559,7 +559,7 @@ impl<'a> ServiceGenerator<'a> { )| { quote! { #( #attrs )* - async fn #ident(self, context: ::tarpc::context::Context, #( #args ),*) -> #output; + fn #ident(self, context: ::tarpc::context::Context, #( #args ),*) -> impl ::std::future::Future + ::core::marker::Send; } }, ); @@ -567,7 +567,7 @@ impl<'a> ServiceGenerator<'a> { let stub_doc = format!("The stub trait for service [`{service_ident}`]."); quote! { #( #attrs )* - #vis trait #service_ident: ::core::marker::Sized { + #vis trait #service_ident: ::core::marker::Sized + ::core::marker::Send { #( #rpc_fns )* /// Returns a serving function to use with diff --git a/tarpc/src/client/stub.rs b/tarpc/src/client/stub.rs index e7c11aa05..6a22650b2 100644 --- a/tarpc/src/client/stub.rs +++ b/tarpc/src/client/stub.rs @@ -1,5 +1,7 @@ //! Provides a Stub trait, implemented by types that can call remote services. +use std::future::Future; + use crate::{ client::{Channel, RpcError}, context, @@ -16,7 +18,7 @@ mod mock; /// A connection to a remote service. /// Calls the service with requests of type `Req` and receives responses of type `Resp`. #[allow(async_fn_in_trait)] -pub trait Stub { +pub trait Stub: Send { /// The service request type. type Req: RequestName; @@ -24,13 +26,17 @@ pub trait Stub { type Resp; /// Calls a remote service. - async fn call(&self, ctx: context::Context, request: Self::Req) - -> Result; + fn call( + &self, + ctx: context::Context, + request: Self::Req, + ) -> impl Future> + Send; } impl Stub for Channel where - Req: RequestName, + Req: RequestName + Send, + Resp: Send, { type Req = Req; type Resp = Resp; @@ -42,7 +48,9 @@ where impl Stub for S where - S: Serve + Clone, + S: Serve + Clone + Send + Sync, + S::Req: Send + Sync, + S::Resp: Send + Sync, { type Req = S::Req; type Resp = S::Resp; diff --git a/tarpc/src/client/stub/load_balance.rs b/tarpc/src/client/stub/load_balance.rs index e586b7937..9410966a7 100644 --- a/tarpc/src/client/stub/load_balance.rs +++ b/tarpc/src/client/stub/load_balance.rs @@ -13,7 +13,8 @@ mod round_robin { impl stub::Stub for RoundRobin where - Stub: stub::Stub, + Stub: stub::Stub + Sync, + Stub::Req: Send, { type Req = Stub::Req; type Resp = Stub::Resp; @@ -110,9 +111,9 @@ mod consistent_hash { impl stub::Stub for ConsistentHash where - Stub: stub::Stub, - Stub::Req: Hash, - S: BuildHasher, + Stub: stub::Stub + Sync, + Stub::Req: Hash + Send, + S: BuildHasher + Send + Sync, { type Req = Stub::Req; type Resp = Stub::Resp; @@ -188,7 +189,7 @@ mod consistent_hash { use std::{ collections::HashMap, hash::{BuildHasher, Hash, Hasher}, - rc::Rc, + sync::Arc, }; #[tokio::test] @@ -230,11 +231,11 @@ mod consistent_hash { } struct FakeHasherBuilder { - recorded_hashes: Rc, u64>>, + recorded_hashes: Arc, u64>>, } struct FakeHasher { - recorded_hashes: Rc, u64>>, + recorded_hashes: Arc, u64>>, output: u64, } @@ -258,7 +259,7 @@ mod consistent_hash { recorded_hashes.insert(recorder.0, fake_hash); } Self { - recorded_hashes: Rc::new(recorded_hashes), + recorded_hashes: Arc::new(recorded_hashes), } } } diff --git a/tarpc/src/client/stub/mock.rs b/tarpc/src/client/stub/mock.rs index ae9ae9b26..af143aced 100644 --- a/tarpc/src/client/stub/mock.rs +++ b/tarpc/src/client/stub/mock.rs @@ -23,8 +23,8 @@ where impl Stub for Mock where - Req: Eq + Hash + RequestName, - Resp: Clone, + Req: Eq + Hash + RequestName + Send + Sync, + Resp: Clone + Send + Sync, { type Req = Req; type Resp = Resp; diff --git a/tarpc/src/client/stub/retry.rs b/tarpc/src/client/stub/retry.rs index 89b033bc8..a9d98c1a9 100644 --- a/tarpc/src/client/stub/retry.rs +++ b/tarpc/src/client/stub/retry.rs @@ -8,9 +8,9 @@ use std::sync::Arc; impl stub::Stub for Retry where - Req: RequestName, - Stub: stub::Stub>, - F: Fn(&Result, u32) -> bool, + Req: RequestName + Send + Sync, + Stub: Sync + stub::Stub>, + F: Send + Sync + Fn(&Result, u32) -> bool, { type Req = Req; type Resp = Stub::Resp; diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index f3348d061..02971e6af 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -255,7 +255,6 @@ use std::{any::Any, error::Error, io, sync::Arc, time::Instant}; /// A message from a client to a server. #[derive(Debug)] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] -#[non_exhaustive] pub enum ClientMessage { /// A request initiated by a user. The server responds to a request by invoking a /// service-provided request handler. The handler completes with a [`response`](Response), which @@ -280,7 +279,6 @@ pub enum ClientMessage { /// A request from a client to a server. #[derive(Clone, Copy, Debug)] -#[non_exhaustive] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] pub struct Request { /// Trace context, deadline, and other cross-cutting concerns. @@ -294,14 +292,14 @@ pub struct Request { /// Implemented by the request types generated by tarpc::service. pub trait RequestName { /// The name of a request. - fn name(&self) -> &'static str; + fn name(&self) -> &str; } impl RequestName for Arc where Req: RequestName, { - fn name(&self) -> &'static str { + fn name(&self) -> &str { self.as_ref().name() } } @@ -310,7 +308,7 @@ impl RequestName for Box where Req: RequestName, { - fn name(&self) -> &'static str { + fn name(&self) -> &str { self.as_ref().name() } } @@ -360,7 +358,6 @@ impl RequestName for u64 { /// A response from a server to a client. #[derive(Clone, Debug, PartialEq, Eq, Hash)] -#[non_exhaustive] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] pub struct Response { /// The ID of the request being responded to. diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index d79d45c2c..d135fd427 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -68,7 +68,7 @@ impl Config { /// Equivalent to a `FnOnce(Req) -> impl Future`. #[allow(async_fn_in_trait)] -pub trait Serve { +pub trait Serve: Send { /// Type of request. type Req: RequestName; @@ -76,7 +76,11 @@ pub trait Serve { type Resp; /// Responds to a single request. - async fn serve(self, ctx: context::Context, req: Self::Req) -> Result; + fn serve( + self, + ctx: context::Context, + req: Self::Req, + ) -> impl Future> + Send; } /// A Serve wrapper around a Fn. @@ -115,9 +119,9 @@ where impl Serve for ServeFn where - Req: RequestName, - F: FnOnce(context::Context, Req) -> Fut, - Fut: Future>, + Req: RequestName + Send, + F: FnOnce(context::Context, Req) -> Fut + Send, + Fut: Future> + Send, { type Req = Req; type Resp = Resp; @@ -1046,7 +1050,7 @@ mod tests { #[tokio::test] async fn serve_before_mutates_context() -> anyhow::Result<()> { struct SetDeadline(Instant); - impl BeforeRequest for SetDeadline { + impl BeforeRequest for SetDeadline { async fn before( &mut self, ctx: &mut context::Context, @@ -1085,7 +1089,7 @@ mod tests { } } } - impl BeforeRequest for PrintLatency { + impl BeforeRequest for PrintLatency { async fn before( &mut self, _: &mut context::Context, @@ -1095,7 +1099,7 @@ mod tests { Ok(()) } } - impl AfterRequest for PrintLatency { + impl AfterRequest for PrintLatency { async fn after(&mut self, _: &mut context::Context, _: &mut Result) { tracing::info!("Elapsed: {:?}", self.start.elapsed()); } diff --git a/tarpc/src/server/request_hook/after.rs b/tarpc/src/server/request_hook/after.rs index 59afb473e..cf35dc7a8 100644 --- a/tarpc/src/server/request_hook/after.rs +++ b/tarpc/src/server/request_hook/after.rs @@ -11,17 +11,18 @@ use futures::prelude::*; /// A hook that runs after request execution. #[allow(async_fn_in_trait)] -pub trait AfterRequest { +pub trait AfterRequest: Send { /// The function that is called after request execution. /// /// The hook can modify the request context and the response. - async fn after(&mut self, ctx: &mut context::Context, resp: &mut Result); + fn after(&mut self, ctx: &mut context::Context, resp: &mut Result) -> impl Future + Send; } impl AfterRequest for F where - F: FnMut(&mut context::Context, &mut Result) -> Fut, - Fut: Future, + F: Send + FnMut(&mut context::Context, &mut Result) -> Fut, + Fut: Send + Future, + Resp: Send, { async fn after(&mut self, ctx: &mut context::Context, resp: &mut Result) { self(ctx, resp).await @@ -53,6 +54,8 @@ impl Serve for ServeThenHook where Serv: Serve, Hook: AfterRequest, + Serv::Req: Send, + Serv::Resp: Send, { type Req = Serv::Req; type Resp = Serv::Resp; diff --git a/tarpc/src/server/request_hook/before.rs b/tarpc/src/server/request_hook/before.rs index a221219ee..ffa9c9d57 100644 --- a/tarpc/src/server/request_hook/before.rs +++ b/tarpc/src/server/request_hook/before.rs @@ -11,7 +11,7 @@ use futures::prelude::*; /// A hook that runs before request execution. #[allow(async_fn_in_trait)] -pub trait BeforeRequest { +pub trait BeforeRequest: Send { /// The function that is called before request execution. /// /// If this function returns an error, the request will not be executed and the error will be @@ -19,11 +19,14 @@ pub trait BeforeRequest { /// /// This function can also modify the request context. This could be used, for example, to /// enforce a maximum deadline on all requests. - async fn before(&mut self, ctx: &mut context::Context, req: &Req) -> Result<(), ServerError>; + fn before(&mut self, ctx: &mut context::Context, req: &Req) -> impl Future> + Send; } /// A list of hooks that run in order before request execution. -pub trait BeforeRequestList: BeforeRequest { +pub trait BeforeRequestList: BeforeRequest +where + Req: Sync, +{ /// The hook returned by `BeforeRequestList::then`. type Then: BeforeRequest where @@ -34,8 +37,8 @@ pub trait BeforeRequestList: BeforeRequest { /// Same as `then`, but helps the compiler with type inference when Next is a closure. fn then_fn< - Next: FnMut(&mut context::Context, &Req) -> Fut, - Fut: Future>, + Next: Send + FnMut(&mut context::Context, &Req) -> Fut, + Fut: Send + Future>, >( self, next: Next, @@ -56,8 +59,9 @@ pub trait BeforeRequestList: BeforeRequest { impl BeforeRequest for F where - F: FnMut(&mut context::Context, &Req) -> Fut, - Fut: Future>, + F: Send + FnMut(&mut context::Context, &Req) -> Fut, + Fut: Send + Future>, + Req: Sync, { async fn before(&mut self, ctx: &mut context::Context, req: &Req) -> Result<(), ServerError> { self(ctx, req).await @@ -81,6 +85,7 @@ impl Serve for HookThenServe where Serv: Serve, Hook: BeforeRequest, + Serv::Req: Send, { type Req = Serv::Req; type Resp = Serv::Resp; @@ -139,6 +144,8 @@ pub struct BeforeRequestNil; impl, Rest: BeforeRequest> BeforeRequest for BeforeRequestCons +where + Req: Sync, { async fn before(&mut self, ctx: &mut context::Context, req: &Req) -> Result<(), ServerError> { let BeforeRequestCons(first, rest) = self; @@ -148,7 +155,10 @@ impl, Rest: BeforeRequest> BeforeRequest BeforeRequest for BeforeRequestNil { +impl BeforeRequest for BeforeRequestNil +where + Req: Sync, +{ async fn before(&mut self, _: &mut context::Context, _: &Req) -> Result<(), ServerError> { Ok(()) } @@ -156,8 +166,13 @@ impl BeforeRequest for BeforeRequestNil { impl, Rest: BeforeRequestList> BeforeRequestList for BeforeRequestCons +where + Req: Send + Sync, { - type Then = BeforeRequestCons> where Next: BeforeRequest; + type Then + = BeforeRequestCons> + where + Next: BeforeRequest; fn then>(self, next: Next) -> Self::Then { let BeforeRequestCons(first, rest) = self; @@ -171,8 +186,14 @@ impl, Rest: BeforeRequestList> BeforeRequest } } -impl BeforeRequestList for BeforeRequestNil { - type Then = BeforeRequestCons where Next: BeforeRequest; +impl BeforeRequestList for BeforeRequestNil +where + Req: Send + Sync, +{ + type Then + = BeforeRequestCons + where + Next: BeforeRequest; fn then>(self, next: Next) -> Self::Then { BeforeRequestCons(next, BeforeRequestNil) @@ -189,22 +210,26 @@ impl BeforeRequestList for BeforeRequestNil { fn before_request_list() { use crate::server::serve; use futures::executor::block_on; - use std::cell::Cell; + use std::sync::Mutex; - let i = Cell::new(0); + let i = Mutex::new(0); let serve = before() .then_fn(|_, _| async { - assert!(i.get() == 0); - i.set(1); + let mut i = i.lock().unwrap(); + assert!(*i == 0); + *i = 1; Ok(()) }) .then_fn(|_, _| async { - assert!(i.get() == 1); - i.set(2); + let mut i = i.lock().unwrap(); + assert!(*i == 1); + *i = 2; Ok(()) }) .serving(serve(|_ctx, i| async move { Ok(i + 1) })); let response = serve.clone().serve(context::current(), 1); assert!(block_on(response).is_ok()); - assert!(i.get() == 2); + + let i = i.lock().unwrap(); + assert!(*i == 2); } diff --git a/tarpc/src/server/request_hook/before_and_after.rs b/tarpc/src/server/request_hook/before_and_after.rs index 8556ac016..25e775994 100644 --- a/tarpc/src/server/request_hook/before_and_after.rs +++ b/tarpc/src/server/request_hook/before_and_after.rs @@ -39,9 +39,10 @@ impl Clone for HookThenServeThenHook Serve for HookThenServeThenHook where - Req: RequestName, + Req: RequestName + Send, + Resp: Send, Serv: Serve, - Hook: BeforeRequest + AfterRequest, + Hook: BeforeRequest + AfterRequest + Send, { type Req = Req; type Resp = Resp;