Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crypto-ffi/src/core_crypto/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use crate::{CoreCryptoResult, Database};
/// CoreCrypto wraps around MLS and Proteus implementations and provides a transactional interface for each.
#[derive(Debug, uniffi::Object)]
pub struct CoreCryptoFfi {
pub(crate) inner: core_crypto::CoreCrypto,
pub(crate) inner: Arc<core_crypto::CoreCrypto>,
}

/// Construct a new `CoreCryptoFfi` instance.
Expand Down
2 changes: 1 addition & 1 deletion crypto/src/ephemeral.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ impl CoreCrypto {
///
/// This client exposes the full interface of `CoreCrypto`, but it should only be used to decrypt messages.
/// Other use is a logic error.
pub async fn history_client(history_secret: HistorySecret) -> Result<Self> {
pub async fn history_client(history_secret: HistorySecret) -> Result<Arc<Self>> {
if !history_secret
.client_id
.starts_with(HISTORY_CLIENT_ID_PREFIX.as_bytes())
Expand Down
13 changes: 6 additions & 7 deletions crypto/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,29 +122,28 @@ impl MlsTransport for CoreCryptoTransportNotImplementedProvider {
///
/// As [std::ops::Deref] is implemented, this struct is automatically dereferred to [mls::session::Session] apart from
/// `proteus_*` calls
///
/// This is cheap to clone as all internal members have `Arc` wrappers or are `Copy`.
#[derive(Debug, Clone)]
#[derive(Debug)]
pub struct CoreCrypto {
database: Database,
pki_environment: Arc<RwLock<Option<PkiEnvironment>>>,
mls: Arc<RwLock<Option<mls::session::Session<Database>>>>,
pki_environment: RwLock<Option<PkiEnvironment>>,
mls: RwLock<Option<mls::session::Session<Database>>>,
#[cfg(feature = "proteus")]
proteus: Arc<Mutex<Option<proteus::ProteusCentral>>>,
proteus: Mutex<Option<proteus::ProteusCentral>>,
#[cfg(not(feature = "proteus"))]
#[allow(dead_code)]
proteus: (),
}

impl CoreCrypto {
/// Create an new CoreCrypto client without any initialized session.
pub fn new(database: Database) -> Self {
pub fn new(database: Database) -> Arc<Self> {
Self {
database,
pki_environment: Default::default(),
mls: Default::default(),
proteus: Default::default(),
}
.into()
}

/// Set the session's PKI Environment
Expand Down
4 changes: 2 additions & 2 deletions crypto/src/proteus.rs
Original file line number Diff line number Diff line change
Expand Up @@ -601,7 +601,7 @@ mod tests {
.await
.unwrap();

let cc: CoreCrypto = CoreCrypto::new(db);
let cc = CoreCrypto::new(db);
let context = cc.new_transaction().await.unwrap();
assert!(context.proteus_init().await.is_ok());
assert!(context.proteus_new_prekey(1).await.is_ok());
Expand All @@ -624,7 +624,7 @@ mod tests {
.await
.unwrap();

let cc: CoreCrypto = CoreCrypto::new(db.clone());
let cc = CoreCrypto::new(db.clone());
let hooks = Arc::new(DummyPkiEnvironmentHooks);
let pki_env = PkiEnvironment::new(hooks, db).await.expect("creating pki environment");
cc.set_pki_environment(Some(pki_env))
Expand Down
4 changes: 2 additions & 2 deletions crypto/src/test_utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ pub struct SessionContext {
mls_transport: Arc<RwLock<Arc<dyn MlsTransportTestExt + 'static>>>,
x509_test_chain: Arc<Option<X509TestChain>>,
history_observer: Arc<RwLock<Option<Arc<TestHistoryObserver>>>>,
core_crypto: CoreCrypto,
core_crypto: Arc<CoreCrypto>,
// We need to store the `TempDir` struct for the duration of the test session,
// because its drop implementation takes care of the directory deletion.
_db: Option<(Database, Arc<tempfile::TempDir>)>,
Expand Down Expand Up @@ -174,7 +174,7 @@ impl SessionContext {

pub(crate) async fn new_from_cc(
context: &TestContext,
core_crypto: CoreCrypto,
core_crypto: Arc<CoreCrypto>,
chain: Option<&X509TestChain>,
) -> Self {
let transport = context.transport.clone();
Expand Down
100 changes: 40 additions & 60 deletions crypto/src/transaction_context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@ pub use error::{Error, Result};
use openmls_traits::OpenMlsCryptoProvider as _;
use wire_e2e_identity::pki_env::PkiEnvironment;

#[cfg(feature = "proteus")]
use crate::proteus::ProteusCentral;
use crate::{
ClientId, ConversationId, CoreCrypto, CredentialFindFilters, CredentialRef, KeystoreError, MlsConversation,
MlsError, MlsTransport, RecursiveError, Session,
Expand Down Expand Up @@ -45,13 +43,9 @@ pub struct TransactionContext {
#[derive(Debug, Clone)]
enum TransactionContextInner {
Valid {
pki_environment: Arc<RwLock<Option<PkiEnvironment>>>,
database: Database,
mls_session: Arc<RwLock<Option<Session<Database>>>>,
core_crypto: Arc<CoreCrypto>,
mls_groups: Arc<RwLock<GroupStore<MlsConversation>>>,
pending_epoch_changes: Arc<Mutex<Vec<(ConversationId, u64)>>>,
#[cfg(feature = "proteus")]
proteus_central: Arc<Mutex<Option<ProteusCentral>>>,
},
Invalid,
}
Expand All @@ -60,15 +54,8 @@ impl CoreCrypto {
/// Creates a new transaction. All operations that persist data will be
/// buffered in memory and when [TransactionContext::finish] is called, the data will be persisted
/// in a single database transaction.
pub async fn new_transaction(&self) -> Result<TransactionContext> {
TransactionContext::new(
self.database.clone(),
self.pki_environment.clone(),
self.mls.clone(),
#[cfg(feature = "proteus")]
self.proteus.clone(),
)
.await
pub async fn new_transaction(self: &Arc<Self>) -> Result<TransactionContext> {
TransactionContext::new(self.clone()).await
}
}

Expand All @@ -91,27 +78,18 @@ impl HasSessionAndCrypto for TransactionContext {
}

impl TransactionContext {
async fn new(
keystore: Database,
pki_environment: Arc<RwLock<Option<PkiEnvironment>>>,
mls_session: Arc<RwLock<Option<Session<Database>>>>,
#[cfg(feature = "proteus")] proteus_central: Arc<Mutex<Option<ProteusCentral>>>,
) -> Result<Self> {
keystore
async fn new(core_crypto: Arc<CoreCrypto>) -> Result<Self> {
core_crypto
.database
.new_transaction()
.await
.map_err(MlsError::wrap("creating new transaction"))?;
let mls_groups = Arc::new(RwLock::new(Default::default()));
Ok(Self {
inner: Arc::new(
TransactionContextInner::Valid {
database: keystore,
pki_environment,
mls_session: mls_session.clone(),
mls_groups,
core_crypto,
mls_groups: Default::default(),
pending_epoch_changes: Default::default(),
#[cfg(feature = "proteus")]
proteus_central,
}
.into(),
),
Expand All @@ -120,7 +98,7 @@ impl TransactionContext {

pub(crate) async fn session(&self) -> Result<Session<Database>> {
match &*self.inner.read().await {
TransactionContextInner::Valid { mls_session, .. } => mls_session.read().await.as_ref().cloned().ok_or(
TransactionContextInner::Valid { core_crypto, .. } => core_crypto.mls.read().await.as_ref().cloned().ok_or(
RecursiveError::mls_client("Getting mls session from transaction context")(
mls::session::Error::MlsNotInitialized,
)
Expand All @@ -133,8 +111,8 @@ impl TransactionContext {
#[cfg(test)]
pub(crate) async fn set_session_if_exists(&self, new_session: Session<Database>) {
match &*self.inner.read().await {
TransactionContextInner::Valid { mls_session, .. } => {
let mut guard = mls_session.write().await;
TransactionContextInner::Valid { core_crypto, .. } => {
let mut guard = core_crypto.mls.write().await;

if guard.as_ref().is_some() {
*guard = Some(new_session)
Expand All @@ -146,14 +124,18 @@ impl TransactionContext {

pub(crate) async fn mls_transport(&self) -> Result<Arc<dyn MlsTransport + 'static>> {
match &*self.inner.read().await {
TransactionContextInner::Valid { mls_session, .. } => {
mls_session.read().await.as_ref().map(|s| s.transport.clone()).ok_or(
TransactionContextInner::Valid { core_crypto, .. } => core_crypto
.mls
.read()
.await
.as_ref()
.map(|s| s.transport.clone())
.ok_or(
RecursiveError::mls_client("Getting mls session from transaction context")(
mls::session::Error::MlsNotInitialized,
)
.into(),
)
}
),

TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
}
Expand All @@ -162,7 +144,8 @@ impl TransactionContext {
/// Clones all references that the [MlsCryptoProvider] comprises.
pub async fn mls_provider(&self) -> Result<MlsCryptoProvider> {
match &*self.inner.read().await {
TransactionContextInner::Valid { mls_session, .. } => mls_session
TransactionContextInner::Valid { core_crypto, .. } => core_crypto
.mls
.read()
.await
.as_ref()
Expand All @@ -179,28 +162,32 @@ impl TransactionContext {

pub(crate) async fn database(&self) -> Result<Database> {
match &*self.inner.read().await {
TransactionContextInner::Valid { database, .. } => Ok(database.clone()),
TransactionContextInner::Valid { core_crypto, .. } => Ok(core_crypto.database.clone()),
TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
}
}

pub(crate) async fn pki_environment(&self) -> Result<PkiEnvironment> {
match &*self.inner.read().await {
TransactionContextInner::Valid { pki_environment, .. } => {
pki_environment.read().await.as_ref().map(Clone::clone).ok_or(
TransactionContextInner::Valid { core_crypto, .. } => core_crypto
.pki_environment
.read()
.await
.as_ref()
.map(Clone::clone)
.ok_or(
RecursiveError::transaction("Getting PKI environment from transaction context")(
e2e_identity::Error::PkiEnvironmentUnset,
)
.into(),
)
}
),
TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
}
}

pub(crate) async fn pki_environment_option(&self) -> Result<Option<PkiEnvironment>> {
match &*self.inner.read().await {
TransactionContextInner::Valid { pki_environment, .. } => Ok(pki_environment.read().await.clone()),
TransactionContextInner::Valid { core_crypto, .. } => Ok(core_crypto.pki_environment.read().await.clone()),

TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
}
Expand All @@ -225,36 +212,28 @@ impl TransactionContext {
}
}

#[cfg(feature = "proteus")]
pub(crate) async fn proteus_central(&self) -> Result<Arc<Mutex<Option<ProteusCentral>>>> {
match &*self.inner.read().await {
TransactionContextInner::Valid { proteus_central, .. } => Ok(proteus_central.clone()),
TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext),
}
}

/// Commits the transaction, meaning it takes all the enqueued operations and persist them into
/// the keystore. After that the internal state is switched to invalid, causing errors if
/// something is called from this object.
pub async fn finish(&self) -> Result<()> {
let mut guard = self.inner.write().await;
let TransactionContextInner::Valid {
database,
core_crypto,
pending_epoch_changes,
mls_session,
..
} = &*guard
else {
return Err(Error::InvalidTransactionContext);
};

let commit_result = database
let commit_result = core_crypto
.database
.commit_transaction()
.await
.map_err(KeystoreError::wrap("commiting transaction"))
.map_err(Into::into);

if let Some(session) = mls_session.read_arc().await.clone()
if let Some(session) = core_crypto.mls.read().await.as_ref()
&& commit_result.is_ok()
{
// We need owned values, so we could just clone the conversation ids, but we don't need the events anymore,
Expand All @@ -276,11 +255,12 @@ impl TransactionContext {
pub async fn abort(&self) -> Result<()> {
let mut guard = self.inner.write().await;

let TransactionContextInner::Valid { database: keystore, .. } = &*guard else {
let TransactionContextInner::Valid { core_crypto, .. } = &*guard else {
return Err(Error::InvalidTransactionContext);
};

let result = keystore
let result = core_crypto
.database
.rollback_transaction()
.await
.map_err(KeystoreError::wrap("rolling back transaction"))
Expand Down Expand Up @@ -310,8 +290,8 @@ impl TransactionContext {
/// Set the `mls_session` Arc (also sets it on the transaction's CoreCrypto instance)
pub(crate) async fn set_mls_session(&self, session: Session<Database>) -> Result<()> {
match &*self.inner.read().await {
TransactionContextInner::Valid { mls_session, .. } => {
let mut guard = mls_session.write().await;
TransactionContextInner::Valid { core_crypto, .. } => {
let mut guard = core_crypto.mls.write().await;
*guard = Some(session);
Ok(())
}
Expand Down
Loading
Loading