src/deep_impact/models/original.py (9 changes: 5 additions & 4 deletions)
@@ -107,7 +107,7 @@ def get_query_document_token_mask(cls, query_terms: Set[str], term_to_token_inde
     def process_query(cls, query: str) -> Set[str]:
         query = cls.tokenizer.normalizer.normalize_str(query)
         terms = map(lambda x: x[0], cls.tokenizer.pre_tokenizer.pre_tokenize_str(query))
-        filtered_terms = filter(lambda x: x not in cls.punctuation, terms)
+        filtered_terms = set(filter(lambda x: x not in cls.punctuation, terms))
         stemmed_terms = set()
         for term in filtered_terms:
             if term not in cls.stemmer_cache:
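The `set(...)` wrapper matters for two reasons: `filter()` returns a lazy, single-pass iterator, so a second traversal would silently yield nothing, and materializing it into a set deduplicates surface forms before they reach the stemming loop, so each unique term is stemmed (and cached) at most once per query. A minimal sketch of the difference, with NLTK's `PorterStemmer` standing in for `cls.stemmer` (an assumption, since the stemmer setup is outside this diff):

from nltk.stem import PorterStemmer  # assumed stand-in for cls.stemmer

punctuation = {".", ",", "!"}
terms = ["running", "running", ",", "runs"]

# Before: filter() is a lazy, single-pass iterator. Duplicates survive,
# and once it is exhausted, re-iterating yields nothing.
lazy = filter(lambda t: t not in punctuation, terms)
assert len(list(lazy)) == 3 and len(list(lazy)) == 0  # second pass is empty

# After: materializing into a set drops duplicates up front.
filtered_terms = set(filter(lambda t: t not in punctuation, terms))
assert filtered_terms == {"running", "runs"}

stemmer = PorterStemmer()
stemmer_cache = {}
stemmed_terms = set()
for term in filtered_terms:
    if term not in stemmer_cache:
        stemmer_cache[term] = stemmer.stem(term)
    stemmed_terms.add(stemmer_cache[term])

print(stemmed_terms)  # {'run'}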
@@ -141,13 +141,14 @@ def process_document(cls, document: str) -> Tuple[tokenizers.Encoding, Dict[str,
 
         # filter out duplicate terms, punctuations, and terms whose tokens overflow the max length
         for i, term in enumerate(document_terms):
-            if term not in cls.stemmer_cache:
-                cls.stemmer_cache[term] = cls.stemmer.stem(term)
-            term = cls.stemmer_cache[term]
 
             if term not in filtered_term_to_token_index \
                     and term not in cls.punctuation \
                     and i in term_index_to_token_index:
+                # check if stem is cached
+                if term not in cls.stemmer_cache:
+                    cls.stemmer_cache[term] = cls.stemmer.stem(term)
+                term = cls.stemmer_cache[term]
                 filtered_term_to_token_index[term] = term_index_to_token_index[i]
         return encoded, filtered_term_to_token_index
 
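Relocating the stemming inside the filter means `cls.stemmer.stem` now runs only for terms that pass the duplicate, punctuation, and overflow checks, rather than for every term in the document. One behavioral nuance follows directly from the diff: the duplicate check `term not in filtered_term_to_token_index` is now evaluated on the unstemmed surface form, so two different surface forms that share a stem each pass the check, and the later one overwrites the earlier token index. A self-contained sketch of the new control flow (the toy inputs and the NLTK stemmer are assumptions; the real `document_terms` and `term_index_to_token_index` are built earlier in `process_document`, outside this diff):

from typing import Dict, List, Set

from nltk.stem import PorterStemmer  # assumed stand-in for cls.stemmer

def filter_terms(document_terms: List[str],
                 term_index_to_token_index: Dict[int, int],
                 punctuation: Set[str],
                 stemmer: PorterStemmer,
                 stemmer_cache: Dict[str, str]) -> Dict[str, int]:
    filtered_term_to_token_index: Dict[str, int] = {}
    for i, term in enumerate(document_terms):
        if term not in filtered_term_to_token_index \
                and term not in punctuation \
                and i in term_index_to_token_index:
            # stem lazily: only terms that survive the filter are stemmed,
            # and the cache avoids repeated calls for the same surface form
            if term not in stemmer_cache:
                stemmer_cache[term] = stemmer.stem(term)
            term = stemmer_cache[term]
            filtered_term_to_token_index[term] = term_index_to_token_index[i]
    return filtered_term_to_token_index

# toy data: "," is punctuation; index 3 has no token index (overflow)
terms = ["running", "runs", ",", "jumped"]
token_index = {0: 0, 1: 1, 2: 2}
print(filter_terms(terms, token_index, {","}, PorterStemmer(), {}))
# {'run': 1} -- "runs" passed the surface-form check, so its stem
# overwrote the token index recorded for "running"

Under the previous ordering, the membership check ran on the stemmed form, so the same inputs would produce {'run': 0}.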