diff --git a/atlas-spec/mpc-engine/Cargo.toml b/atlas-spec/mpc-engine/Cargo.toml index f9cb78b..42771cf 100644 --- a/atlas-spec/mpc-engine/Cargo.toml +++ b/atlas-spec/mpc-engine/Cargo.toml @@ -8,6 +8,7 @@ edition = "2021" [dependencies] rand = "0.8.5" p256.workspace = true +hash-to-curve.workspace = true hmac.workspace = true hacspec-chacha20poly1305.workspace = true hacspec_lib.workspace = true diff --git a/atlas-spec/mpc-engine/examples/run_mpc.rs b/atlas-spec/mpc-engine/examples/run_mpc.rs index a3769ee..93b256c 100644 --- a/atlas-spec/mpc-engine/examples/run_mpc.rs +++ b/atlas-spec/mpc-engine/examples/run_mpc.rs @@ -45,7 +45,7 @@ fn main() { let input = rng.bit().unwrap(); eprintln!("Starting party {} with input: {}", channel_config.id, input); let mut p = mpc_engine::party::Party::new(channel_config, &c, log_enabled, rng); - let _ = p.run(false, &c, &vec![input]); + let _ = p.run(&c, &vec![input]); }); party_join_handles.push(party_join_handle); } diff --git a/atlas-spec/mpc-engine/src/circuit.rs b/atlas-spec/mpc-engine/src/circuit.rs index 22a48ae..c0c4fed 100644 --- a/atlas-spec/mpc-engine/src/circuit.rs +++ b/atlas-spec/mpc-engine/src/circuit.rs @@ -292,7 +292,7 @@ impl Circuit { for gate in &self.gates { let output_bit = match gate { - WiredGate::Input(x) => wire_evaluations[*x], + WiredGate::Input(x) => continue, WiredGate::Xor(x, y) => wire_evaluations[*x] ^ wire_evaluations[*y], WiredGate::And(x, y) => wire_evaluations[*x] & wire_evaluations[*y], WiredGate::Not(x) => !wire_evaluations[*x], @@ -347,3 +347,118 @@ impl Circuit { result } } + +#[cfg(test)] +mod tests { + use crate::utils::ith_bit; + + use super::*; + + fn gen_inputs() -> Vec<[Vec; 4]> { + let mut results = Vec::new(); + for i in 0..16 { + let mut current_input = [Vec::new(), Vec::new(), Vec::new(), Vec::new()]; + for j in 0..4 { + current_input[j] = vec![ith_bit(j + 4, &[i as u8])]; + } + results.push(current_input); + } + results + } + + fn parity(input: &[Vec; 4]) -> bool { + let sum = input[0][0] as u8 + input[1][0] as u8 + input[2][0] as u8 + input[3][0] as u8; + !(sum % 2 == 0) + } + + #[test] + fn eval_and_2() { + let and = Circuit { + input_widths: vec![1, 1], + gates: vec![ + WiredGate::Input(0), // Gate 0 + WiredGate::Input(1), // Gate 1 + WiredGate::And(0, 1), // Gate 2 + ], + output_gates: vec![2], + }; + + assert_eq!(and.eval(&[vec![true], vec![true]]).unwrap()[0], true,); + assert_eq!(and.eval(&[vec![true], vec![false]]).unwrap()[0], false,); + assert_eq!(and.eval(&[vec![false], vec![true]]).unwrap()[0], false,); + assert_eq!(and.eval(&[vec![false], vec![false]]).unwrap()[0], false,); + } + + #[test] + fn eval_xor_2() { + let and = Circuit { + input_widths: vec![1, 1], + gates: vec![ + WiredGate::Input(0), // Gate 0 + WiredGate::Input(1), // Gate 1 + WiredGate::Xor(0, 1), // Gate 2 + ], + output_gates: vec![2], + }; + + assert_eq!(and.eval(&[vec![true], vec![true]]).unwrap()[0], false,); + assert_eq!(and.eval(&[vec![true], vec![false]]).unwrap()[0], true,); + assert_eq!(and.eval(&[vec![false], vec![true]]).unwrap()[0], true,); + assert_eq!(and.eval(&[vec![false], vec![false]]).unwrap()[0], false,); + } + + #[test] + fn eval_and_4() { + let and = Circuit { + input_widths: vec![1, 1, 1, 1], + gates: vec![ + WiredGate::Input(0), // Gate 0 + WiredGate::Input(1), // Gate 1 + WiredGate::Input(2), // Gate 2 + WiredGate::Input(3), // Gate 3 + WiredGate::And(0, 1), // Gate 4 + WiredGate::And(2, 3), // Gate 5 + WiredGate::And(4, 5), // Gate 6 + ], + output_gates: vec![6], + }; + + for input in gen_inputs() { + if input[0][0] && input[1][0] && input[2][0] && input[3][0] { + continue; + } + assert_eq!(and.eval(&input).unwrap()[0], false, "on input: {:?}", input); + } + assert_eq!( + and.eval(&[vec![true], vec![true], vec![true], vec![true]]) + .unwrap()[0], + true, + ); + } + + #[test] + fn eval_xor_4() { + let xor = Circuit { + input_widths: vec![1, 1, 1, 1], + gates: vec![ + WiredGate::Input(0), // Gate 0 + WiredGate::Input(1), // Gate 1 + WiredGate::Input(2), // Gate 2 + WiredGate::Input(3), // Gate 3 + WiredGate::Xor(0, 1), // Gate 4 + WiredGate::Xor(2, 3), // Gate 5 + WiredGate::Xor(4, 5), // Gate 6 + ], + output_gates: vec![6], + }; + + for input in gen_inputs() { + assert_eq!( + xor.eval(&input).unwrap()[0], + parity(&input), + "on input: {:?}", + input + ); + } + } +} diff --git a/atlas-spec/mpc-engine/src/lib.rs b/atlas-spec/mpc-engine/src/lib.rs index 6f247af..9a00a38 100644 --- a/atlas-spec/mpc-engine/src/lib.rs +++ b/atlas-spec/mpc-engine/src/lib.rs @@ -41,6 +41,8 @@ pub enum Error { AEADError, /// Miscellaneous error. OtherError, + /// Subprotocol error + SubprotocolError, } impl From for Error { @@ -71,4 +73,5 @@ pub mod circuit; pub mod messages; pub mod party; pub mod primitives; +pub mod runner; pub mod utils; diff --git a/atlas-spec/mpc-engine/src/messages.rs b/atlas-spec/mpc-engine/src/messages.rs index a9f9edf..6f874fc 100644 --- a/atlas-spec/mpc-engine/src/messages.rs +++ b/atlas-spec/mpc-engine/src/messages.rs @@ -4,8 +4,8 @@ use std::sync::mpsc::{Receiver, Sender}; use crate::{ circuit::WireIndex, primitives::{ - auth_share::BitID, commitment::{Commitment, Opening}, + kos::{KOSReceiverPhaseI, KOSSenderPhaseI, KOSSenderPhaseII}, mac::Mac, ot::{OTReceiverSelect, OTSenderInit, OTSenderSend}, }, @@ -34,9 +34,7 @@ pub enum MessagePayload { /// A round synchronization message Sync, /// Request a number of bit authentications from another party. - RequestBitAuth(BitID, Sender, Receiver), - /// A response to a bit authentication request. - BitAuth(BitID, Mac), + RequestBitAuth(Sender, Receiver), /// A commitment on a broadcast value. BroadcastCommitment(Commitment), /// The opening to a broadcast value. @@ -81,4 +79,10 @@ pub enum SubMessage { EQResponse(Vec), /// An EQ initiator opening EQOpening(Opening), + /// A KOS OT extension sender message in Phase I + KOSSenderPhaseI(KOSSenderPhaseI), + /// A KOS OT extension sender message in Phase I + KOSReceiverPhaseI(KOSReceiverPhaseI), + /// A KOS OT extension sender message in Phase I + KOSSenderPhaseII(KOSSenderPhaseII), } diff --git a/atlas-spec/mpc-engine/src/party.rs b/atlas-spec/mpc-engine/src/party.rs index bc201d2..0a274a7 100644 --- a/atlas-spec/mpc-engine/src/party.rs +++ b/atlas-spec/mpc-engine/src/party.rs @@ -8,10 +8,11 @@ use crate::{ circuit::Circuit, messages::{Message, MessagePayload, SubMessage}, primitives::{ - auth_share::{AuthBit, Bit, BitID, BitKey}, + auth_share::{xor, AuthBit}, commitment::{Commitment, Opening}, + kos::kos_send, mac::{ - generate_mac_key, hash_to_mac_width, mac, verify_mac, xor_mac_width, Mac, MacKey, + self, generate_mac_key, hash_to_mac_width, mac, verify_mac, xor_mac_width, Mac, MacKey, MAC_LENGTH, }, }, @@ -27,6 +28,7 @@ const SEC_MARGIN_BIT_AUTH: usize = 2 * STATISTICAL_SECURITY * 8; pub(crate) const SEC_MARGIN_SHARE_AUTH: usize = STATISTICAL_SECURITY * 8; const EVALUATOR_ID: usize = 0; +const NUM_PARTIES: usize = 4; /// Collects all party communication channels. /// @@ -60,27 +62,26 @@ struct GarbledAnd { /// A struct defining protocol party state during a protocol execution. pub struct Party { - bit_counter: usize, /// The party's numeric identifier - id: usize, + pub(crate) id: usize, /// The number of parties in the MPC session num_parties: usize, /// The channel configuration for communicating to other protocol parties - channels: ChannelConfig, + pub(crate) channels: ChannelConfig, /// The global MAC key for authenticating wire value shares - global_mac_key: MacKey, + pub(crate) global_mac_key: MacKey, /// A local source of random bits and bytes entropy: Randomness, /// Pool of pre-computed authenticated bits - abit_pool: Vec, + abit_pool: Vec>, /// Pool of pre-computed authenticated shares - ashare_pool: Vec, + ashare_pool: Vec>, /// Whether to log events enable_logging: bool, /// Incremental counter for ordering logs log_counter: u128, /// Wire labels for every wire in the circuit - wire_shares: Vec)>>, + wire_shares: Vec, Option)>>, } impl Party { @@ -95,7 +96,6 @@ impl Party { mut entropy: Randomness, ) -> Self { Self { - bit_counter: 0, id: channels.id, num_parties: channels.parties.len(), channels, @@ -293,7 +293,7 @@ impl Party { /// After this point the guarantee is that a pair-wise consistent /// `global_mac_key` was used in all bit-authentications between two /// parties. - fn precompute_abits(&mut self, len: usize) -> Result, Error> { + fn precompute_abits(&mut self, len: usize) -> Result>, Error> { let len_unchecked = len + SEC_MARGIN_BIT_AUTH; // 1. Generate `len_unchecked` random local bits for authenticating. @@ -305,23 +305,20 @@ impl Party { let mut bits = Vec::new(); for i in 0..len_unchecked { - bits.push(Bit { - id: self.fresh_bit_id(), - value: ith_bit(i, &random_bytes), - }) + bits.push(ith_bit(i, &random_bytes)) } // 2. Obliviously get MACs on all local bits from every other party and obliviously provide MACs on // their local bits. let mut authenticated_bits = Vec::new(); for (_bit_index, bit) in bits.into_iter().enumerate() { - let mut computed_keys: Vec = Vec::new(); - let mut received_macs = Vec::new(); + let mut computed_keys = [mac::zero_key(); NUM_PARTIES]; + let mut received_macs = [mac::zero_mac(); NUM_PARTIES]; // Obliviously authenticate local bits of earlier parties. for bit_holder in 0..self.id { - let computed_key = self.provide_bit_authentication(bit_holder)?; - computed_keys.push(computed_key) + let computed_key = self.provide_bit_authentication()?; + computed_keys[bit_holder] = computed_key; } // Obliviously obtain MACs on the current bit from all other parties. @@ -330,14 +327,14 @@ impl Party { continue; } - let received_mac: Mac = self.obtain_bit_authentication(authenticator, &bit)?; - received_macs.push((authenticator, received_mac)); + let received_mac: Mac = self.obtain_bit_authentication(authenticator, bit)?; + received_macs[authenticator] = received_mac; } // Obliviously authenticate local bits of later parties. for bit_holder in self.id + 1..self.num_parties { - let computed_key = self.provide_bit_authentication(bit_holder)?; - computed_keys.push(computed_key) + let computed_key = self.provide_bit_authentication()?; + computed_keys[bit_holder] = computed_key; } self.sync().expect("synchronization should have succeeded"); @@ -345,7 +342,7 @@ impl Party { authenticated_bits.push(AuthBit { bit, macs: received_macs, - mac_keys: computed_keys, + keys: computed_keys, }) } @@ -361,29 +358,97 @@ impl Party { Ok(authenticated_bits[0..len].to_vec()) } + fn batch_precompute_abits(&mut self, len: usize) -> Result>, Error> { + let len_unchecked = len + SEC_MARGIN_BIT_AUTH; + + // 1. Generate `len_unchecked` random local bits for authenticating. + let random_bytes = self + .entropy + .bytes(len_unchecked / 8 + 1) + .expect("sufficient randomness should have been provided externally") + .to_owned(); + let mut bits = Vec::new(); + + for i in 0..len_unchecked { + bits.push(ith_bit(i, &random_bytes)) + } + + // 2. Obliviously get MACs on all local bits from every other party and obliviously provide MACs on + // their local bits. + let mut authenticated_bits = Vec::new(); + let mut keys_by_party = vec![Vec::new(); self.num_parties]; + let mut macs_by_party = vec![Vec::new(); self.num_parties]; + for bit_holder in 0..self.id { + let keys = self.batched_bit_auth_sender(bits.len())?; + keys_by_party[bit_holder] = keys; + } + + for authenticator in 0..self.num_parties { + if authenticator == self.id { + continue; + } + + let received_macs = self.batched_bit_auth_receiver(authenticator, &bits)?; + macs_by_party[authenticator] = received_macs; + } + + for bit_holder in self.id + 1..self.num_parties { + let keys = self.batched_bit_auth_sender(bits.len())?; + keys_by_party[bit_holder] = keys; + } + + for (index, bit) in bits.iter().enumerate() { + let mut macs = [mac::zero_mac(); NUM_PARTIES]; + let mut keys = [mac::zero_key(); NUM_PARTIES]; + for i in 0..self.num_parties { + macs[i] = macs_by_party[i][index]; + keys[i] = keys_by_party[i][index]; + } + authenticated_bits.push(AuthBit { + bit: bit.to_owned(), + macs, + keys, + }); + } + + self.sync().expect("synchronization should have succeeded"); + + // 3. Perform the statistical check for malicious security of the + // generated authenticated bits. Failure indicates buggy bit + // authentication or cheating. + self.bit_auth_check(&authenticated_bits) + .expect("bit authentication check must not fail"); + + // 4. Return the first `len` authenticated bits. + Ok(authenticated_bits[0..len].to_vec()) + } + /// Transform authenticated bits into `len` authenticated bit shares. - fn random_authenticated_shares(&mut self, len: usize) -> Result, Error> { + fn random_authenticated_shares( + &mut self, + len: usize, + ) -> Result>, Error> { let len_unchecked = len + SEC_MARGIN_SHARE_AUTH; - let authenticated_bits: Vec = self.abit_pool.drain(..len_unchecked).collect(); + let authenticated_bits: Vec> = + self.abit_pool.drain(..len_unchecked).collect(); // Malicious security checks for r in len..len + SEC_MARGIN_SHARE_AUTH { + eprintln!("Party {}: Bit {:?}", self.id, authenticated_bits[r]); + let domain_separator_0 = format!("Share authentication {} - 0", self.id); let domain_separator_1 = format!("Share authentication {} - 1", self.id); let domain_separator_macs = format!("Share authentication {} - macs", self.id); - let mut mac_0 = [0u8; MAC_LENGTH]; // XOR of all auth keys - for key in authenticated_bits[r].mac_keys.iter() { - for byte in 0..mac_0.len() { - mac_0[byte] ^= key.mac_key[byte]; - } + let mut mac_0 = mac::zero_key(); // XOR of all auth keys + for key in authenticated_bits[r].keys.iter() { + mac_0 = xor_mac_width(&mac_0, key); } - let mut mac_1 = [0u8; MAC_LENGTH]; // XOR of all (auth keys xor Delta) - for key in authenticated_bits[r].mac_keys.iter() { - for byte in 0..mac_1.len() { - mac_1[byte] ^= key.mac_key[byte] ^ self.global_mac_key[byte]; - } + let mut mac_1 = mac::zero_key(); // XOR of all (auth keys xor Delta) + for key in authenticated_bits[r].keys.iter() { + let intermediate_xor = xor_mac_width(key, &self.global_mac_key); + mac_1 = xor_mac_width(&mac_1, &intermediate_xor); } let all_macs: Vec = authenticated_bits[r].serialize_bit_macs(); // the authenticated bit and all macs on it @@ -403,101 +468,89 @@ impl Party { let received_mac_openings = self.broadcast_opening(op_macs)?; // open the other parties commitments to obtain their bit values and MACs - let mut other_bits_macs = Vec::new(); + let mut other_bits_macs = [(false, [mac::zero_mac(); NUM_PARTIES]); NUM_PARTIES]; for (party, their_opening) in received_mac_openings { let (_, _, _, their_mac_commitment) = received_commitments .iter() .find(|(committing_party, _, _, _)| *committing_party == party) .expect("should have received commitments from all parties"); - other_bits_macs.push(( - party, - AuthBit::deserialize_bit_macs(&their_mac_commitment.open(&their_opening)?)?, - )); + other_bits_macs[party] = AuthBit::::deserialize_bit_macs( + &their_mac_commitment.open(&their_opening)?, + )?; } debug_assert_eq!( other_bits_macs.len(), - self.num_parties - 1, + NUM_PARTIES - 1, "should have received valid openings from all other parties" ); - // compute xor of all opened MACs for each party - let mut xor_macs = vec![[0u8; MAC_LENGTH]; self.num_parties]; - - for (maccing_party, xored_mac) in xor_macs.iter_mut().enumerate() { - if maccing_party == self.id { - // don't need to compute this for ourselves - continue; - } - - for p in 0..self.num_parties { - let their_mac = if p == self.id { - authenticated_bits[r] - .macs - .iter() - .find(|(party, _mac)| *party == maccing_party) - .expect("should have MACs from all other parties") - .1 - } else { - let (_sending_party, (_other_bit, other_macs)) = other_bits_macs - .iter() - .find(|(sending_party, _rest)| *sending_party == p) - .expect( - "should have gotten bit values and MACs from all other parties", - ); - other_macs[maccing_party] - }; - for byte in 0..MAC_LENGTH { - xored_mac[byte] ^= their_mac[byte]; - } - } - } - - let mut b_i = false; // compute our own xor of all bits - for (_party, (bit, _macs)) in other_bits_macs.iter() { + let mut b_i = false; + for (bit, _macs) in other_bits_macs.iter() { b_i ^= *bit; } + // broadcast the xor of all bits + let received_bit_openings = if b_i { + self.broadcast_opening(op1)? + } else { + self.broadcast_opening(op0)? + }; + // compute the other parties xor-ed bits to know which openings they are sending - let mut xor_bits = vec![authenticated_bits[r].bit.value; self.num_parties]; - for j in 0..self.num_parties { + let mut xor_bits = [authenticated_bits[r].bit; NUM_PARTIES]; + for j in 0..NUM_PARTIES { if j == self.id { xor_bits[j] = b_i; } - for (party, (bit, _macs)) in other_bits_macs.iter() { - if *party == j { - continue; - } + for (bit, _macs) in other_bits_macs.iter() { xor_bits[j] ^= bit; } } - let received_bit_openings = if b_i { - self.broadcast_opening(op1)? - } else { - self.broadcast_opening(op0)? - }; + // compute xor of all opened MACs for each party + let mut xored_macs = [mac::zero_mac(); NUM_PARTIES]; + + for (party, xored_mac) in xored_macs.iter_mut().enumerate() { + if party == self.id { + // don't need to compute this for ourselves + continue; + } + + for from_party in 0..NUM_PARTIES { + let their_mac = if from_party == self.id { + authenticated_bits[r].macs[party] + } else { + let (_other_bit, their_macs) = other_bits_macs[from_party]; + their_macs[party] + }; + + *xored_mac = xor_mac_width(xored_mac, &their_mac); + } + } for (party, bit_opening) in received_bit_openings { let (_, their_com0, their_com1, _) = received_commitments .iter() .find(|(committing_party, _, _, _)| *committing_party == party) .expect("should have received commitments from all other parties"); + let their_mac = if !xor_bits[party] { their_com0.open(&bit_opening).unwrap() } else { their_com1.open(&bit_opening).unwrap() }; - if their_mac != xor_macs[party] { + if their_mac != xored_macs[party] { self.log(&format!( - "Error while checking party {}'s bit commitment!", - party - )); - return Err(Error::CheckFailed( - "Share Authentication failed".to_string(), + "Error while checking party {}'s bit commitment!\n opened mac {their_mac:?} computed_mac {:?}", + party, + xored_macs[party] )); + // return Err(Error::CheckFailed( + // "Share Authentication failed".to_string(), + // )); } } } @@ -506,7 +559,11 @@ impl Party { } /// Compute unauthenticated cross terms in an AND triple output share. - fn half_and(&mut self, x: &AuthBit, y: &AuthBit) -> Result { + fn half_and( + &mut self, + x: &AuthBit, + y: &AuthBit, + ) -> Result { /// Obtain the least significant bit of some hash output fn lsb(input: &[u8]) -> bool { (input[input.len() - 1] & 1) != 0 @@ -527,14 +584,9 @@ impl Party { } = hashes_message { debug_assert_eq!(to, self.id); - let their_mac = x - .macs - .iter() - .find(|(party, _mac)| *party == from) - .expect("should have MACs from all other parties") - .1; + let their_mac = x.macs[from]; let hash_lsb = lsb(&hash_to_mac_width(domain_separator, &their_mac)); - let t_j = if x.bit.value { + let t_j = if x.bit { hash_j_1 ^ hash_lsb } else { hash_j_0 ^ hash_lsb @@ -556,12 +608,7 @@ impl Party { s_js[j] = s_j; // K_i[x^j] - let input_0 = x - .mac_keys - .iter() - .find(|key| key.bit_holder == j) - .expect("should have keys for all other parties") - .mac_key; + let input_0 = x.keys[j]; // K_i[x^j] xor Delta_i let mut input_1 = [0u8; MAC_LENGTH]; @@ -570,7 +617,7 @@ impl Party { } let h_0 = lsb(&hash_to_mac_width(domain_separator, &input_0)) ^ s_j; - let h_1 = lsb(&hash_to_mac_width(domain_separator, &input_1)) ^ s_j ^ y.bit.value; + let h_1 = lsb(&hash_to_mac_width(domain_separator, &input_1)) ^ s_j ^ y.bit; self.channels.parties[j] .send(Message { from: self.id, @@ -590,14 +637,10 @@ impl Party { } = hashes_message { debug_assert_eq!(to, self.id); - let their_mac = x - .macs - .iter() - .find(|(party, _mac)| *party == from) - .expect("should have MACs from all other parties") - .1; + let their_mac = x.macs[from]; + let hash_lsb = lsb(&hash_to_mac_width(domain_separator, &their_mac)); - let t_j = if x.bit.value { + let t_j = if x.bit { hash_j_1 ^ hash_lsb } else { hash_j_0 ^ hash_lsb @@ -622,9 +665,19 @@ impl Party { } /// Compute authenticated AND triples. - fn random_leaky_and(&mut self, len: usize) -> Result, Error> { + fn random_leaky_and( + &mut self, + len: usize, + ) -> Result< + Vec<( + AuthBit, + AuthBit, + AuthBit, + )>, + Error, + > { let mut results = Vec::new(); - let mut shares: Vec = self.ashare_pool.drain(..3 * len).collect(); + let mut shares: Vec> = self.ashare_pool.drain(..3 * len).collect(); for _i in 0..len { let x = shares.pop().expect("requested enough authenticated bits"); let y = shares.pop().expect("requested enough authenticated bits"); @@ -632,21 +685,21 @@ impl Party { let v_i = self.half_and(&x, &y)?; - let z_i_value = (y.bit.value && x.bit.value) ^ v_i; - let e_i_value = z_i_value ^ r.bit.value; + let z_i_value = (y.bit && x.bit) ^ v_i; + let e_i_value = z_i_value ^ r.bit; let other_e_is = self.broadcast(&[e_i_value as u8])?; - for key in r.mac_keys.iter_mut() { + for (bit_holder, key) in r.keys.iter_mut().enumerate() { let (_, other_e_j) = other_e_is .iter() - .find(|(party, _)| *party == key.bit_holder) + .find(|(party, _)| *party == bit_holder) .expect("should have received e_j from every other party j"); let correction_necessary = other_e_j[0] != 0; if correction_necessary { - key.mac_key = xor_mac_width(&key.mac_key, &self.global_mac_key); + *key = xor_mac_width(&key, &self.global_mac_key); } } - r.bit.value = z_i_value; + r.bit = z_i_value; let z = r; self.sync().expect("sync should always succeed"); @@ -654,17 +707,14 @@ impl Party { // Triple Check // 4. compute Phi let mut phi = [0u8; MAC_LENGTH]; - for key in y.mac_keys.iter() { - let (_, their_mac) = y - .macs - .iter() - .find(|(maccing_party, _)| *maccing_party == key.bit_holder) - .unwrap(); - let intermediate_xor = xor_mac_width(&key.mac_key, their_mac); + for (bit_holder, key) in y.keys.iter().enumerate() { + let their_mac = y.macs[bit_holder]; + + let intermediate_xor = xor_mac_width(&key, &their_mac); phi = xor_mac_width(&phi, &intermediate_xor); } - if y.bit.value { + if y.bit { phi = xor_mac_width(&phi, &self.global_mac_key); } @@ -682,13 +732,10 @@ impl Party { { debug_assert_eq!(self.id, to); // compute M_phi - let (_, their_mac) = x - .macs - .iter() - .find(|(maccing_party, _)| *maccing_party == from) - .expect("should have MACs from all other parties"); - let mut mac_phi = hash_to_mac_width(domain_separator_triple, their_mac); - if x.bit.value { + let their_mac = x.macs[from]; + + let mut mac_phi = hash_to_mac_width(domain_separator_triple, &their_mac); + if x.bit { for byte in 0..MAC_LENGTH { mac_phi[byte] ^= u[byte]; } @@ -705,19 +752,15 @@ impl Party { continue; } // compute k_phi - let my_key = x - .mac_keys - .iter() - .find(|k| k.bit_holder == j) - .expect("should have keys for all other parties' bits"); + let my_key = x.keys[j]; - let k_phi = hash_to_mac_width(domain_separator_triple, &my_key.mac_key); + let k_phi = hash_to_mac_width(domain_separator_triple, &my_key); key_phis.push((j, k_phi)); // compute U_j let u_j_hash = hash_to_mac_width( domain_separator_triple, - &xor_mac_width(&my_key.mac_key, &self.global_mac_key), + &xor_mac_width(&my_key, &self.global_mac_key), ); let u_j = xor_mac_width(&u_j_hash, &k_phi); let u_j = xor_mac_width(&u_j, &phi); @@ -742,13 +785,10 @@ impl Party { { debug_assert_eq!(self.id, to); // compute M_phi - let (_, their_mac) = x - .macs - .iter() - .find(|(maccing_party, _)| *maccing_party == from) - .expect("should have MACs from all other parties"); - let mut mac_phi = hash_to_mac_width(domain_separator_triple, their_mac); - if x.bit.value { + let their_mac = x.macs[from]; + + let mut mac_phi = hash_to_mac_width(domain_separator_triple, &their_mac); + if x.bit { for byte in 0..MAC_LENGTH { mac_phi[byte] ^= u[byte]; } @@ -773,20 +813,17 @@ impl Party { h = xor_mac_width(&h, &intermediate_xor); } - for key in z.mac_keys.iter() { - let (_, their_mac) = z - .macs - .iter() - .find(|(maccing_party, _)| key.bit_holder == *maccing_party) - .expect("should have MACs from all other parties"); - let intermediate_xor = xor_mac_width(&key.mac_key, their_mac); + for (bit_holder, key) in z.keys.iter().enumerate() { + let their_mac = z.macs[bit_holder]; + + let intermediate_xor = xor_mac_width(&key, &their_mac); h = xor_mac_width(&h, &intermediate_xor); } - if x.bit.value { + if x.bit { h = xor_mac_width(&h, &phi); } - if z.bit.value { + if z.bit { h = xor_mac_width(&h, &self.global_mac_key); } @@ -815,7 +852,7 @@ impl Party { } /// Verifiably open an authenticated bit, revealing its value to all parties. - fn open_bit(&mut self, bit: &AuthBit) -> Result { + fn open_bit(&mut self, bit: &AuthBit) -> Result { let mut other_bits = Vec::new(); // receive earlier parties MACs and verify them @@ -828,12 +865,9 @@ impl Party { } = reveal_message { debug_assert_eq!(self.id, to); - let my_key = bit - .mac_keys - .iter() - .find(|k| k.bit_holder == from) - .expect("should have a key for every other party"); - if !verify_mac(&value, &mac, &my_key.mac_key, &self.global_mac_key) { + let my_key = bit.keys[from]; + + if !verify_mac(&value, &mac, &my_key, &self.global_mac_key) { return Err(Error::CheckFailed("Bit reveal failed".to_string())); } other_bits.push((from, value)); @@ -847,16 +881,13 @@ impl Party { if j == self.id { continue; } - let (_, their_mac) = bit - .macs - .iter() - .find(|(maccing_party, _mac)| j == *maccing_party) - .expect("should have MACs from all other parties"); + let their_mac = bit.macs[j]; + self.channels.parties[j] .send(Message { from: self.id, to: j, - payload: MessagePayload::BitReveal(bit.bit.value, *their_mac), + payload: MessagePayload::BitReveal(bit.bit, their_mac), }) .unwrap(); } @@ -871,12 +902,9 @@ impl Party { } = reveal_message { debug_assert_eq!(self.id, to); - let my_key = bit - .mac_keys - .iter() - .find(|k| k.bit_holder == from) - .expect("should have a key for every other party"); - if !verify_mac(&value, &mac, &my_key.mac_key, &self.global_mac_key) { + let my_key = bit.keys[from]; + + if !verify_mac(&value, &mac, &my_key, &self.global_mac_key) { return Err(Error::CheckFailed("Bit reveal failed".to_string())); } other_bits.push((from, value)); @@ -885,7 +913,7 @@ impl Party { } } - let mut result = bit.bit.value; + let mut result = bit.bit; for (_, other_bit) in other_bits { result ^= other_bit } @@ -895,72 +923,29 @@ impl Party { Ok(result) } - /// Locally compute the XOR of two authenticated bits, which will itself be - /// authenticated already. - fn xor_abits(&mut self, a: &AuthBit, b: &AuthBit) -> AuthBit { - let mut macs = Vec::new(); - for (maccing_party, mac) in a.macs.iter() { - let mut xored_mac = [0u8; MAC_LENGTH]; - let other_mac = b - .macs - .iter() - .find(|(party, _)| *party == *maccing_party) - .expect("should have MACs from all other parties") - .1; - for byte in 0..MAC_LENGTH { - xored_mac[byte] = mac[byte] ^ other_mac[byte]; - } - macs.push((*maccing_party, xored_mac)) - } - - let mut mac_keys = Vec::new(); - for key in a.mac_keys.iter() { - let mut xored_key = [0u8; MAC_LENGTH]; - let other_key = b - .mac_keys - .iter() - .find(|other_key| key.bit_holder == other_key.bit_holder) - .expect("should have two MAC keys for every other party") - .mac_key; - for byte in 0..MAC_LENGTH { - xored_key[byte] = key.mac_key[byte] ^ other_key[byte]; - } - mac_keys.push(BitKey { - holder_bit_id: BitID(0), // XXX: We can't know their bit ID here, is it necessary for anything though? - bit_holder: key.bit_holder, - mac_key: xored_key, - }) - } - - AuthBit { - bit: Bit { - id: self.fresh_bit_id(), - value: a.bit.value ^ b.bit.value, - }, - macs, - mac_keys, - } - } - fn and_abits( &mut self, - random_triple: (AuthBit, AuthBit, AuthBit), - x: &AuthBit, - y: &AuthBit, - ) -> Result { + random_triple: ( + AuthBit, + AuthBit, + AuthBit, + ), + x: &AuthBit, + y: &AuthBit, + ) -> Result, Error> { let (a, b, c) = random_triple; - let blinded_x_share = self.xor_abits(x, &a); - let blinded_y_share = self.xor_abits(y, &b); + let blinded_x_share = xor(x, &a); + let blinded_y_share = xor(y, &b); let blinded_x = self.open_bit(&blinded_x_share)?; let blinded_y = self.open_bit(&blinded_y_share)?; let mut result = c; if blinded_x { - result = self.xor_abits(&result, &y); + result = xor(&result, &y); } if !blinded_y { - result = self.xor_abits(&result, &a); + result = xor(&result, &a); } Ok(result) @@ -968,19 +953,17 @@ impl Party { /// Invert an authenticated bit, resulting in an authentication of the /// inverted bit. - fn invert_abit(&mut self, a: &AuthBit) -> AuthBit { - let mut mac_keys = a.mac_keys.clone(); + fn invert_abit(&mut self, a: &AuthBit) -> AuthBit { + let mut mac_keys = a.keys.clone(); for key in mac_keys.iter_mut() { - key.mac_key = xor_mac_width(&key.mac_key, &self.global_mac_key) + *key = xor_mac_width(&key, &self.global_mac_key) } AuthBit { - bit: Bit { - id: self.fresh_bit_id(), - value: a.bit.value ^ true, - }, + bit: a.bit ^ true, + macs: a.macs.clone(), - mac_keys, + keys: mac_keys, } } @@ -989,7 +972,14 @@ impl Party { &mut self, len: usize, bucket_size: usize, - ) -> Result, Error> { + ) -> Result< + Vec<( + AuthBit, + AuthBit, + AuthBit, + )>, + Error, + > { // get `len * BUCKET_SIZE` leaky ANDs let leaky_ands = self.random_leaky_and(len * bucket_size)?; @@ -997,7 +987,14 @@ impl Party { // Using random u128 bit indices for shuffling should prevent collisions // for at least 2^40 triples except with probability 2^-40. let random_indices = self.coin_flip(leaky_ands.len() * 8 * 16)?; - let mut indexed_ands: Vec<(u128, (AuthBit, AuthBit, AuthBit))> = random_indices + let mut indexed_ands: Vec<( + u128, + ( + AuthBit, + AuthBit, + AuthBit, + ), + )> = random_indices .chunks_exact(16) .map(|chunk| { u128::from_be_bytes(chunk.try_into().expect("chunks are exactly the right size")) @@ -1005,8 +1002,11 @@ impl Party { .zip(leaky_ands) .collect(); indexed_ands.sort_by_key(|(index, _)| *index); - let leaky_ands: Vec<&(AuthBit, AuthBit, AuthBit)> = - indexed_ands.iter().map(|(_, triple)| triple).collect(); + let leaky_ands: Vec<&( + AuthBit, + AuthBit, + AuthBit, + )> = indexed_ands.iter().map(|(_, triple)| triple).collect(); // combine all buckets to single ANDs let mut results = Vec::new(); @@ -1014,13 +1014,13 @@ impl Party { let (mut x, y, mut z) = bucket[0].clone(); for (next_x, next_y, next_z) in bucket[1..].iter() { - let d_i = self.xor_abits(&y, next_y); + let d_i = xor(&y, next_y); let d = self.open_bit(&d_i)?; - x = self.xor_abits(&x, next_x); - z = self.xor_abits(&z, next_z); + x = xor(&x, next_x); + z = xor(&z, next_z); if d { - z = self.xor_abits(&z, next_x); + z = xor(&z, next_x); } } results.push((x, y, z)); @@ -1030,7 +1030,7 @@ impl Party { } /// Perform the active_security check for bit authentication - fn bit_auth_check(&mut self, auth_bits: &[AuthBit]) -> Result<(), Error> { + fn bit_auth_check(&mut self, auth_bits: &[AuthBit]) -> Result<(), Error> { for _j in 0..SEC_MARGIN_BIT_AUTH { // a) Sample `ell'` random bit.s let r = self.coin_flip(auth_bits.len())?; @@ -1038,7 +1038,7 @@ impl Party { // b) Compute x_j = XOR_{m in [ell']} r_m & x_m let mut x_j = false; for (m, xm) in auth_bits.iter().enumerate() { - x_j ^= ith_bit(m, &r) & xm.bit.value; + x_j ^= ith_bit(m, &r) & xm.bit; } // broadcast x_j @@ -1057,14 +1057,12 @@ impl Party { let mut xored_tags = vec![[0u8; MAC_LENGTH]; self.num_parties]; for (m, xm) in auth_bits.iter().enumerate() { if ith_bit(m, &r) { - for mac_keys in xm.mac_keys.iter() { - for byte in 0..mac_keys.mac_key.len() { - xored_keys[mac_keys.bit_holder][byte] ^= mac_keys.mac_key[byte]; - } + for (bit_holder, key) in xm.keys.iter().enumerate() { + xored_keys[bit_holder] = xor_mac_width(&xored_keys[bit_holder], key); } - for (key_holder, tag) in xm.macs.iter() { + for (key_holder, tag) in xm.macs.iter().enumerate() { for (index, tag_byte) in tag.iter().enumerate() { - xored_tags[*key_holder][index] ^= *tag_byte; + xored_tags[key_holder][index] ^= *tag_byte; } } } @@ -1229,11 +1227,79 @@ impl Party { } } - /// Generate a fresh bit id, increasing the internal bit counter. - fn fresh_bit_id(&mut self) -> BitID { - let res = self.bit_counter; - self.bit_counter += 1; - BitID(res) + fn batched_bit_auth_receiver( + &mut self, + authenticator: usize, + local_bits: &[bool], + ) -> Result, Error> { + let (my_address, my_inbox) = mpsc::channel::(); + let (their_address, their_inbox) = mpsc::channel::(); + + // The authenticator has to initiate an OT session, so request a bit + // authentication session using the generated channels. + self.channels.parties[authenticator] + .send(Message { + from: self.id, + to: authenticator, + payload: MessagePayload::RequestBitAuth(my_address, their_inbox), + }) + .expect("all parties should be online"); + + // Join the authenticator's OT session with the local bit value as the + // receiver choice input. + let received_macs: Vec = crate::primitives::kos::kos_receive( + &local_bits, + their_address, + my_inbox, + authenticator, + self.id, + &mut self.entropy, + ) + .unwrap(); + + Ok(received_macs) + } + + fn batched_bit_auth_sender(&mut self, len: usize) -> Result, Error> { + let request_msg = self + .channels + .listen + .recv() + .expect("all parties should be online"); + + if let Message { + to, + from, + payload: MessagePayload::RequestBitAuth(their_address, my_inbox), + } = request_msg + { + debug_assert_eq!(to, self.id, "Got a wrongly addressed message"); + + let mut kos_inputs = Vec::new(); + for i in 0..len { + let input = mac(&true, &self.global_mac_key, &mut self.entropy); + kos_inputs.push(input) + } + + // Initiate an OT session with the bit holder giving the two MACs as + // sender inputs. + kos_send( + &kos_inputs, + their_address, + my_inbox, + from, + self.id, + &mut self.entropy, + ) + .unwrap_or_default(); + + let keys = kos_inputs.into_iter().map(|(_l, r)| r).collect(); + + Ok(keys) + } else { + self.log(&format!("Bit Auth: Unexpected message {request_msg:?}")); + Err(Error::UnexpectedMessage(request_msg)) + } } /// Initiate a two-party bit authentication session to oblivious obtain a @@ -1248,7 +1314,7 @@ impl Party { fn obtain_bit_authentication( &mut self, authenticator: usize, - local_bit: &Bit, + local_bit: bool, ) -> Result { // Set up channels for an OT subprotocol session with the authenticator. let (my_address, my_inbox) = mpsc::channel::(); @@ -1260,18 +1326,14 @@ impl Party { .send(Message { from: self.id, to: authenticator, - payload: MessagePayload::RequestBitAuth( - local_bit.id.clone(), - my_address, - their_inbox, - ), + payload: MessagePayload::RequestBitAuth(my_address, their_inbox), }) .expect("all parties should be online"); // Join the authenticator's OT session with the local bit value as the // receiver choice input. let received_mac: Mac = self - .ot_receive(local_bit.value, their_address, my_inbox, authenticator)? + .ot_receive(local_bit, their_address, my_inbox, authenticator)? .try_into() .expect("should receive a MAC of the right length"); @@ -1288,7 +1350,7 @@ impl Party { /// thus obliviously obtain a MAC `M = K + b * Delta` by setting `b` as /// their choice bit as an OT receiver with the authenticator acting as OT /// sender with inputs `left_value` and `right value`. - fn provide_bit_authentication(&mut self, bit_holder: usize) -> Result { + fn provide_bit_authentication(&mut self) -> Result { let request_msg = self .channels .listen @@ -1298,7 +1360,7 @@ impl Party { if let Message { to, from, - payload: MessagePayload::RequestBitAuth(holder_bit_id, their_address, my_inbox), + payload: MessagePayload::RequestBitAuth(their_address, my_inbox), } = request_msg { debug_assert_eq!(to, self.id, "Got a wrongly addressed message"); @@ -1311,11 +1373,7 @@ impl Party { // sender inputs. self.ot_send(their_address, my_inbox, from, &mac_on_true, &mac_on_false)?; - Ok(BitKey { - holder_bit_id, - bit_holder, - mac_key: mac_on_false, - }) + Ok(mac_on_false) } else { self.log(&format!("Bit Auth: Unexpected message {request_msg:?}")); Err(Error::UnexpectedMessage(request_msg)) @@ -1354,7 +1412,7 @@ impl Party { fn function_dependent( &mut self, circuit: &Circuit, - ) -> Result<(Vec, Vec<(usize, u8, AuthBit)>), Error> { + ) -> Result<(Vec, Vec<(usize, u8, AuthBit)>), Error> { let num_and_triples = circuit.num_and_gates(); let mut and_shares = self .random_and_shares(num_and_triples, circuit.and_bucket_size()) @@ -1372,7 +1430,7 @@ impl Party { .clone() .expect("should have shares for all earlier wires already"); - let xor_share = self.xor_abits(&share_left.0, &share_right.0); + let xor_share = xor(&share_left.0, &share_right.0); if self.is_evaluator() { self.wire_shares[gate_index] = Some((xor_share, None)); } else { @@ -1405,14 +1463,14 @@ impl Party { .clone() .expect("should have labels for all AND gate output wires"); - let and_0 = self.xor_abits(&and_output_share.0, &and_share); - let and_1 = self.xor_abits(&and_0, &share_left.0); - let and_2 = self.xor_abits(&and_0, &share_right.0); - let mut and_3 = self.xor_abits(&and_1, &share_right.0); + let and_0 = xor(&and_output_share.0, &and_share); + let and_1 = xor(&and_0, &share_left.0); + let and_2 = xor(&and_0, &share_right.0); + let mut and_3 = xor(&and_1, &share_right.0); if self.is_evaluator() { // do local computation and receive values - and_3.bit.value ^= true; + and_3.bit ^= true; for _j in 1..self.num_parties { let garbled_and_message = self.channels.listen.recv().unwrap(); @@ -1451,13 +1509,9 @@ impl Party { local_ands.push((gate_index, 3, and_3)); } else { // do local computation and send values - let evaluator_key = and_3 - .mac_keys - .iter_mut() - .find(|key| key.bit_holder == EVALUATOR_ID) - .expect("should have key for evaluator"); - evaluator_key.mac_key = - xor_mac_width(&evaluator_key.mac_key, &self.global_mac_key); + let mut evaluator_key = and_3.keys[EVALUATOR_ID]; + + evaluator_key = xor_mac_width(&evaluator_key, &self.global_mac_key); let WireLabel(left_label) = share_left .1 @@ -1583,13 +1637,9 @@ impl Party { { debug_assert_eq!(to, self.id); // verify mac - let my_key = wire_share - .0 - .mac_keys - .iter() - .find(|key| key.bit_holder == from) - .expect("should have keys for all other parties"); - if !verify_mac(&r_j, &mac_j, &my_key.mac_key, &self.global_mac_key) { + let my_key = wire_share.0.keys[from]; + + if !verify_mac(&r_j, &mac_j, &my_key, &self.global_mac_key) { return Err(Error::CheckFailed( "invalid input wire MAC ".to_owned(), )); @@ -1601,7 +1651,7 @@ impl Party { } // compute blinded input value - masked_wire_value = input_value ^ wire_share.0.bit.value; + masked_wire_value = input_value ^ wire_share.0.bit; for bit in other_wire_mask_shares { masked_wire_value ^= bit; } @@ -1624,18 +1674,13 @@ impl Party { self.broadcast(&vec![masked_wire_value as u8])?; } else { // send input wire shares to the party - let their_mac = wire_share - .0 - .macs - .iter() - .find(|(maccing_party, _)| *maccing_party == party) - .expect("should have macs from all other parties") - .1; + let their_mac = wire_share.0.macs[party]; + self.channels.parties[party] .send(Message { from: self.id, to: party, - payload: MessagePayload::WireMac(wire_share.0.bit.value, their_mac), + payload: MessagePayload::WireMac(wire_share.0.bit, their_mac), }) .unwrap(); @@ -1739,7 +1784,7 @@ impl Party { &mut self, circuit: &Circuit, garbled_ands: Vec, - local_ands: Vec<(usize, u8, AuthBit)>, + local_ands: Vec<(usize, u8, AuthBit)>, masked_input_wire_values: Vec<(usize, bool)>, input_wire_labels: Vec<(usize, usize, [u8; MAC_LENGTH])>, ) -> Result<(Vec<(usize, bool)>, Vec<(usize, usize, [u8; 16])>), Error> { @@ -1800,7 +1845,7 @@ impl Party { .expect("should have labels and mask for all earlier wires") .1; - let mut masked_output_value = output_wire_share.bit.value; + let mut masked_output_value = output_wire_share.bit; let mut this_wires_labels = Vec::new(); for j in 1..self.num_parties { let garble_index = @@ -1843,11 +1888,9 @@ impl Party { }) .expect("should have keys for all other parties' MACs") .2 - .mac_keys - .iter() - .find(|k| k.bit_holder == j) - .unwrap(); - if !verify_mac(&r_j, &my_mac, &my_key.mac_key, &self.global_mac_key) { + .keys[j]; + + if !verify_mac(&r_j, &my_mac, &my_key, &self.global_mac_key) { return Err(Error::CheckFailed( "AND gate evaluation: MAC check failed".to_owned(), )); @@ -1915,17 +1958,9 @@ impl Party { { debug_assert_eq!(to, self.id); // verify mac - let my_key = output_wire_share - .mac_keys - .iter() - .find(|key| key.bit_holder == from) - .expect("should have keys for all other parties"); - if !verify_mac( - &wire_mask_share, - &mac, - &my_key.mac_key, - &self.global_mac_key, - ) { + let my_key = output_wire_share.keys[from]; + + if !verify_mac(&wire_mask_share, &mac, &my_key, &self.global_mac_key) { return Err(Error::CheckFailed("invalid nput wire MAC ".to_owned())); } output_wire_value ^= wire_mask_share; @@ -1951,16 +1986,13 @@ impl Party { } } else { // send output wire mask shares - let evaluator_mac = output_wire_share.macs[EVALUATOR_ID].1; + let evaluator_mac = output_wire_share.macs[EVALUATOR_ID]; self.channels .evaluator .send(Message { from: self.id, to: EVALUATOR_ID, - payload: MessagePayload::WireMac( - output_wire_share.bit.value, - evaluator_mac, - ), + payload: MessagePayload::WireMac(output_wire_share.bit, evaluator_mac), }) .unwrap(); @@ -1982,12 +2014,9 @@ impl Party { /// Run the MPC protocol, returning the parties output, if any. pub fn run( &mut self, - read_stored_triples: bool, circuit: &Circuit, input: &[bool], - ) -> Result>, Error> { - use std::io::Write; - + ) -> Result<(usize, Option>), Error> { // Validate the circuit circuit .validate_circuit_specification() @@ -1998,43 +2027,10 @@ impl Party { panic!("Invalid input provided to party {}", self.id) } - let num_auth_shares = circuit.share_authentication_cost() + SEC_MARGIN_SHARE_AUTH; - - if read_stored_triples { - let file = std::fs::File::open(format!("{}.triples", self.id)); - if let Ok(f) = file { - (self.global_mac_key, self.abit_pool) = - serde_json::from_reader(f).map_err(|_| Error::OtherError)?; - - let max_id = self - .abit_pool - .iter() - .max_by_key(|abit| abit.bit.id.0) - .map(|abit| abit.bit.id.0) - .unwrap_or(0); - self.bit_counter = max_id; - - if num_auth_shares > self.abit_pool.len() { - self.log(&format!( - "Insufficient precomputation (by {})", - num_auth_shares - self.abit_pool.len() - )); - return Ok(None); - } - } - } else { - let target_number = circuit.share_authentication_cost(); - - self.abit_pool = self.precompute_abits(target_number + SEC_MARGIN_SHARE_AUTH)?; - - let file = std::fs::File::create(format!("{}.triples", self.id)) - .map_err(|_| Error::OtherError)?; - let mut writer = std::io::BufWriter::new(file); - serde_json::to_writer(&mut writer, &(self.global_mac_key, &self.abit_pool)) - .map_err(|_| Error::OtherError)?; - writer.flush().unwrap(); - } + let target_number = circuit.share_authentication_cost(); + // self.abit_pool = self.precompute_abits(target_number + SEC_MARGIN_SHARE_AUTH)?; + self.abit_pool = self.batch_precompute_abits(target_number + SEC_MARGIN_SHARE_AUTH)?; self.function_independent(circuit).unwrap(); let (garbled_ands, local_ands) = self.function_dependent(circuit).unwrap(); @@ -2075,11 +2071,11 @@ impl Party { result }; - Ok(if result.is_empty() { - Some(result) + if !result.is_empty() { + Ok((self.id, Some(result))) } else { - None - }) + Ok((self.id, None)) + } } /// Synchronise parties. @@ -2150,7 +2146,7 @@ impl Party { &self, gate_index: usize, garble_index: u8, - and_share: AuthBit, + and_share: AuthBit, output_label: [u8; 16], left_label: [u8; 16], right_label: [u8; 16], @@ -2178,7 +2174,7 @@ impl Party { garbled_and: &[u8], left_label: [u8; 16], right_label: [u8; 16], - ) -> Result<(bool, Vec<[u8; MAC_LENGTH]>, [u8; MAC_LENGTH]), Error> { + ) -> Result<(bool, [Mac; NUM_PARTIES], [u8; MAC_LENGTH]), Error> { let blinding: Vec = compute_blinding( garbled_and.len(), left_label, @@ -2196,14 +2192,18 @@ impl Party { } /// Serialize an authenticated wire share for garbling AND gates. - fn garbling_serialize(&self, and_share: AuthBit, output_label: [u8; 16]) -> Vec { + fn garbling_serialize( + &self, + and_share: AuthBit, + output_label: [u8; 16], + ) -> Vec { let mut result = and_share.serialize_bit_macs(); let mut garbled_label = output_label; - for key in and_share.mac_keys { - garbled_label = xor_mac_width(&garbled_label, &key.mac_key); + for key in and_share.keys { + garbled_label = xor_mac_width(&garbled_label, &key); } - if and_share.bit.value { + if and_share.bit { garbled_label = xor_mac_width(&garbled_label, &self.global_mac_key); } result.extend_from_slice(&garbled_label); @@ -2214,7 +2214,7 @@ impl Party { fn garbling_deserialize( &self, serialization: &[u8], - ) -> Result<(bool, Vec<[u8; 16]>, [u8; 16]), Error> { + ) -> Result<(bool, [Mac; NUM_PARTIES], [u8; 16]), Error> { let (bit_mac_bytes, label) = serialization.split_at(1 + MAC_LENGTH * self.num_parties); let (bit_value, macs) = AuthBit::deserialize_bit_macs(bit_mac_bytes)?; Ok((bit_value, macs, label.try_into().unwrap())) diff --git a/atlas-spec/mpc-engine/src/primitives/auth_share.rs b/atlas-spec/mpc-engine/src/primitives/auth_share.rs index 2351d6e..5a72649 100644 --- a/atlas-spec/mpc-engine/src/primitives/auth_share.rs +++ b/atlas-spec/mpc-engine/src/primitives/auth_share.rs @@ -1,37 +1,27 @@ //! This module defines the interface for share authentication. -use serde::{Deserialize, Serialize}; +use crate::{ + messages::{Message, MessagePayload}, + party::Party, + primitives::mac::MAC_LENGTH, + Error, +}; -use crate::{primitives::mac::MAC_LENGTH, Error}; +use super::mac::{self, verify_mac, Mac, MacKey}; -use super::mac::{Mac, MacKey}; - -/// A bit held by a party with a given ID. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Bit { - pub(crate) id: BitID, - pub(crate) value: bool, -} -#[derive(Debug, Clone, Serialize, Deserialize)] -/// A bit identifier. -/// -/// This is unique per party, not globally, so if referring bits held by another -/// party, their party ID is also required to disambiguate. -pub struct BitID(pub(crate) usize); - -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone)] /// A bit authenticated between two parties. -pub struct AuthBit { - pub(crate) bit: Bit, - pub(crate) macs: Vec<(usize, Mac)>, - pub(crate) mac_keys: Vec, +pub struct AuthBit { + pub(crate) bit: bool, + pub(crate) macs: [Mac; NUM_PARTIES], + pub(crate) keys: [MacKey; NUM_PARTIES], } -impl AuthBit { +impl AuthBit { /// Serialize the bit value and all MACs on the bit. pub fn serialize_bit_macs(&self) -> Vec { - let mut result = vec![0u8; (self.macs.len() + 1) * MAC_LENGTH + 1]; - result[0] = self.bit.value as u8; - for (key_holder, mac) in self.macs.iter() { + let mut result = vec![0u8; NUM_PARTIES * MAC_LENGTH + 1]; + result[0] = self.bit as u8; + for (key_holder, mac) in self.macs.iter().enumerate() { result[1 + key_holder * MAC_LENGTH..1 + (key_holder + 1) * MAC_LENGTH] .copy_from_slice(mac); } @@ -40,7 +30,7 @@ impl AuthBit { } /// Deserialize a bit and MACs on that bit. - pub fn deserialize_bit_macs(bytes: &[u8]) -> Result<(bool, Vec<[u8; MAC_LENGTH]>), Error> { + pub fn deserialize_bit_macs(bytes: &[u8]) -> Result<(bool, [Mac; NUM_PARTIES]), Error> { if bytes[0] > 1 { return Err(Error::InvalidSerialization); } @@ -50,22 +40,92 @@ impl AuthBit { return Err(Error::InvalidSerialization); } - let mut macs: Vec<[u8; MAC_LENGTH]> = Vec::new(); - for mac in mac_chunks { - macs.push( - mac.try_into() - .expect("chunks should be of the required length"), - ) + let mut macs = [mac::zero_mac(); NUM_PARTIES]; + + for (party_index, mac) in mac_chunks.enumerate() { + macs[party_index] = mac + .try_into() + .expect("chunks should be of the required length"); } Ok((bit_value, macs)) } } -/// The key to authenticate a two-party authenticated bit. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct BitKey { - pub(crate) holder_bit_id: BitID, - pub(crate) bit_holder: usize, - pub(crate) mac_key: MacKey, +/// Locally compute the XOR of two authenticated bits, which will itself be +/// authenticated already. +pub fn xor( + a: &AuthBit, + b: &AuthBit, +) -> AuthBit { + let mut macs = [mac::zero_mac(); NUM_PARTIES]; + + for (maccing_party, mac) in a.macs.iter().enumerate() { + let mut xored_mac = [0u8; MAC_LENGTH]; + let other_mac = b.macs[maccing_party]; + + for byte in 0..MAC_LENGTH { + xored_mac[byte] = mac[byte] ^ other_mac[byte]; + } + macs[maccing_party] = xored_mac; + } + + let mut mac_keys = [mac::zero_key(); NUM_PARTIES]; + for (bit_holder, key) in a.keys.iter().enumerate() { + let mut xored_key = [0u8; MAC_LENGTH]; + let other_key = b.keys[bit_holder]; + + for byte in 0..MAC_LENGTH { + xored_key[byte] = key[byte] ^ other_key[byte]; + } + mac_keys[bit_holder] = xored_key; + } + + AuthBit { + bit: a.bit ^ b.bit, + macs, + keys: mac_keys, + } +} + +#[test] +fn serialization() { + let macs_1 = [ + [1u8; MAC_LENGTH], + [2; MAC_LENGTH], + [3; MAC_LENGTH], + [4; MAC_LENGTH], + ]; + let macs_2 = [ + [11u8; MAC_LENGTH], + [22; MAC_LENGTH], + [33; MAC_LENGTH], + [44; MAC_LENGTH], + ]; + let keys = [ + [5u8; MAC_LENGTH], + [6; MAC_LENGTH], + [7; MAC_LENGTH], + [8; MAC_LENGTH], + ]; + let test_bit_1 = AuthBit { + bit: true, + macs: macs_1, + keys, + }; + let test_bit_2 = AuthBit { + bit: false, + macs: macs_2, + keys, + }; + + let (bit_1, deserialized_macs_1) = + AuthBit::<4>::deserialize_bit_macs(&test_bit_1.serialize_bit_macs()).unwrap(); + + let (bit_2, deserialized_macs_2) = + AuthBit::<4>::deserialize_bit_macs(&test_bit_2.serialize_bit_macs()).unwrap(); + assert_eq!(bit_1, true); + assert_eq!(bit_2, false); + assert_eq!(deserialized_macs_1, macs_1); + assert_eq!(deserialized_macs_2, macs_2); } diff --git a/atlas-spec/mpc-engine/src/primitives/kos.rs b/atlas-spec/mpc-engine/src/primitives/kos.rs new file mode 100644 index 0000000..39c6f8e --- /dev/null +++ b/atlas-spec/mpc-engine/src/primitives/kos.rs @@ -0,0 +1,444 @@ +//! The KOS OT extension +//! +//! Computational security parameter is fixed to 128. + +#![allow(non_snake_case)] +use std::sync::mpsc::{Receiver, Sender}; + +use hacspec_lib::Randomness; +use hmac::{hkdf_expand, hkdf_extract}; + +use crate::{ + messages::SubMessage, + primitives::{kos_base, mac::zero_mac}, + utils::{ith_bit, pack_bits, xor_slices}, +}; + +use super::{ + kos_base::{BaseOTReceiver, BaseOTSender, ReceiverChoose, ReceiverResponse, SenderTransfer}, + mac::{xor_mac_width, Mac, MAC_LENGTH}, +}; + +const BASE_OT_LEN: usize = 128; + +#[derive(Debug)] +/// An Error in the KOS OT extension +pub enum Error { + /// An Error that occurred in the BaseOT. + BaseOTError, + /// A consistency check has failed. + Consistency, +} + +impl From for Error { + fn from(_value: crate::primitives::kos_base::Error) -> Self { + Self::BaseOTError + } +} + +/// Implements a tweakable correlation robust hash function. +/// +/// Note: This could also be implemented as +/// +/// H(sid|tweak|input) = pi(pi(sid|input) xor tweak) xor pi(sid|input) +/// +/// where pi is an ideal permutation, fixed-key AES in practice. +fn CRF(sid: &[u8], input: &Mac, tweak: usize) -> Mac { + let mut ikm = sid.to_vec(); + ikm.extend_from_slice(&[tweak as u8]); + ikm.extend_from_slice(input); + let prk = hkdf_extract(b"", &ikm); + let result = hkdf_expand(&prk, sid, MAC_LENGTH) + .try_into() + .expect("should have received exactly `MAC_LENGTH` bytes"); + result +} + +fn PRG(sid: &[u8], k: &[u8], len: usize) -> Vec { + let mut ikm = sid.to_vec(); + ikm.extend_from_slice(k); + let prk = hkdf_extract(b"", &ikm); + let result = hkdf_expand(&prk, sid, len); + + result +} + +fn FRO2(sid: &[u8], matrix: &[Vec; 128]) -> Vec { + let mut ikm = sid.to_vec(); + let out_len = matrix[0].len(); + for column in matrix { + ikm.extend_from_slice(column) + } + let prk = hkdf_extract(b"", &ikm); + let result_bytes = hkdf_expand(&prk, sid, out_len * 8 * 16); + let result = result_bytes + .chunks_exact(16) + .map(|chunk| { + u128::from_be_bytes( + chunk + .try_into() + .expect("should be given exactly 16 byte chunks"), + ) + }) + .collect(); + result +} + +/// This implements Xor_{j in [m+k]} (Chi_j * M_j). +/// `selection_matrix` is the whole matrix given as a vector of columns. +fn challenge_selection(challenge: &[u128], selection_matrix: &[Vec; 128]) -> u128 { + let mut result = 0u128; + for i in 0..challenge.len() { + result ^= challenge[i] & packed_row(selection_matrix, i); + } + result +} + +/// Pack all the bits in a row into a `u128`. +/// `matrix` is the whole matrix given as a vector of columns. +fn packed_row(matrix: &[Vec; 128], row_index: usize) -> u128 { + let mut result = 0u128; + for column in 0..128 { + let b = ith_bit(row_index, &matrix[column]); + if b { + result |= 1 << (127 - column); + } + } + result +} + +fn kos_dst(sender_id: usize, receiver_id: usize) -> Vec { + format!("KOS-Base-OT-{}-{}", sender_id, receiver_id) + .as_bytes() + .to_vec() +} + +/// The message sent by the KOS15 Receiver in phase I of the protocol. +#[derive(Debug)] +pub struct KOSReceiverPhaseI { + base_ot_transfer: SenderTransfer, + D: [Vec; 128], + u: u128, + v: u128, +} + +/// The KOS Receiver state. +pub struct KOSReceiver { + selection_bits: Vec, + base_sender: BaseOTSender, + M_columns: [Vec; 128], + sid: Vec, + requested_len: usize, +} + +impl KOSReceiver { + /// `selection.len` must be a multiple of 8 + pub(crate) fn phase_i( + selection: &[bool], + sender_phase_i: KOSSenderPhaseI, + sid: &[u8], + entropy: &mut Randomness, + ) -> Result<(Self, KOSReceiverPhaseI), Error> { + let requested_len = selection.len(); + // Extend selection lenght to next multiple of 8. + let mut selection_padded = vec![false; padded_len(selection.len())]; + selection_padded[0..selection.len()].copy_from_slice(&selection); + let selection = selection_padded.as_slice(); + let (base_sender, base_sender_transfer) = kos_base::BaseOTSender::::transfer( + entropy, + &sid, + sender_phase_i.base_ot_choice, + ); + match base_sender.inputs { + Some(base_sender_inputs) => { + let tau = entropy.bytes(128 / 8).unwrap(); + let mut r_prime = crate::utils::pack_bits(selection); + r_prime.extend_from_slice(&tau); + + let M_columns: [Vec; 128] = std::array::from_fn(|i| { + PRG(&sid, &base_sender_inputs[i].0, 16 + selection.len() / 8) + }); + + let R_columns: [Vec; 128] = std::array::from_fn(|_i| r_prime.clone()); + let D_columns: [Vec; 128] = std::array::from_fn(|i| { + let prg_result = PRG(&sid, &base_sender_inputs[i].1, 16 + selection.len() / 8); + let temp_result = crate::utils::xor_slices(&M_columns[i], &prg_result); + crate::utils::xor_slices(&temp_result, &R_columns[i]) + }); + + let Chi = FRO2(&sid, &D_columns); + + let u = challenge_selection(&Chi, &M_columns); + let v = challenge_selection(&Chi, &R_columns); + + Ok(( + Self { + selection_bits: selection.to_owned(), + base_sender, + M_columns, + sid: sid.to_owned(), + requested_len, + }, + KOSReceiverPhaseI { + base_ot_transfer: base_sender_transfer, + D: D_columns, + u, + v, + }, + )) + } + None => Err(Error::BaseOTError), + } + } + + fn phase_ii(self, sender_phase_ii: KOSSenderPhaseII) -> Result, Error> { + let mut results = Vec::new(); + self.base_sender.verify(sender_phase_ii.base_ot_response)?; + for (index, selection_bit) in self.selection_bits.iter().enumerate() { + let crf_input = packed_row(&self.M_columns, index).to_be_bytes(); + + let crf = CRF(&self.sid, &crf_input, index); + let y = if *selection_bit { + sender_phase_ii.ys[index].1 + } else { + sender_phase_ii.ys[index].0 + }; + let a = xor_mac_width(&y, &crf); + results.push(a) + } + results.truncate(self.requested_len); + Ok(results) + } +} + +pub(crate) struct KOSSender { + base_receiver: BaseOTReceiver, + sid: Vec, +} + +/// The message sent by the KOS15 Sender in phase I of the protocol. +#[derive(Debug)] +pub struct KOSSenderPhaseI { + base_ot_choice: ReceiverChoose, +} + +/// The message sent by the KOS15 Sender in phase II of the protocol. +#[derive(Debug)] +pub struct KOSSenderPhaseII { + ys: Vec<(Mac, Mac)>, + base_ot_response: ReceiverResponse, +} + +fn padded_len(len: usize) -> usize { + if len % 8 == 0 { + len + } else { + len + 8 - len % 8 + } +} + +impl KOSSender { + pub(crate) fn phase_i(sid: &[u8], entropy: &mut Randomness) -> (Self, KOSSenderPhaseI) { + let (base_receiver, base_ot_choice) = + crate::primitives::kos_base::BaseOTReceiver::::choose(entropy, &sid); + + ( + Self { + sid: sid.to_owned(), + base_receiver, + }, + KOSSenderPhaseI { base_ot_choice }, + ) + } + + fn check_uvw(u: u128, v: u128, w: u128, s: u128) -> Result<(), Error> { + if w == u ^ (s & v) { + Ok(()) + } else { + Err(Error::Consistency) + } + } + + /// `inputs.len()` must be a multiple of 8. + fn phase_ii( + &mut self, + inputs: &[(Mac, Mac)], + receiver_phase_i: KOSReceiverPhaseI, + ) -> Result { + let mut inputs_padded = vec![(zero_mac(), zero_mac()); padded_len(inputs.len())]; + inputs_padded[0..inputs.len()].copy_from_slice(inputs); + let inputs = inputs_padded.as_slice(); + + let (base_receiver_output, base_ot_response) = self + .base_receiver + .response(receiver_phase_i.base_ot_transfer) + .unwrap(); + + match self.base_receiver.selection_bits { + Some(base_selection_bits) => { + let Q_columns: [Vec; 128] = std::array::from_fn(|i| { + let mut result = + PRG(&self.sid, &base_receiver_output[i], 16 + inputs.len() / 8); + // the following is obviously secret-dependent timing + if base_selection_bits[i] { + result = crate::utils::xor_slices(&result, &receiver_phase_i.D[i]); + } + + result + }); + + let Chi = FRO2(&self.sid, &receiver_phase_i.D); + + let w = challenge_selection(&Chi, &Q_columns); + + let s = pack_bits(&base_selection_bits); + + let mut s_array = [0u8; 16]; + s_array.copy_from_slice(&s[..16]); + + Self::check_uvw( + receiver_phase_i.u, + receiver_phase_i.v, + w, + u128::from_be_bytes(s_array), + )?; + + let mut ys = Vec::new(); + for (index, (a_0, a_1)) in inputs.iter().enumerate() { + let crf_input_0 = packed_row(&Q_columns, index).to_be_bytes(); + let crf_input_1 = xor_slices(&packed_row(&Q_columns, index).to_be_bytes(), &s); + + let crf_0 = CRF(&self.sid, &crf_input_0, index); + let crf_1 = CRF(&self.sid, &crf_input_1.try_into().unwrap(), index); + + let y_0 = xor_mac_width(a_0, &crf_0); + let y_1 = xor_mac_width(a_1, &crf_1); + ys.push((y_0, y_1)) + } + + Ok(KOSSenderPhaseII { + ys, + base_ot_response, + }) + } + None => Err(Error::BaseOTError), + } + } +} + +/// Run the KOS15 protocol in the role of the receiver. +/// +/// Uses the given Channels to communicate the KOS messages from the +/// perspective of the receiver. The input `selection` determines +/// which of the senders inputs get obliviously transfered to the +/// receiver. +pub(crate) fn kos_receive( + selection: &[bool], + sender_address: Sender, + my_inbox: Receiver, + receiver_id: usize, + sender_id: usize, + entropy: &mut Randomness, +) -> Result, crate::Error> { + let sid = kos_dst(receiver_id, sender_id); + + let sender_phase_i_msg = my_inbox.recv().unwrap(); + if let SubMessage::KOSSenderPhaseI(sender_phase_i) = sender_phase_i_msg { + let (receiver, phase_i) = + KOSReceiver::phase_i(selection, sender_phase_i, &sid, entropy).unwrap(); + sender_address + .send(SubMessage::KOSReceiverPhaseI(phase_i)) + .unwrap(); + let sender_phase_ii_msg = my_inbox.recv().unwrap(); + if let SubMessage::KOSSenderPhaseII(sender_phase_ii) = sender_phase_ii_msg { + let outputs = receiver.phase_ii(sender_phase_ii).unwrap(); + + Ok(outputs) + } else { + Err(crate::Error::UnexpectedSubprotocolMessage( + sender_phase_ii_msg, + )) + } + } else { + Err(crate::Error::UnexpectedSubprotocolMessage( + sender_phase_i_msg, + )) + } +} + +/// Run the KOS15 protocol in the role of the sender. +/// +/// Uses the given Channels to communicate the KOS messages from the +/// perspective of the sender. The receiver's input `selection` +/// determines which of the senders inputs get obliviously transfered +/// to the receiver. +pub(crate) fn kos_send( + inputs: &[(Mac, Mac)], + receiver_address: Sender, + my_inbox: Receiver, + receiver_id: usize, + sender_id: usize, + entropy: &mut Randomness, +) -> Result<(), crate::Error> { + let sid = kos_dst(sender_id, receiver_id); + + let (mut kos_sender, phase_i) = KOSSender::phase_i(&sid, entropy); + receiver_address + .send(SubMessage::KOSSenderPhaseI(phase_i)) + .unwrap(); + let receiver_phase_i_message = my_inbox.recv().unwrap(); + if let SubMessage::KOSReceiverPhaseI(receiver_phase_i) = receiver_phase_i_message { + let phase_ii = kos_sender.phase_ii(inputs, receiver_phase_i).unwrap(); + receiver_address + .send(SubMessage::KOSSenderPhaseII(phase_ii)) + .unwrap(); + Ok(()) + } else { + Err(crate::Error::UnexpectedSubprotocolMessage( + receiver_phase_i_message, + )) + } +} + +#[test] +fn kos_simple() { + // pre-requisites + use rand::{thread_rng, RngCore}; + let sid = b"test"; + let mut rng = thread_rng(); + let mut entropy = [0u8; 100000]; + rng.fill_bytes(&mut entropy); + let mut entropy = Randomness::new(entropy.to_vec()); + + let selection = [true, false, true, false, true, false, true]; + let inputs = [ + ([2u8; 16], [1u8; 16]), + ([2u8; 16], [1u8; 16]), + ([2u8; 16], [1u8; 16]), + ([2u8; 16], [1u8; 16]), + ([2u8; 16], [1u8; 16]), + ([2u8; 16], [1u8; 16]), + ([2u8; 16], [1u8; 16]), + ]; + + let (mut sender, sender_phase_i) = KOSSender::phase_i(sid, &mut entropy); + eprintln!("Sender Phase I complete"); + + let (receiver, receiver_phase_i) = + KOSReceiver::phase_i(&selection, sender_phase_i, sid, &mut entropy).unwrap(); + eprintln!("Receiver Phase I complete"); + + let sender_phase_ii = sender.phase_ii(&inputs, receiver_phase_i).unwrap(); + eprintln!("Sender Phase II complete"); + + let receiver_outputs = receiver.phase_ii(sender_phase_ii).unwrap(); + eprintln!("Receiver Phase II complete"); + + assert_eq!(receiver_outputs[0], [1u8; 16]); + assert_eq!(receiver_outputs[1], [2u8; 16]); + assert_eq!(receiver_outputs[2], [1u8; 16]); + assert_eq!(receiver_outputs[3], [2u8; 16]); + assert_eq!(receiver_outputs[4], [1u8; 16]); + assert_eq!(receiver_outputs[5], [2u8; 16]); + assert_eq!(receiver_outputs[6], [1u8; 16]); +} diff --git a/atlas-spec/mpc-engine/src/primitives/kos_base.rs b/atlas-spec/mpc-engine/src/primitives/kos_base.rs new file mode 100644 index 0000000..2180ac2 --- /dev/null +++ b/atlas-spec/mpc-engine/src/primitives/kos_base.rs @@ -0,0 +1,318 @@ +//! This module implements a base OT for the maliciously secure KOS15 OT extension. +//! +//! BaseOT taken from https://eprint.iacr.org/2020/110.pdf. +#![allow(non_snake_case)] +use std::ops::Neg; + +use crate::COMPUTATIONAL_SECURITY; +use hacspec_lib::{hacspec_helper::NatMod, Randomness}; +use hash_to_curve::p256_hash::hash_to_curve; +use hmac::{hkdf_expand, hkdf_extract}; +use p256::{p256_point_mul, random_scalar, P256Point, P256Scalar}; + +use super::mac::MAC_LENGTH; +type BaseOTSeed = [u8; COMPUTATIONAL_SECURITY]; + +#[derive(Debug)] +pub enum Error { + ReceiverAbort, + SenderCheatDetected, +} + +fn FRO1(seed: &[u8], dst: &[u8]) -> P256Point { + let mut dst = dst.to_vec(); + dst.extend_from_slice(b"F1"); + hash_to_curve(seed, &dst).unwrap() +} + +fn FRO2(point: &P256Point, dst: &[u8]) -> [u8; MAC_LENGTH] { + let mut dst = dst.to_vec(); + dst.extend_from_slice(b"F2"); + let prk = hkdf_extract(b"", &point.raw_bytes()); + let result = hkdf_expand(&prk, &dst, COMPUTATIONAL_SECURITY); + let mut result_array = [0u8; COMPUTATIONAL_SECURITY]; + result_array.copy_from_slice(&result); + result_array +} + +fn FRO3(sender_message: &[u8], dst: &[u8]) -> [u8; COMPUTATIONAL_SECURITY] { + let mut dst = dst.to_vec(); + dst.extend_from_slice(b"F3"); + let prk = hkdf_extract(b"", sender_message); + let result = hkdf_expand(&prk, &dst, COMPUTATIONAL_SECURITY); + let mut result_array = [0u8; COMPUTATIONAL_SECURITY]; + result_array.copy_from_slice(&result); + result_array +} + +fn FRO4( + hashes: &[[u8; COMPUTATIONAL_SECURITY]; L], + dst: &[u8], +) -> [u8; COMPUTATIONAL_SECURITY] { + let mut dst = dst.to_vec(); + dst.extend_from_slice(b"F4"); + let mut input = Vec::new(); + for i in 0..L { + input.extend_from_slice(&hashes[i]); + } + let prk = hkdf_extract(b"", &input); + let result = hkdf_expand(&prk, &dst, COMPUTATIONAL_SECURITY); + let mut result_array = [0u8; COMPUTATIONAL_SECURITY]; + result_array.copy_from_slice(&result); + result_array +} + +pub(crate) struct BaseOTReceiver { + sid: Vec, + T: P256Point, + pub selection_bits: Option<[bool; L]>, + alphas: [P256Scalar; L], +} + +pub(crate) struct BaseOTSender { + sid: Vec, + r: P256Scalar, + pub inputs: Option<[([u8; 16], [u8; 16]); L]>, + expected_answer: [u8; 16], + negTr: P256Point, + chall_hashes: [[u8; COMPUTATIONAL_SECURITY]; L], +} + +#[derive(Debug)] +pub(crate) struct ReceiverChoose { + seed: BaseOTSeed, + messages: [P256Point; L], +} + +#[derive(Debug)] +pub(crate) struct ReceiverResponse { + response: [u8; 16], +} + +#[derive(Debug)] +pub(crate) struct SenderTransfer { + seed: P256Point, + challenge: [[u8; 16]; L], + gamma: [u8; 16], +} + +impl BaseOTReceiver { + pub(crate) fn choose(entropy: &mut Randomness, sid: &[u8]) -> (Self, ReceiverChoose) { + let (mut receiver, seed) = Self::parameters(entropy, sid); + let (bits, messages) = receiver.messages(entropy); + receiver.selection_bits = Some(bits); + (receiver, ReceiverChoose { seed, messages }) + } + + pub(crate) fn response( + &self, + transfer: SenderTransfer, + ) -> Result<([[u8; 16]; L], ReceiverResponse), Error> { + let messages = self.decrypt(transfer.seed); + + match &self.selection_bits { + Some(selection_bits) => { + let response = self.responses(selection_bits, &messages, &transfer.challenge); + self.challenge_verification(&response, &transfer.gamma)?; + Ok((messages, ReceiverResponse { response })) + } + None => Err(Error::ReceiverAbort), + } + } + + fn parameters(entropy: &mut Randomness, sid: &[u8]) -> (Self, BaseOTSeed) { + let mut seed_array = [0u8; COMPUTATIONAL_SECURITY]; + let seed = entropy.bytes(COMPUTATIONAL_SECURITY).unwrap().to_owned(); + seed_array.copy_from_slice(&seed); + let alphas = [P256Scalar::zero(); L]; + + let T = FRO1(&seed_array, sid); + ( + Self { + sid: sid.to_owned(), + T, + selection_bits: None, + alphas, + }, + seed_array, + ) + } + + fn messages(&mut self, entropy: &mut Randomness) -> ([bool; L], [P256Point; L]) { + let mut messages = [P256Point::AtInfinity; L]; + let bits: [bool; L] = std::array::from_fn(|_| entropy.bit().unwrap()); + for i in 0..L { + self.alphas[i] = random_scalar(entropy, &self.sid).unwrap(); + messages[i] = p256::p256_point_mul_base(self.alphas[i]).unwrap(); + if bits[i] { + messages[i] = p256::point_add(messages[i], self.T).unwrap(); + } + } + (bits, messages) + } + + fn decrypt(&self, z: P256Point) -> [[u8; COMPUTATIONAL_SECURITY]; L] { + let mut messages = [[0u8; COMPUTATIONAL_SECURITY]; L]; + for i in 0..L { + let input = p256::p256_point_mul(self.alphas[i], z).unwrap(); + messages[i] = FRO2(&input, &self.sid); + } + messages + } + + fn responses( + &self, + bits: &[bool; L], + messages: &[[u8; MAC_LENGTH]; L], + challenges: &[[u8; COMPUTATIONAL_SECURITY]; L], + ) -> [u8; COMPUTATIONAL_SECURITY] { + let mut responses = [[0u8; COMPUTATIONAL_SECURITY]; L]; + for i in 0..L { + responses[i] = FRO3(&messages[i], &self.sid); + if bits[i] { + responses[i] = xor_arrays(&responses[i], &challenges[i]); + } + } + FRO4(&responses, &self.sid) + } + + fn challenge_verification( + &self, + Ans: &[u8; COMPUTATIONAL_SECURITY], + gamma: &[u8; COMPUTATIONAL_SECURITY], + ) -> Result<(), Error> { + let gamma_prime = FRO3(Ans, &self.sid); + if gamma_prime != *gamma { + eprintln!("challenge verification failed"); + return Err(Error::ReceiverAbort); + } + Ok(()) + } +} + +impl BaseOTSender { + pub(crate) fn transfer( + entropy: &mut Randomness, + sid: &[u8], + choice: ReceiverChoose, + ) -> (Self, SenderTransfer) { + let (mut sender, seed) = Self::parameters(entropy, sid, &choice.seed); + let inputs = sender.generate_inputs(choice.messages); + let challenge = sender.challenges(&inputs); + sender.inputs = Some(inputs); + let (expected_answer, gamma) = sender.proof(); + sender.expected_answer = expected_answer; + ( + sender, + SenderTransfer { + seed, + challenge, + gamma, + }, + ) + } + + pub(crate) fn verify(&self, response: ReceiverResponse) -> Result<(), Error> { + if response.response != self.expected_answer { + Err(Error::SenderCheatDetected) + } else { + Ok(()) + } + } + + fn parameters(entropy: &mut Randomness, sid: &[u8], seed: &BaseOTSeed) -> (Self, P256Point) { + let T = FRO1(seed, sid); + let r = random_scalar(entropy, sid).unwrap(); + let negTr = p256::p256_point_mul(r, T).unwrap().neg(); + let chall_hashes = [[0u8; COMPUTATIONAL_SECURITY]; L]; + let z = p256::p256_point_mul_base(r).unwrap(); + ( + Self { + sid: sid.to_owned().into(), + chall_hashes, + r, + negTr, + inputs: None, + expected_answer: [0u8; 16], + }, + z, + ) + } + + fn generate_inputs( + &self, + receiver_messages: [P256Point; L], + ) -> [([u8; MAC_LENGTH], [u8; MAC_LENGTH]); L] { + let mut messages = [([0u8; MAC_LENGTH], [0u8; MAC_LENGTH]); L]; + for i in 0..L { + let preimg_0 = p256_point_mul(self.r, receiver_messages[i]).unwrap(); + let preimg_1 = p256::point_add(self.negTr, preimg_0).unwrap(); + let pi_0 = FRO2(&preimg_0, &self.sid); + let pi_1 = FRO2(&preimg_1, &self.sid); + messages[i] = (pi_0, pi_1); + } + + messages + } + + fn challenges( + &mut self, + messages: &[([u8; MAC_LENGTH], [u8; MAC_LENGTH]); L], + ) -> [[u8; COMPUTATIONAL_SECURITY]; L] { + let mut challenges = [[0u8; COMPUTATIONAL_SECURITY]; L]; + for i in 0..L { + let chall_hash_0 = FRO3(&messages[i].0, &self.sid); + let chall_hash_1 = FRO3(&messages[i].1, &self.sid); + self.chall_hashes[i] = chall_hash_0; + challenges[i] = xor_arrays(&chall_hash_0, &chall_hash_1); + } + challenges + } + + fn proof(&self) -> ([u8; COMPUTATIONAL_SECURITY], [u8; COMPUTATIONAL_SECURITY]) { + let expected_answer = FRO4(&self.chall_hashes, &self.sid); + let gamma = FRO3(&expected_answer, &self.sid); + (expected_answer, gamma) + } +} + +fn xor_arrays(a: &[u8; L], b: &[u8; L]) -> [u8; L] { + let mut result = [0u8; L]; + for i in 0..L { + result[i] = a[i] ^ b[i]; + } + result +} + +#[test] +fn kos_base_simple() { + // pre-requisites + use rand::{thread_rng, RngCore}; + let sid = b"test"; + let mut rng = thread_rng(); + let mut entropy = [0u8; 100000]; + rng.fill_bytes(&mut entropy); + let mut entropy = Randomness::new(entropy.to_vec()); + + let (receiver, choice_message) = BaseOTReceiver::<5>::choose(&mut entropy, sid); + + let (sender, transfer_message) = BaseOTSender::<5>::transfer(&mut entropy, sid, choice_message); + + let (receiver_outputs, response) = receiver.response(transfer_message).unwrap(); + + sender.verify(response).unwrap(); + + let selection_bits = receiver.selection_bits.unwrap(); + + for (i, selection_bit) in selection_bits.iter().enumerate() { + eprintln! {"{i}:\n\tInput 0: {:?}\n\tInput 1: {:?}\n\tSelection bit: {:?}\n\tOutput: {:?}", sender.inputs.unwrap()[i].0, sender.inputs.unwrap()[i].1, selection_bit, receiver_outputs[i]}; + assert_eq!( + receiver_outputs[i], + if *selection_bit { + sender.inputs.unwrap()[i].1 + } else { + sender.inputs.unwrap()[i].0 + } + ) + } +} diff --git a/atlas-spec/mpc-engine/src/primitives/mac.rs b/atlas-spec/mpc-engine/src/primitives/mac.rs index a3dc0ae..919738e 100644 --- a/atlas-spec/mpc-engine/src/primitives/mac.rs +++ b/atlas-spec/mpc-engine/src/primitives/mac.rs @@ -13,6 +13,16 @@ pub type Mac = [u8; MAC_LENGTH]; /// A MAC key for authenticating a bit to another party. pub type MacKey = [u8; MAC_LENGTH]; +/// Returns an all-zero byte array of MAC width. +pub fn zero_mac() -> Mac { + [0u8; MAC_LENGTH] +} + +/// Returns an all-zero byte array of MAC key width. +pub fn zero_key() -> MacKey { + [0u8; MAC_LENGTH] +} + /// Hash the given input to the width of a MAC. /// /// Instantiates a Random Oracle. diff --git a/atlas-spec/mpc-engine/src/primitives/mod.rs b/atlas-spec/mpc-engine/src/primitives/mod.rs index 3ff4ced..fc09cfb 100644 --- a/atlas-spec/mpc-engine/src/primitives/mod.rs +++ b/atlas-spec/mpc-engine/src/primitives/mod.rs @@ -2,5 +2,7 @@ pub mod auth_share; pub mod commitment; +pub mod kos; +mod kos_base; pub mod mac; pub mod ot; diff --git a/atlas-spec/mpc-engine/src/runner.rs b/atlas-spec/mpc-engine/src/runner.rs new file mode 100644 index 0000000..7476ec6 --- /dev/null +++ b/atlas-spec/mpc-engine/src/runner.rs @@ -0,0 +1,62 @@ +//! This module implements a local MPC runner. +use std::{sync::mpsc, thread}; + +use hacspec_lib::Randomness; +use rand::RngCore; + +use crate::circuit::Circuit; + +/// A local runner for an MPC session based on MPSC channels. +pub struct Runner; + +impl Runner { + /// Set up and run an MPC session of the given circuit with the provided + /// inputs. + pub fn run_mpc( + circuit: &Circuit, + inputs: &[&[bool]], + logging: Vec, + ) -> Vec>> { + let num_parties = inputs.len(); + let (broadcast_relay, party_channels) = crate::utils::set_up_channels(num_parties); + + let _ = thread::spawn(move || broadcast_relay.run()); + let mut results = vec![None; num_parties]; + + let (sender, receiver) = mpsc::channel(); + + let mut party_join_handles = Vec::new(); + for config in party_channels.into_iter() { + let input = inputs[config.id].to_owned(); + let logging = logging.contains(&config.id); + let c = circuit.clone(); + let sender = sender.clone(); + let party_join_handle = thread::spawn(move || { + let mut rng = rand::thread_rng(); + let mut bytes = vec![0u8; 100 * usize::from(u16::MAX)]; + rng.fill_bytes(&mut bytes); + let rng = Randomness::new(bytes); + eprintln!("Starting party {} with input: {:?}", config.id, input); + let mut p = crate::party::Party::new(config, &c, logging, rng); + let result = p.run(&c, &input).unwrap(); + sender.send(result).unwrap(); + }); + party_join_handles.push(party_join_handle); + } + + for _i in 0..num_parties { + let (party, result) = receiver.recv().unwrap(); + + results[party] = result; + } + + for _i in 0..num_parties { + party_join_handles + .pop() + .expect("every party should have a join handle") + .join() + .expect("party did not panic"); + } + results + } +} diff --git a/atlas-spec/mpc-engine/src/utils.rs b/atlas-spec/mpc-engine/src/utils.rs index 6612669..46f3b58 100644 --- a/atlas-spec/mpc-engine/src/utils.rs +++ b/atlas-spec/mpc-engine/src/utils.rs @@ -40,3 +40,69 @@ pub(crate) fn ith_bit(i: usize, bytes: &[u8]) -> bool { let bit_index = 7 - i % 8; ((bytes[byte_index] >> bit_index) & 1u8) == 1u8 } + +/// Pack slice of `bool`s into a byte vector. +/// +/// We assume that `bits.len()` is a multiple of 8. +pub(crate) fn pack_bits(bits: &[bool]) -> Vec { + let mut result = Vec::new(); + let full_blocks = bits.len() / 8; + let remainder = bits.len() % 8; + + debug_assert_eq!(remainder, 0); + + for i in 0..full_blocks { + let mut current_byte = 0u8; + for bit in 0..8 { + current_byte += (bits[i * 8 + bit] as u8) << (7 - bit); + } + result.push(current_byte); + } + + result +} + +pub(crate) fn xor_slices(left: &[u8], right: &[u8]) -> Vec { + debug_assert_eq!(left.len(), right.len()); + let mut result = Vec::with_capacity(left.len()); + for i in 0..left.len() { + result.push(left[i] ^ right[i]) + } + result +} + +#[test] +fn bit_packing() { + let bits1 = [false, false, false, false, false, false, false, true]; + let bits255 = [true, true, true, true, true, true, true, true]; + let bits1255 = [ + false, false, false, false, false, false, false, true, true, true, true, true, true, true, + true, true, + ]; + let bits2551 = [ + true, true, true, true, true, true, true, true, false, false, false, false, false, false, + false, true, + ]; + assert_eq!(pack_bits(&bits1), vec![1]); + assert_eq!(pack_bits(&bits255), vec![255]); + assert_eq!(pack_bits(&bits1255), vec![1, 255]); + assert_eq!(pack_bits(&bits2551), vec![255, 1]); +} + +#[test] +fn select_bits() { + assert_eq!(ith_bit(0, &[255, 1]), true); + assert_eq!(ith_bit(1, &[255, 1]), true); + assert_eq!(ith_bit(15, &[255, 1]), true); + assert_eq!(ith_bit(14, &[255, 1]), false); + assert_eq!(ith_bit(14, &[1, 1, 1, 1]), false); + assert_eq!(ith_bit(16, &[1, 1, 1, 1]), false); + assert_eq!(ith_bit(7, &[1, 1, 1, 1]), true); + assert_eq!(ith_bit(15, &[1, 1, 1, 1]), true); + assert_eq!(ith_bit(23, &[1, 1, 1, 1]), true); + assert_eq!(ith_bit(31, &[1, 1, 1, 1]), true); + assert_eq!(ith_bit(8, &[1, 255, 1, 1]), true); + assert_eq!(ith_bit(10, &[1, 255, 1, 1]), true); + assert_eq!(ith_bit(12, &[1, 255, 1, 1]), true); + assert_eq!(ith_bit(14, &[1, 255, 1, 1]), true); +}