diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 48a8e3ab6..6e243b506 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -26,7 +26,13 @@ jobs: target/ key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} - uses: dtolnay/rust-toolchain@stable - - run: cargo test --release --no-fail-fast --features pumpkin-solver/check-propagations --features pumpkin-core/check-deductions + - run: | + cargo test \ + --release \ + --no-fail-fast \ + --features pumpkin-solver/check-propagations \ + --features pumpkin-core/check-consistency \ + --features pumpkin-core/check-deductions wasm-test: name: Test Suite for pumpkin-core in WebAssembly diff --git a/Cargo.lock b/Cargo.lock index adbdcb88d..e84e49133 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -120,6 +120,24 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" +[[package]] +name = "bit-set" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ec2f926cc3060f09db9ebc5b52823d85268d24bb917e472c0c4bea35780a7d" +dependencies = [ + "bit-vec", +] + +[[package]] +name = "bit-vec" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b71798fca2c1fe1086445a7258a4bc81e6e49dcd24c8d0dd9a1e57395b603f51" +dependencies = [ + "serde", +] + [[package]] name = "bitfield" version = "0.19.4" @@ -973,6 +991,7 @@ dependencies = [ name = "pumpkin-core" version = "0.3.0" dependencies = [ + "bit-set", "bitfield", "bitfield-struct", "clap", @@ -992,6 +1011,7 @@ dependencies = [ "once_cell", "pumpkin-checking", "rand", + "replace_with", "thiserror", "wasm-bindgen-test", "web-time", @@ -1193,6 +1213,12 @@ version = "0.8.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dc897dd8d9e8bd1ed8cdad82b5966c3e0ecae09fb1907d58efaa013543185d0a" +[[package]] +name = "replace_with" +version = "0.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51743d3e274e2b18df81c4dc6caf8a5b8e15dbe799e0dca05c7617380094e884" + [[package]] name = "rustc_version" version = "0.4.1" diff --git a/pumpkin-crates/constraints/src/constraints/arithmetic/mod.rs b/pumpkin-crates/constraints/src/constraints/arithmetic/mod.rs index fe41b4721..eccea63c6 100644 --- a/pumpkin-crates/constraints/src/constraints/arithmetic/mod.rs +++ b/pumpkin-crates/constraints/src/constraints/arithmetic/mod.rs @@ -3,6 +3,7 @@ mod inequality; pub use equality::*; pub use inequality::*; +use pumpkin_core::checkers::support::SupportsValue; use pumpkin_core::constraints::Constraint; use pumpkin_core::proof::ConstraintTag; use pumpkin_core::variables::IntegerVariable; @@ -23,9 +24,9 @@ pub fn plus( /// Creates the [`Constraint`] `a * b = c`. pub fn times( - a: impl IntegerVariable + 'static, - b: impl IntegerVariable + 'static, - c: impl IntegerVariable + 'static, + a: impl IntegerVariable + SupportsValue + 'static, + b: impl IntegerVariable + SupportsValue + 'static, + c: impl IntegerVariable + SupportsValue + 'static, constraint_tag: ConstraintTag, ) -> impl Constraint { IntegerMultiplicationArgs { diff --git a/pumpkin-crates/core/Cargo.toml b/pumpkin-crates/core/Cargo.toml index 371849216..36796d6d3 100644 --- a/pumpkin-crates/core/Cargo.toml +++ b/pumpkin-crates/core/Cargo.toml @@ -30,6 +30,8 @@ clap = { version = "4.5.40", optional = true, features=["derive"] } indexmap = "2.10.0" dyn-clone = "1.0.20" flate2 = { version = "1.1.2" } +bit-set = "0.10.0" +replace_with = "0.1.8" [target.'cfg(target_arch = "wasm32")'.dependencies] web-time = "1.1" @@ -39,6 +41,7 @@ getrandom = { version = "0.4.2", features = ["wasm_js"] } wasm-bindgen-test = "0.3" [features] +check-consistency = [] check-propagations = [] check-deductions = [] debug-checks = [] diff --git a/pumpkin-crates/core/src/api/mod.rs b/pumpkin-crates/core/src/api/mod.rs index 22f3a90d7..1f5944d9a 100644 --- a/pumpkin-crates/core/src/api/mod.rs +++ b/pumpkin-crates/core/src/api/mod.rs @@ -69,8 +69,8 @@ pub mod options { pub use crate::engine::ConflictResolverType; pub use crate::engine::RestartOptions; pub use crate::engine::SatisfactionSolverOptions as SolverOptions; + pub use crate::propagators::ReifiedPropagatorArgs; pub use crate::propagators::nogoods::LearningOptions; - pub use crate::propagators::reified_propagator::ReifiedPropagatorArgs; } pub mod termination { diff --git a/pumpkin-crates/core/src/basic_types/propositional_conjunction.rs b/pumpkin-crates/core/src/basic_types/propositional_conjunction.rs index 97c2f2ef1..4a718b654 100644 --- a/pumpkin-crates/core/src/basic_types/propositional_conjunction.rs +++ b/pumpkin-crates/core/src/basic_types/propositional_conjunction.rs @@ -21,6 +21,12 @@ impl Deref for PropositionalConjunction { } } +impl Into> for PropositionalConjunction { + fn into(self) -> Box<[Predicate]> { + self.predicates_in_conjunction.into() + } +} + impl PropositionalConjunction { pub fn new(predicates_in_conjunction: Vec) -> Self { PropositionalConjunction { diff --git a/pumpkin-crates/core/src/checkers/mod.rs b/pumpkin-crates/core/src/checkers/mod.rs new file mode 100644 index 000000000..05d6bc62d --- /dev/null +++ b/pumpkin-crates/core/src/checkers/mod.rs @@ -0,0 +1,14 @@ +mod propagation_checker; +mod retention_checker; +mod scope; +mod self_disabling; +mod store; +mod strong_retention_checker; +pub mod support; + +pub use propagation_checker::*; +pub use retention_checker::*; +pub use scope::*; +pub use self_disabling::*; +pub use store::*; +pub use strong_retention_checker::*; diff --git a/pumpkin-crates/core/src/checkers/propagation_checker.rs b/pumpkin-crates/core/src/checkers/propagation_checker.rs new file mode 100644 index 000000000..67de6fd32 --- /dev/null +++ b/pumpkin-crates/core/src/checkers/propagation_checker.rs @@ -0,0 +1,64 @@ +use pumpkin_checking::BoxedChecker; +use pumpkin_checking::VariableState; + +use crate::predicates::Predicate; +use crate::propagation::Domains; +use crate::propagation::ReadDomains; +use crate::variables::DomainId; + +/// Tests whether an inference is correct given the solver state. +/// +/// An inference is correct when: +/// 1. All premises are satisfied. +/// 2. The conjunction of the premises and negation of the consequent is consistent. +/// 3. The consequent is logically entailed given the inference code. +#[derive(Clone, Debug)] +pub struct PropagationChecker { + inference_checker: BoxedChecker, +} + +impl PropagationChecker { + /// Create a new propagation checker given an inference checker and inference code. + pub fn new(inference_checker: BoxedChecker) -> PropagationChecker { + PropagationChecker { inference_checker } + } + + /// Run the propagation checker for the given inference. + pub fn check( + &self, + premises: &[Predicate], + consequent: Option, + domains: Domains<'_>, + ) -> Result<(), InvalidInference> { + let premises_satisfied = premises + .iter() + .all(|&premise| domains.evaluate_predicate(premise) == Some(true)); + + if !premises_satisfied { + return Err(InvalidInference::UnsatisfiedPremises); + } + + let variable_state = + VariableState::prepare_for_conflict_check(premises.iter().copied(), consequent) + .map_err(InvalidInference::InconsistentPredicates)?; + + if self + .inference_checker + .check(variable_state, &premises, consequent.as_ref()) + { + Ok(()) + } else { + Err(InvalidInference::Unsound) + } + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum InvalidInference { + /// Not all premises are true given the current state. + UnsatisfiedPremises, + /// The predicates that make up the inference are trivially inconsistent. + InconsistentPredicates(DomainId), + /// Cannot establish that the inference is sound. + Unsound, +} diff --git a/pumpkin-crates/core/src/checkers/retention_checker.rs b/pumpkin-crates/core/src/checkers/retention_checker.rs new file mode 100644 index 000000000..e3dd157bf --- /dev/null +++ b/pumpkin-crates/core/src/checkers/retention_checker.rs @@ -0,0 +1,39 @@ +use std::fmt::Debug; + +use dyn_clone::DynClone; + +use crate::checkers::Scope; +use crate::propagation::Domains; + +/// A runtime verifier that determines whether domains are sufficiently pruned. +pub trait RetentionChecker: Debug + DynClone { + /// Ensure the domains do not have values that should have been removed by propagation. + /// + /// Returns `true` if the domains are sufficiently pruned, or `false` otherwise. + fn check_retention(&mut self, scope: &Scope, domains: Domains<'_>) -> bool; +} + +/// Wrapper around `Box` that implements [`Clone`]. +#[derive(Debug)] +pub struct BoxedRetentionChecker(Box); + +impl Clone for BoxedRetentionChecker { + fn clone(&self) -> Self { + BoxedRetentionChecker(dyn_clone::clone_box(&*self.0)) + } +} + +impl From for BoxedRetentionChecker +where + T: RetentionChecker + 'static, +{ + fn from(value: T) -> Self { + BoxedRetentionChecker(Box::new(value)) + } +} + +impl BoxedRetentionChecker { + pub fn check_retention(&mut self, scope: &Scope, domains: Domains<'_>) -> bool { + self.0.check_retention(scope, domains) + } +} diff --git a/pumpkin-crates/core/src/checkers/scope.rs b/pumpkin-crates/core/src/checkers/scope.rs new file mode 100644 index 000000000..b8b9c9279 --- /dev/null +++ b/pumpkin-crates/core/src/checkers/scope.rs @@ -0,0 +1,71 @@ +use crate::containers::HashMap; +use crate::propagation::LocalId; +use crate::variables::DomainId; + +/// The scope of a constraint is the collection of variables involved in the relation. +#[derive(Clone, Debug, Default)] +pub struct Scope { + domains: HashMap, +} + +impl FromIterator<(LocalId, DomainId)> for Scope { + fn from_iter>(iter: T) -> Self { + Scope { + domains: iter.into_iter().collect(), + } + } +} + +impl Scope { + /// Add a new domain to the scope with the given local id. + /// + /// Any previous occurrance of this local id will be overridden. + pub fn add_domain(&mut self, local_id: LocalId, domain_id: DomainId) { + let _ = self.domains.insert(local_id, domain_id); + } + + /// The integer domains in the scope with the [`LocalId`]s they are registered. + pub fn domains(&self) -> impl ExactSizeIterator { + self.domains.iter().map(|(lid, did)| (*lid, *did)) + } + + /// Returns a copy of this scope with the entry for `local_id` removed. + pub fn without(&self, local_id: LocalId) -> Scope { + let mut scope = self.clone(); + let _ = scope.domains.remove(&local_id); + scope + } +} + +macro_rules! impl_scope_from_tuple { + ($($lid_name:ident,$var_name:ident : $ty_name:ident),+) => { + impl<$($ty_name),+> From<($((LocalId, &$ty_name)),+)> for Scope + where + $($ty_name: ScopeItem),+ + { + fn from( + ($(($lid_name, $var_name)),+): ($((LocalId, &$ty_name)),+), + ) -> Self { + let mut scope = Scope::default(); + + $($var_name.add_to_scope(&mut scope, $lid_name);)+ + + scope + } + } + }; +} + +impl_scope_from_tuple!(la,va: VA, lb,vb: VB); +impl_scope_from_tuple!(la,va: VA, lb,vb: VB, lc,vc: VC); + +pub trait ScopeItem { + /// Adds self to the given scope with the given [`LocalId`]. + fn add_to_scope(&self, scope: &mut Scope, local_id: LocalId); +} + +impl ScopeItem for i32 { + fn add_to_scope(&self, _: &mut Scope, _: LocalId) { + // Do nothing + } +} diff --git a/pumpkin-crates/core/src/checkers/self_disabling.rs b/pumpkin-crates/core/src/checkers/self_disabling.rs new file mode 100644 index 000000000..9a5770f3a --- /dev/null +++ b/pumpkin-crates/core/src/checkers/self_disabling.rs @@ -0,0 +1,44 @@ +use std::sync::Arc; +use std::sync::atomic::AtomicBool; +use std::sync::atomic::Ordering; + +use super::RetentionChecker; +use super::Scope; +use crate::propagation::Domains; + +/// A [`ConsistencyChecker`] wrapper that skips the inner check when the associated constraint has +/// been deleted. +/// +/// The deletion flag is shared with the constraint owner (e.g. the nogood propagator). Setting the +/// flag to `true` causes the checker to become a permanent no-op. +#[derive(Debug, Clone)] +pub struct SelfDisablingChecker { + inner: T, + is_deleted: Arc, +} + +impl SelfDisablingChecker { + /// Create a new self-disabling checker. + /// + /// The deletion flag can be obtained with [`SelfDisablingChecker::deletion_flag`]. + pub fn new(checker: T) -> Self { + SelfDisablingChecker { + inner: checker, + is_deleted: Arc::new(AtomicBool::new(false)), + } + } + + /// The deletion flag for this self-disabling checker. + pub fn deletion_flag(&self) -> Arc { + Arc::clone(&self.is_deleted) + } +} + +impl RetentionChecker for SelfDisablingChecker { + fn check_retention(&mut self, scope: &Scope, domains: Domains<'_>) -> bool { + if self.is_deleted.load(Ordering::Relaxed) { + return true; + } + self.inner.check_retention(scope, domains) + } +} diff --git a/pumpkin-crates/core/src/checkers/store.rs b/pumpkin-crates/core/src/checkers/store.rs new file mode 100644 index 000000000..79e141422 --- /dev/null +++ b/pumpkin-crates/core/src/checkers/store.rs @@ -0,0 +1,90 @@ +use crate::checkers::BoxedRetentionChecker; +use crate::checkers::Scope; +use crate::containers::KeyedBitSet; +use crate::containers::KeyedVec; +use crate::containers::StorageKey; +use crate::propagation::Domains; +use crate::variables::DomainId; + +/// Holds the consistency checkers in the solver. +/// +/// Also responsible for enqueueing the checkers and dispatching them when instructed via +/// [`ConsistencyCheckerStore::run_enqueued`]. +#[derive(Clone, Debug, Default)] +pub struct ConsistencyCheckerStore { + /// The checkers in the store. + store: KeyedVec, + /// Map from [`DomainId`] to the relevant checkers via their ID. + watch_list: KeyedVec>, + /// The checkers to run the next time. + queue: Vec, + /// Marks which checkers are enqueued to prevent duplicate checkers in + /// [`ConsistencyCheckerStore::queue`]. + enqueued: KeyedBitSet, +} + +impl ConsistencyCheckerStore { + /// Add a new `checker` to the store with the given `scope`. + pub fn register(&mut self, scope: Scope, checker: BoxedRetentionChecker) { + let checker_slot = self.store.new_slot(); + + for (_, domain) in scope.domains() { + self.watch_list.accomodate(domain, vec![]); + self.watch_list[domain].push(checker_slot.key()); + } + + let _ = checker_slot.populate((scope, checker)); + } + + /// Called when the domain is modified. + /// + /// Causes the checkers for this domain to be enqueued. + pub fn on_domain_event(&mut self, domain_id: DomainId) { + let Some(list) = self.watch_list.get(domain_id) else { + return; + }; + + for &checker_id in list { + if !self.enqueued.insert(checker_id) { + continue; + } + + self.queue.push(checker_id); + } + } + + /// Run the enqueued consistency checkers. + pub fn run_enqueued(&mut self, mut domains: Domains<'_>) -> bool { + for checker_id in self.queue.drain(..) { + assert!(self.enqueued.remove(checker_id)); + + let (scope, checker) = &mut self.store[checker_id]; + + if !checker.check_retention(scope, domains.reborrow()) { + return false; + } + } + + true + } + + /// Clear the queue of consistency checkers. + pub fn clear_queue(&mut self) { + self.queue.clear(); + self.enqueued.clear(); + } +} + +/// An identifier for added checkers. +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] +struct CheckerId(u32); + +impl StorageKey for CheckerId { + fn index(&self) -> usize { + self.0 as usize + } + + fn create_from_index(index: usize) -> Self { + CheckerId(index as u32) + } +} diff --git a/pumpkin-crates/core/src/checkers/strong_retention_checker.rs b/pumpkin-crates/core/src/checkers/strong_retention_checker.rs new file mode 100644 index 000000000..3c151a994 --- /dev/null +++ b/pumpkin-crates/core/src/checkers/strong_retention_checker.rs @@ -0,0 +1,113 @@ +use super::Scope; +use crate::checkers::RetentionChecker; +use crate::checkers::support::Support; +use crate::checkers::support::SupportGenerator; +use crate::checkers::support::SupportValue; +use crate::checkers::support::UnsupportedValue; +use crate::containers::HashSet; +use crate::propagation::Domains; +use crate::propagation::ReadDomains; +use crate::variables::DomainId; + +/// The consistency level advertised by the propagator. +#[derive(Clone, Copy, Debug)] +pub enum StrongConsistency { + Domain, + Bounds, +} + +/// A [`ConsistencyChecker`] that enforces a strong consistency property. +/// +/// The level of consistency is configured via [`StrongConsistency`]. +#[derive(Clone, Debug)] +pub struct StrongRetentionChecker { + /// The generator of supports. + supports: Supports, + /// A cache of domain-value pairs that are supported. + supported_values: HashSet<(DomainId, i32)>, + /// The consistency level to test for. + consistency_level: StrongConsistency, + /// Re-usable buffer of the current support that is operated on. + support: Support, +} + +impl StrongRetentionChecker { + pub fn new(consistency_level: StrongConsistency, supports: Supports) -> Self { + StrongRetentionChecker { + consistency_level, + supports, + supported_values: HashSet::default(), + support: Support::default(), + } + } +} + +impl RetentionChecker for StrongRetentionChecker { + fn check_retention(&mut self, scope: &Scope, domains: Domains<'_>) -> bool { + // Make sure to clear the cache of supported values. At the beginning, no values are + // supported. + self.supported_values.clear(); + + for (local_id, domain) in scope.domains() { + let values_to_support = match self.consistency_level { + StrongConsistency::Domain => { + itertools::Either::Left(domains.iterate_domain(&domain)) + } + StrongConsistency::Bounds => itertools::Either::Right( + [domains.lower_bound(&domain), domains.upper_bound(&domain)].into_iter(), + ), + }; + + for value in values_to_support { + if self.supported_values.contains(&(domain, value)) { + // If this domain-value pair is already supported in this check + // then there is no need to generate a new support for it. + continue; + } + + // Generate the support for this domain-value pair. + self.supports.support( + &mut self.support, + local_id, + UnsupportedValue(value), + &domains, + ); + + if !self.process_support(&domains) { + // The support was incomplete or not a solution. Either way, the + // consistency check fails. + return false; + } + } + } + + // All required values are successfully supported, so the check passes. + true + } +} + +impl StrongRetentionChecker { + /// Tests whether the [`StrongConsistencyChecker::support`] is a valid support. + /// + /// Drains the support in the process, so it can be used again by subsequent calls to + /// [`SupportGenerator::support`]. + fn process_support(&mut self, domains: &Domains<'_>) -> bool { + if !self.supports.is_solution(&self.support) { + log::error!("Support is not a solution"); + return false; + } + + for (domain, value) in self.support.drain() { + if !value.is_in(domain, domains) { + log::error!("Support value is not in the domain"); + return false; + } + + if let Some(int) = value.as_int() { + let _ = self.supported_values.insert((domain, int)); + } + } + + true + } +} diff --git a/pumpkin-crates/core/src/checkers/support.rs b/pumpkin-crates/core/src/checkers/support.rs new file mode 100644 index 000000000..fb06a3077 --- /dev/null +++ b/pumpkin-crates/core/src/checkers/support.rs @@ -0,0 +1,162 @@ +use std::fmt::Debug; + +use crate::containers::HashMap; +use crate::propagation::Domains; +use crate::propagation::LocalId; +use crate::propagation::ReadDomains; +use crate::variables::DomainId; + +/// A [`SupportGenerator`] can produce [`Support`]s for values in domains. +pub trait SupportGenerator: Clone + Debug { + /// The type of value used in the support. + /// + /// Depending on how the generator is used, this may be a float or an integer. + type Value: SupportValue; + + /// Produce a support where the domain corresponding to `local_id` is assigned to `value`. + /// + /// The support is written into the `support` buffer. Implementations can assume that this + /// support is empty when this function is called. + /// + /// The support must satisfy the constraint it is supporting, and all assignments must be + /// within the domain bounds. + fn support( + &mut self, + support: &mut Support, + local_id: LocalId, + value: UnsupportedValue, + domains: &Domains<'_>, + ); + + /// Returns true if the support is a solution to the constraint. + /// + /// Called with the support generated by [`SupportGenerator::support`]. + fn is_solution(&self, support: &Support) -> bool; +} + +/// A value that may be used in a [`Support`]. +pub trait SupportValue: Clone + Debug { + /// Returns `true` if `self` is in the given domain. + fn is_in(&self, domain: DomainId, domains: &Domains<'_>) -> bool; + + /// If the value is an integer, we can cache it to prevent recreating supports for the same + /// value. + fn as_int(&self) -> Option; +} + +impl SupportValue for i32 { + fn is_in(&self, domain: DomainId, domains: &Domains<'_>) -> bool { + domains.contains(&domain, *self) + } + + fn as_int(&self) -> Option { + Some(*self) + } +} + +impl SupportValue for f32 { + fn is_in(&self, domain: DomainId, domains: &Domains<'_>) -> bool { + let lb = domains.lower_bound(&domain) as f32; + let ub = domains.upper_bound(&domain) as f32; + + lb <= *self && *self <= ub + } + + fn as_int(&self) -> Option { + if (self.round() - self).abs() < f32::EPSILON { + Some(self.round() as i32) + } else { + None + } + } +} + +/// An assignment which supports a value for a particular domain. +#[derive(Clone, Debug)] +pub struct Support { + assignment: HashMap, +} + +impl Default for Support { + fn default() -> Self { + Self { + assignment: Default::default(), + } + } +} + +impl Support { + /// Add a domain assignment to the support. + /// + /// Previous assignments of the given domain are overwritten. + pub fn with_assignment(&mut self, domain_id: DomainId, value: Value) { + let _ = self.assignment.insert(domain_id, value); + } + + /// Get the value for the given domain in this support. + /// + /// Panics if the domain is unassigned. + pub fn assignment(&self, domain_id: DomainId) -> Value { + self.assignment + .get(&domain_id) + .cloned() + .unwrap_or_else(|| panic!("could not get assignment for {domain_id}")) + } + + /// Drain the support of all its assigned domains. + /// + /// Will leave the support empty. + pub(super) fn drain(&mut self) -> impl ExactSizeIterator { + self.assignment.drain() + } +} + +/// A domain value that needs to be unpacked through [`UnpackUnsupportedValue::unpack`]. +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct UnsupportedValue(pub(crate) i32); + +/// Implementors know how to unpack an [`UnsupportedValue`]. +pub trait UnpackUnsupportedValue { + /// Turn the unsupported value into an item in the domain of [`self`] (the variable). + fn unpack(&self, unsupported_value: UnsupportedValue) -> i32; +} + +impl UnpackUnsupportedValue for i32 { + fn unpack(&self, UnsupportedValue(value): UnsupportedValue) -> i32 { + assert_eq!(value, *self); + value + } +} + +/// A trait to identify types that can serve as variables in a support. +pub trait SupportsValue: UnpackUnsupportedValue { + /// Add the assignment `self = value` to the `support`. + fn assign(&self, value: Value, support: &mut Support); + + /// Get the value from the given support. + /// + /// Called with the result of [`SupportsValue::assign`]. + /// + /// Panics if the support has no value for this variable. + fn support_value(&self, support: &Support) -> Value; +} + +impl SupportsValue for i32 { + fn assign(&self, _: i32, _: &mut Support) { + // Do nothing + } + + fn support_value(&self, _: &Support) -> i32 { + *self + } +} + +impl SupportsValue for i32 { + fn assign(&self, _: f32, _: &mut Support) { + // Do nothing + } + + fn support_value(&self, _: &Support) -> f32 { + *self as f32 + } +} diff --git a/pumpkin-crates/core/src/constraints/mod.rs b/pumpkin-crates/core/src/constraints/mod.rs index 1e7f34304..1e0fb5c4b 100644 --- a/pumpkin-crates/core/src/constraints/mod.rs +++ b/pumpkin-crates/core/src/constraints/mod.rs @@ -2,7 +2,7 @@ use crate::ConstraintOperationError; use crate::Solver; use crate::propagation::PropagatorConstructor; -use crate::propagators::reified_propagator::ReifiedPropagatorArgs; +use crate::propagators::ReifiedPropagatorArgs; use crate::variables::Literal; mod constraint_poster; diff --git a/pumpkin-crates/core/src/containers/keyed_bit_set.rs b/pumpkin-crates/core/src/containers/keyed_bit_set.rs new file mode 100644 index 000000000..194a22874 --- /dev/null +++ b/pumpkin-crates/core/src/containers/keyed_bit_set.rs @@ -0,0 +1,56 @@ +use std::marker::PhantomData; + +use bit_set::BitSet; + +use crate::containers::StorageKey; + +/// A bit-set for types that implement [`StorageKey`]. +#[derive(Debug)] +pub struct KeyedBitSet { + bitset: BitSet, + key: PhantomData, +} + +impl KeyedBitSet { + /// Add the key to the set. + /// + /// Returns `true` if the set did _not_ previously contain `key`. + pub fn insert(&mut self, key: Key) -> bool { + self.bitset.insert(key.index()) + } + + /// Remove the key from the set. + /// + /// If the key was present, returns true. + pub fn remove(&mut self, key: Key) -> bool { + self.bitset.remove(key.index()) + } + + /// Get all keys in the set and remove them. + pub fn drain(&self) -> impl Iterator { + self.bitset.iter().map(Key::create_from_index) + } + + /// Remove all keys in the set. + pub fn clear(&mut self) { + self.bitset.make_empty(); + } +} + +impl Clone for KeyedBitSet { + fn clone(&self) -> Self { + Self { + bitset: self.bitset.clone(), + key: PhantomData, + } + } +} + +impl Default for KeyedBitSet { + fn default() -> Self { + Self { + bitset: BitSet::default(), + key: PhantomData, + } + } +} diff --git a/pumpkin-crates/core/src/containers/mod.rs b/pumpkin-crates/core/src/containers/mod.rs index 39343eb44..d16b86ae7 100644 --- a/pumpkin-crates/core/src/containers/mod.rs +++ b/pumpkin-crates/core/src/containers/mod.rs @@ -1,12 +1,14 @@ //! Contains containers which are used by the solver. mod key_generator; mod key_value_heap; +mod keyed_bit_set; mod keyed_vec; mod sparse_set; use fnv::FnvBuildHasher; pub use key_generator::*; pub use key_value_heap::*; +pub use keyed_bit_set::*; pub use keyed_vec::*; pub use sparse_set::*; diff --git a/pumpkin-crates/core/src/engine/cp/test_solver.rs b/pumpkin-crates/core/src/engine/cp/test_solver.rs index 4bbcdf795..a75ba9821 100644 --- a/pumpkin-crates/core/src/engine/cp/test_solver.rs +++ b/pumpkin-crates/core/src/engine/cp/test_solver.rs @@ -232,14 +232,32 @@ impl TestSolver { } pub fn propagate(&mut self, propagator: PropagatorId) -> Result<(), Conflict> { + let State { + propagators, + trailed_values, + assignments, + reason_store, + notification_engine, + #[cfg(feature = "check-consistency")] + consistency_checkers, + #[cfg(feature = "check-propagations")] + propagation_checkers, + .. + } = &mut self.state; + let context = PropagationContext::new( - &mut self.state.trailed_values, - &mut self.state.assignments, - &mut self.state.reason_store, - &mut self.state.notification_engine, + trailed_values, + assignments, + reason_store, + notification_engine, propagator, + #[cfg(feature = "check-consistency")] + consistency_checkers, + #[cfg(feature = "check-propagations")] + propagation_checkers, ); - self.state.propagators[propagator].propagate(context) + + propagators[propagator].propagate(context) } pub fn propagate_until_fixed_point( @@ -251,14 +269,32 @@ impl TestSolver { loop { { // Specify the life-times to be able to retrieve the trail entries + let State { + propagators, + trailed_values, + assignments, + reason_store, + notification_engine, + #[cfg(feature = "check-consistency")] + consistency_checkers, + #[cfg(feature = "check-propagations")] + propagation_checkers, + .. + } = &mut self.state; + let context = PropagationContext::new( - &mut self.state.trailed_values, - &mut self.state.assignments, - &mut self.state.reason_store, - &mut self.state.notification_engine, + trailed_values, + assignments, + reason_store, + notification_engine, propagator, + #[cfg(feature = "check-consistency")] + consistency_checkers, + #[cfg(feature = "check-propagations")] + propagation_checkers, ); - self.state.propagators[propagator].propagate(context)?; + + propagators[propagator].propagate(context)?; self.notify_propagator(propagator); } if self.state.assignments.num_trail_entries() == num_trail_entries { diff --git a/pumpkin-crates/core/src/engine/debug_helper.rs b/pumpkin-crates/core/src/engine/debug_helper.rs index 8e7a334ae..0e90f5760 100644 --- a/pumpkin-crates/core/src/engine/debug_helper.rs +++ b/pumpkin-crates/core/src/engine/debug_helper.rs @@ -9,6 +9,10 @@ use super::notifications::NotificationEngine; use super::predicates::predicate::Predicate; use super::reason::ReasonStore; use crate::basic_types::PropositionalConjunction; +#[cfg(feature = "check-consistency")] +use crate::checkers::ConsistencyCheckerStore; +#[cfg(feature = "check-propagations")] +use crate::containers::HashMap; use crate::engine::cp::Assignments; use crate::propagation::ExplanationContext; use crate::propagation::PropagationContext; @@ -76,6 +80,11 @@ impl DebugHelper { let num_entries_on_trail_before_propagation = assignments_clone.num_trail_entries(); + #[cfg(feature = "check-consistency")] + let mut consistency_checkers = ConsistencyCheckerStore::default(); + #[cfg(feature = "check-propagations")] + let mut inference_checkers = HashMap::default(); + let mut reason_store = Default::default(); let context = PropagationContext::new( &mut trailed_values_clone, @@ -83,6 +92,10 @@ impl DebugHelper { &mut reason_store, &mut notification_engine_clone, PropagatorId(propagator_id as u32), + #[cfg(feature = "check-consistency")] + &mut consistency_checkers, + #[cfg(feature = "check-propagations")] + &mut inference_checkers, ); let propagation_status_cp = propagator.propagate_from_scratch(context); @@ -252,6 +265,11 @@ impl DebugHelper { notification_engine_clone.debug_create_from_assignments(&assignments_clone); if adding_predicates_was_successful { + #[cfg(feature = "check-consistency")] + let mut consistency_checkers = ConsistencyCheckerStore::default(); + #[cfg(feature = "check-propagations")] + let mut inference_checkers = HashMap::default(); + // Now propagate using the debug propagation method. let mut reason_store = Default::default(); let context = PropagationContext::new( @@ -260,6 +278,10 @@ impl DebugHelper { &mut reason_store, &mut notification_engine_clone, propagator_id, + #[cfg(feature = "check-consistency")] + &mut consistency_checkers, + #[cfg(feature = "check-propagations")] + &mut inference_checkers, ); let debug_propagation_status_cp = propagator.propagate_from_scratch(context); @@ -369,12 +391,21 @@ impl DebugHelper { loop { let num_predicates_before = assignments_clone.num_trail_entries(); + #[cfg(feature = "check-consistency")] + let mut consistency_checkers = ConsistencyCheckerStore::default(); + #[cfg(feature = "check-propagations")] + let mut inference_checkers = HashMap::default(); + let context = PropagationContext::new( &mut trailed_values_clone, &mut assignments_clone, &mut reason_store, &mut notification_engine_clone, propagator_id, + #[cfg(feature = "check-consistency")] + &mut consistency_checkers, + #[cfg(feature = "check-propagations")] + &mut inference_checkers, ); let debug_propagation_status_cp = propagator.propagate_from_scratch(context); @@ -433,6 +464,11 @@ impl DebugHelper { notification_engine_clone.debug_create_from_assignments(&assignments_clone); if adding_predicates_was_successful { + #[cfg(feature = "check-consistency")] + let mut consistency_checkers = ConsistencyCheckerStore::default(); + #[cfg(feature = "check-propagations")] + let mut inference_checkers = HashMap::default(); + // now propagate using the debug propagation method let mut reason_store = Default::default(); let context = PropagationContext::new( @@ -441,6 +477,10 @@ impl DebugHelper { &mut reason_store, &mut notification_engine_clone, propagator_id, + #[cfg(feature = "check-consistency")] + &mut consistency_checkers, + #[cfg(feature = "check-propagations")] + &mut inference_checkers, ); let debug_propagation_status_cp = propagator.propagate_from_scratch(context); assert!( diff --git a/pumpkin-crates/core/src/engine/state.rs b/pumpkin-crates/core/src/engine/state.rs index 9e358fab6..6b79dd269 100644 --- a/pumpkin-crates/core/src/engine/state.rs +++ b/pumpkin-crates/core/src/engine/state.rs @@ -2,9 +2,11 @@ use std::sync::Arc; use pumpkin_checking::BoxedChecker; use pumpkin_checking::InferenceChecker; -#[cfg(feature = "check-propagations")] -use pumpkin_checking::VariableState; +use crate::checkers::BoxedRetentionChecker; +use crate::checkers::ConsistencyCheckerStore; +use crate::checkers::PropagationChecker; +use crate::checkers::Scope; use crate::containers::HashMap; use crate::containers::KeyGenerator; use crate::create_statistics_struct; @@ -28,8 +30,6 @@ use crate::proof::InferenceCode; use crate::propagation::CurrentNogood; use crate::propagation::Domains; use crate::propagation::ExplanationContext; -#[cfg(feature = "check-propagations")] -use crate::propagation::InferenceCheckers; use crate::propagation::NotificationContext; use crate::propagation::PropagationContext; use crate::propagation::Propagator; @@ -81,7 +81,8 @@ pub struct State { statistics: StateStatistics, /// Inference checkers to run in the propagation loop. - checkers: HashMap>>, + pub(crate) propagation_checkers: HashMap>, + pub(crate) consistency_checkers: ConsistencyCheckerStore, } create_statistics_struct!(StateStatistics { @@ -111,7 +112,8 @@ impl Default for State { notification_engine: NotificationEngine::default(), statistics: StateStatistics::default(), constraint_tags: KeyGenerator::default(), - checkers: HashMap::default(), + propagation_checkers: HashMap::default(), + consistency_checkers: Default::default(), }; // As a convention, the assignments contain a dummy domain_id=0, which represents a 0-1 // variable that is assigned to one. We use it to represent predicates that are @@ -333,9 +335,6 @@ impl State { Constructor: PropagatorConstructor, Constructor::PropagatorImpl: 'static, { - #[cfg(feature = "check-propagations")] - constructor.add_inference_checkers(InferenceCheckers::new(self)); - let original_handle: PropagatorHandle = self.propagators.new_propagator().key(); @@ -373,8 +372,18 @@ impl State { inference_code: InferenceCode, checker: Box>, ) { - let checkers = self.checkers.entry(inference_code).or_default(); - checkers.push(BoxedChecker::from(checker)); + let checkers = self.propagation_checkers.entry(inference_code).or_default(); + checkers.push(PropagationChecker::new(BoxedChecker::from(checker))); + } + + /// Add a consistency checker for the scope. + pub fn add_consistency_checker( + &mut self, + scope: impl Into, + checker: impl Into, + ) { + self.consistency_checkers + .register(scope.into(), checker.into()); } } @@ -412,16 +421,31 @@ impl State { &mut self, handle: PropagatorHandle

, ) -> (Option<&mut P>, PropagationContext<'_>) { - ( - self.propagators.get_propagator_mut(handle), - PropagationContext::new( - &mut self.trailed_values, - &mut self.assignments, - &mut self.reason_store, - &mut self.notification_engine, - handle.propagator_id(), - ), - ) + let Self { + propagators, + trailed_values, + assignments, + reason_store, + notification_engine, + #[cfg(feature = "check-consistency")] + consistency_checkers, + #[cfg(feature = "check-propagations")] + propagation_checkers: checkers, + .. + } = self; + let propagator = propagators.get_propagator_mut(handle); + let context = PropagationContext::new( + trailed_values, + assignments, + reason_store, + notification_engine, + handle.propagator_id(), + #[cfg(feature = "check-consistency")] + consistency_checkers, + #[cfg(feature = "check-propagations")] + checkers, + ); + (propagator, context) } } @@ -604,13 +628,29 @@ impl State { let num_trail_entries_before = self.assignments.num_trail_entries(); let propagation_status = { - let propagator = &mut self.propagators[propagator_id]; + let Self { + propagators, + trailed_values, + assignments, + reason_store, + notification_engine, + #[cfg(feature = "check-consistency")] + consistency_checkers, + #[cfg(feature = "check-propagations")] + propagation_checkers: checkers, + .. + } = self; + let propagator = &mut propagators[propagator_id]; let context = PropagationContext::new( - &mut self.trailed_values, - &mut self.assignments, - &mut self.reason_store, - &mut self.notification_engine, + trailed_values, + assignments, + reason_store, + notification_engine, propagator_id, + #[cfg(feature = "check-consistency")] + consistency_checkers, + #[cfg(feature = "check-propagations")] + checkers, ); propagator.propagate(context) }; @@ -618,6 +658,9 @@ impl State { #[cfg(feature = "check-propagations")] self.check_propagations(num_trail_entries_before); + #[cfg(feature = "check-consistency")] + self.enqueue_consistency_checkers(num_trail_entries_before); + match propagation_status { Ok(_) => { // Notify other propagators of the propagations and continue. @@ -645,6 +688,9 @@ impl State { #[cfg(feature = "check-propagations")] self.check_conflict(&conflict); + #[cfg(feature = "check-propagations")] + self.consistency_checkers.clear_queue(); + self.statistics.num_conflicts += 1; if let Conflict::Propagator(inner) = &conflict { pumpkin_assert_advanced!(DebugHelper::debug_reported_failure( @@ -718,6 +764,16 @@ impl State { } } + #[cfg(feature = "check-consistency")] + fn enqueue_consistency_checkers(&mut self, first_propagation_index: usize) { + for trail_index in first_propagation_index..self.assignments.num_trail_entries() { + let entry = self.assignments.get_trail_entry(trail_index); + + self.consistency_checkers + .on_domain_event(entry.predicate.get_domain()); + } + } + /// Performs fixed-point propagation using the propagators defined in the [`State`]. /// /// The posted [`Predicate`]s (using [`State::post`]) and added propagators (using @@ -746,6 +802,13 @@ impl State { self.propagate(propagator_id)?; } + if cfg!(feature = "check-consistency") { + assert!( + self.consistency_checkers + .run_enqueued(Domains::new(&self.assignments, &mut self.trailed_values)) + ); + } + // Only check fixed point propagation if there was no reported conflict, // since otherwise the state may be inconsistent. pumpkin_assert_extreme!(DebugHelper::debug_fixed_point_propagation( @@ -763,7 +826,7 @@ impl State { impl State { /// Run the checker for the given inference code on the given inference. fn run_checker( - &self, + &mut self, premises: impl IntoIterator, consequent: Option, inference_code: &InferenceCode, @@ -771,7 +834,7 @@ impl State { let premises: Vec<_> = premises.into_iter().collect(); let checkers = self - .checkers + .propagation_checkers .get(inference_code) .map(|vec| vec.as_slice()) .unwrap_or(&[]); @@ -782,18 +845,13 @@ impl State { ); let any_checker_accepts_inference = checkers.iter().any(|checker| { - // Construct the variable state for the conflict check. - let variable_state = VariableState::prepare_for_conflict_check( - premises.clone(), - consequent, - ) - .unwrap_or_else(|domain| { - panic!( - "inconsistent atomics over domain {domain:?} in inference by {inference_code:?}" + checker + .check( + &premises, + consequent, + Domains::new(&self.assignments, &mut self.trailed_values), ) - }); - - checker.check(variable_state, &premises, consequent.as_ref()) + .is_ok() }); assert!( @@ -1142,12 +1200,27 @@ impl State { } pub fn get_propagation_context(&mut self) -> PropagationContext<'_> { + let Self { + trailed_values, + assignments, + reason_store, + notification_engine, + #[cfg(feature = "check-consistency")] + consistency_checkers, + #[cfg(feature = "check-propagations")] + propagation_checkers: checkers, + .. + } = self; PropagationContext::new( - &mut self.trailed_values, - &mut self.assignments, - &mut self.reason_store, - &mut self.notification_engine, + trailed_values, + assignments, + reason_store, + notification_engine, PropagatorId(0), + #[cfg(feature = "check-consistency")] + consistency_checkers, + #[cfg(feature = "check-propagations")] + checkers, ) } } diff --git a/pumpkin-crates/core/src/engine/variables/affine_view.rs b/pumpkin-crates/core/src/engine/variables/affine_view.rs index 404184e57..daf56fd59 100644 --- a/pumpkin-crates/core/src/engine/variables/affine_view.rs +++ b/pumpkin-crates/core/src/engine/variables/affine_view.rs @@ -5,6 +5,12 @@ use pumpkin_checking::CheckerVariable; use pumpkin_checking::IntExt; use super::TransformableVariable; +use crate::checkers::Scope; +use crate::checkers::ScopeItem; +use crate::checkers::support::Support; +use crate::checkers::support::SupportsValue; +use crate::checkers::support::UnpackUnsupportedValue; +use crate::checkers::support::UnsupportedValue; use crate::engine::Assignments; use crate::engine::notifications::DomainEvent; use crate::engine::notifications::OpaqueDomainEvent; @@ -14,6 +20,7 @@ use crate::engine::predicates::predicate_constructor::PredicateConstructor; use crate::engine::variables::DomainId; use crate::engine::variables::IntegerVariable; use crate::math::num_ext::NumExt; +use crate::propagation::LocalId; /// Models the constraint `y = ax + b`, by expressing the domain of `y` as a transformation of the /// domain of `x`. @@ -46,6 +53,13 @@ impl AffineView { match rounding { Rounding::Up => ::div_ceil(inverted_translation, self.scale), Rounding::Down => ::div_floor(inverted_translation, self.scale), + Rounding::None => { + if inverted_translation % self.scale == 0 { + inverted_translation / self.scale + } else { + panic!("do not want to round but cannot unscale") + } + } } } @@ -54,6 +68,50 @@ impl AffineView { } } +impl ScopeItem for AffineView { + fn add_to_scope(&self, scope: &mut Scope, local_id: LocalId) { + self.inner.add_to_scope(scope, local_id); + } +} + +impl UnpackUnsupportedValue for AffineView +where + Inner: UnpackUnsupportedValue, +{ + fn unpack(&self, unsupported_value: UnsupportedValue) -> i32 { + self.map(self.inner.unpack(unsupported_value)) + } +} + +impl SupportsValue for AffineView +where + Inner: SupportsValue, +{ + fn assign(&self, value: i32, support: &mut Support) { + let value = self.invert(value, Rounding::None); + self.inner.assign(value, support); + } + + fn support_value(&self, support: &Support) -> i32 { + self.map(self.inner.support_value(support)) + } +} + +impl SupportsValue for AffineView +where + Inner: SupportsValue, +{ + fn assign(&self, value: f32, support: &mut Support) { + let inverted_translation = value - self.offset as f32; + let value = inverted_translation / self.scale as f32; + self.inner.assign(value, support); + } + + fn support_value(&self, support: &Support) -> f32 { + self.scale as f32 * self.inner.support_value(support) + self.offset as f32 + } +} + impl CheckerVariable for AffineView { fn does_atomic_constrain_self(&self, atomic: &Predicate) -> bool { self.inner.does_atomic_constrain_self(atomic) @@ -404,6 +462,7 @@ impl From for AffineView { enum Rounding { Up, Down, + None, } #[cfg(test)] diff --git a/pumpkin-crates/core/src/engine/variables/domain_id.rs b/pumpkin-crates/core/src/engine/variables/domain_id.rs index 22ede3a40..9a1a02af7 100644 --- a/pumpkin-crates/core/src/engine/variables/domain_id.rs +++ b/pumpkin-crates/core/src/engine/variables/domain_id.rs @@ -2,6 +2,12 @@ use enumset::EnumSet; use pumpkin_checking::CheckerVariable; use super::TransformableVariable; +use crate::checkers::Scope; +use crate::checkers::ScopeItem; +use crate::checkers::support::Support; +use crate::checkers::support::SupportsValue; +use crate::checkers::support::UnpackUnsupportedValue; +use crate::checkers::support::UnsupportedValue; use crate::containers::StorageKey; use crate::engine::Assignments; use crate::engine::notifications::DomainEvent; @@ -12,6 +18,7 @@ use crate::engine::variables::IntegerVariable; use crate::predicates::Predicate; use crate::predicates::PredicateConstructor; use crate::predicates::PredicateType; +use crate::propagation::LocalId; use crate::pumpkin_assert_simple; /// A structure which represents the most basic [`IntegerVariable`]; it is simply the id which links @@ -32,6 +39,38 @@ impl DomainId { } } +impl ScopeItem for DomainId { + fn add_to_scope(&self, scope: &mut Scope, local_id: LocalId) { + scope.add_domain(local_id, *self); + } +} + +impl UnpackUnsupportedValue for DomainId { + fn unpack(&self, UnsupportedValue(value): UnsupportedValue) -> i32 { + value + } +} + +impl SupportsValue for DomainId { + fn assign(&self, value: i32, support: &mut Support) { + support.with_assignment(*self, value); + } + + fn support_value(&self, support: &Support) -> i32 { + support.assignment(*self) + } +} + +impl SupportsValue for DomainId { + fn assign(&self, value: f32, support: &mut Support) { + support.with_assignment(*self, value); + } + + fn support_value(&self, support: &Support) -> f32 { + support.assignment(*self) + } +} + impl CheckerVariable for DomainId { fn does_atomic_constrain_self(&self, atomic: &Predicate) -> bool { atomic.get_domain() == *self diff --git a/pumpkin-crates/core/src/engine/variables/integer_variable.rs b/pumpkin-crates/core/src/engine/variables/integer_variable.rs index 09badb92e..1ed710e61 100644 --- a/pumpkin-crates/core/src/engine/variables/integer_variable.rs +++ b/pumpkin-crates/core/src/engine/variables/integer_variable.rs @@ -4,6 +4,8 @@ use enumset::EnumSet; use pumpkin_checking::CheckerVariable; use super::TransformableVariable; +use crate::checkers::ScopeItem; +use crate::checkers::support::SupportsValue; use crate::engine::Assignments; use crate::engine::notifications::DomainEvent; use crate::engine::notifications::OpaqueDomainEvent; @@ -19,6 +21,8 @@ pub trait IntegerVariable: + TransformableVariable + Debug + CheckerVariable + + ScopeItem + + SupportsValue { type AffineView: IntegerVariable; diff --git a/pumpkin-crates/core/src/engine/variables/literal.rs b/pumpkin-crates/core/src/engine/variables/literal.rs index 980ffa737..a2c4a88e0 100644 --- a/pumpkin-crates/core/src/engine/variables/literal.rs +++ b/pumpkin-crates/core/src/engine/variables/literal.rs @@ -8,6 +8,12 @@ use pumpkin_checking::VariableState; use super::DomainId; use super::IntegerVariable; use super::TransformableVariable; +use crate::checkers::Scope; +use crate::checkers::ScopeItem; +use crate::checkers::support::Support; +use crate::checkers::support::SupportsValue; +use crate::checkers::support::UnpackUnsupportedValue; +use crate::checkers::support::UnsupportedValue; use crate::engine::Assignments; use crate::engine::notifications::DomainEvent; use crate::engine::notifications::OpaqueDomainEvent; @@ -15,6 +21,7 @@ use crate::engine::notifications::Watchers; use crate::engine::predicates::predicate::Predicate; use crate::engine::predicates::predicate_constructor::PredicateConstructor; use crate::engine::variables::AffineView; +use crate::propagation::LocalId; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct Literal { @@ -73,6 +80,28 @@ macro_rules! forward { } } +impl ScopeItem for Literal { + fn add_to_scope(&self, scope: &mut Scope, local_id: LocalId) { + self.integer_variable.add_to_scope(scope, local_id); + } +} + +impl UnpackUnsupportedValue for Literal { + fn unpack(&self, unsupported_value: UnsupportedValue) -> i32 { + self.integer_variable.unpack(unsupported_value) + } +} + +impl SupportsValue for Literal { + fn assign(&self, value: i32, support: &mut Support) { + self.integer_variable.assign(value, support) + } + + fn support_value(&self, support: &Support) -> i32 { + self.integer_variable.support_value(support) + } +} + impl CheckerVariable for Literal { forward!(integer_variable, fn does_atomic_constrain_self(&self, atomic: &Predicate) -> bool); forward!(integer_variable, fn atomic_less_than(&self, value: i32) -> Predicate); diff --git a/pumpkin-crates/core/src/lib.rs b/pumpkin-crates/core/src/lib.rs index 2a8680544..19d6fcefb 100644 --- a/pumpkin-crates/core/src/lib.rs +++ b/pumpkin-crates/core/src/lib.rs @@ -12,6 +12,7 @@ use crate::branching::Brancher; use crate::termination::TerminationCondition; pub mod branching; +pub mod checkers; pub mod conflict_resolving; pub mod constraints; pub mod optimisation; diff --git a/pumpkin-crates/core/src/propagation/constructor.rs b/pumpkin-crates/core/src/propagation/constructor.rs index 791627880..220441c3d 100644 --- a/pumpkin-crates/core/src/propagation/constructor.rs +++ b/pumpkin-crates/core/src/propagation/constructor.rs @@ -12,6 +12,8 @@ use super::PropagatorVarId; use crate::Solver; use crate::basic_types::PredicateId; use crate::basic_types::RefOrOwned; +use crate::checkers::BoxedRetentionChecker; +use crate::checkers::Scope; use crate::engine::Assignments; use crate::engine::State; use crate::engine::TrailedValues; @@ -25,9 +27,7 @@ use crate::proof::InferenceCode; #[cfg(doc)] use crate::propagation::DomainEvent; use crate::propagation::DomainEvents; -use crate::propagators::reified_propagator::ReifiedChecker; use crate::variables::IntegerVariable; -use crate::variables::Literal; /// A propagator constructor creates a fully initialized instance of a [`Propagator`]. /// @@ -35,63 +35,18 @@ use crate::variables::Literal; /// 1) Indicating on which [`DomainEvent`]s the propagator should be enqueued (via the /// [`PropagatorConstructorContext`]). /// 2) Initialising the [`PropagatorConstructor::PropagatorImpl`] and its structures. +/// +/// Inference checkers and consistency checkers should be registered inside [`Self::create`] via +/// [`PropagatorConstructorContext::add_inference_checker`] and +/// [`PropagatorConstructorContext::add_consistency_checker`]. pub trait PropagatorConstructor { /// The propagator that is produced by this constructor. type PropagatorImpl: Propagator + Clone; - /// Add inference checkers to the solver if applicable. - /// - /// If the `check-propagations` feature is turned on, then the inference checker will be used - /// to verify the propagations done by this propagator are correct. - /// - /// See [`InferenceChecker`] for more information. - fn add_inference_checkers(&self, _checkers: InferenceCheckers<'_>) {} - /// Create the propagator instance from `Self`. fn create(self, context: PropagatorConstructorContext) -> Self::PropagatorImpl; } -/// Interface used to add [`InferenceChecker`]s to the [`State`]. -#[derive(Debug)] -pub struct InferenceCheckers<'state> { - state: &'state mut State, - reification_literal: Option, -} - -impl<'state> InferenceCheckers<'state> { - #[cfg(feature = "check-propagations")] - pub(crate) fn new(state: &'state mut State) -> Self { - InferenceCheckers { - state, - reification_literal: None, - } - } -} - -impl InferenceCheckers<'_> { - /// Forwards to [`State::add_inference_checker`]. - pub fn add_inference_checker( - &mut self, - inference_code: InferenceCode, - checker: Box>, - ) { - if let Some(reification_literal) = self.reification_literal { - let reification_checker = ReifiedChecker { - inner: checker.into(), - reification_literal, - }; - self.state - .add_inference_checker(inference_code, Box::new(reification_checker)); - } else { - self.state.add_inference_checker(inference_code, checker); - } - } - - pub fn with_reification_literal(&mut self, literal: Literal) { - self.reification_literal = Some(literal) - } -} - /// [`PropagatorConstructorContext`] is used when [`Propagator`]s are initialised after creation. /// /// It represents a communication point between the [`Solver`] and the [`Propagator`]. @@ -110,6 +65,19 @@ pub struct PropagatorConstructorContext<'a> { /// Marker to indicate whether the constructor registered for at least one domain event or /// predicate becoming assigned. If not, the [`Drop`] implementation will cause a panic. did_register: RefOrOwned<'a, bool>, + + /// Pending consistency checkers to be registered into [`State`] when this context is dropped. + #[cfg(feature = "check-consistency")] + pub(crate) pending_consistency_checkers: RefOrOwned<'a, Vec<(Scope, BoxedRetentionChecker)>>, + + /// Pending inference checkers to be registered into [`State`] when this context is dropped. + #[cfg(feature = "check-propagations")] + #[allow( + clippy::type_complexity, + reason = "it's not clear where the type alias would live" + )] + pub(crate) pending_inference_checkers: + RefOrOwned<'a, Vec<(InferenceCode, Box>)>>, } impl PropagatorConstructorContext<'_> { @@ -122,6 +90,10 @@ impl PropagatorConstructorContext<'_> { propagator_id, state, did_register: RefOrOwned::Owned(false), + #[cfg(feature = "check-consistency")] + pending_consistency_checkers: RefOrOwned::Owned(vec![]), + #[cfg(feature = "check-propagations")] + pending_inference_checkers: RefOrOwned::Owned(vec![]), } } @@ -225,19 +197,51 @@ impl PropagatorConstructorContext<'_> { next_local_id: self.next_local_id.reborrow(), did_register: self.did_register.reborrow(), state: self.state, + #[cfg(feature = "check-consistency")] + pending_consistency_checkers: self.pending_consistency_checkers.reborrow(), + #[cfg(feature = "check-propagations")] + pending_inference_checkers: self.pending_inference_checkers.reborrow(), + } + } + + /// Add a consistency checker for the given constraint and scope. + /// + /// If the `check-consistency` feature is not enabled, this is a no-op. + pub fn add_consistency_checker( + &mut self, + scope: impl Into, + checker: impl Into, + ) { + #[cfg(feature = "check-consistency")] + self.pending_consistency_checkers + .push((scope.into(), checker.into())); + + // Avoid unused variable warning. + #[cfg(not(feature = "check-consistency"))] + { + let _ = scope; + let _ = checker; } } /// Add an inference checker for inferences produced by the propagator. /// - /// If the `check-propagations` feature is not enabled, adding an [`InferenceChecker`] will not - /// do anything. + /// If the `check-propagations` feature is not enabled, this is a no-op. pub fn add_inference_checker( &mut self, inference_code: InferenceCode, checker: Box>, ) { - self.state.add_inference_checker(inference_code, checker); + #[cfg(feature = "check-propagations")] + self.pending_inference_checkers + .push((inference_code, checker)); + + // Avoid unused variable warning. + #[cfg(not(feature = "check-propagations"))] + { + let _ = inference_code; + let _ = checker; + } } /// Set the next local id to be at least one more than the largest encountered local id. @@ -256,7 +260,8 @@ impl Drop for PropagatorConstructorContext<'_> { } let did_register = match self.did_register { - // If we are in a reborrowed context, we do not want to enforce registration. + // If we are in a reborrowed context, we do not want to enforce registration or drain + // pending checkers (the root context handles this). RefOrOwned::Ref(_) => return, RefOrOwned::Owned(did_register) => did_register, @@ -267,6 +272,16 @@ impl Drop for PropagatorConstructorContext<'_> { "Propagator did not register to be enqueued. If this is intentional, call PropagatorConstructorContext::will_not_register_any_events()." ); } + + #[cfg(feature = "check-consistency")] + for (scope, checker) in self.pending_consistency_checkers.drain(..) { + self.state.add_consistency_checker(scope, checker); + } + + #[cfg(feature = "check-propagations")] + for (inference_code, checker) in self.pending_inference_checkers.drain(..) { + self.state.add_inference_checker(inference_code, checker); + } } } diff --git a/pumpkin-crates/core/src/propagation/contexts/propagation_context.rs b/pumpkin-crates/core/src/propagation/contexts/propagation_context.rs index ab3218552..0ebe4ffd0 100644 --- a/pumpkin-crates/core/src/propagation/contexts/propagation_context.rs +++ b/pumpkin-crates/core/src/propagation/contexts/propagation_context.rs @@ -1,4 +1,16 @@ +#[cfg(feature = "check-propagations")] +use pumpkin_checking::BoxedChecker; +use pumpkin_checking::InferenceChecker; + use crate::basic_types::PredicateId; +use crate::checkers::BoxedRetentionChecker; +#[cfg(feature = "check-consistency")] +use crate::checkers::ConsistencyCheckerStore; +#[cfg(feature = "check-propagations")] +use crate::checkers::PropagationChecker; +use crate::checkers::Scope; +#[cfg(feature = "check-propagations")] +use crate::containers::HashMap; use crate::engine::Assignments; use crate::engine::EmptyDomain; use crate::engine::EmptyDomainConflict; @@ -10,6 +22,7 @@ use crate::engine::reason::Reason; use crate::engine::reason::ReasonStore; use crate::engine::reason::StoredReason; use crate::engine::variables::Literal; +use crate::proof::InferenceCode; use crate::propagation::DomainEvents; use crate::propagation::Domains; use crate::propagation::HasAssignments; @@ -84,6 +97,11 @@ pub struct PropagationContext<'a> { pub(crate) propagator_id: PropagatorId, pub(crate) notification_engine: &'a mut NotificationEngine, reification_literal: Option, + + #[cfg(feature = "check-consistency")] + pub(crate) consistency_checkers: &'a mut ConsistencyCheckerStore, + #[cfg(feature = "check-propagations")] + pub(crate) inference_checkers: &'a mut HashMap>, } impl<'a> HasAssignments for PropagationContext<'a> { @@ -107,6 +125,11 @@ impl<'a> PropagationContext<'a> { reason_store: &'a mut ReasonStore, notification_engine: &'a mut NotificationEngine, propagator_id: PropagatorId, + #[cfg(feature = "check-consistency")] consistency_checkers: &'a mut ConsistencyCheckerStore, + #[cfg(feature = "check-propagations")] inference_checkers: &'a mut HashMap< + InferenceCode, + Vec, + >, ) -> Self { PropagationContext { trailed_values, @@ -115,6 +138,62 @@ impl<'a> PropagationContext<'a> { propagator_id, notification_engine, reification_literal: None, + #[cfg(feature = "check-consistency")] + consistency_checkers, + #[cfg(feature = "check-propagations")] + inference_checkers, + } + } + + /// Add a consistency checker for the given constraint and scope. + /// + /// If the `check-consistency` feature is not enabled, this is a no-op. + pub fn add_consistency_checker( + &mut self, + scope: impl Into, + checker: impl Into, + ) { + pumpkin_assert_simple!( + self.reification_literal.is_none(), + "Cannot add consistency checkers from within a reified propagation context." + ); + + #[cfg(feature = "check-consistency")] + self.consistency_checkers + .register(scope.into(), checker.into()); + + // Use variables to avoid unused warnings. + #[cfg(not(feature = "check-consistency"))] + { + let _ = scope; + let _ = checker; + } + } + + /// Add an inference checker for inferences produced by the propagator. + /// + /// If the `check-propagations` feature is not enabled, this is a no-op. + pub fn add_inference_checker( + &mut self, + inference_code: InferenceCode, + checker: Box>, + ) { + pumpkin_assert_simple!( + self.reification_literal.is_none(), + "Cannot add inference checkers from within a reified propagation context." + ); + + #[cfg(feature = "check-propagations")] + self.inference_checkers + .entry(inference_code) + .or_default() + .push(PropagationChecker::new(BoxedChecker::from(checker))); + + // Use variables to avoid unused warnings. + #[cfg(not(feature = "check-propagations"))] + { + let _ = inference_code; + let _ = checker; } } @@ -224,6 +303,10 @@ impl<'a> PropagationContext<'a> { propagator_id: self.propagator_id, notification_engine: self.notification_engine, reification_literal: self.reification_literal, + #[cfg(feature = "check-consistency")] + consistency_checkers: self.consistency_checkers, + #[cfg(feature = "check-propagations")] + inference_checkers: self.inference_checkers, } } } diff --git a/pumpkin-crates/core/src/propagators/hypercube_linear/propagator.rs b/pumpkin-crates/core/src/propagators/hypercube_linear/propagator.rs index fc1e0a3f7..92d0f7fb9 100644 --- a/pumpkin-crates/core/src/propagators/hypercube_linear/propagator.rs +++ b/pumpkin-crates/core/src/propagators/hypercube_linear/propagator.rs @@ -7,7 +7,6 @@ use crate::predicates::PropositionalConjunction; use crate::proof::ConstraintTag; use crate::proof::InferenceCode; use crate::propagation::DomainEvents; -use crate::propagation::InferenceCheckers; use crate::propagation::LocalId; use crate::propagation::PropagationContext; use crate::propagation::Propagator; @@ -33,17 +32,6 @@ pub struct HypercubeLinearConstructor { impl PropagatorConstructor for HypercubeLinearConstructor { type PropagatorImpl = HypercubeLinearPropagator; - fn add_inference_checkers(&self, mut checkers: InferenceCheckers<'_>) { - checkers.add_inference_checker( - InferenceCode::new(self.constraint_tag, HypercubeLinear), - Box::new(HypercubeLinearChecker { - hypercube: self.hypercube.iter_predicates().collect(), - terms: self.linear.terms().collect(), - bound: self.linear.bound(), - }), - ); - } - fn create(self, mut context: PropagatorConstructorContext) -> Self::PropagatorImpl { let HypercubeLinearConstructor { hypercube, @@ -51,6 +39,15 @@ impl PropagatorConstructor for HypercubeLinearConstructor { constraint_tag, } = self; + context.add_inference_checker( + InferenceCode::new(constraint_tag, HypercubeLinear), + Box::new(HypercubeLinearChecker { + hypercube: hypercube.iter_predicates().collect(), + terms: linear.terms().collect(), + bound: linear.bound(), + }), + ); + let hypercube_predicates = hypercube.iter_predicates().collect::>(); let watched_predicates = if hypercube_predicates.is_empty() { diff --git a/pumpkin-crates/core/src/propagators/mod.rs b/pumpkin-crates/core/src/propagators/mod.rs index 2af1b61a5..25afdc21d 100644 --- a/pumpkin-crates/core/src/propagators/mod.rs +++ b/pumpkin-crates/core/src/propagators/mod.rs @@ -1,5 +1,5 @@ pub mod hypercube_linear; pub mod nogoods; -pub(crate) mod reified_propagator; +mod reified_propagator; pub use reified_propagator::*; diff --git a/pumpkin-crates/core/src/propagators/nogoods/checker.rs b/pumpkin-crates/core/src/propagators/nogoods/checker.rs index 700ee6a2c..381b15ba4 100644 --- a/pumpkin-crates/core/src/propagators/nogoods/checker.rs +++ b/pumpkin-crates/core/src/propagators/nogoods/checker.rs @@ -2,6 +2,13 @@ use std::fmt::Debug; use pumpkin_checking::AtomicConstraint; use pumpkin_checking::InferenceChecker; +use pumpkin_checking::VariableState; + +use crate::checkers::RetentionChecker; +use crate::checkers::Scope; +use crate::predicates::Predicate; +use crate::propagation::Domains; +use crate::propagation::ReadDomains; #[derive(Debug, Clone)] pub struct NogoodChecker { @@ -12,12 +19,85 @@ impl InferenceChecker for NogoodChecker where Atomic: AtomicConstraint + Clone + Debug, { - fn check( - &self, - state: pumpkin_checking::VariableState, - _: &[Atomic], - _: Option<&Atomic>, - ) -> bool { + fn check(&self, state: VariableState, _: &[Atomic], _: Option<&Atomic>) -> bool { self.nogood.iter().all(|atomic| state.is_true(atomic)) } } + +impl RetentionChecker for NogoodChecker { + fn check_retention(&mut self, _: &Scope, domains: Domains<'_>) -> bool { + // For unit propagation, the state is consistent if: + // - at least two predicates are unassigned + // - or otherwise, at least one predicate is assigned + + let untrue_predicate_count = self + .nogood + .iter() + .filter(|&&predicate| domains.evaluate_predicate(predicate) != Some(true)) + .count(); + + if untrue_predicate_count >= 2 { + // If at least two predicates are not true, then the domain is + // unit-propagation consistent. + return true; + } + + // At least one predicate must be false for the domain to be unit-propagation consistent. + self.nogood + .iter() + .any(|&predicate| domains.evaluate_predicate(predicate) == Some(false)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::conjunction; + use crate::propagation::LocalId; + use crate::state::State; + + #[test] + fn a_nogood_with_multiple_untrue_predicates_is_consistent() { + let mut state = State::default(); + + let x = state.new_interval_variable(1, 5, Some("x".into())); + let y = state.new_interval_variable(1, 5, Some("y".into())); + + let mut checker = NogoodChecker { + nogood: conjunction!([x >= 4] & [y <= 2]).into(), + }; + + let scope = Scope::from_iter([(LocalId::from(0), x), (LocalId::from(1), y)]); + assert!(checker.check_retention(&scope, state.get_domains())); + } + + #[test] + fn a_nogood_with_one_untrue_predicates_and_no_false_predicates_is_inconsistent() { + let mut state = State::default(); + + let x = state.new_interval_variable(1, 5, Some("x".into())); + let y = state.new_interval_variable(1, 5, Some("y".into())); + + let mut checker = NogoodChecker { + nogood: conjunction!([x >= 4] & [y <= 5]).into(), + }; + + let scope = Scope::from_iter([(LocalId::from(0), x), (LocalId::from(1), y)]); + assert!(!checker.check_retention(&scope, state.get_domains())); + } + + #[test] + fn a_nogood_with_any_false_predicates_is_consistent() { + let mut state = State::default(); + + let x = state.new_interval_variable(1, 3, Some("x".into())); + let y = state.new_interval_variable(1, 5, Some("y".into())); + + let mut checker = NogoodChecker { + nogood: conjunction!([x >= 4] & [y <= 2]).into(), + }; + + let scope = Scope::from_iter([(LocalId::from(0), x), (LocalId::from(1), y)]); + assert!(checker.check_retention(&scope, state.get_domains())); + } +} diff --git a/pumpkin-crates/core/src/propagators/nogoods/nogood_propagator.rs b/pumpkin-crates/core/src/propagators/nogoods/nogood_propagator.rs index 4312f65cd..291493054 100644 --- a/pumpkin-crates/core/src/propagators/nogoods/nogood_propagator.rs +++ b/pumpkin-crates/core/src/propagators/nogoods/nogood_propagator.rs @@ -1,5 +1,11 @@ use std::cmp::max; use std::ops::Not; +#[cfg(feature = "check-consistency")] +use std::sync::Arc; +#[cfg(feature = "check-consistency")] +use std::sync::atomic::AtomicBool; +#[cfg(feature = "check-consistency")] +use std::sync::atomic::Ordering; use log::warn; @@ -8,6 +14,10 @@ use super::NogoodId; use super::NogoodInfo; use crate::basic_types::PredicateId; use crate::basic_types::PropositionalConjunction; +#[cfg(feature = "check-consistency")] +use crate::checkers::Scope; +#[cfg(feature = "check-consistency")] +use crate::checkers::SelfDisablingChecker; use crate::containers::KeyedVec; use crate::containers::StorageKey; use crate::engine::Assignments; @@ -79,6 +89,13 @@ pub struct NogoodPropagator { /// current subtree. To test for that, we compare this handle with the propagator ID of a /// proapgated literal to see if this propagator propagated a predicate. handle: PropagatorHandle, + + /// Flags shared with consistency checkers to signal that a nogood has been deleted. + /// + /// When clause management deletes a nogood, the corresponding flag is set to `true`, causing + /// the checker to become a no-op. + #[cfg(feature = "check-consistency")] + deletion_flags: KeyedVec>, } /// [`PropagatorConstructor`] for constructing a new instance of the [`NogoodPropagator`] with the @@ -117,6 +134,8 @@ impl PropagatorConstructor for NogoodPropagatorConstructor { lbd_helper: Default::default(), bumped_nogoods: Default::default(), temp_nogood_reason: Default::default(), + #[cfg(feature = "check-consistency")] + deletion_flags: Default::default(), } } } @@ -452,6 +471,10 @@ impl NogoodPropagator { .lbd_helper .compute_lbd(&nogood.as_slice()[1..], context); + // Capture checker predicates before conversion to PredicateIds. + #[cfg(any(feature = "check-consistency", feature = "check-propagations"))] + let checker_predicates: Box<[Predicate]> = nogood.clone().into(); + let nogood = nogood .iter() .map(|predicate| context.get_id(*predicate)) @@ -464,7 +487,18 @@ impl NogoodPropagator { let _ = self .nogood_info .push(NogoodInfo::new_learned_nogood_info(lbd)); - let _ = self.inference_codes.push(inference_code); + let _ = self.inference_codes.push(inference_code.clone()); + + #[cfg(feature = "check-consistency")] + self.add_consistency_checker(checker_predicates.clone(), context); + + #[cfg(feature = "check-propagations")] + context.add_inference_checker( + inference_code, + Box::new(super::NogoodChecker { + nogood: checker_predicates, + }), + ); let watcher = Watcher { nogood_id, @@ -619,6 +653,10 @@ impl NogoodPropagator { // // The preprocessing ensures that all predicates are unassigned. else { + // Capture the checker predicates before conversion to PredicateIds. + #[cfg(any(feature = "check-consistency", feature = "check-propagations"))] + let checker_predicates: Box<[Predicate]> = input_nogood.clone().into(); + #[cfg(feature = "check-propagations")] let nogood = input_nogood .iter() @@ -638,10 +676,21 @@ impl NogoodPropagator { let _ = self .nogood_info .push(NogoodInfo::new_permanent_nogood_info()); - let _ = self.inference_codes.push(inference_code); + let _ = self.inference_codes.push(inference_code.clone()); self.permanent_nogood_ids.push(nogood_id); + #[cfg(feature = "check-consistency")] + self.add_consistency_checker(checker_predicates.clone(), context); + + #[cfg(feature = "check-propagations")] + context.add_inference_checker( + inference_code, + Box::new(super::NogoodChecker { + nogood: checker_predicates, + }), + ); + let watcher = Watcher { nogood_id, cached_predicate: self.nogood_predicates[nogood_id][0], @@ -663,6 +712,58 @@ impl NogoodPropagator { Ok(()) } } + + /// Add a consistency checker for the given nogood predicates. + #[cfg(feature = "check-consistency")] + fn add_consistency_checker( + &mut self, + nogood: Box<[Predicate]>, + context: &mut PropagationContext, + ) { + let scope = build_nogood_scope(&nogood); + let checker = SelfDisablingChecker::new(super::NogoodChecker { nogood }); + let _ = self.deletion_flags.push(checker.deletion_flag()); + context.add_consistency_checker(scope, checker); + } +} + +/// Build a [`Scope`] for a nogood by extracting unique [`DomainId`]s from its predicates. +/// +/// Avoids multiple enqueuing of the consistency checker if the nogood contains multiple +/// predicates over the same variable. +#[cfg(feature = "check-consistency")] +fn build_nogood_scope(predicates: &[Predicate]) -> Scope { + use crate::containers::HashSet; + use crate::containers::KeyGenerator; + use crate::variables::DomainId; + + let mut scope = Scope::default(); + let mut seen: HashSet = HashSet::default(); + let mut id_generator = KeyGenerator::default(); + + for predicate in predicates { + let domain = predicate.get_domain(); + if seen.insert(domain) { + scope.add_domain(id_generator.next_key(), domain); + } + } + + scope +} + +#[cfg(feature = "check-consistency")] +impl NogoodPropagator { + /// Set the deletion flag for every nogood that has been marked as deleted in `nogood_info`. + /// + /// Called after clause management removes nogoods so that consistency checkers self-disable. + fn signal_deleted_checker_flags(&self) { + for idx in 0..self.nogood_info.len() { + let idx = NogoodIndex::create_from_index(idx); + if self.nogood_info[idx].is_deleted { + self.deletion_flags[idx].store(true, Ordering::Relaxed); + } + } + } } /// Methods concerning the watchers and watch lists @@ -810,6 +911,8 @@ impl NogoodPropagator { } if removed_at_least_one_nogood { + #[cfg(feature = "check-consistency")] + self.signal_deleted_checker_flags(); self.remove_deleted_nogoods_from_watchers(assignments, notification_engine); } } diff --git a/pumpkin-crates/core/src/propagators/reified_propagator/checker.rs b/pumpkin-crates/core/src/propagators/reified_propagator/checker.rs new file mode 100644 index 000000000..825b62e6b --- /dev/null +++ b/pumpkin-crates/core/src/propagators/reified_propagator/checker.rs @@ -0,0 +1,68 @@ +use pumpkin_checking::AtomicConstraint; +use pumpkin_checking::BoxedChecker; +use pumpkin_checking::CheckerVariable; +use pumpkin_checking::InferenceChecker; +use pumpkin_checking::VariableState; + +use crate::checkers::BoxedRetentionChecker; +use crate::checkers::RetentionChecker; +use crate::checkers::Scope; +use crate::propagation::Domains; +use crate::propagation::LocalId; +use crate::propagation::ReadDomains; +use crate::variables::Literal; + +/// A [`ConsistencyChecker`] wrapper that skips the inner check when the reification literal is +/// not assigned to true. +#[derive(Debug, Clone)] +pub struct ReifiedConsistencyChecker { + pub inner: BoxedRetentionChecker, + pub reification_literal: Literal, + /// The [`LocalId`] of the reification literal in the scope, used to strip it before passing + /// the scope to the inner checker. + pub reification_literal_id: LocalId, +} + +impl RetentionChecker for ReifiedConsistencyChecker { + fn check_retention(&mut self, scope: &Scope, domains: Domains<'_>) -> bool { + if domains.evaluate_literal(self.reification_literal) != Some(true) { + return true; + } + + let inner_scope = scope.without(self.reification_literal_id); + self.inner.check_retention(&inner_scope, domains) + } +} + +#[derive(Debug, Clone)] +pub struct ReifiedChecker { + pub inner: BoxedChecker, + pub reification_literal: Var, +} + +impl InferenceChecker for ReifiedChecker +where + Atomic: AtomicConstraint + Clone, + Var: CheckerVariable, +{ + fn check( + &self, + state: VariableState, + premises: &[Atomic], + consequent: Option<&Atomic>, + ) -> bool { + if self.reification_literal.induced_domain_contains(&state, 0) { + return false; + } + + if let Some(consequent) = consequent + && self + .reification_literal + .does_atomic_constrain_self(consequent) + { + self.inner.check(state, premises, None) + } else { + self.inner.check(state, premises, consequent) + } + } +} diff --git a/pumpkin-crates/core/src/propagators/reified_propagator/constructor.rs b/pumpkin-crates/core/src/propagators/reified_propagator/constructor.rs new file mode 100644 index 000000000..deabe1a0e --- /dev/null +++ b/pumpkin-crates/core/src/propagators/reified_propagator/constructor.rs @@ -0,0 +1,103 @@ +#[cfg(feature = "check-consistency")] +use crate::checkers::BoxedRetentionChecker; +use crate::propagation::DomainEvents; +#[cfg(feature = "check-consistency")] +use crate::propagation::LocalId; +use crate::propagation::Propagator; +use crate::propagation::PropagatorConstructor; +use crate::propagation::PropagatorConstructorContext; +#[cfg(feature = "check-consistency")] +use crate::propagators::ReifiedConsistencyChecker; +use crate::propagators::ReifiedPropagator; +use crate::variables::Literal; + +/// A [`PropagatorConstructor`] for the reified propagator. +#[derive(Clone, Debug)] +pub struct ReifiedPropagatorArgs { + pub propagator: WrappedArgs, + pub reification_literal: Literal, +} + +impl PropagatorConstructor for ReifiedPropagatorArgs +where + WrappedArgs: PropagatorConstructor, + WrappedPropagator: Propagator + Clone, +{ + type PropagatorImpl = ReifiedPropagator; + + fn create(self, mut context: PropagatorConstructorContext) -> Self::PropagatorImpl { + let ReifiedPropagatorArgs { + propagator, + reification_literal, + } = self; + + let propagator = propagator.create(context.reborrow()); + + let reification_literal_id = context.get_next_local_id(); + + context.register( + reification_literal, + DomainEvents::BOUNDS, + reification_literal_id, + ); + + #[cfg(feature = "check-propagations")] + wrap_inference_checkers(&mut context, reification_literal); + + #[cfg(feature = "check-consistency")] + wrap_consistency_checkers(&mut context, reification_literal, reification_literal_id); + + let name = format!("Reified({})", propagator.name()); + + ReifiedPropagator { + propagator, + reification_literal, + reification_literal_id, + name, + reason_buffer: vec![], + } + } +} + +/// Wrap inference checkers: the literal is already known, no local id needed. +#[cfg(feature = "check-propagations")] +fn wrap_inference_checkers( + context: &mut PropagatorConstructorContext<'_>, + reification_literal: Literal, +) { + use crate::propagators::ReifiedChecker; + + for (_, checker) in context.pending_inference_checkers.iter_mut() { + replace_with::replace_with_or_abort(checker, |inner_checker| { + use pumpkin_checking::BoxedChecker; + + Box::new(ReifiedChecker { + inner: BoxedChecker::from(inner_checker), + reification_literal, + }) + }); + } +} + +/// Wrap consistency checkers: add the reification literal to each scope with the now-known +/// local id, then wrap the checker. +#[cfg(feature = "check-consistency")] +fn wrap_consistency_checkers( + context: &mut PropagatorConstructorContext<'_>, + reification_literal: Literal, + reification_literal_id: LocalId, +) { + use crate::checkers::ScopeItem; + + for (scope, checker) in context.pending_consistency_checkers.iter_mut() { + reification_literal.add_to_scope(scope, reification_literal_id); + + replace_with::replace_with_or_abort(checker, |inner_checker| { + BoxedRetentionChecker::from(ReifiedConsistencyChecker { + inner: inner_checker, + reification_literal, + reification_literal_id, + }) + }); + } +} diff --git a/pumpkin-crates/core/src/propagators/reified_propagator/mod.rs b/pumpkin-crates/core/src/propagators/reified_propagator/mod.rs new file mode 100644 index 000000000..550564f93 --- /dev/null +++ b/pumpkin-crates/core/src/propagators/reified_propagator/mod.rs @@ -0,0 +1,7 @@ +mod checker; +mod constructor; +mod propagator; + +pub use checker::*; +pub use constructor::*; +pub use propagator::*; diff --git a/pumpkin-crates/core/src/propagators/reified_propagator.rs b/pumpkin-crates/core/src/propagators/reified_propagator/propagator.rs similarity index 83% rename from pumpkin-crates/core/src/propagators/reified_propagator.rs rename to pumpkin-crates/core/src/propagators/reified_propagator/propagator.rs index 0bb782495..6a683be04 100644 --- a/pumpkin-crates/core/src/propagators/reified_propagator.rs +++ b/pumpkin-crates/core/src/propagators/reified_propagator/propagator.rs @@ -1,76 +1,20 @@ -use pumpkin_checking::AtomicConstraint; -use pumpkin_checking::BoxedChecker; -use pumpkin_checking::CheckerVariable; -use pumpkin_checking::InferenceChecker; - use crate::engine::PropagationStatusCP; use crate::engine::notifications::OpaqueDomainEvent; use crate::predicates::Predicate; -use crate::propagation::DomainEvents; use crate::propagation::Domains; use crate::propagation::EnqueueDecision; use crate::propagation::ExplanationContext; -use crate::propagation::InferenceCheckers; use crate::propagation::LazyExplanation; use crate::propagation::LocalId; use crate::propagation::NotificationContext; use crate::propagation::Priority; use crate::propagation::PropagationContext; use crate::propagation::Propagator; -use crate::propagation::PropagatorConstructor; -use crate::propagation::PropagatorConstructorContext; use crate::propagation::ReadDomains; use crate::pumpkin_assert_simple; use crate::state::Conflict; use crate::variables::Literal; -/// A [`PropagatorConstructor`] for the reified propagator. -#[derive(Clone, Debug)] -pub struct ReifiedPropagatorArgs { - pub propagator: WrappedArgs, - pub reification_literal: Literal, -} - -impl PropagatorConstructor for ReifiedPropagatorArgs -where - WrappedArgs: PropagatorConstructor, - WrappedPropagator: Propagator + Clone, -{ - type PropagatorImpl = ReifiedPropagator; - - fn create(self, mut context: PropagatorConstructorContext) -> Self::PropagatorImpl { - let ReifiedPropagatorArgs { - propagator, - reification_literal, - } = self; - - let propagator = propagator.create(context.reborrow()); - let reification_literal_id = context.get_next_local_id(); - - context.register( - self.reification_literal, - DomainEvents::BOUNDS, - reification_literal_id, - ); - - let name = format!("Reified({})", propagator.name()); - - ReifiedPropagator { - propagator, - reification_literal, - reification_literal_id, - name, - reason_buffer: vec![], - } - } - - fn add_inference_checkers(&self, mut checkers: InferenceCheckers<'_>) { - checkers.with_reification_literal(self.reification_literal); - - self.propagator.add_inference_checkers(checkers); - } -} - /// Propagator for the constraint `r -> p`, where `r` is a Boolean literal and `p` is an arbitrary /// propagator. /// @@ -80,16 +24,16 @@ where /// propagated to false. #[derive(Clone, Debug)] pub struct ReifiedPropagator { - propagator: WrappedPropagator, - reification_literal: Literal, + pub(super) propagator: WrappedPropagator, + pub(super) reification_literal: Literal, /// The formatted name of the propagator. - name: String, + pub(super) name: String, /// The `LocalId` of the reification literal. Is guaranteed to be a larger ID than any of the /// registered ids of the wrapped propagator. - reification_literal_id: LocalId, + pub(super) reification_literal_id: LocalId, /// Holds the lazy explanations. - reason_buffer: Vec, + pub(super) reason_buffer: Vec, } impl Propagator for ReifiedPropagator { @@ -231,37 +175,6 @@ impl ReifiedPropagator { } } -#[derive(Debug, Clone)] -pub struct ReifiedChecker { - pub inner: BoxedChecker, - pub reification_literal: Var, -} - -impl> InferenceChecker - for ReifiedChecker -{ - fn check( - &self, - state: pumpkin_checking::VariableState, - premises: &[Atomic], - consequent: Option<&Atomic>, - ) -> bool { - if self.reification_literal.induced_domain_contains(&state, 0) { - return false; - } - - if let Some(consequent) = consequent - && self - .reification_literal - .does_atomic_constrain_self(consequent) - { - self.inner.check(state, premises, None) - } else { - self.inner.check(state, premises, consequent) - } - } -} - #[allow(deprecated, reason = "Will be refactored")] #[cfg(test)] mod tests { @@ -274,6 +187,10 @@ mod tests { use crate::predicates::PropositionalConjunction; use crate::proof::ConstraintTag; use crate::proof::InferenceCode; + use crate::propagation::DomainEvents; + use crate::propagation::PropagatorConstructor; + use crate::propagation::PropagatorConstructorContext; + use crate::propagators::ReifiedPropagatorArgs; use crate::variables::DomainId; #[test] diff --git a/pumpkin-crates/propagators/src/propagators/arithmetic/absolute_value.rs b/pumpkin-crates/propagators/src/propagators/arithmetic/absolute_value.rs index 2881dfa39..e06585952 100644 --- a/pumpkin-crates/propagators/src/propagators/arithmetic/absolute_value.rs +++ b/pumpkin-crates/propagators/src/propagators/arithmetic/absolute_value.rs @@ -8,7 +8,6 @@ use pumpkin_core::predicate; use pumpkin_core::proof::ConstraintTag; use pumpkin_core::proof::InferenceCode; use pumpkin_core::propagation::DomainEvents; -use pumpkin_core::propagation::InferenceCheckers; use pumpkin_core::propagation::LocalId; use pumpkin_core::propagation::Priority; use pumpkin_core::propagation::PropagationContext; @@ -35,16 +34,6 @@ where { type PropagatorImpl = AbsoluteValuePropagator; - fn add_inference_checkers(&self, mut checkers: InferenceCheckers<'_>) { - checkers.add_inference_checker( - InferenceCode::new(self.constraint_tag, AbsoluteValue), - Box::new(AbsoluteValueChecker { - signed: self.signed.clone(), - absolute: self.absolute.clone(), - }), - ); - } - fn create(self, mut context: PropagatorConstructorContext) -> Self::PropagatorImpl { let AbsoluteValueArgs { signed, @@ -52,6 +41,14 @@ where constraint_tag, } = self; + context.add_inference_checker( + InferenceCode::new(constraint_tag, AbsoluteValue), + Box::new(AbsoluteValueChecker { + signed: signed.clone(), + absolute: absolute.clone(), + }), + ); + context.register(signed.clone(), DomainEvents::BOUNDS, LocalId::from(0)); context.register(absolute.clone(), DomainEvents::BOUNDS, LocalId::from(1)); diff --git a/pumpkin-crates/propagators/src/propagators/arithmetic/binary/binary_not_equals.rs b/pumpkin-crates/propagators/src/propagators/arithmetic/binary/binary_not_equals.rs index b33a98543..e7e4669c8 100644 --- a/pumpkin-crates/propagators/src/propagators/arithmetic/binary/binary_not_equals.rs +++ b/pumpkin-crates/propagators/src/propagators/arithmetic/binary/binary_not_equals.rs @@ -8,7 +8,6 @@ use pumpkin_core::proof::ConstraintTag; use pumpkin_core::proof::InferenceCode; use pumpkin_core::propagation::DomainEvents; use pumpkin_core::propagation::Domains; -use pumpkin_core::propagation::InferenceCheckers; use pumpkin_core::propagation::LocalId; use pumpkin_core::propagation::Priority; use pumpkin_core::propagation::PropagationContext; @@ -37,16 +36,6 @@ where { type PropagatorImpl = BinaryNotEqualsPropagator; - fn add_inference_checkers(&self, mut checkers: InferenceCheckers<'_>) { - checkers.add_inference_checker( - InferenceCode::new(self.constraint_tag, BinaryNotEquals), - Box::new(BinaryNotEqualsChecker { - lhs: self.a.clone(), - rhs: self.b.clone(), - }), - ); - } - fn create(self, mut context: PropagatorConstructorContext) -> Self::PropagatorImpl { let BinaryNotEqualsPropagatorArgs { a, @@ -54,6 +43,14 @@ where constraint_tag, } = self; + context.add_inference_checker( + InferenceCode::new(constraint_tag, BinaryNotEquals), + Box::new(BinaryNotEqualsChecker { + lhs: a.clone(), + rhs: b.clone(), + }), + ); + // We only care about the case where one of the two is assigned context.register(a.clone(), DomainEvents::ASSIGN, LocalId::from(0)); context.register(b.clone(), DomainEvents::ASSIGN, LocalId::from(1)); diff --git a/pumpkin-crates/propagators/src/propagators/arithmetic/binary/mod.rs b/pumpkin-crates/propagators/src/propagators/arithmetic/binary/mod.rs index 8d3c6bbb8..39cfb5e78 100644 --- a/pumpkin-crates/propagators/src/propagators/arithmetic/binary/mod.rs +++ b/pumpkin-crates/propagators/src/propagators/arithmetic/binary/mod.rs @@ -1,5 +1,3 @@ -pub(crate) mod binary_equals; pub(crate) mod binary_not_equals; -pub use binary_equals::*; pub use binary_not_equals::*; diff --git a/pumpkin-crates/propagators/src/propagators/arithmetic/binary_equals/checker.rs b/pumpkin-crates/propagators/src/propagators/arithmetic/binary_equals/checker.rs new file mode 100644 index 000000000..e50818450 --- /dev/null +++ b/pumpkin-crates/propagators/src/propagators/arithmetic/binary_equals/checker.rs @@ -0,0 +1,81 @@ +use pumpkin_checking::AtomicConstraint; +use pumpkin_checking::CheckerVariable; +use pumpkin_checking::InferenceChecker; +use pumpkin_checking::IntExt; +use pumpkin_core::checkers::support::Support; +use pumpkin_core::checkers::support::SupportGenerator; +use pumpkin_core::checkers::support::SupportsValue; +use pumpkin_core::checkers::support::UnsupportedValue; +use pumpkin_core::propagation::Domains; +use pumpkin_core::propagation::LocalId; +use pumpkin_core::variables::IntegerVariable; + +#[derive(Clone, Debug)] +pub struct BinaryEqualsChecker { + pub lhs: Lhs, + pub rhs: Rhs, +} + +impl InferenceChecker for BinaryEqualsChecker +where + Atomic: AtomicConstraint, + Lhs: CheckerVariable, + Rhs: CheckerVariable, +{ + fn check( + &self, + mut state: pumpkin_checking::VariableState, + _: &[Atomic], + _: Option<&Atomic>, + ) -> bool { + // We apply the domain of variable 2 to variable 1. If the state remains consistent, then + // the step is unsound! + let mut consistent = true; + + if let IntExt::Int(value) = self.rhs.induced_upper_bound(&state) { + let atomic = self.lhs.atomic_less_than(value); + consistent &= state.apply(&atomic); + } + + if let IntExt::Int(value) = self.rhs.induced_lower_bound(&state) { + let atomic = self.lhs.atomic_greater_than(value); + consistent &= state.apply(&atomic); + } + + for value in self.rhs.induced_holes(&state).collect::>() { + let atomic = self.lhs.atomic_not_equal(value); + consistent &= state.apply(&atomic); + } + + !consistent + } +} + +impl SupportGenerator for BinaryEqualsChecker +where + Lhs: IntegerVariable + SupportsValue, + Rhs: IntegerVariable + SupportsValue, +{ + type Value = i32; + + fn support( + &mut self, + support: &mut Support, + local_id: LocalId, + value: UnsupportedValue, + _: &Domains<'_>, + ) { + let value = match local_id { + super::ID_LHS => self.lhs.unpack(value), + super::ID_RHS => self.rhs.unpack(value), + _ => unreachable!(), + }; + + self.lhs.assign(value, support); + self.rhs.assign(value, support); + } + + fn is_solution(&self, support: &Support) -> bool { + self.lhs.support_value(support) == self.rhs.support_value(support) + } +} diff --git a/pumpkin-crates/propagators/src/propagators/arithmetic/binary_equals/constructor.rs b/pumpkin-crates/propagators/src/propagators/arithmetic/binary_equals/constructor.rs new file mode 100644 index 000000000..b144492ed --- /dev/null +++ b/pumpkin-crates/propagators/src/propagators/arithmetic/binary_equals/constructor.rs @@ -0,0 +1,73 @@ +use pumpkin_core::checkers::StrongConsistency; +use pumpkin_core::checkers::StrongRetentionChecker; +use pumpkin_core::containers::HashSet; +use pumpkin_core::predicates::Predicate; +use pumpkin_core::proof::ConstraintTag; +use pumpkin_core::proof::InferenceCode; +use pumpkin_core::propagation::DomainEvents; +use pumpkin_core::propagation::PropagatorConstructor; +use pumpkin_core::propagation::PropagatorConstructorContext; +use pumpkin_core::variables::IntegerVariable; + +use crate::arithmetic::BinaryEqualsChecker; +use crate::arithmetic::BinaryEqualsPropagator; + +/// The [`PropagatorConstructor`] for the [`BinaryEqualsPropagator`]. +#[derive(Clone, Debug)] +pub struct BinaryEqualsPropagatorArgs { + pub a: AVar, + pub b: BVar, + pub constraint_tag: ConstraintTag, +} + +impl PropagatorConstructor for BinaryEqualsPropagatorArgs +where + AVar: IntegerVariable + 'static, + BVar: IntegerVariable + 'static, +{ + type PropagatorImpl = BinaryEqualsPropagator; + + fn create(self, mut context: PropagatorConstructorContext) -> Self::PropagatorImpl { + let BinaryEqualsPropagatorArgs { + a, + b, + constraint_tag, + } = self; + + context.add_inference_checker( + InferenceCode::new(constraint_tag, super::BinaryEquals), + Box::new(BinaryEqualsChecker { + lhs: a.clone(), + rhs: b.clone(), + }), + ); + + context.add_consistency_checker( + ((super::ID_LHS, &a), (super::ID_RHS, &b)), + StrongRetentionChecker::new( + StrongConsistency::Domain, + BinaryEqualsChecker { + lhs: a.clone(), + rhs: b.clone(), + }, + ), + ); + + context.register(a.clone(), DomainEvents::ANY_INT, super::ID_LHS); + context.register(b.clone(), DomainEvents::ANY_INT, super::ID_RHS); + + BinaryEqualsPropagator { + a, + b, + + a_removed_values: HashSet::default(), + b_removed_values: HashSet::default(), + + inference_code: InferenceCode::new(constraint_tag, super::BinaryEquals), + + has_backtracked: false, + first_propagation_loop: true, + reason: Predicate::trivially_false(), + } + } +} diff --git a/pumpkin-crates/propagators/src/propagators/arithmetic/binary_equals/mod.rs b/pumpkin-crates/propagators/src/propagators/arithmetic/binary_equals/mod.rs new file mode 100644 index 000000000..bc6e585ef --- /dev/null +++ b/pumpkin-crates/propagators/src/propagators/arithmetic/binary_equals/mod.rs @@ -0,0 +1,14 @@ +mod checker; +mod constructor; +mod propagator; + +pub use checker::*; +pub use constructor::*; +pub use propagator::*; +use pumpkin_core::declare_inference_label; +use pumpkin_core::propagation::LocalId; + +const ID_LHS: LocalId = LocalId::from(0); +const ID_RHS: LocalId = LocalId::from(1); + +declare_inference_label!(BinaryEquals); diff --git a/pumpkin-crates/propagators/src/propagators/arithmetic/binary/binary_equals.rs b/pumpkin-crates/propagators/src/propagators/arithmetic/binary_equals/propagator.rs similarity index 81% rename from pumpkin-crates/propagators/src/propagators/arithmetic/binary/binary_equals.rs rename to pumpkin-crates/propagators/src/propagators/arithmetic/binary_equals/propagator.rs index b95b86b46..166af6a37 100644 --- a/pumpkin-crates/propagators/src/propagators/arithmetic/binary/binary_equals.rs +++ b/pumpkin-crates/propagators/src/propagators/arithmetic/binary_equals/propagator.rs @@ -2,26 +2,18 @@ use std::slice; use bitfield_struct::bitfield; -use pumpkin_checking::AtomicConstraint; -use pumpkin_checking::CheckerVariable; -use pumpkin_checking::InferenceChecker; -use pumpkin_checking::IntExt; use pumpkin_core::asserts::pumpkin_assert_advanced; use pumpkin_core::conjunction; use pumpkin_core::containers::HashSet; -use pumpkin_core::declare_inference_label; use pumpkin_core::predicate; use pumpkin_core::predicates::Predicate; use pumpkin_core::predicates::PredicateConstructor; use pumpkin_core::predicates::PredicateType; -use pumpkin_core::proof::ConstraintTag; use pumpkin_core::proof::InferenceCode; use pumpkin_core::propagation::DomainEvent; -use pumpkin_core::propagation::DomainEvents; use pumpkin_core::propagation::Domains; use pumpkin_core::propagation::EnqueueDecision; use pumpkin_core::propagation::ExplanationContext; -use pumpkin_core::propagation::InferenceCheckers; use pumpkin_core::propagation::LazyExplanation; use pumpkin_core::propagation::LocalId; use pumpkin_core::propagation::NotificationContext; @@ -29,99 +21,44 @@ use pumpkin_core::propagation::OpaqueDomainEvent; use pumpkin_core::propagation::Priority; use pumpkin_core::propagation::PropagationContext; use pumpkin_core::propagation::Propagator; -use pumpkin_core::propagation::PropagatorConstructor; -use pumpkin_core::propagation::PropagatorConstructorContext; use pumpkin_core::propagation::ReadDomains; use pumpkin_core::state::EmptyDomainConflict; use pumpkin_core::state::PropagationStatusCP; use pumpkin_core::state::PropagatorConflict; use pumpkin_core::variables::IntegerVariable; -declare_inference_label!(BinaryEquals); - -/// The [`PropagatorConstructor`] for the [`BinaryEqualsPropagator`]. -#[derive(Clone, Debug)] -pub struct BinaryEqualsPropagatorArgs { - pub a: AVar, - pub b: BVar, - pub constraint_tag: ConstraintTag, -} - -impl PropagatorConstructor for BinaryEqualsPropagatorArgs -where - AVar: IntegerVariable + 'static, - BVar: IntegerVariable + 'static, -{ - type PropagatorImpl = BinaryEqualsPropagator; - - fn add_inference_checkers(&self, mut checkers: InferenceCheckers<'_>) { - checkers.add_inference_checker( - InferenceCode::new(self.constraint_tag, BinaryEquals), - Box::new(BinaryEqualsChecker { - lhs: self.a.clone(), - rhs: self.b.clone(), - }), - ); - } - - fn create(self, mut context: PropagatorConstructorContext) -> Self::PropagatorImpl { - let BinaryEqualsPropagatorArgs { - a, - b, - constraint_tag, - } = self; - - context.register(a.clone(), DomainEvents::ANY_INT, LocalId::from(0)); - context.register(b.clone(), DomainEvents::ANY_INT, LocalId::from(1)); - - BinaryEqualsPropagator { - a, - b, - - a_removed_values: HashSet::default(), - b_removed_values: HashSet::default(), - - inference_code: InferenceCode::new(constraint_tag, BinaryEquals), - - has_backtracked: false, - first_propagation_loop: true, - reason: Predicate::trivially_false(), - } - } -} - /// Propagator for the constraint `a = b`. #[derive(Clone, Debug)] pub struct BinaryEqualsPropagator { - a: AVar, - b: BVar, + pub(super) a: AVar, + pub(super) b: BVar, /// The removed value from [`Self::a`]. /// /// These are tracked to make sure that they are also removed from [`Self::b`]. - a_removed_values: HashSet, + pub(super) a_removed_values: HashSet, /// The removed value from [`Self::b`] /// /// These are tracked to make sure that they are also removed from [`Self::a`]. - b_removed_values: HashSet, + pub(super) b_removed_values: HashSet, /// If a backtrack has occurred which caused one of the removals to be backtracked then we need /// to ensure that we do not erroneously remove values which are now part of the domain after /// backtracking. - has_backtracked: bool, + pub(super) has_backtracked: bool, /// If it is the first time that the propagator is called then we need to ensure that the /// domains of [`Self::a`] and [`Self::b`] are equal to the intersection of these domains. - first_propagation_loop: bool, + pub(super) first_propagation_loop: bool, - inference_code: InferenceCode, + pub(super) inference_code: InferenceCode, /// A re-usable buffer to store the explanations of propagations. This will always be a single /// [`Predicate`]. /// /// This field is only written to in the `lazy_explanation` function, as that returns a slice /// which needs to be owned somewhere. Hence we put that ownership here. - reason: Predicate, + pub(super) reason: Predicate, } impl BinaryEqualsPropagator @@ -401,47 +338,6 @@ struct BinaryEqualsPropagation { __: u16, } -#[derive(Clone, Debug)] -pub struct BinaryEqualsChecker { - pub lhs: Lhs, - pub rhs: Rhs, -} - -impl InferenceChecker for BinaryEqualsChecker -where - Atomic: AtomicConstraint, - Lhs: CheckerVariable, - Rhs: CheckerVariable, -{ - fn check( - &self, - mut state: pumpkin_checking::VariableState, - _: &[Atomic], - _: Option<&Atomic>, - ) -> bool { - // We apply the domain of variable 2 to variable 1. If the state remains consistent, then - // the step is unsound! - let mut consistent = true; - - if let IntExt::Int(value) = self.rhs.induced_upper_bound(&state) { - let atomic = self.lhs.atomic_less_than(value); - consistent &= state.apply(&atomic); - } - - if let IntExt::Int(value) = self.rhs.induced_lower_bound(&state) { - let atomic = self.lhs.atomic_greater_than(value); - consistent &= state.apply(&atomic); - } - - for value in self.rhs.induced_holes(&state).collect::>() { - let atomic = self.lhs.atomic_not_equal(value); - consistent &= state.apply(&atomic); - } - - !consistent - } -} - #[cfg(test)] mod tests { use pumpkin_core::state::State; diff --git a/pumpkin-crates/propagators/src/propagators/arithmetic/integer_division.rs b/pumpkin-crates/propagators/src/propagators/arithmetic/integer_division.rs index 882495a48..a9a22125b 100644 --- a/pumpkin-crates/propagators/src/propagators/arithmetic/integer_division.rs +++ b/pumpkin-crates/propagators/src/propagators/arithmetic/integer_division.rs @@ -9,7 +9,6 @@ use pumpkin_core::predicate; use pumpkin_core::proof::ConstraintTag; use pumpkin_core::proof::InferenceCode; use pumpkin_core::propagation::DomainEvents; -use pumpkin_core::propagation::InferenceCheckers; use pumpkin_core::propagation::LocalId; use pumpkin_core::propagation::Priority; use pumpkin_core::propagation::PropagationContext; @@ -51,6 +50,15 @@ where constraint_tag, } = self; + context.add_inference_checker( + InferenceCode::new(constraint_tag, Division), + Box::new(IntegerDivisionChecker { + numerator: numerator.clone(), + denominator: denominator.clone(), + rhs: rhs.clone(), + }), + ); + pumpkin_assert_simple!( !context.contains(&denominator, 0), "Denominator cannot contain 0" @@ -69,17 +77,6 @@ where inference_code, } } - - fn add_inference_checkers(&self, mut checkers: InferenceCheckers<'_>) { - checkers.add_inference_checker( - InferenceCode::new(self.constraint_tag, Division), - Box::new(IntegerDivisionChecker { - numerator: self.numerator.clone(), - denominator: self.denominator.clone(), - rhs: self.rhs.clone(), - }), - ); - } } /// A propagator for maintaining the constraint `numerator / denominator = rhs`; note that this diff --git a/pumpkin-crates/propagators/src/propagators/arithmetic/integer_multiplication.rs b/pumpkin-crates/propagators/src/propagators/arithmetic/integer_multiplication.rs deleted file mode 100644 index 9e7d8d953..000000000 --- a/pumpkin-crates/propagators/src/propagators/arithmetic/integer_multiplication.rs +++ /dev/null @@ -1,617 +0,0 @@ -use pumpkin_checking::AtomicConstraint; -use pumpkin_checking::CheckerVariable; -use pumpkin_checking::InferenceChecker; -use pumpkin_core::asserts::pumpkin_assert_simple; -use pumpkin_core::conjunction; -use pumpkin_core::declare_inference_label; -use pumpkin_core::predicate; -use pumpkin_core::proof::ConstraintTag; -use pumpkin_core::proof::InferenceCode; -use pumpkin_core::propagation::DomainEvents; -use pumpkin_core::propagation::InferenceCheckers; -use pumpkin_core::propagation::LocalId; -use pumpkin_core::propagation::Priority; -use pumpkin_core::propagation::PropagationContext; -use pumpkin_core::propagation::Propagator; -use pumpkin_core::propagation::PropagatorConstructor; -use pumpkin_core::propagation::PropagatorConstructorContext; -use pumpkin_core::propagation::ReadDomains; -use pumpkin_core::state::PropagationStatusCP; -use pumpkin_core::state::propagator_conflict; -use pumpkin_core::variables::IntegerVariable; - -declare_inference_label!(IntegerMultiplication); - -/// The [`PropagatorConstructor`] for [`IntegerMultiplicationPropagator`]. -#[derive(Clone, Debug)] -pub struct IntegerMultiplicationArgs { - pub a: VA, - pub b: VB, - pub c: VC, - pub constraint_tag: ConstraintTag, -} - -impl PropagatorConstructor for IntegerMultiplicationArgs -where - VA: IntegerVariable + 'static, - VB: IntegerVariable + 'static, - VC: IntegerVariable + 'static, -{ - type PropagatorImpl = IntegerMultiplicationPropagator; - - fn add_inference_checkers(&self, mut checkers: InferenceCheckers<'_>) { - checkers.add_inference_checker( - InferenceCode::new(self.constraint_tag, IntegerMultiplication), - Box::new(IntegerMultiplicationChecker { - a: self.a.clone(), - b: self.b.clone(), - c: self.c.clone(), - }), - ); - } - - fn create(self, mut context: PropagatorConstructorContext) -> Self::PropagatorImpl { - let IntegerMultiplicationArgs { - a, - b, - c, - constraint_tag, - } = self; - - context.register(a.clone(), DomainEvents::ANY_INT, ID_A); - context.register(b.clone(), DomainEvents::ANY_INT, ID_B); - context.register(c.clone(), DomainEvents::ANY_INT, ID_C); - - IntegerMultiplicationPropagator { - a, - b, - c, - inference_code: InferenceCode::new(constraint_tag, IntegerMultiplication), - } - } -} - -/// A propagator for maintaining the constraint `a * b = c`. The propagator -/// (currently) only propagates the signs of the variables, the case where a, b, c >= 0, and detects -/// a conflict if the variables are fixed. -#[derive(Clone, Debug)] -pub struct IntegerMultiplicationPropagator { - a: VA, - b: VB, - c: VC, - inference_code: InferenceCode, -} - -const ID_A: LocalId = LocalId::from(0); -const ID_B: LocalId = LocalId::from(1); -const ID_C: LocalId = LocalId::from(2); - -impl Propagator - for IntegerMultiplicationPropagator -where - VA: IntegerVariable, - VB: IntegerVariable, - VC: IntegerVariable, -{ - fn priority(&self) -> Priority { - Priority::High - } - - fn name(&self) -> &str { - "IntTimes" - } - - fn propagate_from_scratch(&self, context: PropagationContext) -> PropagationStatusCP { - perform_propagation(context, &self.a, &self.b, &self.c, &self.inference_code) - } -} - -fn perform_propagation( - mut context: PropagationContext, - a: &VA, - b: &VB, - c: &VC, - inference_code: &InferenceCode, -) -> PropagationStatusCP { - // First we propagate the signs - propagate_signs(&mut context, a, b, c, inference_code)?; - - let a_min = context.lower_bound(a); - let a_max = context.upper_bound(a); - let b_min = context.lower_bound(b); - let b_max = context.upper_bound(b); - let c_min = context.lower_bound(c); - let c_max = context.upper_bound(c); - - if a_min >= 0 && b_min >= 0 { - let new_max_c = a_max.saturating_mul(b_max); - let new_min_c = a_min.saturating_mul(b_min); - - // c is smaller than the maximum value that a * b can take - // - // We need the lower-bounds in the explanation as well because the reasoning does not - // hold in the case of a negative lower-bound - context.post( - predicate![c <= new_max_c], - ( - conjunction!([a >= 0] & [a <= a_max] & [b >= 0] & [b <= b_max]), - inference_code, - ), - )?; - - // c is larger than the minimum value that a * b can take - context.post( - predicate![c >= new_min_c], - (conjunction!([a >= a_min] & [b >= b_min]), inference_code), - )?; - } - - if b_min >= 0 && b_max >= 1 && c_min >= 1 { - // a >= ceil(c.min / b.max) - let bound = div_ceil_pos(c_min, b_max); - context.post( - predicate![a >= bound], - ( - conjunction!([c >= c_min] & [b >= 0] & [b <= b_max]), - inference_code, - ), - )?; - } - - if b_min >= 1 && c_min >= 0 && c_max >= 1 { - // a <= floor(c.max / b.min) - let bound = c_max / b_min; - context.post( - predicate![a <= bound], - ( - conjunction!([c >= 0] & [c <= c_max] & [b >= b_min]), - inference_code, - ), - )?; - } - - if a_min >= 1 && c_min >= 0 && c_max >= 1 { - // b <= floor(c.max / a.min) - let bound = c_max / a_min; - context.post( - predicate![b <= bound], - ( - conjunction!([c >= 0] & [c <= c_max] & [a >= a_min]), - inference_code, - ), - )?; - } - - // b >= ceil(c.min / a.max) - if a_min >= 0 && a_max >= 1 && c_min >= 1 { - let bound = div_ceil_pos(c_min, a_max); - - context.post( - predicate![b >= bound], - ( - conjunction!([c >= c_min] & [a >= 0] & [a <= a_max]), - inference_code, - ), - )?; - } - - if let Some(fixed_a) = context.fixed_value(a) - && let Some(fixed_b) = context.fixed_value(b) - && let Some(fixed_c) = context.fixed_value(c) - && (fixed_a * fixed_b) != fixed_c - { - // All variables are assigned but the resulting value is not correct, so we report a - // conflict - return propagator_conflict( - conjunction!( - [a == context.lower_bound(a)] - & [b == context.lower_bound(b)] - & [c == context.lower_bound(c)] - ), - inference_code, - ); - } - - Ok(()) -} - -/// Propagates the signs of the variables, it performs the following propagations: -/// - Propagating based on positive bounds -/// - If a is positive and b is positive then c is positive -/// - If a is positive and c is positive then b is positive -/// - If b is positive and c is positive then a is positive -/// - Propagating based on negative bounds -/// - If a is negative and b is negative then c is positive -/// - If a is negative and c is negative then b is positive -/// - If b is negative and c is negative then b is positive -/// - Propagating based on mixed bounds -/// - Propagating c based on a and b -/// - If a is negative and b is positive then c is negative -/// - If a is positive and b is negative then c is negative -/// - Propagating b based on a and c -/// - If a is negative and c is positive then b is negative -/// - If a is positive and c is negative then b is negative -/// - Propagating a based on b and c -/// - If b is negative and c is positive then a is negative -/// - If b is positive and c is negative then a is negative -/// -/// Note that this method does not propagate a value if 0 is in the domain as, for example, 0 * -3 = -/// 0 and 0 * 3 = 0 are both equally valid. -fn propagate_signs( - context: &mut PropagationContext, - a: &VA, - b: &VB, - c: &VC, - inference_code: &InferenceCode, -) -> PropagationStatusCP { - let a_min = context.lower_bound(a); - let a_max = context.upper_bound(a); - let b_min = context.lower_bound(b); - let b_max = context.upper_bound(b); - let c_min = context.lower_bound(c); - let c_max = context.upper_bound(c); - - // Propagating based on positive bounds - // a is positive and b is positive -> c is positive - if a_min >= 0 && b_min >= 0 { - context.post( - predicate![c >= 0], - (conjunction!([a >= 0] & [b >= 0]), inference_code), - )?; - } - - // a is positive and c is positive -> b is positive - if a_min >= 1 && c_min >= 1 { - context.post( - predicate![b >= 1], - (conjunction!([a >= 1] & [c >= 1]), inference_code), - )?; - } - - // b is positive and c is positive -> a is positive - if b_min >= 1 && c_min >= 1 { - context.post( - predicate![a >= 1], - (conjunction!([b >= 1] & [c >= 1]), inference_code), - )?; - } - - // Propagating based on negative bounds - // a is negative and b is negative -> c is positive - if a_max <= 0 && b_max <= 0 { - context.post( - predicate![c >= 0], - (conjunction!([a <= 0] & [b <= 0]), inference_code), - )?; - } - - // a is negative and c is negative -> b is positive - if a_max <= -1 && c_max <= -1 { - context.post( - predicate![b >= 1], - (conjunction!([a <= -1] & [c <= -1]), inference_code), - )?; - } - - // b is negative and c is negative -> a is positive - if b_max <= -1 && c_max <= -1 { - context.post( - predicate![a >= 1], - (conjunction!([b <= -1] & [c <= -1]), inference_code), - )?; - } - - // Propagating based on mixed bounds (i.e. one positive and one negative) - // Propagating c based on a and b - // a is negative and b is positive -> c is negative - if a_max <= 0 && b_min >= 0 { - context.post( - predicate![c <= 0], - (conjunction!([a <= 0] & [b >= 0]), inference_code), - )?; - } - - // a is positive and b is negative -> c is negative - if a_min >= 0 && b_max <= 0 { - context.post( - predicate![c <= 0], - (conjunction!([a >= 0] & [b <= 0]), inference_code), - )?; - } - - // Propagating b based on a and c - // a is negative and c is positive -> b is negative - if a_max <= -1 && c_min >= 1 { - context.post( - predicate![b <= -1], - (conjunction!([a <= -1] & [c >= 1]), inference_code), - )?; - } - - // a is positive and c is negative -> b is negative - if a_min >= 1 && c_max <= -1 { - context.post( - predicate![b <= -1], - (conjunction!([a >= 1] & [c <= -1]), inference_code), - )?; - } - - // Propagating a based on b and c - // b is negative and c is positive -> a is negative - if b_max <= -1 && c_min >= 1 { - context.post( - predicate![a <= -1], - (conjunction!([b <= -1] & [c >= 1]), inference_code), - )?; - } - - // b is positive and c is negative -> a is negative - if b_min >= 1 && c_max <= -1 { - context.post( - predicate![a <= -1], - (conjunction!([b >= 1] & [c <= -1]), inference_code), - )?; - } - - Ok(()) -} - -/// Compute `ceil(numerator / denominator)`. -/// -/// Assumes `numerator, denominator > 0`. -#[inline] -fn div_ceil_pos(numerator: i32, denominator: i32) -> i32 { - pumpkin_assert_simple!( - numerator > 0 && denominator > 0, - "Either the numerator {numerator} was non-positive or the denominator {denominator} was non-positive" - ); - numerator / denominator + (numerator % denominator).signum() -} - -#[derive(Clone, Debug)] -pub struct IntegerMultiplicationChecker { - pub a: VA, - pub b: VB, - pub c: VC, -} - -impl InferenceChecker for IntegerMultiplicationChecker -where - Atomic: AtomicConstraint, - VA: CheckerVariable, - VB: CheckerVariable, - VC: CheckerVariable, -{ - fn check( - &self, - state: pumpkin_checking::VariableState, - _: &[Atomic], - _: Option<&Atomic>, - ) -> bool { - // We apply interval arithmetic to determine that the computed interval `a times b` - // does not intersect with the domain of `c`. - // - // See https://en.wikipedia.org/wiki/Interval_arithmetic#Interval_operators. - - let x1 = self.a.induced_lower_bound(&state); - let x2 = self.a.induced_upper_bound(&state); - let y1 = self.b.induced_lower_bound(&state); - let y2 = self.b.induced_upper_bound(&state); - - let c_lower = self.c.induced_lower_bound(&state); - let c_upper = self.c.induced_upper_bound(&state); - - let x1y1 = x1 * y1; - let x1y2 = x1 * y2; - let x2y1 = x2 * y1; - let x2y2 = x2 * y2; - - let computed_c_lower = x1y1.min(x1y2).min(x2y1).min(x2y2); - let computed_c_upper = x1y1.max(x1y2).max(x2y1).max(x2y2); - - computed_c_upper < c_lower || computed_c_lower > c_upper - } -} - -#[cfg(test)] -mod tests { - use pumpkin_core::predicate; - use pumpkin_core::predicates::Predicate; - use pumpkin_core::predicates::PropositionalConjunction; - use pumpkin_core::propagation::CurrentNogood; - use pumpkin_core::state::State; - - use super::*; - use crate::StateExt; - - #[test] - fn bounds_of_a_and_b_propagate_bounds_c() { - let mut state = State::default(); - let a = state.new_interval_variable(1, 3, None); - let b = state.new_interval_variable(0, 4, None); - let c = state.new_interval_variable(-10, 20, None); - - let constraint_tag = state.new_constraint_tag(); - - let _ = state.add_propagator(IntegerMultiplicationArgs { - a, - b, - c, - constraint_tag, - }); - state.propagate_to_fixed_point().expect("no empty domains"); - - state.assert_bounds(a, 1, 3); - state.assert_bounds(b, 0, 4); - state.assert_bounds(c, 0, 12); - - let mut reason_buffer: Vec = vec![]; - let _ = state.get_propagation_reason( - predicate![c >= 0], - &mut reason_buffer, - CurrentNogood::empty(), - ); - let reason_lb: PropositionalConjunction = reason_buffer.into(); - assert_eq!(conjunction!([a >= 0] & [b >= 0]), reason_lb); - - let mut reason_buffer: Vec = vec![]; - let _ = state.get_propagation_reason( - predicate![c <= 12], - &mut reason_buffer, - CurrentNogood::empty(), - ); - let reason_ub: PropositionalConjunction = reason_buffer.into(); - assert_eq!( - conjunction!([a >= 0] & [a <= 3] & [b >= 0] & [b <= 4]), - reason_ub - ); - } - - #[test] - fn bounds_of_a_and_c_propagate_bounds_b() { - let mut state = State::default(); - let a = state.new_interval_variable(2, 3, None); - let b = state.new_interval_variable(0, 12, None); - let c = state.new_interval_variable(2, 12, None); - - let constraint_tag = state.new_constraint_tag(); - - let _ = state.add_propagator(IntegerMultiplicationArgs { - a, - b, - c, - constraint_tag, - }); - state.propagate_to_fixed_point().expect("no empty domains"); - - state.assert_bounds(a, 2, 3); - state.assert_bounds(b, 1, 6); - state.assert_bounds(c, 2, 12); - - let mut reason_buffer: Vec = vec![]; - let _ = state.get_propagation_reason( - predicate![b >= 1], - &mut reason_buffer, - CurrentNogood::empty(), - ); - let reason_lb: PropositionalConjunction = reason_buffer.into(); - assert_eq!(conjunction!([a >= 1] & [c >= 1]), reason_lb); - - let mut reason_buffer: Vec = vec![]; - let _ = state.get_propagation_reason( - predicate![b <= 6], - &mut reason_buffer, - CurrentNogood::empty(), - ); - let reason_ub: PropositionalConjunction = reason_buffer.into(); - assert_eq!(conjunction!([a >= 2] & [c >= 0] & [c <= 12]), reason_ub); - } - - #[test] - fn bounds_of_b_and_c_propagate_bounds_a() { - let mut state = State::default(); - let a = state.new_interval_variable(0, 10, None); - let b = state.new_interval_variable(3, 6, None); - let c = state.new_interval_variable(2, 12, None); - - let constraint_tag = state.new_constraint_tag(); - - let _ = state.add_propagator(IntegerMultiplicationArgs { - a, - b, - c, - constraint_tag, - }); - state.propagate_to_fixed_point().expect("no empty domains"); - - state.assert_bounds(a, 1, 4); - state.assert_bounds(b, 3, 6); - state.assert_bounds(c, 3, 12); - - let mut reason_buffer: Vec = vec![]; - let _ = state.get_propagation_reason( - predicate![a >= 1], - &mut reason_buffer, - CurrentNogood::empty(), - ); - let reason_lb: PropositionalConjunction = reason_buffer.into(); - assert_eq!(conjunction!([b >= 1] & [c >= 1]), reason_lb); - - let mut reason_buffer: Vec = vec![]; - let _ = state.get_propagation_reason( - predicate![a <= 4], - &mut reason_buffer, - CurrentNogood::empty(), - ); - let reason_ub: PropositionalConjunction = reason_buffer.into(); - assert_eq!(conjunction!([b >= 3] & [c >= 0] & [c <= 12]), reason_ub); - } - - #[test] - fn b_unbounded_does_not_panic() { - let mut state = State::default(); - let a = state.new_interval_variable(12, 12, None); - let b = state.new_interval_variable(i32::MIN, i32::MAX, None); - let c = state.new_interval_variable(144, 144, None); - - let constraint_tag = state.new_constraint_tag(); - let _ = state.add_propagator(IntegerMultiplicationArgs { - a, - b, - c, - constraint_tag, - }); - state.propagate_to_fixed_point().expect("No empty domains"); - } - - #[test] - fn a_unbounded_does_not_panic() { - let mut state = State::default(); - let a = state.new_interval_variable(i32::MIN, i32::MAX, None); - let b = state.new_interval_variable(12, 12, None); - let c = state.new_interval_variable(144, 144, None); - - let constraint_tag = state.new_constraint_tag(); - let _ = state.add_propagator(IntegerMultiplicationArgs { - a, - b, - c, - constraint_tag, - }); - state.propagate_to_fixed_point().expect("No empty domains"); - } - - #[test] - fn c_unbounded_does_not_panic() { - let mut state = State::default(); - let a = state.new_interval_variable(12, 12, None); - let b = state.new_interval_variable(12, 12, None); - let c = state.new_interval_variable(i32::MIN, i32::MAX, None); - - let constraint_tag = state.new_constraint_tag(); - let _ = state.add_propagator(IntegerMultiplicationArgs { - a, - b, - c, - constraint_tag, - }); - state.propagate_to_fixed_point().expect("No empty domains"); - } - - #[test] - fn all_unbounded_does_not_panic() { - let mut state = State::default(); - let a = state.new_interval_variable(i32::MIN, i32::MAX, None); - let b = state.new_interval_variable(i32::MIN, i32::MAX, None); - let c = state.new_interval_variable(i32::MIN, i32::MAX, None); - - let constraint_tag = state.new_constraint_tag(); - let _ = state.add_propagator(IntegerMultiplicationArgs { - a, - b, - c, - constraint_tag, - }); - state.propagate_to_fixed_point().expect("No empty domains"); - } -} diff --git a/pumpkin-crates/propagators/src/propagators/arithmetic/linear_less_or_equal.rs b/pumpkin-crates/propagators/src/propagators/arithmetic/linear_less_or_equal.rs index 075af43d9..7cde056cf 100644 --- a/pumpkin-crates/propagators/src/propagators/arithmetic/linear_less_or_equal.rs +++ b/pumpkin-crates/propagators/src/propagators/arithmetic/linear_less_or_equal.rs @@ -14,7 +14,6 @@ use pumpkin_core::propagation::DomainEvents; use pumpkin_core::propagation::Domains; use pumpkin_core::propagation::EnqueueDecision; use pumpkin_core::propagation::ExplanationContext; -use pumpkin_core::propagation::InferenceCheckers; use pumpkin_core::propagation::LazyExplanation; use pumpkin_core::propagation::LocalId; use pumpkin_core::propagation::NotificationContext; @@ -46,16 +45,6 @@ where { type PropagatorImpl = LinearLessOrEqualPropagator; - fn add_inference_checkers(&self, mut checkers: InferenceCheckers<'_>) { - checkers.add_inference_checker( - InferenceCode::new(self.constraint_tag, LinearBounds), - Box::new(LinearLessOrEqualInferenceChecker::new( - self.x.clone(), - self.c, - )), - ); - } - fn create(self, mut context: PropagatorConstructorContext) -> Self::PropagatorImpl { let LinearLessOrEqualPropagatorArgs { x, @@ -63,6 +52,11 @@ where constraint_tag, } = self; + context.add_inference_checker( + InferenceCode::new(constraint_tag, LinearBounds), + Box::new(LinearLessOrEqualInferenceChecker::new(x.clone(), c)), + ); + let mut lower_bound_left_hand_side = 0_i64; let mut current_bounds = vec![]; diff --git a/pumpkin-crates/propagators/src/propagators/arithmetic/linear_not_equal.rs b/pumpkin-crates/propagators/src/propagators/arithmetic/linear_not_equal.rs index 130d2b2d8..2aa6e11d9 100644 --- a/pumpkin-crates/propagators/src/propagators/arithmetic/linear_not_equal.rs +++ b/pumpkin-crates/propagators/src/propagators/arithmetic/linear_not_equal.rs @@ -18,7 +18,6 @@ use pumpkin_core::propagation::DomainEvent; use pumpkin_core::propagation::DomainEvents; use pumpkin_core::propagation::Domains; use pumpkin_core::propagation::EnqueueDecision; -use pumpkin_core::propagation::InferenceCheckers; use pumpkin_core::propagation::LocalId; use pumpkin_core::propagation::NotificationContext; use pumpkin_core::propagation::OpaqueDomainEvent; @@ -50,16 +49,6 @@ where { type PropagatorImpl = LinearNotEqualPropagator; - fn add_inference_checkers(&self, mut checkers: InferenceCheckers<'_>) { - checkers.add_inference_checker( - InferenceCode::new(self.constraint_tag, LinearNotEquals), - Box::new(LinearNotEqualChecker { - terms: self.terms.as_ref().into(), - bound: self.rhs, - }), - ); - } - fn create(self, mut context: PropagatorConstructorContext) -> Self::PropagatorImpl { let LinearNotEqualPropagatorArgs { terms, @@ -67,6 +56,14 @@ where constraint_tag, } = self; + context.add_inference_checker( + InferenceCode::new(constraint_tag, LinearNotEquals), + Box::new(LinearNotEqualChecker { + terms: terms.as_ref().into(), + bound: rhs, + }), + ); + for (i, x_i) in terms.iter().enumerate() { context.register(x_i.clone(), DomainEvents::ASSIGN, LocalId::from(i as u32)); context.register_backtrack( diff --git a/pumpkin-crates/propagators/src/propagators/arithmetic/maximum.rs b/pumpkin-crates/propagators/src/propagators/arithmetic/maximum.rs index 2a847bc4c..693869274 100644 --- a/pumpkin-crates/propagators/src/propagators/arithmetic/maximum.rs +++ b/pumpkin-crates/propagators/src/propagators/arithmetic/maximum.rs @@ -9,7 +9,6 @@ use pumpkin_core::predicates::PropositionalConjunction; use pumpkin_core::proof::ConstraintTag; use pumpkin_core::proof::InferenceCode; use pumpkin_core::propagation::DomainEvents; -use pumpkin_core::propagation::InferenceCheckers; use pumpkin_core::propagation::LocalId; use pumpkin_core::propagation::Priority; use pumpkin_core::propagation::PropagationContext; @@ -36,16 +35,6 @@ where { type PropagatorImpl = MaximumPropagator; - fn add_inference_checkers(&self, mut checkers: InferenceCheckers<'_>) { - checkers.add_inference_checker( - InferenceCode::new(self.constraint_tag, Maximum), - Box::new(MaximumChecker { - array: self.array.clone(), - rhs: self.rhs.clone(), - }), - ); - } - fn create(self, mut context: PropagatorConstructorContext) -> Self::PropagatorImpl { let MaximumArgs { array, @@ -53,6 +42,14 @@ where constraint_tag, } = self; + context.add_inference_checker( + InferenceCode::new(constraint_tag, Maximum), + Box::new(MaximumChecker { + array: array.clone(), + rhs: rhs.clone(), + }), + ); + for (idx, var) in array.iter().enumerate() { context.register(var.clone(), DomainEvents::BOUNDS, LocalId::from(idx as u32)); } diff --git a/pumpkin-crates/propagators/src/propagators/arithmetic/mod.rs b/pumpkin-crates/propagators/src/propagators/arithmetic/mod.rs index 21fc6eb9a..c849b2426 100644 --- a/pumpkin-crates/propagators/src/propagators/arithmetic/mod.rs +++ b/pumpkin-crates/propagators/src/propagators/arithmetic/mod.rs @@ -1,16 +1,18 @@ //! Contains a number of propagators for a variety of arithmetic constraints. pub(crate) mod absolute_value; pub(crate) mod binary; +mod binary_equals; pub(crate) mod integer_division; -pub(crate) mod integer_multiplication; pub(crate) mod linear_less_or_equal; pub(crate) mod linear_not_equal; pub(crate) mod maximum; +mod multiplication; pub use absolute_value::*; pub use binary::*; +pub use binary_equals::*; pub use integer_division::*; -pub use integer_multiplication::*; pub use linear_less_or_equal::*; pub use linear_not_equal::*; pub use maximum::*; +pub use multiplication::*; diff --git a/pumpkin-crates/propagators/src/propagators/arithmetic/multiplication/checker.rs b/pumpkin-crates/propagators/src/propagators/arithmetic/multiplication/checker.rs new file mode 100644 index 000000000..cdb5fa200 --- /dev/null +++ b/pumpkin-crates/propagators/src/propagators/arithmetic/multiplication/checker.rs @@ -0,0 +1,143 @@ +use pumpkin_checking::AtomicConstraint; +use pumpkin_checking::CheckerVariable; +use pumpkin_checking::InferenceChecker; +use pumpkin_core::checkers::support::Support; +use pumpkin_core::checkers::support::SupportGenerator; +use pumpkin_core::checkers::support::SupportsValue; +use pumpkin_core::checkers::support::UnsupportedValue; +use pumpkin_core::propagation::Domains; +use pumpkin_core::propagation::LocalId; +use pumpkin_core::propagation::ReadDomains; +use pumpkin_core::variables::IntegerVariable; + +#[derive(Clone, Debug)] +pub struct IntegerMultiplicationChecker { + pub a: VA, + pub b: VB, + pub c: VC, +} + +impl InferenceChecker for IntegerMultiplicationChecker +where + Atomic: AtomicConstraint, + VA: CheckerVariable, + VB: CheckerVariable, + VC: CheckerVariable, +{ + fn check( + &self, + state: pumpkin_checking::VariableState, + _: &[Atomic], + _: Option<&Atomic>, + ) -> bool { + // We apply interval arithmetic to determine that the computed interval `a times b` + // does not intersect with the domain of `c`. + // + // See https://en.wikipedia.org/wiki/Interval_arithmetic#Interval_operators. + + let x1 = self.a.induced_lower_bound(&state); + let x2 = self.a.induced_upper_bound(&state); + let y1 = self.b.induced_lower_bound(&state); + let y2 = self.b.induced_upper_bound(&state); + + let c_lower = self.c.induced_lower_bound(&state); + let c_upper = self.c.induced_upper_bound(&state); + + let x1y1 = x1 * y1; + let x1y2 = x1 * y2; + let x2y1 = x2 * y1; + let x2y2 = x2 * y2; + + let computed_c_lower = x1y1.min(x1y2).min(x2y1).min(x2y2); + let computed_c_upper = x1y1.max(x1y2).max(x2y1).max(x2y2); + + computed_c_upper < c_lower || computed_c_lower > c_upper + } +} + +impl SupportGenerator for IntegerMultiplicationChecker +where + VA: IntegerVariable + SupportsValue, + VB: IntegerVariable + SupportsValue, + VC: IntegerVariable + SupportsValue, +{ + type Value = f32; + + fn support( + &mut self, + support: &mut Support, + local_id: LocalId, + unsupported_value: UnsupportedValue, + domains: &Domains<'_>, + ) { + let a_min = domains.lower_bound(&self.a) as f32; + let a_max = domains.upper_bound(&self.a) as f32; + let b_min = domains.lower_bound(&self.b) as f32; + let b_max = domains.upper_bound(&self.b) as f32; + let c_min = domains.lower_bound(&self.c) as f32; + let c_max = domains.upper_bound(&self.c) as f32; + + let (value_a, value_b, value_c) = match local_id { + super::ID_A => { + let value_a = self.a.unpack(unsupported_value) as f32; + + let Some((value_b, value_c)) = [b_min, b_max].into_iter().find_map(|value_b| { + let value_c = value_a * value_b; + if c_min <= value_c && value_c <= c_max { + Some((value_b, value_c)) + } else { + None + } + }) else { + return; + }; + + (value_a, value_b, value_c) + } + super::ID_B => { + let value_b = self.b.unpack(unsupported_value) as f32; + + let Some((value_a, value_c)) = [a_min, a_max].into_iter().find_map(|value_a| { + let value_c = value_a * value_b; + if c_min <= value_c && value_c <= c_max { + Some((value_a, value_c)) + } else { + None + } + }) else { + return; + }; + + (value_a, value_b, value_c) + } + super::ID_C => { + let value_c = self.c.unpack(unsupported_value) as f32; + + let Some(values) = [a_min, a_max].into_iter().find_map(|value_a| { + let value_b = value_c / value_a; + + if b_min <= value_b && value_b <= b_max { + Some((value_a, value_b, value_c)) + } else { + None + } + }) else { + return; + }; + + values + } + + _ => unreachable!(), + }; + + self.a.assign(value_a, support); + self.b.assign(value_b, support); + self.c.assign(value_c, support); + } + + fn is_solution(&self, support: &Support) -> bool { + self.a.support_value(support) * self.b.support_value(support) + == self.c.support_value(support) + } +} diff --git a/pumpkin-crates/propagators/src/propagators/arithmetic/multiplication/constructor.rs b/pumpkin-crates/propagators/src/propagators/arithmetic/multiplication/constructor.rs new file mode 100644 index 000000000..fa69b8383 --- /dev/null +++ b/pumpkin-crates/propagators/src/propagators/arithmetic/multiplication/constructor.rs @@ -0,0 +1,72 @@ +use pumpkin_core::checkers::StrongConsistency; +use pumpkin_core::checkers::StrongRetentionChecker; +use pumpkin_core::checkers::support::SupportsValue; +use pumpkin_core::proof::ConstraintTag; +use pumpkin_core::proof::InferenceCode; +use pumpkin_core::propagation::DomainEvents; +use pumpkin_core::propagation::PropagatorConstructor; +use pumpkin_core::propagation::PropagatorConstructorContext; +use pumpkin_core::variables::IntegerVariable; + +use crate::arithmetic::IntegerMultiplicationPropagator; +use crate::arithmetic::multiplication::IntegerMultiplication; +use crate::arithmetic::multiplication::checker::IntegerMultiplicationChecker; + +/// The [`PropagatorConstructor`] for [`IntegerMultiplicationPropagator`]. +#[derive(Clone, Debug)] +pub struct IntegerMultiplicationArgs { + pub a: VA, + pub b: VB, + pub c: VC, + pub constraint_tag: ConstraintTag, +} + +impl PropagatorConstructor for IntegerMultiplicationArgs +where + VA: IntegerVariable + SupportsValue + 'static, + VB: IntegerVariable + SupportsValue + 'static, + VC: IntegerVariable + SupportsValue + 'static, +{ + type PropagatorImpl = IntegerMultiplicationPropagator; + + fn create(self, mut context: PropagatorConstructorContext) -> Self::PropagatorImpl { + let IntegerMultiplicationArgs { + a, + b, + c, + constraint_tag, + } = self; + + context.add_inference_checker( + InferenceCode::new(constraint_tag, IntegerMultiplication), + Box::new(IntegerMultiplicationChecker { + a: a.clone(), + b: b.clone(), + c: c.clone(), + }), + ); + + context.add_consistency_checker( + ((super::ID_A, &a), (super::ID_B, &b), (super::ID_C, &c)), + StrongRetentionChecker::new( + StrongConsistency::Bounds, + IntegerMultiplicationChecker { + a: a.clone(), + b: b.clone(), + c: c.clone(), + }, + ), + ); + + context.register(a.clone(), DomainEvents::ANY_INT, super::ID_A); + context.register(b.clone(), DomainEvents::ANY_INT, super::ID_B); + context.register(c.clone(), DomainEvents::ANY_INT, super::ID_C); + + IntegerMultiplicationPropagator { + a, + b, + c, + inference_code: InferenceCode::new(constraint_tag, IntegerMultiplication), + } + } +} diff --git a/pumpkin-crates/propagators/src/propagators/arithmetic/multiplication/mod.rs b/pumpkin-crates/propagators/src/propagators/arithmetic/multiplication/mod.rs new file mode 100644 index 000000000..afef4b2bb --- /dev/null +++ b/pumpkin-crates/propagators/src/propagators/arithmetic/multiplication/mod.rs @@ -0,0 +1,17 @@ +mod checker; +mod constructor; +mod propagator; + +pub use checker::*; +pub use constructor::*; +pub use propagator::*; +use pumpkin_core::declare_inference_label; +use pumpkin_core::propagation::LocalId; + +// The LocalId's for the variables. +const ID_A: LocalId = LocalId::from(0); +const ID_B: LocalId = LocalId::from(1); +const ID_C: LocalId = LocalId::from(2); + +// The inference label for integer multiplication. +declare_inference_label!(IntegerMultiplication); diff --git a/pumpkin-crates/propagators/src/propagators/arithmetic/multiplication/propagator.rs b/pumpkin-crates/propagators/src/propagators/arithmetic/multiplication/propagator.rs new file mode 100644 index 000000000..9e4f9cac5 --- /dev/null +++ b/pumpkin-crates/propagators/src/propagators/arithmetic/multiplication/propagator.rs @@ -0,0 +1,304 @@ +use pumpkin_core::asserts::pumpkin_assert_simple; +use pumpkin_core::conjunction; +use pumpkin_core::predicate; +use pumpkin_core::proof::InferenceCode; +use pumpkin_core::propagation::Priority; +use pumpkin_core::propagation::PropagationContext; +use pumpkin_core::propagation::Propagator; +use pumpkin_core::propagation::ReadDomains; +use pumpkin_core::state::PropagationStatusCP; +use pumpkin_core::state::propagator_conflict; +use pumpkin_core::variables::IntegerVariable; + +/// A propagator for maintaining the constraint `a * b = c`. The propagator +/// (currently) only propagates the signs of the variables, the case where a, b, c >= 0, and detects +/// a conflict if the variables are fixed. +#[derive(Clone, Debug)] +pub struct IntegerMultiplicationPropagator { + pub(super) a: VA, + pub(super) b: VB, + pub(super) c: VC, + pub(super) inference_code: InferenceCode, +} + +impl Propagator + for IntegerMultiplicationPropagator +where + VA: IntegerVariable, + VB: IntegerVariable, + VC: IntegerVariable, +{ + fn priority(&self) -> Priority { + Priority::High + } + + fn name(&self) -> &str { + "IntTimes" + } + + fn propagate_from_scratch(&self, context: PropagationContext) -> PropagationStatusCP { + perform_propagation(context, &self.a, &self.b, &self.c, &self.inference_code) + } +} + +fn perform_propagation( + mut context: PropagationContext, + a: &VA, + b: &VB, + c: &VC, + inference_code: &InferenceCode, +) -> PropagationStatusCP { + // First we propagate the signs + propagate_signs(&mut context, a, b, c, inference_code)?; + + let a_min = context.lower_bound(a); + let a_max = context.upper_bound(a); + let b_min = context.lower_bound(b); + let b_max = context.upper_bound(b); + let c_min = context.lower_bound(c); + let c_max = context.upper_bound(c); + + if a_min >= 0 && b_min >= 0 { + let new_max_c = a_max.saturating_mul(b_max); + let new_min_c = a_min.saturating_mul(b_min); + + // c is smaller than the maximum value that a * b can take + // + // We need the lower-bounds in the explanation as well because the reasoning does not + // hold in the case of a negative lower-bound + context.post( + predicate![c <= new_max_c], + ( + conjunction!([a >= 0] & [a <= a_max] & [b >= 0] & [b <= b_max]), + inference_code, + ), + )?; + + // c is larger than the minimum value that a * b can take + context.post( + predicate![c >= new_min_c], + (conjunction!([a >= a_min] & [b >= b_min]), inference_code), + )?; + } + + if b_min >= 0 && b_max >= 1 && c_min >= 1 { + // a >= ceil(c.min / b.max) + let bound = div_ceil_pos(c_min, b_max); + context.post( + predicate![a >= bound], + ( + conjunction!([c >= c_min] & [b >= 0] & [b <= b_max]), + inference_code, + ), + )?; + } + + if b_min >= 1 && c_min >= 0 && c_max >= 1 { + // a <= floor(c.max / b.min) + let bound = c_max / b_min; + context.post( + predicate![a <= bound], + ( + conjunction!([c >= 0] & [c <= c_max] & [b >= b_min]), + inference_code, + ), + )?; + } + + if a_min >= 1 && c_min >= 0 && c_max >= 1 { + // b <= floor(c.max / a.min) + let bound = c_max / a_min; + context.post( + predicate![b <= bound], + ( + conjunction!([c >= 0] & [c <= c_max] & [a >= a_min]), + inference_code, + ), + )?; + } + + // b >= ceil(c.min / a.max) + if a_min >= 0 && a_max >= 1 && c_min >= 1 { + let bound = div_ceil_pos(c_min, a_max); + + context.post( + predicate![b >= bound], + ( + conjunction!([c >= c_min] & [a >= 0] & [a <= a_max]), + inference_code, + ), + )?; + } + + if let Some(fixed_a) = context.fixed_value(a) + && let Some(fixed_b) = context.fixed_value(b) + && let Some(fixed_c) = context.fixed_value(c) + && (fixed_a * fixed_b) != fixed_c + { + // All variables are assigned but the resulting value is not correct, so we report a + // conflict + return propagator_conflict( + conjunction!( + [a == context.lower_bound(a)] + & [b == context.lower_bound(b)] + & [c == context.lower_bound(c)] + ), + inference_code, + ); + } + + Ok(()) +} + +/// Propagates the signs of the variables, it performs the following propagations: +/// - Propagating based on positive bounds +/// - If a is positive and b is positive then c is positive +/// - If a is positive and c is positive then b is positive +/// - If b is positive and c is positive then a is positive +/// - Propagating based on negative bounds +/// - If a is negative and b is negative then c is positive +/// - If a is negative and c is negative then b is positive +/// - If b is negative and c is negative then b is positive +/// - Propagating based on mixed bounds +/// - Propagating c based on a and b +/// - If a is negative and b is positive then c is negative +/// - If a is positive and b is negative then c is negative +/// - Propagating b based on a and c +/// - If a is negative and c is positive then b is negative +/// - If a is positive and c is negative then b is negative +/// - Propagating a based on b and c +/// - If b is negative and c is positive then a is negative +/// - If b is positive and c is negative then a is negative +/// +/// Note that this method does not propagate a value if 0 is in the domain as, for example, 0 * -3 = +/// 0 and 0 * 3 = 0 are both equally valid. +fn propagate_signs( + context: &mut PropagationContext, + a: &VA, + b: &VB, + c: &VC, + inference_code: &InferenceCode, +) -> PropagationStatusCP { + let a_min = context.lower_bound(a); + let a_max = context.upper_bound(a); + let b_min = context.lower_bound(b); + let b_max = context.upper_bound(b); + let c_min = context.lower_bound(c); + let c_max = context.upper_bound(c); + + // Propagating based on positive bounds + // a is positive and b is positive -> c is positive + if a_min >= 0 && b_min >= 0 { + context.post( + predicate![c >= 0], + (conjunction!([a >= 0] & [b >= 0]), inference_code), + )?; + } + + // a is positive and c is positive -> b is positive + if a_min >= 1 && c_min >= 1 { + context.post( + predicate![b >= 1], + (conjunction!([a >= 1] & [c >= 1]), inference_code), + )?; + } + + // b is positive and c is positive -> a is positive + if b_min >= 1 && c_min >= 1 { + context.post( + predicate![a >= 1], + (conjunction!([b >= 1] & [c >= 1]), inference_code), + )?; + } + + // Propagating based on negative bounds + // a is negative and b is negative -> c is positive + if a_max <= 0 && b_max <= 0 { + context.post( + predicate![c >= 0], + (conjunction!([a <= 0] & [b <= 0]), inference_code), + )?; + } + + // a is negative and c is negative -> b is positive + if a_max <= -1 && c_max <= -1 { + context.post( + predicate![b >= 1], + (conjunction!([a <= -1] & [c <= -1]), inference_code), + )?; + } + + // b is negative and c is negative -> a is positive + if b_max <= -1 && c_max <= -1 { + context.post( + predicate![a >= 1], + (conjunction!([b <= -1] & [c <= -1]), inference_code), + )?; + } + + // Propagating based on mixed bounds (i.e. one positive and one negative) + // Propagating c based on a and b + // a is negative and b is positive -> c is negative + if a_max <= 0 && b_min >= 0 { + context.post( + predicate![c <= 0], + (conjunction!([a <= 0] & [b >= 0]), inference_code), + )?; + } + + // a is positive and b is negative -> c is negative + if a_min >= 0 && b_max <= 0 { + context.post( + predicate![c <= 0], + (conjunction!([a >= 0] & [b <= 0]), inference_code), + )?; + } + + // Propagating b based on a and c + // a is negative and c is positive -> b is negative + if a_max <= -1 && c_min >= 1 { + context.post( + predicate![b <= -1], + (conjunction!([a <= -1] & [c >= 1]), inference_code), + )?; + } + + // a is positive and c is negative -> b is negative + if a_min >= 1 && c_max <= -1 { + context.post( + predicate![b <= -1], + (conjunction!([a >= 1] & [c <= -1]), inference_code), + )?; + } + + // Propagating a based on b and c + // b is negative and c is positive -> a is negative + if b_max <= -1 && c_min >= 1 { + context.post( + predicate![a <= -1], + (conjunction!([b <= -1] & [c >= 1]), inference_code), + )?; + } + + // b is positive and c is negative -> a is negative + if b_min >= 1 && c_max <= -1 { + context.post( + predicate![a <= -1], + (conjunction!([b >= 1] & [c <= -1]), inference_code), + )?; + } + + Ok(()) +} + +/// Compute `ceil(numerator / denominator)`. +/// +/// Assumes `numerator, denominator > 0`. +#[inline] +fn div_ceil_pos(numerator: i32, denominator: i32) -> i32 { + pumpkin_assert_simple!( + numerator > 0 && denominator > 0, + "Either the numerator {numerator} was non-positive or the denominator {denominator} was non-positive" + ); + numerator / denominator + (numerator % denominator).signum() +} diff --git a/pumpkin-crates/propagators/src/propagators/cumulative/time_table/over_interval_incremental_propagator/time_table_over_interval_incremental.rs b/pumpkin-crates/propagators/src/propagators/cumulative/time_table/over_interval_incremental_propagator/time_table_over_interval_incremental.rs index fc5d108e4..e67e186ed 100644 --- a/pumpkin-crates/propagators/src/propagators/cumulative/time_table/over_interval_incremental_propagator/time_table_over_interval_incremental.rs +++ b/pumpkin-crates/propagators/src/propagators/cumulative/time_table/over_interval_incremental_propagator/time_table_over_interval_incremental.rs @@ -11,7 +11,6 @@ use pumpkin_core::proof::InferenceCode; use pumpkin_core::propagation::DomainEvent; use pumpkin_core::propagation::Domains; use pumpkin_core::propagation::EnqueueDecision; -use pumpkin_core::propagation::InferenceCheckers; use pumpkin_core::propagation::LocalId; use pumpkin_core::propagation::NotificationContext; use pumpkin_core::propagation::OpaqueDomainEvent; @@ -110,8 +109,8 @@ impl PropagatorConstruc { type PropagatorImpl = Self; - fn add_inference_checkers(&self, mut checkers: InferenceCheckers<'_>) { - checkers.add_inference_checker( + fn create(mut self, mut context: PropagatorConstructorContext) -> Self::PropagatorImpl { + context.add_inference_checker( InferenceCode::new(self.constraint_tag, TimeTable), Box::new(TimeTableChecker { tasks: self @@ -127,9 +126,6 @@ impl PropagatorConstruc capacity: self.parameters.capacity, }), ); - } - - fn create(mut self, mut context: PropagatorConstructorContext) -> Self::PropagatorImpl { // We only register for notifications of backtrack events if incremental backtracking is // enabled register_tasks( diff --git a/pumpkin-crates/propagators/src/propagators/cumulative/time_table/per_point_incremental_propagator/time_table_per_point_incremental.rs b/pumpkin-crates/propagators/src/propagators/cumulative/time_table/per_point_incremental_propagator/time_table_per_point_incremental.rs index 9b4ffa1e5..05ac11900 100644 --- a/pumpkin-crates/propagators/src/propagators/cumulative/time_table/per_point_incremental_propagator/time_table_per_point_incremental.rs +++ b/pumpkin-crates/propagators/src/propagators/cumulative/time_table/per_point_incremental_propagator/time_table_per_point_incremental.rs @@ -11,7 +11,6 @@ use pumpkin_core::proof::InferenceCode; use pumpkin_core::propagation::DomainEvent; use pumpkin_core::propagation::Domains; use pumpkin_core::propagation::EnqueueDecision; -use pumpkin_core::propagation::InferenceCheckers; use pumpkin_core::propagation::LocalId; use pumpkin_core::propagation::NotificationContext; use pumpkin_core::propagation::OpaqueDomainEvent; @@ -107,8 +106,8 @@ impl Propagator { type PropagatorImpl = Self; - fn add_inference_checkers(&self, mut checkers: InferenceCheckers<'_>) { - checkers.add_inference_checker( + fn create(mut self, mut context: PropagatorConstructorContext) -> Self::PropagatorImpl { + context.add_inference_checker( InferenceCode::new(self.constraint_tag, TimeTable), Box::new(TimeTableChecker { tasks: self @@ -124,9 +123,6 @@ impl Propagator capacity: self.parameters.capacity, }), ); - } - - fn create(mut self, mut context: PropagatorConstructorContext) -> Self::PropagatorImpl { register_tasks(&self.parameters.tasks, context.reborrow(), true); self.updatable_structures .reset_all_bounds_and_remove_fixed(context.domains(), &self.parameters); diff --git a/pumpkin-crates/propagators/src/propagators/cumulative/time_table/time_table_over_interval.rs b/pumpkin-crates/propagators/src/propagators/cumulative/time_table/time_table_over_interval.rs index cb8b6926b..b8b3fa314 100644 --- a/pumpkin-crates/propagators/src/propagators/cumulative/time_table/time_table_over_interval.rs +++ b/pumpkin-crates/propagators/src/propagators/cumulative/time_table/time_table_over_interval.rs @@ -9,7 +9,6 @@ use pumpkin_core::proof::InferenceCode; use pumpkin_core::propagation::DomainEvent; use pumpkin_core::propagation::Domains; use pumpkin_core::propagation::EnqueueDecision; -use pumpkin_core::propagation::InferenceCheckers; use pumpkin_core::propagation::LocalId; use pumpkin_core::propagation::NotificationContext; use pumpkin_core::propagation::OpaqueDomainEvent; @@ -110,8 +109,8 @@ impl PropagatorConstructor { type PropagatorImpl = Self; - fn add_inference_checkers(&self, mut checkers: InferenceCheckers<'_>) { - checkers.add_inference_checker( + fn create(mut self, mut context: PropagatorConstructorContext) -> Self::PropagatorImpl { + context.add_inference_checker( InferenceCode::new(self.constraint_tag, TimeTable), Box::new(TimeTableChecker { tasks: self @@ -127,9 +126,6 @@ impl PropagatorConstructor capacity: self.parameters.capacity, }), ); - } - - fn create(mut self, mut context: PropagatorConstructorContext) -> Self::PropagatorImpl { self.updatable_structures .initialise_bounds_and_remove_fixed(context.domains(), &self.parameters); register_tasks(&self.parameters.tasks, context.reborrow(), false); diff --git a/pumpkin-crates/propagators/src/propagators/cumulative/time_table/time_table_per_point.rs b/pumpkin-crates/propagators/src/propagators/cumulative/time_table/time_table_per_point.rs index 5451cd01d..0c7fceed6 100644 --- a/pumpkin-crates/propagators/src/propagators/cumulative/time_table/time_table_per_point.rs +++ b/pumpkin-crates/propagators/src/propagators/cumulative/time_table/time_table_per_point.rs @@ -11,7 +11,6 @@ use pumpkin_core::proof::ConstraintTag; use pumpkin_core::proof::InferenceCode; use pumpkin_core::propagation::DomainEvent; use pumpkin_core::propagation::EnqueueDecision; -use pumpkin_core::propagation::InferenceCheckers; use pumpkin_core::propagation::LocalId; use pumpkin_core::propagation::NotificationContext; use pumpkin_core::propagation::OpaqueDomainEvent; @@ -100,8 +99,8 @@ impl TimeTablePerPointPropagator { impl PropagatorConstructor for TimeTablePerPointPropagator { type PropagatorImpl = Self; - fn add_inference_checkers(&self, mut checkers: InferenceCheckers<'_>) { - checkers.add_inference_checker( + fn create(mut self, mut context: PropagatorConstructorContext) -> Self::PropagatorImpl { + context.add_inference_checker( InferenceCode::new(self.constraint_tag, TimeTable), Box::new(TimeTableChecker { tasks: self @@ -117,9 +116,6 @@ impl PropagatorConstructor for TimeTablePerPoint capacity: self.parameters.capacity, }), ); - } - - fn create(mut self, mut context: PropagatorConstructorContext) -> Self::PropagatorImpl { self.updatable_structures .initialise_bounds_and_remove_fixed(context.domains(), &self.parameters); register_tasks(&self.parameters.tasks, context.reborrow(), false); diff --git a/pumpkin-crates/propagators/src/propagators/disjunctive/disjunctive_propagator.rs b/pumpkin-crates/propagators/src/propagators/disjunctive/disjunctive_propagator.rs index 0d2da6002..376712d4b 100644 --- a/pumpkin-crates/propagators/src/propagators/disjunctive/disjunctive_propagator.rs +++ b/pumpkin-crates/propagators/src/propagators/disjunctive/disjunctive_propagator.rs @@ -8,7 +8,6 @@ use pumpkin_core::predicates::PropositionalConjunction; use pumpkin_core::proof::ConstraintTag; use pumpkin_core::proof::InferenceCode; use pumpkin_core::propagation::DomainEvents; -use pumpkin_core::propagation::InferenceCheckers; use pumpkin_core::propagation::LocalId; use pumpkin_core::propagation::PropagationContext; use pumpkin_core::propagation::Propagator; @@ -80,6 +79,20 @@ impl PropagatorConstructor for DisjunctiveConstr type PropagatorImpl = DisjunctivePropagator; fn create(self, mut context: PropagatorConstructorContext) -> Self::PropagatorImpl { + context.add_inference_checker( + InferenceCode::new(self.constraint_tag, DisjunctiveEdgeFinding), + Box::new(DisjunctiveEdgeFindingChecker { + tasks: self + .tasks + .iter() + .map(|task| ArgDisjunctiveTask { + start_time: task.start_time.clone(), + processing_time: task.processing_time, + }) + .collect(), + }), + ); + let tasks = self .tasks .into_iter() @@ -106,22 +119,6 @@ impl PropagatorConstructor for DisjunctiveConstr inference_code, } } - - fn add_inference_checkers(&self, mut checkers: InferenceCheckers<'_>) { - checkers.add_inference_checker( - InferenceCode::new(self.constraint_tag, DisjunctiveEdgeFinding), - Box::new(DisjunctiveEdgeFindingChecker { - tasks: self - .tasks - .iter() - .map(|task| ArgDisjunctiveTask { - start_time: task.start_time.clone(), - processing_time: task.processing_time, - }) - .collect(), - }), - ); - } } impl Propagator for DisjunctivePropagator { diff --git a/pumpkin-crates/propagators/src/propagators/element.rs b/pumpkin-crates/propagators/src/propagators/element.rs index 8874a259c..19139c998 100644 --- a/pumpkin-crates/propagators/src/propagators/element.rs +++ b/pumpkin-crates/propagators/src/propagators/element.rs @@ -18,7 +18,6 @@ use pumpkin_core::proof::ConstraintTag; use pumpkin_core::proof::InferenceCode; use pumpkin_core::propagation::DomainEvents; use pumpkin_core::propagation::ExplanationContext; -use pumpkin_core::propagation::InferenceCheckers; use pumpkin_core::propagation::LazyExplanation; use pumpkin_core::propagation::LocalId; use pumpkin_core::propagation::Priority; @@ -49,17 +48,6 @@ where { type PropagatorImpl = ElementPropagator; - fn add_inference_checkers(&self, mut checkers: InferenceCheckers<'_>) { - checkers.add_inference_checker( - InferenceCode::new(self.constraint_tag, Element), - Box::new(ElementChecker::new( - self.array.clone(), - self.index.clone(), - self.rhs.clone(), - )), - ); - } - fn create(self, mut context: PropagatorConstructorContext) -> Self::PropagatorImpl { let ElementArgs { array, @@ -68,6 +56,15 @@ where constraint_tag, } = self; + context.add_inference_checker( + InferenceCode::new(constraint_tag, Element), + Box::new(ElementChecker::new( + array.clone(), + index.clone(), + rhs.clone(), + )), + ); + for (i, x_i) in array.iter().enumerate() { context.register( x_i.clone(),