diff --git a/.github/workflows/medcat-v2-tutorials_main.yml b/.github/workflows/medcat-v2-tutorials_main.yml index 61f297dc5..c7fe8eb3c 100644 --- a/.github/workflows/medcat-v2-tutorials_main.yml +++ b/.github/workflows/medcat-v2-tutorials_main.yml @@ -83,4 +83,4 @@ jobs: - name: Smoke test tutorial run: | pytest --capture=no --collect-only --nbmake ${{ matrix.part }} - pytest --capture=no --nbmake -n=auto --nbmake-kernel=smoketests --nbmake-timeout=1800 ${{ matrix.part }} + pytest --capture=no --nbmake -n=auto --nbmake-kernel=smoketests --nbmake-timeout=2400 ${{ matrix.part }} diff --git a/medcat-v2/medcat/pipeline/pipeline.py b/medcat-v2/medcat/pipeline/pipeline.py index deaedd47e..dd50055f6 100644 --- a/medcat-v2/medcat/pipeline/pipeline.py +++ b/medcat-v2/medcat/pipeline/pipeline.py @@ -1,6 +1,7 @@ from typing import Optional, Iterable, Union import logging import os +import warnings from medcat.utils.defaults import COMPONENTS_FOLDER from medcat.tokenizing.tokenizers import BaseTokenizer, create_tokenizer @@ -43,8 +44,19 @@ def create_entity(self, doc: MutableDocument, doc, token_start_index, token_end_index, label) def entity_from_tokens(self, tokens: list[MutableToken]) -> MutableEntity: + warnings.warn( + "The `medcat.pipeline.pipeline.entity_from_tokens` method is " + "deprecated and subject to removal in a future release. &#10;
Please " + "use `medcat.pipeline.pipeline.entity_from_tokens_in_doc` instead.", + DeprecationWarning, + stacklevel=2 + ) return self.tokenizer.entity_from_tokens(tokens) + def entity_from_tokens_in_doc( + self, tokens: list[MutableToken], doc: MutableDocument) -> MutableEntity: + return self.tokenizer.entity_from_tokens_in_doc(tokens, doc) + def __call__(self, text: str) -> MutableDocument: doc = self.tokenizer(text) for comp in self.components: @@ -342,6 +354,23 @@ def entity_from_tokens(self, tokens: list[MutableToken]) -> MutableEntity: """ return self._tokenizer.entity_from_tokens(tokens) + def entity_from_tokens_in_doc(self, tokens: list[MutableToken], + doc: MutableDocument) -> MutableEntity: + """Get the entity from the list of tokens in a document. + + This effectively turns a list of (consecutive) documents + into an entity. But it is also designed to reuse existing + instances on the document instead of creating new ones. + + Args: + tokens (list[MutableToken]): The tokens to use. + doc (MutableDocument): The document for these tokens. + + Returns: + MutableEntity: The resulting entity. + """ + return self._tokenizer.entity_from_tokens_in_doc(tokens, doc) + def get_component(self, ctype: CoreComponentType) -> CoreComponent: """Get the core component by the component type. 
diff --git a/medcat-v2/medcat/tokenizing/regex_impl/tokenizer.py b/medcat-v2/medcat/tokenizing/regex_impl/tokenizer.py index 01b60b016..407c65b4e 100644 --- a/medcat-v2/medcat/tokenizing/regex_impl/tokenizer.py +++ b/medcat-v2/medcat/tokenizing/regex_impl/tokenizer.py @@ -1,6 +1,7 @@ import re from typing import cast, Optional, Iterator, overload, Union, Any, Type from collections import defaultdict +import warnings from medcat.tokenizing.tokens import ( BaseToken, BaseEntity, BaseDocument, @@ -340,6 +341,14 @@ def create_entity(self, doc: MutableDocument, # return Entity(span) def entity_from_tokens(self, tokens: list[MutableToken]) -> MutableEntity: + warnings.warn( + "The `medcat.tokenizing.tokenizers.Tokenizer.entity_from_tokens` method is " + "deprecated and subject to removal in a future release. Please use " + "`medcat.tokenizing.tokenizers.Tokenizer.entity_from_tokens_in_doc` " + "instead.", + DeprecationWarning, + stacklevel=2 + ) if not tokens: raise ValueError("Need at least one token for an entity") doc = cast(Token, tokens[0])._doc @@ -347,6 +356,23 @@ def entity_from_tokens(self, tokens: list[MutableToken]) -> MutableEntity: end_index = doc._tokens.index(tokens[-1]) return _entity_from_tokens(doc, tokens, start_index, end_index) + def _get_existing_entity(self, tokens: list[MutableToken], + doc: MutableDocument) -> Optional[MutableEntity]: + if not tokens: + return None + for ent in doc.ner_ents + doc.linked_ents: + if (ent.base.start_index == tokens[0].base.index and + ent.base.end_index == tokens[-1].base.index): + return ent + return None + + def entity_from_tokens_in_doc(self, tokens: list[MutableToken], + doc: MutableDocument) -> MutableEntity: + existing_ent = self._get_existing_entity(tokens, doc) + if existing_ent: + return existing_ent + return self.entity_from_tokens(tokens) + def _get_tokens_matches(self, text: str) -> list[re.Match[str]]: tokens = self.REGEX.finditer(text) return list(tokens) diff --git &#10;
a/medcat-v2/medcat/tokenizing/spacy_impl/tokenizers.py b/medcat-v2/medcat/tokenizing/spacy_impl/tokenizers.py index 0d18ed4f5..ed1493225 100644 --- a/medcat-v2/medcat/tokenizing/spacy_impl/tokenizers.py +++ b/medcat-v2/medcat/tokenizing/spacy_impl/tokenizers.py @@ -3,6 +3,7 @@ import os import shutil import logging +import warnings import spacy from spacy.tokens import Span @@ -77,6 +78,14 @@ def create_entity(self, doc: MutableDocument, return Entity(span) def entity_from_tokens(self, tokens: list[MutableToken]) -> MutableEntity: + warnings.warn( + "The `medcat.tokenizing.tokenizers.Tokenizer.entity_from_tokens` method is " + "deprecated and subject to removal in a future release. Please use " + "`medcat.tokenizing.tokenizers.Tokenizer.entity_from_tokens_in_doc` " + "instead.", + DeprecationWarning, + stacklevel=2 + ) if not tokens: raise ValueError("Need at least one token for an entity") spacy_tokens = cast(list[Token], tokens) @@ -84,6 +93,23 @@ def entity_from_tokens(self, tokens: list[MutableToken]) -> MutableEntity: spacy_tokens[-1].index + 1) return Entity(span) + def _get_existing_entity(self, tokens: list[MutableToken], + doc: MutableDocument) -> Optional[MutableEntity]: + if not tokens: + return None + for ent in doc.ner_ents + doc.linked_ents: + if (ent.base.start_index == tokens[0].base.index and + ent.base.end_index == tokens[-1].base.index): + return ent + return None + + def entity_from_tokens_in_doc(self, tokens: list[MutableToken], + doc: MutableDocument) -> MutableEntity: + existing_ent = self._get_existing_entity(tokens, doc) + if existing_ent: + return existing_ent + return self.entity_from_tokens(tokens) + def __call__(self, text: str) -> MutableDocument: if self._avoid_pipe: doc = Document(self._nlp.make_doc(text)) diff --git a/medcat-v2/medcat/tokenizing/tokenizers.py b/medcat-v2/medcat/tokenizing/tokenizers.py index 834a7ec34..4c7d05a45 100644 --- a/medcat-v2/medcat/tokenizing/tokenizers.py +++ b/medcat-v2/medcat/tokenizing/tokenizers.py @@ 
-34,15 +34,22 @@ def create_entity(self, doc: MutableDocument, pass def entity_from_tokens(self, tokens: list[MutableToken]) -> MutableEntity: - """Get an entity from the list of tokens. + """Deprecated: use entity_from_tokens_in_doc instead.""" + pass + + def entity_from_tokens_in_doc(self, tokens: list[MutableToken], + doc: MutableDocument) -> MutableEntity: + """Get an entity from the list of tokens in the specified document. + + This method is designed to reuse entities where possible. Args: tokens (list[MutableToken]): List of tokens. + doc (MutableDocument): The document for these tokens. Returns: MutableEntity: The resulting entity. """ - pass def __call__(self, text: str) -> MutableDocument: pass diff --git a/medcat-v2/medcat/trainer.py b/medcat-v2/medcat/trainer.py index e462e8d5d..7a2b32ef8 100644 --- a/medcat-v2/medcat/trainer.py +++ b/medcat-v2/medcat/trainer.py @@ -11,7 +11,7 @@ from medcat.utils.data_utils import make_mc_train_test, get_false_positives from medcat.utils.filters import project_filters from medcat.data.mctexport import ( - MedCATTrainerExport, MedCATTrainerExportProject, + MedCATTrainerExport, MedCATTrainerExportAnnotation, MedCATTrainerExportProject, MedCATTrainerExportDocument, count_all_annotations, iter_anns) from medcat.preprocessors.cleaners import prepare_name, NameDescriptor from medcat.components.types import CoreComponentType, TrainableComponent @@ -397,6 +397,20 @@ def _train_supervised_for_project(self, docs, current_document, train_from_false_positives, devalue_others) + def _prepare_doc_with_anns( + self, doc: MutableDocument, + anns: list[MedCATTrainerExportAnnotation]) -> None: + ents = [] + for ann in anns: + tkns = doc.get_tokens(ann['start'], ann['end']) + ents.append(self._pipeline.entity_from_tokens_in_doc(tkns, doc)) + # set NER ents + doc.ner_ents.clear() + doc.ner_ents.extend(ents) + # duplicate for linked as well, but in a separate list + doc.linked_ents.clear() + doc.linked_ents.extend(ents) + def &#10;
_train_supervised_for_project2(self, docs: list[MedCATTrainerExportDocument], current_document: int, @@ -412,9 +426,10 @@ def _train_supervised_for_project2(self, with temp_changed_config(self.config.components.linking, 'train', False): mut_doc = self.caller(doc['text']) + self._prepare_doc_with_anns(mut_doc, doc['annotations']) # Compatibility with old output where annotations are a list - for ann in doc['annotations']: + for ann, mut_entity in zip(doc['annotations'], mut_doc.linked_ents): if ann.get('killed', False): continue logger.info(" Annotation %s (%s) [%d:%d]", @@ -422,7 +437,6 @@ def _train_supervised_for_project2(self, cui = ann['cui'] start = ann['start'] end = ann['end'] - mut_entity = mut_doc.get_tokens(start, end) if not mut_entity: logger.warning( "When looking for CUI '%s' (value '%s') [%d...%d] " diff --git a/medcat-v2/tests/resources/supervised_mct_export.json b/medcat-v2/tests/resources/supervised_mct_export.json index d6cdcd828..3f2c6df54 100644 --- a/medcat-v2/tests/resources/supervised_mct_export.json +++ b/medcat-v2/tests/resources/supervised_mct_export.json @@ -58,7 +58,7 @@ { "cui": "C04", "start": 81, - "end": 87, + "end": 88, "value": "fittest" } ], @@ -125,7 +125,7 @@ "id": "ID-3", "last_modified": "2024-08-21", "name": "Doc#4", - "text": "The RHS male is healthy as considered by all available tests. There are no indications that the patient is not fittest." + "text": "The RHS male is healthy as considered by all available tests. 
There are no indications that the patient is not fittest" } ], "id": "Project#0", diff --git a/medcat-v2/tests/test_cat.py b/medcat-v2/tests/test_cat.py index 26733d2ea..f28ba7e5e 100644 --- a/medcat-v2/tests/test_cat.py +++ b/medcat-v2/tests/test_cat.py @@ -7,6 +7,7 @@ from contextlib import contextmanager from medcat import cat +from medcat.data.mctexport import count_all_annotations, iter_anns from medcat.data.model_card import ModelCard from medcat.vocab import Vocab from medcat.config import Config @@ -576,7 +577,7 @@ class CATSupTrainingTests(CATUnsupTrainingTests): os.path.dirname(__file__), 'resources', 'supervised_mct_export.json' ) # NOTE: should remain consistent unless we change the model or data - EXPECTED_HASH = "7bfe01e8e36eb07d" + EXPECTED_HASH = "9c299628c9e6c220" @classmethod def _get_cui_counts(cls) -> dict[str, int]: @@ -620,6 +621,21 @@ def test_clearing_training_works(self): self.assertEqual(self.cat.config.meta.unsup_trained, []) self.assertEqual(self.cat.config.meta.sup_trained, []) + def test_training_happens_in_correct_order(self): + with captured_state_cdb(self.cat.cdb): + with unittest.mock.patch.object( + self.cat.trainer, "add_and_train_concept") as mock_add_and_train_concept: + self._perform_training() + mct_export = self._get_data() + called_ents = [ + args.kwargs['mut_entity'] for args in mock_add_and_train_concept.call_args_list + ] + self.assertEqual(len(called_ents), count_all_annotations(mct_export)) + for (_, _, ann), ent in zip(iter_anns(mct_export), called_ents): + with self.subTest(f"Ann: {ann} vs Ent: {ent}"): + self.assertEqual(ann['start'], ent.base.start_char_index) + self.assertEqual(ann['end'], ent.base.end_char_index) + class CATWithDictNERSupTrainingTests(CATSupTrainingTests): from medcat.components.types import CoreComponentType diff --git a/medcat-v2/tests/test_trainer.py b/medcat-v2/tests/test_trainer.py index 03b181547..3f049a4f1 100644 --- a/medcat-v2/tests/test_trainer.py +++ b/medcat-v2/tests/test_trainer.py 
@@ -1,12 +1,14 @@ import os import json +from medcat.tokenizing.tokens import MutableDocument from medcat.trainer import Trainer from medcat.config import Config from medcat.vocab import Vocab from medcat.data.mctexport import MedCATTrainerExport import unittest +import unittest.mock import random import pandas as pd @@ -25,6 +27,22 @@ def _add_concept(self, *args, **kwargs) -> None: pass +class FakeMutToken: + + def __init__(self, doc: 'FakeMutDoc', + index: int, start_char_index: int, + end_char_index: int) -> None: + self.index = index + self.char_index = start_char_index + self.text = doc.text[start_char_index: end_char_index] + self.to_skip = False + self.base = self + + @property + def lower(self): + return self.text.lower() + + class FakeMutEnt: def __init__(self, doc: 'FakeMutDoc', @@ -46,15 +64,20 @@ class FakeMutDoc: def __init__(self, text: str): self.text = text self.base = self + self.ner_ents = [] + self.linked_ents = [] def isupper(self) -> bool: return self.text.isupper() - def get_tokens(self, start_index: int, end_index: int): - return FakeMutEnt(self, start_index, end_index) + def get_tokens(self, start_index: int, end_index: int, chars_per_tkns: int = 5): + return [ + FakeMutToken(self, (cstart // chars_per_tkns), cstart, cstart + chars_per_tkns) + for cstart in range(start_index, end_index, chars_per_tkns) + ] def __iter__(self): - yield self.get_tokens(0, 1) + yield from self.get_tokens(0, len(self.text)) class FakeComponent: @@ -72,6 +95,9 @@ def tokenizer_with_tag(self, text: str) -> FakeMutDoc: def get_component(self, comp_type): return FakeComponent + def entity_from_tokens_in_doc(self, tkns: list, doc: MutableDocument) -> FakeMutEnt: + return FakeMutEnt(doc, tkns[0].index, tkns[-1].index) + class TrainerTestsBase(unittest.TestCase): DATA_CNT = 14 @@ -177,6 +203,33 @@ class TrainerSupervisedTests(TrainerUnsupervisedTests): } ] } + TRAIN_DATA_MULTI: MedCATTrainerExport = { + "projects": [ + *TRAIN_DATA['projects'], + { + 'cuis': '', + 
'tuis': '', + 'documents': [ + { + 'id': "P2D1", + 'name': "Project#2Doc#1", + 'last_modified': 'N/A', + 'text': 'Some long text', + 'annotations': [ + { + 'cui': "C1", + 'start': 0, + 'end': 4, + 'value': 'SOME', + } + ] + } + ], + 'id': "PID#2", + 'name': "PROJECT#2", + } + ] + } @classmethod def setUpClass(cls): @@ -185,6 +238,20 @@ def setUpClass(cls): def test_training_gets_remembered_gen(self): pass # NOTE: no generation for supervised training + def test_trains_all_projects(self): + with unittest.mock.patch.object(self.trainer, "_train_supervised_for_project") as mock_train_project: + self.train(self.TRAIN_DATA_MULTI) + self.assertTrue(mock_train_project.call_count, 2) + + def test_training_happens_on_linked_ents_on_doc(self): + with unittest.mock.patch.object(self.trainer, "add_and_train_concept") as mock_add_and_train_concept: + self.train(self.TRAIN_DATA) + for num, args in enumerate(mock_add_and_train_concept.call_args_list): + with self.subTest(str(num)): + mock_add_and_train_concept.assert_called() + doc, ent = args.kwargs['mut_doc'], args.kwargs['mut_entity'] + self.assertIn(ent, doc.linked_ents) + class FromSratchBase(TrainedModelTests): RNG_SEED = 42 diff --git a/medcat-v2/tests/tokenizing/spacy_impl/test_tokenizers.py b/medcat-v2/tests/tokenizing/spacy_impl/test_tokenizers.py index fca604c0e..49b9252cf 100644 --- a/medcat-v2/tests/tokenizing/spacy_impl/test_tokenizers.py +++ b/medcat-v2/tests/tokenizing/spacy_impl/test_tokenizers.py @@ -4,6 +4,7 @@ from medcat.tokenizing.spacy_impl.tokenizers import SpacyTokenizer from medcat.tokenizing.regex_impl.tokenizer import RegexTokenizer from medcat.config import Config +from medcat.tokenizing.tokens import MutableDocument, MutableEntity, MutableToken from medcat.utils.registry import Registry import unittest @@ -42,3 +43,40 @@ class DefaultTokenizerInitTests2(DefaultTokenizerInitTests): default_provider = 'regex' default_cls = RegexTokenizer default_creator = RegexTokenizer.create_new_tokenizer + + +class 
TokenizerTests(unittest.TestCase): + default_provider = 'spacy' + text = "Some text to tokenize" + + @classmethod + def setUpClass(cls): + cls.cnf = Config() + + def setUp(self) -> None: + self.tokenizer = tokenizers.create_tokenizer( + self.default_provider, self.cnf) + self.doc = self.tokenizer(self.text) + self.doc.ner_ents = self._create_ner_ents(self.doc) + self.doc.linked_ents = self.doc.ner_ents.copy() + + def _create_ner_ents( + self, doc: MutableDocument, + targets: list[str] = ["text",]) -> list[MutableEntity]: + return [ + self.tokenizer.create_entity( + doc, + (tkns := doc.get_tokens( + start := doc.base.text.index(target), start + len(target)))[0].base.index, + tkns[-1].base.index + 1, + label=target) + for target in targets + ] + + def test_getting_entity_based_on_tokens_gets_same_instance(self): + for ent in self.doc.ner_ents: + with self.subTest(f"Ent: {ent} in doc {self.doc}"): + tokens = list(ent) + got_ent = self.tokenizer.entity_from_tokens_in_doc(tokens, self.doc) + self.assertIs(got_ent, ent) + self.assertIn(got_ent, self.doc.ner_ents)