diff --git a/src/deep_impact/models/original.py b/src/deep_impact/models/original.py
index 005d74e..2496b23 100644
--- a/src/deep_impact/models/original.py
+++ b/src/deep_impact/models/original.py
@@ -107,7 +107,7 @@ def get_query_document_token_mask(cls, query_terms: Set[str], term_to_token_inde
     def process_query(cls, query: str) -> Set[str]:
         query = cls.tokenizer.normalizer.normalize_str(query)
         terms = map(lambda x: x[0], cls.tokenizer.pre_tokenizer.pre_tokenize_str(query))
-        filtered_terms = filter(lambda x: x not in cls.punctuation, terms)
+        filtered_terms = set(filter(lambda x: x not in cls.punctuation, terms))
         stemmed_terms = set()
         for term in filtered_terms:
             if term not in cls.stemmer_cache:
@@ -141,13 +141,14 @@ def process_document(cls, document: str) -> Tuple[tokenizers.Encoding, Dict[str,
 
         # filter out duplicate terms, punctuations, and terms whose tokens overflow the max length
         for i, term in enumerate(document_terms):
+            if term not in cls.stemmer_cache:
+                cls.stemmer_cache[term] = cls.stemmer.stem(term)
+            term = cls.stemmer_cache[term]
+
             if term not in filtered_term_to_token_index \
                     and term not in cls.punctuation \
                     and i in term_index_to_token_index:
                 # check if stemm is cached
-                if term not in cls.stemmer_cache:
-                    cls.stemmer_cache[term] = cls.stemmer.stem(term)
-                term = cls.stemmer_cache[term]
                 filtered_term_to_token_index[term] = term_index_to_token_index[i]
 
         return encoded, filtered_term_to_token_index
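
Note (editorial, not part of the patch): both hunks change when stemming happens. In process_query, wrapping the filter in set() materializes and deduplicates surface terms before they reach the stemmer cache. In process_document, stemming now runs before the duplicate/punctuation check, so the membership test operates on the stemmed term and the first occurrence of each stem is kept, rather than a later surface form silently overwriting an earlier one. Below is a minimal sketch of that behavioral difference; toy_stem, document_terms, and the token indices are made up for illustration and stand in for cls.stemmer and the real encodings (punctuation filtering is elided for brevity).

def toy_stem(term: str) -> str:
    # Crude suffix stripping, illustration only; any stemmer that maps
    # several surface forms to one stem shows the same effect.
    for suffix in ("ning", "ing", "s"):
        if term.endswith(suffix):
            return term[: -len(suffix)]
    return term

document_terms = ["running", "runs", "cat"]          # hypothetical terms
term_index_to_token_index = {0: 10, 1: 20, 2: 30}    # hypothetical indices

# Before the patch: dedup on the raw term, stem afterwards. "running" and
# "runs" both pass the membership check, so the second occurrence
# overwrites the first under their shared stem "run".
before = {}
for i, term in enumerate(document_terms):
    if term not in before and i in term_index_to_token_index:
        before[toy_stem(term)] = term_index_to_token_index[i]

# After the patch: stem first, then dedup on the stemmed term, so the
# first occurrence of each stem wins.
after = {}
for i, term in enumerate(document_terms):
    term = toy_stem(term)
    if term not in after and i in term_index_to_token_index:
        after[term] = term_index_to_token_index[i]

print(before)  # {'run': 20, 'cat': 30} - later duplicate clobbers earlier entry
print(after)   # {'run': 10, 'cat': 30} - first occurrence kept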