Skip to content

Commit 70da310

Browse files
authored
perf: deduplicate queries (#2698)
* deduplicate queries: deduplicate queries in the UserInputAst after parsing queries
* add return type
1 parent 85010b5 commit 70da310

File tree

7 files changed

+59
-21
lines changed

7 files changed

+59
-21
lines changed

query-grammar/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,5 @@ edition = "2024"
1515
nom = "7"
1616
serde = { version = "1.0.219", features = ["derive"] }
1717
serde_json = "1.0.140"
18+
ordered-float = "5.0.0"
19+
fnv = "1.0.7"

query-grammar/src/lib.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,17 @@ pub fn parse_query_lenient(query: &str) -> (UserInputAst, Vec<LenientError>) {
3131

3232
#[cfg(test)]
3333
mod tests {
34-
use crate::{parse_query, parse_query_lenient};
34+
use crate::{UserInputAst, parse_query, parse_query_lenient};
35+
36+
#[test]
37+
fn test_deduplication() {
38+
let ast: UserInputAst = parse_query("a a").unwrap();
39+
let json = serde_json::to_string(&ast).unwrap();
40+
assert_eq!(
41+
json,
42+
r#"{"type":"bool","clauses":[[null,{"type":"literal","field_name":null,"phrase":"a","delimiter":"none","slop":0,"prefix":false}]]}"#
43+
);
44+
}
3545

3646
#[test]
3747
fn test_parse_query_serialization() {

query-grammar/src/query_grammar.rs

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use std::borrow::Cow;
22
use std::iter::once;
33

4+
use fnv::FnvHashSet;
45
use nom::IResult;
56
use nom::branch::alt;
67
use nom::bytes::complete::tag;
@@ -814,7 +815,7 @@ fn boosted_leaf(inp: &str) -> IResult<&str, UserInputAst> {
814815
tuple((leaf, fallible(boost))),
815816
|(leaf, boost_opt)| match boost_opt {
816817
Some(boost) if (boost - 1.0).abs() > f64::EPSILON => {
817-
UserInputAst::Boost(Box::new(leaf), boost)
818+
UserInputAst::Boost(Box::new(leaf), boost.into())
818819
}
819820
_ => leaf,
820821
},
@@ -826,7 +827,7 @@ fn boosted_leaf_infallible(inp: &str) -> JResult<&str, Option<UserInputAst>> {
826827
tuple_infallible((leaf_infallible, boost)),
827828
|((leaf, boost_opt), error)| match boost_opt {
828829
Some(boost) if (boost - 1.0).abs() > f64::EPSILON => (
829-
leaf.map(|leaf| UserInputAst::Boost(Box::new(leaf), boost)),
830+
leaf.map(|leaf| UserInputAst::Boost(Box::new(leaf), boost.into())),
830831
error,
831832
),
832833
_ => (leaf, error),
@@ -1077,12 +1078,25 @@ pub fn parse_to_ast_lenient(query_str: &str) -> (UserInputAst, Vec<LenientError>
10771078
(rewrite_ast(res), errors)
10781079
}
10791080

1080-
/// Removes unnecessary children clauses in AST
1081-
///
1082-
/// Motivated by [issue #1433](https://github.com/quickwit-oss/tantivy/issues/1433)
10831081
fn rewrite_ast(mut input: UserInputAst) -> UserInputAst {
1084-
if let UserInputAst::Clause(terms) = &mut input {
1085-
for term in terms {
1082+
if let UserInputAst::Clause(sub_clauses) = &mut input {
1083+
// call rewrite_ast recursively on children clauses if applicable
1084+
let mut new_clauses = Vec::with_capacity(sub_clauses.len());
1085+
for (occur, clause) in sub_clauses.drain(..) {
1086+
let rewritten_clause = rewrite_ast(clause);
1087+
new_clauses.push((occur, rewritten_clause));
1088+
}
1089+
*sub_clauses = new_clauses;
1090+
1091+
// remove duplicate child clauses
1092+
// e.g. (+a +b) OR (+c +d) OR (+a +b) => (+a +b) OR (+c +d)
1093+
let mut seen = FnvHashSet::default();
1094+
sub_clauses.retain(|term| seen.insert(term.clone()));
1095+
1096+
// Removes unnecessary children clauses in AST
1097+
//
1098+
// Motivated by [issue #1433](https://github.com/quickwit-oss/tantivy/issues/1433)
1099+
for term in sub_clauses {
10861100
rewrite_ast_clause(term);
10871101
}
10881102
}

query-grammar/src/user_input_ast.rs

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use serde::Serialize;
55

66
use crate::Occur;
77

8-
#[derive(PartialEq, Clone, Serialize)]
8+
#[derive(PartialEq, Eq, Hash, Clone, Serialize)]
99
#[serde(tag = "type")]
1010
#[serde(rename_all = "snake_case")]
1111
pub enum UserInputLeaf {
@@ -120,15 +120,15 @@ impl Debug for UserInputLeaf {
120120
}
121121
}
122122

123-
#[derive(Copy, Clone, Eq, PartialEq, Debug, Serialize)]
123+
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, Serialize)]
124124
#[serde(rename_all = "snake_case")]
125125
pub enum Delimiter {
126126
SingleQuotes,
127127
DoubleQuotes,
128128
None,
129129
}
130130

131-
#[derive(PartialEq, Clone, Serialize)]
131+
#[derive(PartialEq, Eq, Hash, Clone, Serialize)]
132132
#[serde(rename_all = "snake_case")]
133133
pub struct UserInputLiteral {
134134
pub field_name: Option<String>,
@@ -167,7 +167,7 @@ impl fmt::Debug for UserInputLiteral {
167167
}
168168
}
169169

170-
#[derive(PartialEq, Debug, Clone, Serialize)]
170+
#[derive(PartialEq, Eq, Hash, Debug, Clone, Serialize)]
171171
#[serde(tag = "type", content = "value")]
172172
#[serde(rename_all = "snake_case")]
173173
pub enum UserInputBound {
@@ -204,11 +204,11 @@ impl UserInputBound {
204204
}
205205
}
206206

207-
#[derive(PartialEq, Clone, Serialize)]
207+
#[derive(PartialEq, Eq, Hash, Clone, Serialize)]
208208
#[serde(into = "UserInputAstSerde")]
209209
pub enum UserInputAst {
210210
Clause(Vec<(Option<Occur>, UserInputAst)>),
211-
Boost(Box<UserInputAst>, f64),
211+
Boost(Box<UserInputAst>, ordered_float::OrderedFloat<f64>),
212212
Leaf(Box<UserInputLeaf>),
213213
}
214214

@@ -230,9 +230,10 @@ impl From<UserInputAst> for UserInputAstSerde {
230230
fn from(ast: UserInputAst) -> Self {
231231
match ast {
232232
UserInputAst::Clause(clause) => UserInputAstSerde::Bool { clauses: clause },
233-
UserInputAst::Boost(underlying, boost) => {
234-
UserInputAstSerde::Boost { underlying, boost }
235-
}
233+
UserInputAst::Boost(underlying, boost) => UserInputAstSerde::Boost {
234+
underlying,
235+
boost: boost.into_inner(),
236+
},
236237
UserInputAst::Leaf(leaf) => UserInputAstSerde::Leaf(leaf),
237238
}
238239
}
@@ -391,7 +392,7 @@ mod tests {
391392
#[test]
392393
fn test_boost_serialization() {
393394
let inner_ast = UserInputAst::Leaf(Box::new(UserInputLeaf::All));
394-
let boost_ast = UserInputAst::Boost(Box::new(inner_ast), 2.5);
395+
let boost_ast = UserInputAst::Boost(Box::new(inner_ast), 2.5.into());
395396
let json = serde_json::to_string(&boost_ast).unwrap();
396397
assert_eq!(
397398
json,
@@ -418,7 +419,7 @@ mod tests {
418419
}))),
419420
),
420421
])),
421-
2.5,
422+
2.5.into(),
422423
);
423424
let json = serde_json::to_string(&boost_ast).unwrap();
424425
assert_eq!(

src/query/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ mod tests {
104104
let query = query_parser.parse_query("a a a a a").unwrap();
105105
let mut terms = Vec::new();
106106
query.query_terms(&mut |term, pos| terms.push((term, pos)));
107-
assert_eq!(vec![(&term_a, false); 5], terms);
107+
assert_eq!(vec![(&term_a, false); 1], terms);
108108
}
109109
{
110110
let query = query_parser.parse_query("a -b").unwrap();

src/query/query_parser/logical_ast.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ impl LogicalAst {
4545
}
4646
}
4747

48+
// TODO: Move to rewrite_ast in query_grammar
4849
pub fn simplify(self) -> LogicalAst {
4950
match self {
5051
LogicalAst::Clause(clauses) => {

src/query/query_parser/query_parser.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -672,7 +672,7 @@ impl QueryParser {
672672
}
673673
UserInputAst::Boost(ast, boost) => {
674674
let (ast, errors) = self.compute_logical_ast_with_occur_lenient(*ast);
675-
(ast.boost(boost as Score), errors)
675+
(ast.boost(boost.into_inner() as Score), errors)
676676
}
677677
UserInputAst::Leaf(leaf) => {
678678
let (ast, errors) = self.compute_logical_ast_from_leaf_lenient(*leaf);
@@ -2050,6 +2050,16 @@ mod test {
20502050
);
20512051
}
20522052

2053+
#[test]
2054+
pub fn test_deduplication() {
2055+
let query = "be be";
2056+
test_parse_query_to_logical_ast_helper(
2057+
query,
2058+
"(Term(field=0, type=Str, \"be\") Term(field=1, type=Str, \"be\"))",
2059+
false,
2060+
);
2061+
}
2062+
20532063
#[test]
20542064
pub fn test_regex() {
20552065
let expected_regex = tantivy_fst::Regex::new(r".*b").unwrap();

0 commit comments

Comments
 (0)