diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 0000000..740daf8 --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,70 @@ +name: Tests + +on: + push: + branches: + - main + pull_request: + branches: + - main + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + test: + name: pytest (Python ${{ matrix.python-version }}) + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.10", "3.11", "3.12"] + + steps: + - name: Check out repository + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Cache pip + uses: actions/cache@v4 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-py${{ matrix.python-version }}-pip-${{ hashFiles('pyproject.toml') }} + restore-keys: | + ${{ runner.os }}-py${{ matrix.python-version }}-pip- + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e . + pip install pytest pytest-cov + + - name: Run pytest + run: pytest -v --tb=short --cov=src/glinker --cov-report=term-missing + + lint: + name: ruff + runs-on: ubuntu-latest + steps: + - name: Check out repository + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.12" + + - name: Install ruff + run: pip install ruff + + - name: ruff check + run: ruff check src/glinker + + - name: ruff format --check + run: ruff format --check src/glinker diff --git a/pyproject.toml b/pyproject.toml index 9c6849c..e172381 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ dependencies = [ "pydantic>=2.0.0", "pyyaml>=6.0.0", "tqdm>=4.65.0", + "rapidfuzz>=3.0.0", ] [project.optional-dependencies] @@ -72,6 +73,101 @@ target-version = ["py310", "py311"] [tool.ruff] line-length = 100 target-version = "py310" +src = ["src"] + +[tool.ruff.lint] +select = [ + "F", # Pyflakes + "E", # pycodestyle errors + "W", # pycodestyle warnings + "I", # isort + "D", # pydocstyle + "UP", # pyupgrade + "B", # bugbear + "SIM", # simplify + "ARG", # unused arguments + "T20", # print statements + "C4", # comprehensions + "EM", # errmsg + "PL", # Pylint + "RUF" # Ruff-specific rules +] + +ignore = [ + "D100", # Missing docstring in public module + "D101", # Missing docstring in public class + "D102", # Missing docstring in public method + "D103", # Missing docstring in public function + "D104", # Missing docstring in public package + "D105", # Missing docstring in magic method + "D107", # Missing docstring in `__init__` + "D200", # One-line docstring should fit on one line + "D205", # Blank line required between summary and description + "D212", # Multi-line docstring summary should start at the first line + "D400", # First line should end with a period + "D401", # First line should be in imperative mood + "D417", # Missing argument descriptions + "RUF012", # Mutable class attributes should be annotated + "PLR0913",# Too many arguments + "PLR0912",# Too many branches + "PLR0915",# Too many statements + "PLR2004",# Magic value used in comparison + "PLW2901",# Loop variable overwritten + "B006", # Mutable defaults + "B027", # Empty method without abstract decorator + "B904", # Raise without from inside except + "S101", # Use of `assert` detected + "SIM102", # Collapsible if statements + "SIM105", # Use contextlib.suppress instead of try-except-pass + "SIM108", # Use ternary operator + "UP035", # Deprecated typing imports + "UP006", # Deprecated typing imports + "EM101", # Exception message formatting + "EM102", # Exception message formatting + "ARG002", # Unused method arguments + "B905", # zip() without explicit strict= parameter + "E402", # Module level import not at top (for conditional imports) + "E501", # Line too long (handled by formatter) + "E722", # Bare except (intentional for fallback logic) + "E741", # Ambiguous variable name (short names in comprehensions) + "PLC0206",# Extracting value from dictionary without .items() + "PLC0415",# Import should be at top-level (for dynamic imports) + "T201", # Print statements (used for debugging) +] + +unfixable = [ + "T201", # Print statements + "T203", # pprint statements +] + +dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" + +[tool.ruff.lint.per-file-ignores] +"__init__.py" = [ + "F401", # Unused imports in __init__.py are often intentional + "F403", # Star imports in __init__.py +] +"tests/**/*.py" = [ + "D", # No docstring requirements in tests + "ARG", # Unused arguments are common in test fixtures + "PLR2004",# Magic values are fine in tests + "S101", # Assert is expected in tests +] + +[tool.ruff.lint.isort] +length-sort = true +length-sort-straight = true +combine-as-imports = true +known-first-party = ["glinker"] + +[tool.ruff.lint.pydocstyle] +convention = "google" + +[tool.ruff.lint.pylint] +max-args = 20 +max-branches = 15 +max-returns = 8 +max-statements = 60 [tool.pytest.ini_options] testpaths = ["tests"] diff --git a/src/glinker/__init__.py b/src/glinker/__init__.py index ab7fedc..c9b27d6 100644 --- a/src/glinker/__init__.py +++ b/src/glinker/__init__.py @@ -10,45 +10,44 @@ from glinker.l2 import processor as _l2_processor from glinker.l3 import processor as _l3_processor from glinker.l4 import processor as _l4_processor - from glinker.core import ( - BaseConfig, + PipeNode, BaseInput, + BaseConfig, BaseOutput, + DAGExecutor, + DAGPipeline, + InputConfig, + PipeContext, + OutputConfig, BaseComponent, BaseProcessor, - ProcessorRegistry, - processor_registry, + ConfigBuilder, + FieldResolver, + ReshapeConfig, ProcessorFactory, + ProcessorRegistry, load_yaml, - InputConfig, - OutputConfig, - ReshapeConfig, - PipeNode, - PipeContext, - FieldResolver, - DAGPipeline, - DAGExecutor, - ConfigBuilder, + processor_registry, ) __all__ = [ - 'BaseConfig', - 'BaseInput', - 'BaseOutput', - 'BaseComponent', - 'BaseProcessor', - 'ProcessorRegistry', - 'processor_registry', - 'ProcessorFactory', - 'load_yaml', - 'InputConfig', - 'OutputConfig', - 'ReshapeConfig', - 'PipeNode', - 'PipeContext', - 'FieldResolver', - 'DAGPipeline', - 'DAGExecutor', - 'ConfigBuilder', + "BaseComponent", + "BaseConfig", + "BaseInput", + "BaseOutput", + "BaseProcessor", + "ConfigBuilder", + "DAGExecutor", + "DAGPipeline", + "FieldResolver", + "InputConfig", + "OutputConfig", + "PipeContext", + "PipeNode", + "ProcessorFactory", + "ProcessorRegistry", + "ReshapeConfig", + "load_yaml", + "processor_registry", ] diff --git a/src/glinker/core/__init__.py b/src/glinker/core/__init__.py index 6f977b4..cee531c 100644 --- a/src/glinker/core/__init__.py +++ b/src/glinker/core/__init__.py @@ -1,56 +1,47 @@ +from .dag import ( + PipeNode, + DAGExecutor, + DAGPipeline, + InputConfig, + PipeContext, + OutputConfig, + FieldResolver, + ReshapeConfig, +) from .base import ( - BaseConfig, + InputT, + ConfigT, + OutputT, BaseInput, + BaseConfig, BaseOutput, BaseComponent, BaseProcessor, - ConfigT, - InputT, - OutputT -) -from .registry import ( - ProcessorRegistry, - processor_registry ) from .factory import ProcessorFactory, load_yaml - -from .dag import ( - InputConfig, - OutputConfig, - ReshapeConfig, - PipeNode, - PipeContext, - FieldResolver, - DAGPipeline, - DAGExecutor -) - from .builders import ConfigBuilder +from .registry import ProcessorRegistry, processor_registry __all__ = [ - 'BaseConfig', - 'BaseInput', - 'BaseOutput', - 'BaseComponent', - 'BaseProcessor', - 'ConfigT', - 'InputT', - 'OutputT', - - 'ProcessorRegistry', - 'processor_registry', - - 'ProcessorFactory', - 'load_yaml', - - 'InputConfig', - 'OutputConfig', - 'ReshapeConfig', - 'PipeNode', - 'PipeContext', - 'FieldResolver', - 'DAGPipeline', - 'DAGExecutor', - - 'ConfigBuilder', -] \ No newline at end of file + "BaseComponent", + "BaseConfig", + "BaseInput", + "BaseOutput", + "BaseProcessor", + "ConfigBuilder", + "ConfigT", + "DAGExecutor", + "DAGPipeline", + "FieldResolver", + "InputConfig", + "InputT", + "OutputConfig", + "OutputT", + "PipeContext", + "PipeNode", + "ProcessorFactory", + "ProcessorRegistry", + "ReshapeConfig", + "load_yaml", + "processor_registry", +] diff --git a/src/glinker/core/base.py b/src/glinker/core/base.py index 088b39d..9c53d50 100644 --- a/src/glinker/core/base.py +++ b/src/glinker/core/base.py @@ -1,25 +1,28 @@ from abc import ABC, abstractmethod -from typing import Generic, TypeVar, Any -from pydantic import BaseModel, Field +from typing import Any, Generic, TypeVar +from pydantic import BaseModel -ConfigT = TypeVar('ConfigT', bound=BaseModel) -InputT = TypeVar('InputT', bound=BaseModel) -OutputT = TypeVar('OutputT', bound=BaseModel) +ConfigT = TypeVar("ConfigT", bound=BaseModel) +InputT = TypeVar("InputT", bound=BaseModel) +OutputT = TypeVar("OutputT", bound=BaseModel) class BaseConfig(BaseModel): - """Base configuration for all components""" + """Base configuration for all components.""" + pass class BaseInput(BaseModel): - """Base input model""" + """Base input model.""" + pass class BaseOutput(BaseModel): - """Base output model""" + """Base output model.""" + pass @@ -28,18 +31,18 @@ class BaseComponent(ABC, Generic[ConfigT]): Base component class that implements core logic. Each component should have discrete methods that can be chained. """ - + def __init__(self, config: ConfigT): self.config = config self._setup() - + def _setup(self): - """Override this for initialization logic""" + """Override this for initialization logic.""" pass - + @abstractmethod def get_available_methods(self) -> list[str]: - """Return list of available pipeline methods""" + """Return list of available pipeline methods.""" pass @@ -47,57 +50,49 @@ class BaseProcessor(ABC, Generic[ConfigT, InputT, OutputT]): """ Base processor that orchestrates component methods via pipeline. """ - + def __init__( self, config: ConfigT, component: BaseComponent[ConfigT], - pipeline: list[tuple[str, dict[str, Any]]] = None + pipeline: list[tuple[str, dict[str, Any]]] | None = None, ): self.config = config self.component = component self.pipeline = pipeline or self._default_pipeline() - + @abstractmethod def _default_pipeline(self) -> list[tuple[str, dict[str, Any]]]: - """Define default pipeline for this processor""" + """Define default pipeline for this processor.""" pass - + def _validate_pipeline(self): - """Validate that all pipeline methods exist in component""" + """Validate that all pipeline methods exist in component.""" available = self.component.get_available_methods() for method_name, _ in self.pipeline: if method_name not in available: raise ValueError( - f"Method '{method_name}' not found in component. " - f"Available: {available}" + f"Method '{method_name}' not found in component. Available: {available}" ) - - def _execute_pipeline_step( - self, - data: Any, - method_name: str, - kwargs: dict[str, Any] - ) -> Any: - """Execute single pipeline step""" + + def _execute_pipeline_step(self, data: Any, method_name: str, kwargs: dict[str, Any]) -> Any: + """Execute single pipeline step.""" method = getattr(self.component, method_name) return method(data, **kwargs) - + def _execute_pipeline( - self, - data: Any, - pipeline: list[tuple[str, dict[str, Any]]] = None + self, data: Any, pipeline: list[tuple[str, dict[str, Any]]] | None = None ) -> Any: - """Execute full pipeline on data""" + """Execute full pipeline on data.""" pipe = pipeline or self.pipeline result = data - + for method_name, kwargs in pipe: result = self._execute_pipeline_step(result, method_name, kwargs) - + return result - + @abstractmethod def __call__(self, input_data: InputT) -> OutputT: - """Process input through pipeline""" - pass \ No newline at end of file + """Process input through pipeline.""" + pass diff --git a/src/glinker/core/builders.py b/src/glinker/core/builders.py index 1b7eece..28a3542 100644 --- a/src/glinker/core/builders.py +++ b/src/glinker/core/builders.py @@ -4,10 +4,11 @@ ConfigBuilder: Unified builder with automatic defaults and full customization support. """ -from typing import List, Optional, Dict, Any, Literal -import yaml +from typing import Any, Dict, List, Literal from pathlib import Path +import yaml + class ConfigBuilder: """ @@ -35,7 +36,7 @@ class ConfigBuilder: """ class L1Builder: - """L1 configuration builder""" + """L1 configuration builder.""" def __init__(self, parent): self.parent = parent @@ -48,9 +49,9 @@ def spacy( max_right_context: int = 50, max_left_context: int = 50, min_entity_length: int = 2, - include_noun_chunks: bool = False + include_noun_chunks: bool = False, ) -> "ConfigBuilder": - """Configure L1 with spaCy NER""" + """Configure L1 with spaCy NER.""" self.parent._l1_type = "l1_spacy" self.parent._l1_config = { "model": model, @@ -59,7 +60,7 @@ def spacy( "max_right_context": max_right_context, "max_left_context": max_left_context, "min_entity_length": min_entity_length, - "include_noun_chunks": include_noun_chunks + "include_noun_chunks": include_noun_chunks, } return self.parent @@ -67,7 +68,7 @@ def gliner( self, model: str, labels: List[str], - token: Optional[str] = None, + token: str | None = None, device: str = "cpu", threshold: float = 0.3, flat_ner: bool = True, @@ -77,9 +78,9 @@ def gliner( max_left_context: int = 50, min_entity_length: int = 2, use_precomputed_embeddings: bool = False, - max_length: Optional[int] = 512 + max_length: int | None = 512, ) -> "ConfigBuilder": - """Configure L1 with GLiNER""" + """Configure L1 with GLiNER.""" self.parent._l1_type = "l1_gliner" self.parent._l1_config = { "model": model, @@ -94,12 +95,12 @@ def gliner( "max_left_context": max_left_context, "min_entity_length": min_entity_length, "use_precomputed_embeddings": use_precomputed_embeddings, - "max_length": max_length + "max_length": max_length, } return self.parent class L2Builder: - """L2 configuration builder""" + """L2 configuration builder.""" def __init__(self, parent): self.parent = parent @@ -108,12 +109,12 @@ def add( self, layer_type: Literal["dict", "redis", "elasticsearch", "postgres"], priority: int = 0, - write: bool = None, - search_mode: List[str] = None, - ttl: int = None, - cache_policy: str = None, - fuzzy_similarity: float = None, - **db_config + write: bool | None = None, + search_mode: List[str] | None = None, + ttl: int | None = None, + cache_policy: str | None = None, + fuzzy_similarity: float | None = None, + **db_config, ) -> "ConfigBuilder": """ Add a database layer to L2. @@ -142,7 +143,9 @@ def add( search_mode = ["exact"] if layer_type == "redis" else ["exact", "fuzzy"] if ttl is None: - ttl = {"dict": 0, "redis": 3600, "elasticsearch": 86400, "postgres": 0}.get(layer_type, 0) + ttl = {"dict": 0, "redis": 3600, "elasticsearch": 86400, "postgres": 0}.get( + layer_type, 0 + ) if cache_policy is None: cache_policy = "miss" if layer_type == "elasticsearch" else "always" @@ -155,7 +158,7 @@ def add( "search_mode": search_mode, "ttl": ttl, "cache_policy": cache_policy, - "field_mapping": self._default_field_mapping() + "field_mapping": self._default_field_mapping(), } # Add database-specific config @@ -166,14 +169,14 @@ def add( "max_distance": 64, "min_similarity": fuzzy_similarity, "n_gram_size": 3, - "prefix_length": 1 + "prefix_length": 1, } elif layer_type == "redis": layer["config"] = { "host": db_config.get("host", "localhost"), "port": db_config.get("port", 6379), - "db": db_config.get("db", 0) + "db": db_config.get("db", 0), } elif layer_type == "elasticsearch": @@ -195,7 +198,7 @@ def add( "port": db_config.get("port", 5432), "database": db_config.get("database", "entities_db"), "user": db_config.get("user", "postgres"), - "password": db_config.get("password", "postgres") + "password": db_config.get("password", "postgres"), } layer["fuzzy"] = {"min_similarity": fuzzy_similarity} @@ -207,14 +210,14 @@ def embeddings( enabled: bool = True, model_name: str = "knowledgator/gliner-linker-large-v1.0", dim: int = 768, - precompute_on_load: bool = False + precompute_on_load: bool = False, ) -> "ConfigBuilder": - """Configure embeddings for L2 (BiEncoder support)""" + """Configure embeddings for L2 (BiEncoder support).""" self.parent._l2_embeddings = { "enabled": enabled, "model_name": model_name, "dim": dim, - "precompute_on_load": precompute_on_load + "precompute_on_load": precompute_on_load, } # Add embedding fields to all layers @@ -225,14 +228,14 @@ def embeddings( return self.parent def _default_field_mapping(self) -> Dict[str, str]: - """Default field mapping""" + """Default field mapping.""" mapping = { "entity_id": "entity_id", "label": "label", "aliases": "aliases", "description": "description", "entity_type": "entity_type", - "popularity": "popularity" + "popularity": "popularity", } # Add embedding fields if embeddings enabled @@ -243,7 +246,7 @@ def _default_field_mapping(self) -> Dict[str, str]: return mapping class L3Builder: - """L3 configuration builder""" + """L3 configuration builder.""" def __init__(self, parent): self.parent = parent @@ -251,7 +254,7 @@ def __init__(self, parent): def configure( self, model: str = "knowledgator/gliner-linker-large-v1.0", - token: Optional[str] = None, + token: str | None = None, device: str = "cpu", threshold: float = 0.5, flat_ner: bool = True, @@ -259,9 +262,9 @@ def configure( batch_size: int = 1, use_precomputed_embeddings: bool = False, cache_embeddings: bool = False, - max_length: Optional[int] = 512 + max_length: int | None = 512, ) -> "ConfigBuilder": - """Configure L3 entity disambiguation""" + """Configure L3 entity disambiguation.""" self.parent._l3_config = { "model_name": model, "huggingface_token": token, @@ -272,12 +275,12 @@ def configure( "batch_size": batch_size, "use_precomputed_embeddings": use_precomputed_embeddings, "cache_embeddings": cache_embeddings, - "max_length": max_length + "max_length": max_length, } return self.parent class L4Builder: - """L4 configuration builder (optional GLiNER reranker with chunking)""" + """L4 configuration builder (optional GLiNER reranker with chunking).""" def __init__(self, parent): self.parent = parent @@ -285,13 +288,13 @@ def __init__(self, parent): def configure( self, model: str = "knowledgator/gliner-linker-large-v1.0", - token: Optional[str] = None, + token: str | None = None, device: str = "cpu", threshold: float = 0.5, flat_ner: bool = True, multi_label: bool = False, max_labels: int = 20, - max_length: Optional[int] = 512 + max_length: int | None = 512, ) -> "ConfigBuilder": """Configure L4 GLiNER reranker with candidate chunking. @@ -309,12 +312,12 @@ def configure( "flat_ner": flat_ner, "multi_label": multi_label, "max_labels": max_labels, - "max_length": max_length + "max_length": max_length, } return self.parent class L0Builder: - """L0 configuration builder""" + """L0 configuration builder.""" def __init__(self, parent): self.parent = parent @@ -325,19 +328,19 @@ def configure( include_unlinked: bool = True, return_all_candidates: bool = False, strict_matching: bool = True, - position_tolerance: int = 2 + position_tolerance: int = 2, ) -> "ConfigBuilder": - """Configure L0 aggregation parameters""" + """Configure L0 aggregation parameters.""" self.parent._l0_config = { "min_confidence": min_confidence, "include_unlinked": include_unlinked, "return_all_candidates": return_all_candidates, "strict_matching": strict_matching, - "position_tolerance": position_tolerance + "position_tolerance": position_tolerance, } return self.parent - def __init__(self, name: str = "pipeline", description: str = None): + def __init__(self, name: str = "pipeline", description: str | None = None): self.name = name self.description = description or f"{name} - auto-generated configuration" self._l1_config = None @@ -351,7 +354,7 @@ def __init__(self, name: str = "pipeline", description: str = None): "include_unlinked": True, "return_all_candidates": False, "strict_matching": True, - "position_tolerance": 2 + "position_tolerance": 2, } self._schema_template = "{label}: {description}" @@ -363,7 +366,7 @@ def __init__(self, name: str = "pipeline", description: str = None): self.l0 = self.L0Builder(self) def set_schema_template(self, template: str) -> "ConfigBuilder": - """Set label formatting template for L2/L3/L0""" + """Set label formatting template for L2/L3/L0.""" self._schema_template = template return self @@ -388,34 +391,36 @@ def build(self) -> Dict[str, Any]: # Auto-add dict layer if no L2 layers specified if not self._l2_layers: - self._l2_layers.append({ - "type": "dict", - "priority": 0, - "write": True, - "search_mode": ["exact", "fuzzy"], - "ttl": 0, - "cache_policy": "always", - "field_mapping": { - "entity_id": "entity_id", - "label": "label", - "aliases": "aliases", - "description": "description", - "entity_type": "entity_type", - "popularity": "popularity" - }, - "fuzzy": { - "max_distance": 64, - "min_similarity": 0.6, - "n_gram_size": 3, - "prefix_length": 1 + self._l2_layers.append( + { + "type": "dict", + "priority": 0, + "write": True, + "search_mode": ["exact", "fuzzy"], + "ttl": 0, + "cache_policy": "always", + "field_mapping": { + "entity_id": "entity_id", + "label": "label", + "aliases": "aliases", + "description": "description", + "entity_type": "entity_type", + "popularity": "popularity", + }, + "fuzzy": { + "max_distance": 64, + "min_similarity": 0.6, + "n_gram_size": 3, + "prefix_length": 1, + }, } - }) + ) # Build L2 config l2_config = { "max_candidates": 10 if self._l2_embeddings else 5, "min_popularity": 0, - "layers": self._l2_layers + "layers": self._l2_layers, } if self._l2_embeddings: @@ -430,29 +435,19 @@ def build(self) -> Dict[str, Any]: { "id": "l1", "processor": self._l1_type, - "inputs": { - "texts": { - "source": "$input", - "fields": "texts" - } - }, + "inputs": {"texts": {"source": "$input", "fields": "texts"}}, "output": {"key": "l1_result"}, - "config": self._l1_config + "config": self._l1_config, }, # L2 Node { "id": "l2", "processor": "l2_chain", "requires": ["l1"], - "inputs": { - "mentions": { - "source": "l1_result", - "fields": "entities" - } - }, + "inputs": {"mentions": {"source": "l1_result", "fields": "entities"}}, "output": {"key": "l2_result"}, "schema": {"template": self._schema_template}, - "config": l2_config + "config": l2_config, }, # L3 Node { @@ -460,22 +455,13 @@ def build(self) -> Dict[str, Any]: "processor": "l3_batch", "requires": ["l1", "l2"], "inputs": { - "texts": { - "source": "$input", - "fields": "texts" - }, - "candidates": { - "source": "l2_result", - "fields": "candidates" - }, - "l1_entities": { - "source": "l1_result", - "fields": "entities" - } + "texts": {"source": "$input", "fields": "texts"}, + "candidates": {"source": "l2_result", "fields": "candidates"}, + "l1_entities": {"source": "l1_result", "fields": "entities"}, }, "output": {"key": "l3_result"}, "schema": {"template": self._schema_template}, - "config": self._l3_config + "config": self._l3_config, }, ] @@ -485,60 +471,42 @@ def build(self) -> Dict[str, Any]: # Optional L4 reranker node if self._l4_config: - nodes.append({ - "id": "l4", - "processor": "l4_reranker", - "requires": ["l1", "l2", "l3"], - "inputs": { - "texts": { - "source": "$input", - "fields": "texts" - }, - "candidates": { - "source": "l2_result", - "fields": "candidates" + nodes.append( + { + "id": "l4", + "processor": "l4_reranker", + "requires": ["l1", "l2", "l3"], + "inputs": { + "texts": {"source": "$input", "fields": "texts"}, + "candidates": {"source": "l2_result", "fields": "candidates"}, + "l1_entities": {"source": "l1_result", "fields": "entities"}, }, - "l1_entities": { - "source": "l1_result", - "fields": "entities" - } - }, - "output": {"key": "l4_result"}, - "schema": {"template": self._schema_template}, - "config": self._l4_config - }) + "output": {"key": "l4_result"}, + "schema": {"template": self._schema_template}, + "config": self._l4_config, + } + ) l0_entity_source = "l4_result" l0_requires.append("l4") # L0 Node - nodes.append({ - "id": "l0", - "processor": "l0_aggregator", - "requires": l0_requires, - "inputs": { - "l1_entities": { - "source": "l1_result", - "fields": "entities" - }, - "l2_candidates": { - "source": "l2_result", - "fields": "candidates" + nodes.append( + { + "id": "l0", + "processor": "l0_aggregator", + "requires": l0_requires, + "inputs": { + "l1_entities": {"source": "l1_result", "fields": "entities"}, + "l2_candidates": {"source": "l2_result", "fields": "candidates"}, + "l3_entities": {"source": l0_entity_source, "fields": "entities"}, }, - "l3_entities": { - "source": l0_entity_source, - "fields": "entities" - } - }, - "output": {"key": "l0_result"}, - "config": self._l0_config, - "schema": {"template": self._schema_template} - }) - - config = { - "name": self.name, - "description": self.description, - "nodes": nodes - } + "output": {"key": "l0_result"}, + "config": self._l0_config, + "schema": {"template": self._schema_template}, + } + ) + + config = {"name": self.name, "description": self.description, "nodes": nodes} return config @@ -550,40 +518,25 @@ def _build_linking_only(self, l2_config: Dict[str, Any]) -> Dict[str, Any]: "processor": "l2_chain", "requires": [], "inputs": { - "mentions": { - "source": "$input", - "fields": "entities" - }, - "texts": { - "source": "$input", - "fields": "texts" - } + "mentions": {"source": "$input", "fields": "entities"}, + "texts": {"source": "$input", "fields": "texts"}, }, "output": {"key": "l2_result"}, "schema": {"template": self._schema_template}, - "config": l2_config + "config": l2_config, }, { "id": "l3", "processor": "l3_batch", "requires": ["l2"], "inputs": { - "texts": { - "source": "$input", - "fields": "texts" - }, - "candidates": { - "source": "l2_result", - "fields": "candidates" - }, - "l1_entities": { - "source": "$input", - "fields": "entities" - } + "texts": {"source": "$input", "fields": "texts"}, + "candidates": {"source": "l2_result", "fields": "candidates"}, + "l1_entities": {"source": "$input", "fields": "entities"}, }, "output": {"key": "l3_result"}, "schema": {"template": self._schema_template}, - "config": self._l3_config + "config": self._l3_config, }, ] @@ -591,64 +544,49 @@ def _build_linking_only(self, l2_config: Dict[str, Any]) -> Dict[str, Any]: l0_requires = ["l2", "l3"] if self._l4_config: - nodes.append({ - "id": "l4", - "processor": "l4_reranker", - "requires": ["l2", "l3"], - "inputs": { - "texts": { - "source": "$input", - "fields": "texts" + nodes.append( + { + "id": "l4", + "processor": "l4_reranker", + "requires": ["l2", "l3"], + "inputs": { + "texts": {"source": "$input", "fields": "texts"}, + "candidates": {"source": "l2_result", "fields": "candidates"}, }, - "candidates": { - "source": "l2_result", - "fields": "candidates" - } - }, - "output": {"key": "l4_result"}, - "schema": {"template": self._schema_template}, - "config": self._l4_config - }) + "output": {"key": "l4_result"}, + "schema": {"template": self._schema_template}, + "config": self._l4_config, + } + ) l0_entity_source = "l4_result" l0_requires.append("l4") - nodes.append({ - "id": "l0", - "processor": "l0_aggregator", - "requires": l0_requires, - "inputs": { - "l1_entities": { - "source": "$input", - "fields": "entities" - }, - "l2_candidates": { - "source": "l2_result", - "fields": "candidates" + nodes.append( + { + "id": "l0", + "processor": "l0_aggregator", + "requires": l0_requires, + "inputs": { + "l1_entities": {"source": "$input", "fields": "entities"}, + "l2_candidates": {"source": "l2_result", "fields": "candidates"}, + "l3_entities": {"source": l0_entity_source, "fields": "entities"}, }, - "l3_entities": { - "source": l0_entity_source, - "fields": "entities" - } - }, - "output": {"key": "l0_result"}, - "config": self._l0_config, - "schema": {"template": self._schema_template} - }) - - return { - "name": self.name, - "description": self.description, - "nodes": nodes - } + "output": {"key": "l0_result"}, + "config": self._l0_config, + "schema": {"template": self._schema_template}, + } + ) + + return {"name": self.name, "description": self.description, "nodes": nodes} def save(self, filepath: str) -> None: - """Save configuration to YAML file""" + """Save configuration to YAML file.""" config = self.build() # Create directory if needed Path(filepath).parent.mkdir(parents=True, exist_ok=True) - with open(filepath, 'w') as f: + with open(filepath, "w") as f: yaml.dump(config, f, default_flow_style=False, sort_keys=False) print(f"✓ Configuration saved to {filepath}") diff --git a/src/glinker/core/dag.py b/src/glinker/core/dag.py index 8138066..7e647c2 100644 --- a/src/glinker/core/dag.py +++ b/src/glinker/core/dag.py @@ -1,11 +1,12 @@ -from typing import Dict, List, Set, Any, Optional, Literal, Union -from collections import defaultdict, deque, OrderedDict -from pydantic import BaseModel, Field, ConfigDict -from datetime import datetime -from pathlib import Path import re import json import logging +from typing import Any, Set, Dict, List, Literal +from pathlib import Path +from datetime import datetime +from collections import OrderedDict, deque, defaultdict + +from pydantic import Field, BaseModel, ConfigDict logger = logging.getLogger(__name__) @@ -14,63 +15,55 @@ # INPUT/OUTPUT CONFIG # ============================================================================ + class ReshapeConfig(BaseModel): - """Configuration for data reshaping""" + """Configuration for data reshaping.""" + by: str = Field(..., description="Reference structure path: 'l1_result.entities'") mode: Literal["flatten_per_group", "preserve_structure"] = Field( - "flatten_per_group", - description="Reshape mode" + "flatten_per_group", description="Reshape mode" ) class InputConfig(BaseModel): """ - Unified input data specification - + Unified input data specification. + Examples: source: "l1_result" fields: "entities[*].text" reduce: "flatten" """ + source: str = Field( - ..., - description="Data source: key ('l1_result'), index ('outputs[-1]'), or '$input'" + ..., description="Data source: key ('l1_result'), index ('outputs[-1]'), or '$input'" ) - - fields: Union[str, List[str], None] = Field( - None, - description="JSONPath fields: 'entities[*].text' or ['label', 'score']" + + fields: str | List[str] | None = Field( + None, description="JSONPath fields: 'entities[*].text' or ['label', 'score']" ) - + reduce: Literal["all", "first", "last", "flatten"] = Field( - "all", - description="Reduction mode for lists" - ) - - reshape: Optional[ReshapeConfig] = Field( - None, - description="Data reshaping configuration" - ) - - template: Optional[str] = Field( - None, - description="Field concatenation template: '{label}: {description}'" + "all", description="Reduction mode for lists" ) - - filter: Optional[str] = Field( - None, - description="Filter expression: 'score > 0.5'" + + reshape: ReshapeConfig | None = Field(None, description="Data reshaping configuration") + + template: str | None = Field( + None, description="Field concatenation template: '{label}: {description}'" ) - + + filter: str | None = Field(None, description="Filter expression: 'score > 0.5'") + default: Any = None class OutputConfig(BaseModel): - """Output specification""" + """Output specification.""" + key: str = Field(..., description="Key for storing in context") - fields: Union[str, List[str], None] = Field( - None, - description="Fields to save (optional, defaults to all)" + fields: str | List[str] | None = Field( + None, description="Fields to save (optional, defaults to all)" ) @@ -78,49 +71,40 @@ class OutputConfig(BaseModel): # PIPE NODE # ============================================================================ + class PipeNode(BaseModel): """ - Single node in DAG pipeline - + Single node in DAG pipeline. + Represents one processing stage with: - Inputs (where to get data) - Processor (what to do) - Output (where to store result) - Dependencies (execution order) """ - + id: str = Field(..., description="Unique node identifier") - + processor: str = Field(..., description="Processor name from registry") - + inputs: Dict[str, InputConfig] = Field( - default_factory=dict, - description="Input parameter mappings" + default_factory=dict, description="Input parameter mappings" ) - + output: OutputConfig = Field(..., description="Output specification") - + requires: List[str] = Field( - default_factory=list, - description="Explicit dependencies (node IDs)" - ) - - config: Dict[str, Any] = Field( - default_factory=dict, - description="Processor configuration" + default_factory=list, description="Explicit dependencies (node IDs)" ) - field_schema: Optional[Dict[str, Any]] = Field( - None, - alias="schema", - description="Schema for field mappings/transformations" - ) + config: Dict[str, Any] = Field(default_factory=dict, description="Processor configuration") - condition: Optional[str] = Field( - None, - description="Conditional execution expression" + field_schema: Dict[str, Any] | None = Field( + None, alias="schema", description="Schema for field mappings/transformations" ) + condition: str | None = Field(None, description="Conditional execution expression") + model_config = ConfigDict(populate_by_name=True) @@ -128,26 +112,27 @@ class PipeNode(BaseModel): # PIPE CONTEXT # ============================================================================ + class PipeContext: """ - Pipeline execution context - + Pipeline execution context. + Stores all outputs from pipeline stages and provides unified access: - By key: "l1_result" - By index: "outputs[-1]" (last output) - Pipeline input: "$input" """ - + def __init__(self, pipeline_input: Any = None): self._outputs: OrderedDict[str, Any] = OrderedDict() self._execution_order: List[str] = [] self._pipeline_input = pipeline_input self._metadata: Dict[str, Any] = {} - - def set(self, key: str, value: Any, metadata: Optional[Dict] = None): + + def set(self, key: str, value: Any, metadata: Dict | None = None): """ - Store output - + Store output. + Args: key: Output key value: Output value @@ -155,14 +140,14 @@ def set(self, key: str, value: Any, metadata: Optional[Dict] = None): """ self._outputs[key] = value self._execution_order.append(key) - + if metadata: self._metadata[key] = metadata - + def get(self, source: str) -> Any: """ - Unified data access - + Unified data access. + Examples: - "$input" → pipeline input - "outputs[-1]" → last output @@ -171,149 +156,151 @@ def get(self, source: str) -> Any: """ if source == "$input": return self._pipeline_input - + if source.startswith("outputs["): index_str = source.replace("outputs[", "").replace("]", "") index = int(index_str) - + if index < 0: index = len(self._execution_order) + index - + if 0 <= index < len(self._execution_order): key = self._execution_order[index] return self._outputs[key] - + return None - + return self._outputs.get(source) - + def has(self, key: str) -> bool: - """Check if output exists""" + """Check if output exists.""" return key in self._outputs - + def get_all_outputs(self) -> Dict[str, Any]: - """Get all outputs as dict""" + """Get all outputs as dict.""" return dict(self._outputs) - - def get_metadata(self, key: str) -> Optional[Dict[str, Any]]: - """Get metadata for output""" + + def get_metadata(self, key: str) -> Dict[str, Any] | None: + """Get metadata for output.""" return self._metadata.get(key) - + def get_execution_order(self) -> List[str]: - """Get list of output keys in execution order""" + """Get list of output keys in execution order.""" return list(self._execution_order) - + @property def data(self) -> Dict[str, Any]: - """For compatibility""" + """For compatibility.""" return dict(self._outputs) - + def to_dict(self) -> Dict[str, Any]: """ - Serialize context to dict - + Serialize context to dict. + Returns: Dict with full context state """ + def serialize(value): - if hasattr(value, 'dict'): - return {'__type__': 'pydantic', 'data': value.dict()} + if hasattr(value, "dict"): + return {"__type__": "pydantic", "data": value.dict()} elif isinstance(value, list): return [serialize(item) for item in value] elif isinstance(value, dict): return {k: serialize(v) for k, v in value.items()} elif isinstance(value, OrderedDict): - return {'__type__': 'OrderedDict', 'data': list(value.items())} + return {"__type__": "OrderedDict", "data": list(value.items())} elif isinstance(value, (str, int, float, bool, type(None))): return value else: - return {'__type__': 'object', 'repr': repr(value)} - + return {"__type__": "object", "repr": repr(value)} + return { - 'outputs': {k: serialize(v) for k, v in self._outputs.items()}, - 'execution_order': self._execution_order, - 'pipeline_input': serialize(self._pipeline_input), - 'metadata': self._metadata, - 'saved_at': datetime.now().isoformat() + "outputs": {k: serialize(v) for k, v in self._outputs.items()}, + "execution_order": self._execution_order, + "pipeline_input": serialize(self._pipeline_input), + "metadata": self._metadata, + "saved_at": datetime.now().isoformat(), } - + @classmethod - def from_dict(cls, data: Dict[str, Any]) -> 'PipeContext': + def from_dict(cls, data: Dict[str, Any]) -> "PipeContext": """ - Deserialize context from dict - + Deserialize context from dict. + Args: data: Dict with saved state - + Returns: Restored PipeContext """ + def deserialize(value): if isinstance(value, dict): - if '__type__' in value: - if value['__type__'] == 'OrderedDict': - return OrderedDict(value['data']) - elif value['__type__'] == 'pydantic': - return value['data'] - elif value['__type__'] == 'object': - return value['repr'] + if "__type__" in value: + if value["__type__"] == "OrderedDict": + return OrderedDict(value["data"]) + elif value["__type__"] == "pydantic": + return value["data"] + elif value["__type__"] == "object": + return value["repr"] return {k: deserialize(v) for k, v in value.items()} elif isinstance(value, list): return [deserialize(item) for item in value] else: return value - - pipeline_input = deserialize(data.get('pipeline_input')) + + pipeline_input = deserialize(data.get("pipeline_input")) context = cls(pipeline_input) - - outputs_data = data.get('outputs', {}) + + outputs_data = data.get("outputs", {}) for key, value in outputs_data.items(): context._outputs[key] = deserialize(value) - - context._execution_order = data.get('execution_order', []) - context._metadata = data.get('metadata', {}) - + + context._execution_order = data.get("execution_order", []) + context._metadata = data.get("metadata", {}) + return context - - def to_json(self, filepath: str = None, indent: int = 2) -> str: + + def to_json(self, filepath: str | None = None, indent: int = 2) -> str: """ - Serialize to JSON - + Serialize to JSON. + Args: filepath: Path to save (optional) indent: Indentation for formatting - + Returns: JSON string """ data = self.to_dict() json_str = json.dumps(data, indent=indent, ensure_ascii=False) - + if filepath: - Path(filepath).write_text(json_str, encoding='utf-8') + Path(filepath).write_text(json_str, encoding="utf-8") logger.info(f"Context saved to {filepath}") - + return json_str - + @classmethod - def from_json(cls, json_data: str = None, filepath: str = None) -> 'PipeContext': + def from_json(cls, json_data: str | None = None, filepath: str | None = None) -> "PipeContext": """ - Load from JSON - + Load from JSON. + Args: json_data: JSON string (optional) filepath: Path to JSON file (optional) - + Returns: Restored PipeContext """ if filepath: - json_data = Path(filepath).read_text(encoding='utf-8') + json_data = Path(filepath).read_text(encoding="utf-8") logger.info(f"Context loaded from {filepath}") - + if not json_data: raise ValueError("Either json_data or filepath must be provided") - + data = json.loads(json_data) return cls.from_dict(data) @@ -322,112 +309,113 @@ def from_json(cls, json_data: str = None, filepath: str = None) -> 'PipeContext' # FIELD RESOLVER # ============================================================================ + class FieldResolver: - """Resolve fields from data using path expressions""" - + """Resolve fields from data using path expressions.""" + @staticmethod def resolve(context: PipeContext, config: InputConfig) -> Any: - """Main resolve method""" + """Main resolve method.""" data = context.get(config.source) if data is None: return config.default - + if config.fields: data = FieldResolver._extract_fields(data, config.fields) - + if config.template: data = FieldResolver._format_template(data, config.template) - + if isinstance(data, list) and config.reduce: data = FieldResolver._apply_reduce(data, config.reduce) - + if config.filter: data = FieldResolver._apply_filter(data, config.filter) - + return data - + @staticmethod def _extract_fields(data: Any, path: str) -> Any: """ - Extract field from data using path - + Extract field from data using path. + Examples: 'entities' -> data.entities 'entities[*]' -> [item for item in data.entities] 'entities[*].text' -> [item.text for item in data.entities] 'entities[*][*].text' -> [[e.text for e in group] for group in data.entities] """ - parts = path.split('.') + parts = path.split(".") current = data - + for part in parts: - if '[' in part: + if "[" in part: current = FieldResolver._handle_brackets(current, part) else: current = FieldResolver._get_attr(current, part) - + if current is None: return None - + return current - + @staticmethod def _handle_brackets(data: Any, part: str) -> Any: - """Handle parts with brackets like 'entities[*]' or 'items[0]' or '[*][*]'""" - if part.startswith('['): + """Handle parts with brackets like 'entities[*]' or 'items[0]' or '[*][*]'.""" + if part.startswith("["): field_name = None brackets = part else: - bracket_idx = part.index('[') + bracket_idx = part.index("[") field_name = part[:bracket_idx] brackets = part[bracket_idx:] - + current = data if field_name: current = FieldResolver._get_attr(current, field_name) - + if current is None: return None - - while '[' in brackets: - start = brackets.index('[') - end = brackets.index(']') - content = brackets[start+1:end] - brackets = brackets[end+1:] - - if content == '*': + + while "[" in brackets: + start = brackets.index("[") + end = brackets.index("]") + content = brackets[start + 1 : end] + brackets = brackets[end + 1 :] + + if content == "*": if not isinstance(current, list): current = [current] - elif ':' in content: - parts = content.split(':') + elif ":" in content: + parts = content.split(":") s = int(parts[0]) if parts[0] else None e = int(parts[1]) if parts[1] else None current = current[s:e] else: idx = int(content) current = current[idx] - + return current - + @staticmethod def _get_attr(data: Any, field: str) -> Any: - """Get attribute from data (works with dict, object, list)""" + """Get attribute from data (works with dict, object, list).""" if isinstance(data, list): return [FieldResolver._get_attr(item, field) for item in data] - + if isinstance(data, dict): return data.get(field) - + return getattr(data, field, None) - + @staticmethod - def _format_template(data: Union[List[Any], Any], template: str) -> Union[List[str], str]: - """Format data using template""" + def _format_template(data: List[Any] | Any, template: str) -> List[str] | str: + """Format data using template.""" if isinstance(data, list): results = [] for item in data: try: - if hasattr(item, 'dict'): + if hasattr(item, "dict"): results.append(template.format(**item.dict())) elif isinstance(item, dict): results.append(template.format(**item)) @@ -438,7 +426,7 @@ def _format_template(data: Union[List[Any], Any], template: str) -> Union[List[s return results else: try: - if hasattr(data, 'dict'): + if hasattr(data, "dict"): return template.format(**data.dict()) elif isinstance(data, dict): return template.format(**data) @@ -446,17 +434,18 @@ def _format_template(data: Union[List[Any], Any], template: str) -> Union[List[s return str(data) except: return str(data) - + @staticmethod def _apply_reduce(data: List[Any], mode: str) -> Any: - """Reduce list based on mode""" + """Reduce list based on mode.""" if mode == "first": return data[0] if data else None - + elif mode == "last": return data[-1] if data else None - + elif mode == "flatten": + def flatten(lst): result = [] for item in lst: @@ -465,35 +454,35 @@ def flatten(lst): else: result.append(item) return result - + return flatten(data) - + return data - + @staticmethod def _apply_filter(data: List[Any], filter_expr: str) -> List[Any]: - """Filter list based on expression""" + """Filter list based on expression.""" if not isinstance(data, list): return data - - pattern = r'(\w+)\s*(>=|<=|>|<|==|!=)\s*(.+)' + + pattern = r"(\w+)\s*(>=|<=|>|<|==|!=)\s*(.+)" match = re.match(pattern, filter_expr) - + if not match: return data - + field, operator, value = match.groups() - + try: if value.startswith("'") or value.startswith('"'): value = value.strip("'\"") - elif '.' in value: + elif "." in value: value = float(value) else: value = int(value) except: pass - + result = [] for item in data: try: @@ -501,29 +490,29 @@ def _apply_filter(data: List[Any], filter_expr: str) -> List[Any]: item_value = item.get(field) else: item_value = getattr(item, field, None) - + if item_value is None: continue - + passes = False - if operator == '>': + if operator == ">": passes = item_value > value - elif operator == '>=': + elif operator == ">=": passes = item_value >= value - elif operator == '<': + elif operator == "<": passes = item_value < value - elif operator == '<=': + elif operator == "<=": passes = item_value <= value - elif operator == '==': + elif operator == "==": passes = item_value == value - elif operator == '!=': + elif operator == "!=": passes = item_value != value - + if passes: result.append(item) except: continue - + return result @@ -531,128 +520,124 @@ def _apply_filter(data: List[Any], filter_expr: str) -> List[Any]: # DAG EXECUTOR # ============================================================================ + class DAGPipeline(BaseModel): - """DAG pipeline configuration""" + """DAG pipeline configuration.""" + name: str = Field(...) nodes: List[PipeNode] = Field(...) - description: Optional[str] = None + description: str | None = None class DAGExecutor: - """Executes DAG pipeline with topological sort""" - + """Executes DAG pipeline with topological sort.""" + def __init__(self, pipeline: DAGPipeline, verbose: bool = False): self.pipeline = pipeline self.verbose = verbose - + self.nodes_map: Dict[str, PipeNode] = {} self.dependency_graph: Dict[str, List[str]] = defaultdict(list) self.reverse_graph: Dict[str, Set[str]] = defaultdict(set) self.processors: Dict[str, Any] = {} - + self._build_dependency_graph() self._initialize_processors() - + def _build_dependency_graph(self): - """Build dependency graph from nodes""" + """Build dependency graph from nodes.""" for node in self.pipeline.nodes: self.nodes_map[node.id] = node - + for node in self.pipeline.nodes: # Explicit dependencies for dep_id in node.requires: self.dependency_graph[dep_id].append(node.id) self.reverse_graph[node.id].add(dep_id) - + # Implicit dependencies from inputs for input_config in node.inputs.values(): source = input_config.source - + if source == "$input" or source.startswith("outputs["): continue - + if source in self.nodes_map: self.dependency_graph[source].append(node.id) self.reverse_graph[node.id].add(source) - + if input_config.reshape and input_config.reshape.by: - reshape_source = input_config.reshape.by.split('.')[0] + reshape_source = input_config.reshape.by.split(".")[0] if reshape_source in self.nodes_map: if reshape_source not in self.reverse_graph[node.id]: self.dependency_graph[reshape_source].append(node.id) self.reverse_graph[node.id].add(reshape_source) - + def _initialize_processors(self): - """Initialize all processors once""" + """Initialize all processors once.""" from .registry import processor_registry - + if self.verbose: logger.info(f"Initializing {len(self.nodes_map)} processors...") - + for node_id, node in self.nodes_map.items(): try: processor_factory = processor_registry.get(node.processor) processor = processor_factory(config_dict=node.config, pipeline=None) self.processors[node_id] = processor - + if self.verbose: logger.info(f" Created processor for '{node_id}' ({node.processor})") except Exception as e: raise RuntimeError( f"Failed to create processor '{node.processor}' for node '{node_id}': {e}" ) - + if self.verbose: - logger.info(f"All processors initialized and cached") - + logger.info("All processors initialized and cached") + def _topological_sort(self) -> List[List[str]]: - """Topological sort with level grouping""" + """Topological sort with level grouping.""" in_degree = {} for node_id in self.nodes_map: in_degree[node_id] = len(self.reverse_graph.get(node_id, set())) - - queue = deque([ - node_id for node_id, degree in in_degree.items() - if degree == 0 - ]) - + + queue = deque([node_id for node_id, degree in in_degree.items() if degree == 0]) + levels = [] visited = set() - + while queue: current_level = list(queue) levels.append(current_level) - + next_queue = deque() for node_id in current_level: visited.add(node_id) - + for dependent_id in self.dependency_graph.get(node_id, []): in_degree[dependent_id] -= 1 - + if in_degree[dependent_id] == 0: next_queue.append(dependent_id) - + queue = next_queue - + if len(visited) != len(self.nodes_map): unvisited = set(self.nodes_map.keys()) - visited - raise ValueError( - f"Cycle detected in pipeline DAG! " - f"Unvisited nodes: {unvisited}" - ) - + raise ValueError(f"Cycle detected in pipeline DAG! Unvisited nodes: {unvisited}") + return levels - + def load_entities( self, source, - target_layers: List[str] = None, + target_layers: List[str] | None = None, batch_size: int = 1000, - overwrite: bool = False + overwrite: bool = False, ) -> Dict[str, Dict[str, int]]: """ - Load entities into database layers + Load entities into database layers. Finds all L2 processors and loads entities into their database layers. @@ -669,7 +654,7 @@ def load_entities( results = {} for node_id, processor in self.processors.items(): - if hasattr(processor, 'component') and hasattr(processor.component, 'load_entities'): + if hasattr(processor, "component") and hasattr(processor.component, "load_entities"): if self.verbose: logger.info(f"\nLoading entities for node '{node_id}'") @@ -677,38 +662,36 @@ def load_entities( source=source, target_layers=target_layers, batch_size=batch_size, - overwrite=overwrite + overwrite=overwrite, ) results[node_id] = result return results - - def clear_databases(self, layer_names: List[str] = None) -> Dict[str, bool]: - """Clear database layers in all L2 processors""" + + def clear_databases(self, layer_names: List[str] | None = None) -> Dict[str, bool]: + """Clear database layers in all L2 processors.""" results = {} - + for node_id, processor in self.processors.items(): - if hasattr(processor, 'component') and hasattr(processor.component, 'clear_layers'): + if hasattr(processor, "component") and hasattr(processor.component, "clear_layers"): processor.component.clear_layers(layer_names) results[node_id] = True - + return results - + def count_entities(self) -> Dict[str, Dict[str, int]]: - """Count entities in all database layers""" + """Count entities in all database layers.""" results = {} for node_id, processor in self.processors.items(): - if hasattr(processor, 'component') and hasattr(processor.component, 'count_entities'): + if hasattr(processor, "component") and hasattr(processor.component, "count_entities"): counts = processor.component.count_entities() results[node_id] = counts return results def precompute_embeddings( - self, - target_layers: List[str] = None, - batch_size: int = 32 + self, target_layers: List[str] | None = None, batch_size: int = 32 ) -> Dict[str, int]: """ Precompute embeddings for all entities using L3 model and L2 schema. @@ -734,12 +717,14 @@ def precompute_embeddings( node = self.nodes_map[node_id] # Check for L2 (has component with precompute_embeddings) - if hasattr(processor, 'component') and hasattr(processor.component, 'precompute_embeddings'): + if hasattr(processor, "component") and hasattr( + processor.component, "precompute_embeddings" + ): l2_processor = processor l2_node = node # Check for L3 (has component with encode_labels) - if hasattr(processor, 'component') and hasattr(processor.component, 'encode_labels'): + if hasattr(processor, "component") and hasattr(processor.component, "encode_labels"): l3_processor = processor l3_node = node @@ -757,11 +742,11 @@ def precompute_embeddings( ) # Get schema from L2 node (or L3 as fallback) - template = '{label}' + template = "{label}" if l2_node and l2_node.field_schema: - template = l2_node.field_schema.get('template', '{label}') + template = l2_node.field_schema.get("template", "{label}") elif l3_node and l3_node.field_schema: - template = l3_node.field_schema.get('template', '{label}') + template = l3_node.field_schema.get("template", "{label}") # Apply schema to L2 processor if l2_node and l2_node.field_schema: @@ -770,7 +755,7 @@ def precompute_embeddings( model_id = l3_processor.config.model_name if self.verbose: - logger.info(f"\nPrecomputing embeddings:") + logger.info("\nPrecomputing embeddings:") logger.info(f" Model: {model_id}") logger.info(f" Template: {template}") logger.info(f" Target layers: {target_layers or 'all'}") @@ -785,114 +770,115 @@ def encoder_fn(labels: List[str]): template=template, model_id=model_id, target_layers=target_layers, - batch_size=batch_size + batch_size=batch_size, ) if self.verbose: - logger.info(f"\nPrecompute completed:") + logger.info("\nPrecompute completed:") for layer, count in results.items(): logger.info(f" {layer}: {count} entities") return results def setup_l3_cache_writeback(self): - """Setup L3 processor to write back embeddings to L2""" + """Setup L3 processor to write back embeddings to L2.""" l2_processor = None l3_processor = None - for node_id, processor in self.processors.items(): - if hasattr(processor, 'component') and hasattr(processor.component, 'precompute_embeddings'): + for _node_id, processor in self.processors.items(): + if hasattr(processor, "component") and hasattr( + processor.component, "precompute_embeddings" + ): l2_processor = processor - if hasattr(processor, '_l2_processor'): + if hasattr(processor, "_l2_processor"): l3_processor = processor if l2_processor and l3_processor: l3_processor._l2_processor = l2_processor if self.verbose: logger.info("L3 cache write-back enabled") - + def execute(self, pipeline_input: Any) -> PipeContext: - """Execute full pipeline""" + """Execute full pipeline.""" context = PipeContext(pipeline_input) execution_levels = self._topological_sort() - + if self.verbose: logger.info(f"Executing pipeline: {self.pipeline.name}") logger.info(f"Total nodes: {len(self.nodes_map)}") logger.info(f"Execution levels: {len(execution_levels)}") - + for level_idx, level_nodes in enumerate(execution_levels): if self.verbose: - logger.info(f"\n{'='*60}") + logger.info(f"\n{'=' * 60}") logger.info( - f"Level {level_idx + 1}/{len(execution_levels)} " - f"({len(level_nodes)} nodes)" + f"Level {level_idx + 1}/{len(execution_levels)} ({len(level_nodes)} nodes)" ) - logger.info(f"{'='*60}") - + logger.info(f"{'=' * 60}") + for node_id in level_nodes: self._run_node(node_id, context) - + if self.verbose: - logger.info(f"\nPipeline completed successfully!") - + logger.info("\nPipeline completed successfully!") + return context - + def _run_node(self, node_id: str, context: PipeContext): - """Execute single node""" + """Execute single node.""" node = self.nodes_map[node_id] - + if self.verbose: logger.info(f"\nExecuting: {node.id} (processor: {node.processor})") - + if node.condition and not self._evaluate_condition(node.condition, context): if self.verbose: - logger.info(f" Skipped (condition not met)") + logger.info(" Skipped (condition not met)") return - + # Resolve inputs kwargs = {} for param_name, input_config in node.inputs.items(): try: value = FieldResolver.resolve(context, input_config) kwargs[param_name] = value - + if self.verbose: logger.info(f" Input '{param_name}': {input_config.source}") except Exception as e: raise ValueError( f"Failed to resolve input '{param_name}' for node '{node_id}': {e}" ) - + # Get cached processor processor = self.processors[node_id] # Apply schema if needed - if node.field_schema and hasattr(processor, 'schema'): + if node.field_schema and hasattr(processor, "schema"): processor.schema = node.field_schema - + # Execute processor try: result = processor(**kwargs) - + if self.verbose: - logger.info(f" Processing...") + logger.info(" Processing...") except Exception as e: if self.verbose: logger.error(f" Failed: {e}") raise RuntimeError(f"Node '{node_id}' failed: {e}") - + # Extract output fields if specified if node.output.fields: result = FieldResolver._extract_fields(result, node.output.fields) - + # Store output context.set(node.output.key, result) - + if self.verbose: logger.info(f" Output: '{node.output.key}'") - logger.info(f" Success") - + logger.info(" Success") + def _evaluate_condition(self, condition: str, context: PipeContext) -> bool: - """Evaluate conditional expression""" - return True \ No newline at end of file + """Evaluate conditional expression.""" + return True diff --git a/src/glinker/core/factory.py b/src/glinker/core/factory.py index ff295c0..420ca57 100644 --- a/src/glinker/core/factory.py +++ b/src/glinker/core/factory.py @@ -1,44 +1,44 @@ import warnings +from typing import Any, Dict, List from pathlib import Path -from typing import Any, Dict, List, Optional, Union + import yaml + +from .dag import PipeNode, DAGExecutor, DAGPipeline, InputConfig, OutputConfig from .registry import processor_registry -from .dag import DAGPipeline, DAGExecutor, PipeNode, InputConfig, OutputConfig def load_yaml(path: str | Path) -> dict: - """Load YAML configuration file""" - with open(path, 'r') as f: + """Load YAML configuration file.""" + with open(path) as f: return yaml.safe_load(f) class ProcessorFactory: - """Factory for creating pipelines from configs""" - + """Factory for creating pipelines from configs.""" + @staticmethod def create_from_registry( - processor_name: str, - config_dict: dict, - pipeline: list[tuple[str, dict]] = None + processor_name: str, config_dict: dict, pipeline: list[tuple[str, dict]] | None = None ): """ - Create single processor from registry - + Create single processor from registry. + For internal use by DAGExecutor """ factory = processor_registry.get(processor_name) return factory(config_dict, pipeline) - + @staticmethod def create_pipeline(config_path: str | Path, verbose: bool = False) -> DAGExecutor: """ - Create DAG pipeline from YAML config - + Create DAG pipeline from YAML config. + Supports: - Single node (just L2) - Multiple nodes (L1 → L2 → L3) - Complex DAGs with dependencies - + Example config: name: "my_pipeline" nodes: @@ -50,60 +50,56 @@ def create_pipeline(config_path: str | Path, verbose: bool = False) -> DAGExecut config: {...} """ config = load_yaml(config_path) - + nodes = [] - for node_cfg in config['nodes']: + for node_cfg in config["nodes"]: inputs = {} - for name, data in node_cfg['inputs'].items(): + for name, data in node_cfg["inputs"].items(): inputs[name] = InputConfig(**data) - + node = PipeNode( - id=node_cfg['id'], - processor=node_cfg['processor'], + id=node_cfg["id"], + processor=node_cfg["processor"], inputs=inputs, - output=OutputConfig(**node_cfg['output']), - requires=node_cfg.get('requires', []), - config=node_cfg['config'], - schema=node_cfg.get('schema') + output=OutputConfig(**node_cfg["output"]), + requires=node_cfg.get("requires", []), + config=node_cfg["config"], + schema=node_cfg.get("schema"), ) nodes.append(node) - + pipeline = DAGPipeline( - name=config['name'], - description=config.get('description'), - nodes=nodes + name=config["name"], description=config.get("description"), nodes=nodes ) - + return DAGExecutor(pipeline, verbose=verbose) - + @staticmethod def create_from_dict(config_dict: dict, verbose: bool = False) -> DAGExecutor: """ - Create pipeline from dict (for programmatic use) - + Create pipeline from dict (for programmatic use). + Same as create_pipeline but accepts dict instead of file path """ nodes = [] - for node_cfg in config_dict['nodes']: + for node_cfg in config_dict["nodes"]: inputs = {} - for name, data in node_cfg['inputs'].items(): + for name, data in node_cfg["inputs"].items(): inputs[name] = InputConfig(**data) - + node = PipeNode( - id=node_cfg['id'], - processor=node_cfg['processor'], + id=node_cfg["id"], + processor=node_cfg["processor"], inputs=inputs, - output=OutputConfig(**node_cfg['output']), - requires=node_cfg.get('requires', []), - config=node_cfg['config'], - schema=node_cfg.get('schema') + output=OutputConfig(**node_cfg["output"]), + requires=node_cfg.get("requires", []), + config=node_cfg["config"], + schema=node_cfg.get("schema"), ) nodes.append(node) - + pipeline = DAGPipeline( - name=config_dict['name'], - description=config_dict.get('description'), - nodes=nodes + name=config_dict["name"], description=config_dict.get("description"), nodes=nodes ) return DAGExecutor(pipeline, verbose=verbose) @@ -114,14 +110,14 @@ def create_simple( device: str = "cpu", threshold: float = 0.5, template: str = "{label}", - max_length: Optional[int] = 512, - token: Optional[str] = None, - entities: Optional[Union[str, Path, List[Dict[str, Any]], Dict[str, Dict[str, Any]]]] = None, + max_length: int | None = 512, + token: str | None = None, + entities: str | Path | List[Dict[str, Any]] | Dict[str, Dict[str, Any]] | None = None, precompute_embeddings: bool = False, verbose: bool = False, - reranker_model: Optional[str] = None, + reranker_model: str | None = None, reranker_max_labels: int = 20, - reranker_threshold: Optional[float] = None, + reranker_threshold: float | None = None, external_entities: bool = False, ) -> DAGExecutor: """ @@ -233,27 +229,31 @@ def create_simple( l0_requires = ["l2", "l3"] if reranker_model: - nodes.append({ - "id": "l4", - "processor": "l4_reranker", - "requires": ["l2", "l3"], - "inputs": { - "texts": {"source": "$input", "fields": "texts"}, - "candidates": {"source": "l2_result", "fields": "candidates"}, - }, - "output": {"key": "l4_result"}, - "schema": {"template": template}, - "config": { - "model_name": reranker_model, - "device": device, - "threshold": reranker_threshold if reranker_threshold is not None else threshold, - "flat_ner": True, - "multi_label": False, - "max_labels": reranker_max_labels, - "max_length": max_length, - "token": token, - }, - }) + nodes.append( + { + "id": "l4", + "processor": "l4_reranker", + "requires": ["l2", "l3"], + "inputs": { + "texts": {"source": "$input", "fields": "texts"}, + "candidates": {"source": "l2_result", "fields": "candidates"}, + }, + "output": {"key": "l4_result"}, + "schema": {"template": template}, + "config": { + "model_name": reranker_model, + "device": device, + "threshold": reranker_threshold + if reranker_threshold is not None + else threshold, + "flat_ner": True, + "multi_label": False, + "max_labels": reranker_max_labels, + "max_length": max_length, + "token": token, + }, + } + ) l0_entity_source = "l4_result" l0_requires.append("l4") @@ -265,20 +265,22 @@ def create_simple( if external_entities: l0_inputs["l1_entities"] = {"source": "$input", "fields": "entities"} - nodes.append({ - "id": "l0", - "processor": "l0_aggregator", - "requires": l0_requires, - "inputs": l0_inputs, - "output": {"key": "l0_result"}, - "schema": {"template": template}, - "config": { - "strict_matching": external_entities, - "min_confidence": 0.0, - "include_unlinked": True, - "position_tolerance": 2, - }, - }) + nodes.append( + { + "id": "l0", + "processor": "l0_aggregator", + "requires": l0_requires, + "inputs": l0_inputs, + "output": {"key": "l0_result"}, + "schema": {"template": template}, + "config": { + "strict_matching": external_entities, + "min_confidence": 0.0, + "include_unlinked": True, + "position_tolerance": 2, + }, + } + ) config = { "name": "simple", @@ -292,4 +294,4 @@ def create_simple( if precompute_embeddings: executor.precompute_embeddings() - return executor \ No newline at end of file + return executor diff --git a/src/glinker/core/registry.py b/src/glinker/core/registry.py index 773b515..0878132 100644 --- a/src/glinker/core/registry.py +++ b/src/glinker/core/registry.py @@ -2,30 +2,31 @@ class ProcessorRegistry: - """Registry for processor factory functions""" - + """Registry for processor factory functions.""" + def __init__(self): self._registry: Dict[str, Callable] = {} - + def register(self, name: str): - """Decorator to register processor factory""" + """Decorator to register processor factory.""" + def decorator(factory: Callable): self._registry[name] = factory return factory + return decorator - + def get(self, name: str) -> Callable: - """Get processor factory by name""" + """Get processor factory by name.""" if name not in self._registry: raise KeyError( - f"Processor '{name}' not found. " - f"Available: {list(self._registry.keys())}" + f"Processor '{name}' not found. Available: {list(self._registry.keys())}" ) return self._registry[name] - + def list_available(self) -> list[str]: - """List all registered processor names""" + """List all registered processor names.""" return list(self._registry.keys()) -processor_registry = ProcessorRegistry() \ No newline at end of file +processor_registry = ProcessorRegistry() diff --git a/src/glinker/l0/__init__.py b/src/glinker/l0/__init__.py index 9d0b906..97719ca 100644 --- a/src/glinker/l0/__init__.py +++ b/src/glinker/l0/__init__.py @@ -1,21 +1,21 @@ """ -L0 - Aggregation Layer +L0 - Aggregation Layer. Combines outputs from L1 (mention extraction), L2 (candidate retrieval), and L3 (entity linking) into unified L0Entity structures with full pipeline context. """ -from .models import L0Config, L0Input, L0Output, L0Entity, LinkedEntity +from .models import L0Input, L0Config, L0Entity, L0Output, LinkedEntity from .component import L0Component from .processor import L0Processor, create_l0_processor __all__ = [ + "L0Component", "L0Config", + "L0Entity", "L0Input", "L0Output", - "L0Entity", - "LinkedEntity", - "L0Component", "L0Processor", - "create_l0_processor" + "LinkedEntity", + "create_l0_processor", ] diff --git a/src/glinker/l0/component.py b/src/glinker/l0/component.py index 8ad31f6..9e1cf60 100644 --- a/src/glinker/l0/component.py +++ b/src/glinker/l0/component.py @@ -1,17 +1,17 @@ import re -from typing import List, Optional, Dict, Tuple +from typing import Dict, List, Tuple + from glinker.core.base import BaseComponent -from .models import ( - L0Config, L0Entity, LinkedEntity -) from glinker.l1.models import L1Entity from glinker.l2.models import DatabaseRecord from glinker.l3.models import L3Entity +from .models import L0Config, L0Entity, LinkedEntity + class L0Component(BaseComponent[L0Config]): """ - L0 aggregation component - combines outputs from L1, L2, L3 + L0 aggregation component - combines outputs from L1, L2, L3. Workflow: 1. For each L1 mention → find its L2 candidates @@ -20,12 +20,7 @@ class L0Component(BaseComponent[L0Config]): """ def get_available_methods(self) -> List[str]: - return [ - "aggregate", - "filter_by_confidence", - "sort_by_confidence", - "calculate_stats" - ] + return ["aggregate", "filter_by_confidence", "sort_by_confidence", "calculate_stats"] @staticmethod def _normalize_entity(e) -> L1Entity: @@ -48,10 +43,10 @@ def aggregate( l1_entities: List[List[L1Entity]], l2_candidates: List[List[DatabaseRecord]], l3_entities: List[List[L3Entity]], - template: str = "{label}" + template: str = "{label}", ) -> List[List[L0Entity]]: """ - Main aggregation method - combines all layers + Main aggregation method - combines all layers. Args: l1_entities: [[L1Entity, ...], ...] - one list per text @@ -63,10 +58,7 @@ def aggregate( [[L0Entity, ...], ...] - aggregated entities per text """ # Normalize dict entities to L1Entity objects - l1_entities = [ - [self._normalize_entity(e) for e in text_ents] - for text_ents in l1_entities - ] + l1_entities = [[self._normalize_entity(e) for e in text_ents] for text_ents in l1_entities] all_results = [] @@ -86,10 +78,10 @@ def _aggregate_single_text( l1_mentions: List[L1Entity], l2_candidates: List[DatabaseRecord], l3_links: List[L3Entity], - template: str = "{label}" + template: str = "{label}", ) -> List[L0Entity]: """ - Aggregate data for a single text + Aggregate data for a single text. Strategy: 1. Build index of L3 linked entities by position @@ -114,8 +106,11 @@ def _aggregate_single_text( # Check if this mention was linked in L3 linked_entity, l3_pos = self._find_linked_entity_with_position( - l1_mention, l3_by_position, mention_candidates, template, - tolerance=self.config.position_tolerance + l1_mention, + l3_by_position, + mention_candidates, + template, + tolerance=self.config.position_tolerance, ) if l3_pos: @@ -135,7 +130,7 @@ def _aggregate_single_text( # Create L0Entity l0_entity = L0Entity( mention_text=l1_mention.text, - label=getattr(l1_mention, 'label', None), # Safe access in case L1 is skipped + label=getattr(l1_mention, "label", None), # Safe access in case L1 is skipped mention_start=l1_mention.start, mention_end=l1_mention.end, left_context=l1_mention.left_context, @@ -145,7 +140,7 @@ def _aggregate_single_text( linked_entity=linked_entity, is_linked=linked_entity is not None, candidate_scores=candidate_scores, - pipeline_stage=pipeline_stage + pipeline_stage=pipeline_stage, ) results.append(l0_entity) @@ -173,7 +168,7 @@ def _aggregate_single_text( confidence=l3_entity.score, start=l3_entity.start, end=l3_entity.end, - matched_text=l3_entity.text + matched_text=l3_entity.text, ) l0_entity = L0Entity( @@ -187,14 +182,14 @@ def _aggregate_single_text( linked_entity=linked, is_linked=True, candidate_scores=candidate_scores, - pipeline_stage="l3_only" # Indicates L3 found it without L1 + pipeline_stage="l3_only", # Indicates L3 found it without L1 ) results.append(l0_entity) return results def _build_l3_index(self, l3_links: List[L3Entity]) -> Dict[Tuple[int, int], L3Entity]: - """Build index of L3 entities by (start, end) position""" + """Build index of L3 entities by (start, end) position.""" index = {} for entity in l3_links: key = (entity.start, entity.end) @@ -202,10 +197,7 @@ def _build_l3_index(self, l3_links: List[L3Entity]) -> Dict[Tuple[int, int], L3E return index def _get_candidates_for_mention( - self, - mention_idx: int, - l1_mention: L1Entity, - all_candidates: List[DatabaseRecord] + self, mention_idx: int, l1_mention: L1Entity, all_candidates: List[DatabaseRecord] ) -> List[DatabaseRecord]: """ Get candidates for specific mention. @@ -236,10 +228,10 @@ def _find_linked_entity( l1_mention: L1Entity, l3_by_position: Dict[Tuple[int, int], L3Entity], candidates: List[DatabaseRecord], - template: str = "{label}" - ) -> Optional[LinkedEntity]: + template: str = "{label}", + ) -> LinkedEntity | None: """ - Find if this L1 mention was linked in L3 + Find if this L1 mention was linked in L3. Strategy: 1. Look up L3 entity by position (start, end) @@ -257,10 +249,10 @@ def _find_linked_entity_with_position( l3_by_position: Dict[Tuple[int, int], L3Entity], candidates: List[DatabaseRecord], template: str = "{label}", - tolerance: int = 2 - ) -> Tuple[Optional[LinkedEntity], Optional[Tuple[int, int]]]: + tolerance: int = 2, + ) -> Tuple[LinkedEntity | None, Tuple[int, int] | None]: """ - Find if this L1 mention was linked in L3, and return the matched position + Find if this L1 mention was linked in L3, and return the matched position. Returns: Tuple of (LinkedEntity or None, matched position tuple or None) @@ -290,7 +282,7 @@ def _find_linked_entity_with_position( confidence=l3_entity.score, start=l3_entity.start, end=l3_entity.end, - matched_text=l3_entity.text + matched_text=l3_entity.text, ), matched_key return LinkedEntity( @@ -299,7 +291,7 @@ def _find_linked_entity_with_position( confidence=l3_entity.score, start=l3_entity.start, end=l3_entity.end, - matched_text=l3_entity.text + matched_text=l3_entity.text, ), matched_key def _fuzzy_position_match( @@ -307,9 +299,9 @@ def _fuzzy_position_match( start: int, end: int, l3_by_position: Dict[Tuple[int, int], L3Entity], - tolerance: int = 2 - ) -> Optional[L3Entity]: - """Find L3 entity with position close to given range""" + tolerance: int = 2, + ) -> L3Entity | None: + """Find L3 entity with position close to given range.""" entity, _ = self._fuzzy_position_match_with_key(start, end, l3_by_position, tolerance) return entity @@ -318,9 +310,9 @@ def _fuzzy_position_match_with_key( start: int, end: int, l3_by_position: Dict[Tuple[int, int], L3Entity], - tolerance: int = 2 - ) -> Tuple[Optional[L3Entity], Optional[Tuple[int, int]]]: - """Find L3 entity with position close to given range, return with its key""" + tolerance: int = 2, + ) -> Tuple[L3Entity | None, Tuple[int, int] | None]: + """Find L3 entity with position close to given range, return with its key.""" for (l3_start, l3_end), entity in l3_by_position.items(): if abs(l3_start - start) <= tolerance and abs(l3_end - end) <= tolerance: return entity, (l3_start, l3_end) @@ -330,7 +322,7 @@ def _build_candidate_scores( self, class_probs: Dict[str, float], candidates: List[DatabaseRecord], - template: str = "{label}" + template: str = "{label}", ) -> Dict[str, float]: """ Map L3 class_probs (label -> probability) to candidate entity_ids. @@ -351,13 +343,10 @@ def _build_candidate_scores( return scores def _match_candidate_by_label( - self, - l3_label: str, - candidates: List[DatabaseRecord], - template: str = "{label}" - ) -> Optional[DatabaseRecord]: + self, l3_label: str, candidates: List[DatabaseRecord], template: str = "{label}" + ) -> DatabaseRecord | None: """ - Match L3 label with L2 candidate using the same template + Match L3 label with L2 candidate using the same template. Uses the schema template to format candidate labels the same way L3 did, enabling exact matching. @@ -381,16 +370,16 @@ def _match_candidate_by_label( for candidate in candidates: try: # Format candidate using same template as L3 - if hasattr(candidate, 'dict'): + if hasattr(candidate, "dict"): cand_dict = candidate.dict() else: cand_dict = { - 'label': candidate.label, - 'description': getattr(candidate, 'description', ''), - 'entity_id': getattr(candidate, 'entity_id', ''), - 'entity_type': getattr(candidate, 'entity_type', ''), - 'popularity': getattr(candidate, 'popularity', 0), - 'aliases': getattr(candidate, 'aliases', []) + "label": candidate.label, + "description": getattr(candidate, "description", ""), + "entity_id": getattr(candidate, "entity_id", ""), + "entity_type": getattr(candidate, "entity_type", ""), + "popularity": getattr(candidate, "popularity", 0), + "aliases": getattr(candidate, "aliases", []), } formatted_label = template.format(**cand_dict) @@ -405,7 +394,7 @@ def _match_candidate_by_label( # Check aliases for exact match for candidate in candidates: - for alias in getattr(candidate, 'aliases', []): + for alias in getattr(candidate, "aliases", []): if alias.lower().strip() == l3_label_lower: return candidate @@ -422,11 +411,9 @@ def _match_candidate_by_label( return best_match def _determine_stage( - self, - candidates: List[DatabaseRecord], - linked_entity: Optional[LinkedEntity] + self, candidates: List[DatabaseRecord], linked_entity: LinkedEntity | None ) -> str: - """Determine which pipeline stage was last successful""" + """Determine which pipeline stage was last successful.""" if linked_entity: return "l3_linked" elif candidates: @@ -435,16 +422,14 @@ def _determine_stage( return "l1_only" def filter_by_confidence( - self, - entities: List[List[L0Entity]], - min_confidence: float = None + self, entities: List[List[L0Entity]], min_confidence: float | None = None ) -> List[List[L0Entity]]: - """Filter entities by linking confidence""" + """Filter entities by linking confidence.""" threshold = min_confidence if min_confidence is not None else self.config.min_confidence filtered = [] for text_entities in entities: - # Keep linked entities above threshold OR keep unlinked if configured + # Keep linked entities above threshold OR keep unlinked if configured filtered_text = [] for e in text_entities: if e.linked_entity: @@ -452,25 +437,25 @@ def filter_by_confidence( filtered_text.append(e) elif self.config.include_unlinked: filtered_text.append(e) - + filtered.append(filtered_text) return filtered def sort_by_confidence(self, entities: List[List[L0Entity]]) -> List[List[L0Entity]]: - """Sort entities by linking confidence (descending)""" + """Sort entities by linking confidence (descending).""" sorted_results = [] for text_entities in entities: sorted_text = sorted( text_entities, key=lambda e: e.linked_entity.confidence if e.linked_entity else 0.0, - reverse=True + reverse=True, ) sorted_results.append(sorted_text) return sorted_results def calculate_stats(self, entities: List[List[L0Entity]]) -> dict: - """Calculate pipeline statistics""" + """Calculate pipeline statistics.""" total = 0 linked = 0 unlinked = 0 @@ -506,6 +491,6 @@ def calculate_stats(self, entities: List[List[L0Entity]]) -> dict: "l1_only": l1_only, "l2_found": l2_found, "l3_linked": l3_linked, - "l3_only": l3_only - } + "l3_only": l3_only, + }, } diff --git a/src/glinker/l0/models.py b/src/glinker/l0/models.py index bd336cc..905a441 100644 --- a/src/glinker/l0/models.py +++ b/src/glinker/l0/models.py @@ -1,36 +1,47 @@ +from typing import Dict, List + from pydantic import Field -from typing import Dict, List, Optional -from glinker.core.base import BaseConfig, BaseInput, BaseOutput + +from glinker.core.base import BaseInput, BaseConfig, BaseOutput from glinker.l1.models import L1Entity from glinker.l2.models import DatabaseRecord from glinker.l3.models import L3Entity class L0Config(BaseConfig): - """L0 aggregation configuration""" - min_confidence: float = Field(0.0, description="Minimum confidence threshold for linked entities") + """L0 aggregation configuration.""" + + min_confidence: float = Field( + 0.0, description="Minimum confidence threshold for linked entities" + ) include_unlinked: bool = Field(True, description="Include mentions without linked entities") - return_all_candidates: bool = Field(False, description="Return all candidates or only top match") + return_all_candidates: bool = Field( + False, description="Return all candidates or only top match" + ) strict_matching: bool = Field( True, description="If True, only include entities that match L1 mentions. " - "If False, also include L3 entities found outside L1 mentions." + "If False, also include L3 entities found outside L1 mentions.", ) position_tolerance: int = Field( 2, - description="Maximum character difference for fuzzy position matching between L1 and L3 entities" + description="Maximum character difference for fuzzy position matching between L1 and L3 entities", ) class L0Input(BaseInput): - """L0 processor input - outputs from L1, L2, L3""" + """L0 processor input - outputs from L1, L2, L3.""" + l1_entities: List[List[L1Entity]] = Field(..., description="Entities from L1 (per text)") - l2_candidates: List[List[DatabaseRecord]] = Field(..., description="Candidates from L2 (per mention)") + l2_candidates: List[List[DatabaseRecord]] = Field( + ..., description="Candidates from L2 (per mention)" + ) l3_entities: List[List[L3Entity]] = Field(..., description="Linked entities from L3 (per text)") class LinkedEntity(BaseOutput): - """Linked entity information from L3""" + """Linked entity information from L3.""" + entity_id: str = Field(..., description="Entity ID from matched candidate") label: str = Field(..., description="Entity label") confidence: float = Field(..., description="Linking confidence score from L3") @@ -44,11 +55,12 @@ class L0Entity(BaseOutput): Aggregated entity combining information from all layers: - L1: mention detection (text, position, context) - L2: candidates (entity database records) - - L3: disambiguation (linked entity with confidence) + - L3: disambiguation (linked entity with confidence). """ + # From L1 - mention detection mention_text: str = Field(..., description="Extracted mention text from L1") - label: Optional[str] = Field(None, description="Entity label from L1") + label: str | None = Field(None, description="Entity label from L1") mention_start: int = Field(..., description="Start position in original text") mention_end: int = Field(..., description="End position in original text") left_context: str = Field(..., description="Left context from L1") @@ -56,36 +68,31 @@ class L0Entity(BaseOutput): # From L2 - candidate retrieval candidates: List[DatabaseRecord] = Field( - default_factory=list, - description="All candidates found in L2 for this mention" + default_factory=list, description="All candidates found in L2 for this mention" ) num_candidates: int = Field(0, description="Number of candidates found") # From L3 - entity linking - linked_entity: Optional[LinkedEntity] = Field( - None, - description="Linked entity if disambiguation was successful" + linked_entity: LinkedEntity | None = Field( + None, description="Linked entity if disambiguation was successful" ) is_linked: bool = Field(False, description="Whether entity was successfully linked") candidate_scores: Dict[str, float] = Field( - default_factory=dict, - description="L3 class probability per candidate entity_id" + default_factory=dict, description="L3 class probability per candidate entity_id" ) # Aggregated metadata pipeline_stage: str = Field( - "", - description="Last successful stage: 'l1_only', 'l2_found', 'l3_linked'" + "", description="Last successful stage: 'l1_only', 'l2_found', 'l3_linked'" ) class L0Output(BaseOutput): - """L0 processor output""" + """L0 processor output.""" + entities: List[List[L0Entity]] = Field( - ..., - description="Aggregated entities per text with full pipeline information" + ..., description="Aggregated entities per text with full pipeline information" ) stats: dict = Field( - default_factory=dict, - description="Pipeline statistics (total, linked, unlinked, etc.)" + default_factory=dict, description="Pipeline statistics (total, linked, unlinked, etc.)" ) diff --git a/src/glinker/l0/processor.py b/src/glinker/l0/processor.py index 52a3b15..b73a376 100644 --- a/src/glinker/l0/processor.py +++ b/src/glinker/l0/processor.py @@ -1,16 +1,18 @@ from typing import Any, List + from glinker.core.base import BaseProcessor -from glinker.core.registry import processor_registry -from .models import L0Config, L0Input, L0Output, L0Entity -from .component import L0Component from glinker.l1.models import L1Entity from glinker.l2.models import DatabaseRecord from glinker.l3.models import L3Entity +from glinker.core.registry import processor_registry + +from .models import L0Input, L0Config, L0Output +from .component import L0Component class L0Processor(BaseProcessor[L0Config, L0Input, L0Output]): """ - L0 aggregation processor - combines outputs from all pipeline layers + L0 aggregation processor - combines outputs from all pipeline layers. This processor aggregates information from: - L1: Entity mentions (text, position, context) @@ -24,7 +26,7 @@ def __init__( self, config: L0Config, component: L0Component, - pipeline: list[tuple[str, dict[str, Any]]] = None + pipeline: list[tuple[str, dict[str, Any]]] | None = None, ): super().__init__(config, component, pipeline) self._validate_pipeline() @@ -35,18 +37,18 @@ def _default_pipeline(self) -> list[tuple[str, dict[str, Any]]]: ("aggregate", {}), ("filter_by_confidence", {}), ("sort_by_confidence", {}), - ("calculate_stats", {}) + ("calculate_stats", {}), ] def __call__( self, - l1_entities: List[List[L1Entity]] = None, - l2_candidates: List[List[DatabaseRecord]] = None, - l3_entities: List[List[L3Entity]] = None, - input_data: L0Input = None + l1_entities: List[List[L1Entity]] | None = None, + l2_candidates: List[List[DatabaseRecord]] | None = None, + l3_entities: List[List[L3Entity]] | None = None, + input_data: L0Input = None, ) -> L0Output: """ - Process and aggregate outputs from L1, L2, L3 + Process and aggregate outputs from L1, L2, L3. Args: l1_entities: Entities from L1 (mention extraction) @@ -57,7 +59,6 @@ def __call__( Returns: L0Output with aggregated entities and statistics """ - # Support both direct params and L0Input if input_data is not None: l1_entities = input_data.l1_entities @@ -75,7 +76,7 @@ def __call__( l1_entities = [[] for _ in l3_entities] # Pass schema template to component for matching - template = self.schema.get('template', '{label}') if self.schema else '{label}' + template = self.schema.get("template", "{label}") if self.schema else "{label}" # Execute aggregation pipeline aggregated_entities = self.component.aggregate( @@ -101,8 +102,8 @@ def __call__( @processor_registry.register("l0_aggregator") -def create_l0_processor(config_dict: dict, pipeline: list = None) -> L0Processor: - """Factory: creates component + processor""" +def create_l0_processor(config_dict: dict, pipeline: list | None = None) -> L0Processor: + """Factory: creates component + processor.""" config = L0Config(**config_dict) component = L0Component(config) return L0Processor(config, component, pipeline) diff --git a/src/glinker/l1/__init__.py b/src/glinker/l1/__init__.py index 90723a3..bdaa0c6 100644 --- a/src/glinker/l1/__init__.py +++ b/src/glinker/l1/__init__.py @@ -1,15 +1,15 @@ -from .models import L1Config, L1GlinerConfig, L1Input, L1Output, L1Entity +from .models import L1Input, L1Config, L1Entity, L1Output, L1GlinerConfig from .component import L1SpacyComponent, L1GlinerComponent from .processor import L1SpacyProcessor, L1GlinerProcessor __all__ = [ "L1Config", + "L1Entity", + "L1GlinerComponent", "L1GlinerConfig", + "L1GlinerProcessor", "L1Input", "L1Output", - "L1Entity", "L1SpacyComponent", "L1SpacyProcessor", - "L1GlinerComponent", - "L1GlinerProcessor", -] \ No newline at end of file +] diff --git a/src/glinker/l1/component.py b/src/glinker/l1/component.py index 2b27552..eec2247 100644 --- a/src/glinker/l1/component.py +++ b/src/glinker/l1/component.py @@ -1,20 +1,23 @@ -import spacy -from spacy.language import Language from typing import List + +import spacy import torch +from spacy.language import Language + from glinker.core.base import BaseComponent -from .models import L1Config, L1GlinerConfig, L1Entity + +from .models import L1Config, L1Entity, L1GlinerConfig class L1SpacyComponent(BaseComponent[L1Config]): - """spaCy-based entity extraction component""" + """spaCy-based entity extraction component.""" def _setup(self): - """Initialize spaCy model""" + """Initialize spaCy model.""" self.nlp = self._load_model() - + def _load_model(self) -> Language: - """Load or download spaCy model""" + """Load or download spaCy model.""" try: nlp = spacy.load(self.config.model) if self.config.device != "cpu": @@ -22,112 +25,109 @@ def _load_model(self) -> Language: return nlp except OSError: from spacy.cli import download + download(self.config.model) return spacy.load(self.config.model) - + def get_available_methods(self) -> list[str]: - """Return list of available pipeline methods""" + """Return list of available pipeline methods.""" return [ "extract_entities", "filter_by_length", "deduplicate", "sort_by_position", - "add_noun_chunks" + "add_noun_chunks", ] - + def extract_entities(self, text: str) -> list[L1Entity]: - """Extract named entities from text""" + """Extract named entities from text.""" doc = self.nlp(text) entities = [] seen_spans = set() - + for ent in doc.ents: span = (ent.start_char, ent.end_char) if span in seen_spans: continue - - left_context, right_context = self._get_context( - text, ent.start_char, ent.end_char + + left_context, right_context = self._get_context(text, ent.start_char, ent.end_char) + + entities.append( + L1Entity( + text=ent.text, + label=getattr(ent, "label_", None), # Safe access in case label_ is missing + start=ent.start_char, + end=ent.end_char, + left_context=left_context, + right_context=right_context, + ) ) - - entities.append(L1Entity( - text=ent.text, - label=getattr(ent, 'label_', None), # Safe access in case label_ is missing - start=ent.start_char, - end=ent.end_char, - left_context=left_context, - right_context=right_context - )) seen_spans.add(span) - + return entities - + def filter_by_length( - self, - entities: list[L1Entity], - min_length: int = None + self, entities: list[L1Entity], min_length: int | None = None ) -> list[L1Entity]: - """Filter entities by minimum text length""" + """Filter entities by minimum text length.""" min_len = min_length if min_length is not None else self.config.min_entity_length return [e for e in entities if len(e.text) >= min_len] - + def deduplicate(self, entities: list[L1Entity]) -> list[L1Entity]: - """Remove duplicate entities by span""" + """Remove duplicate entities by span.""" seen_spans = set() unique = [] - + for entity in entities: span = (entity.start, entity.end) if span not in seen_spans: unique.append(entity) seen_spans.add(span) - + return unique - + def sort_by_position(self, entities: list[L1Entity]) -> list[L1Entity]: - """Sort entities by start position""" + """Sort entities by start position.""" return sorted(entities, key=lambda x: x.start) - - def add_noun_chunks( - self, - text: str, - entities: list[L1Entity] = None - ) -> list[L1Entity]: - """Add noun chunks to entities list""" + + def add_noun_chunks(self, text: str, entities: list[L1Entity] | None = None) -> list[L1Entity]: + """Add noun chunks to entities list.""" if entities is None: entities = [] - + doc = self.nlp(text) seen_spans = {(e.start, e.end) for e in entities} - + for chunk in doc.noun_chunks: span = (chunk.start_char, chunk.end_char) - + overlap = False - for (s, e) in seen_spans: + for s, e in seen_spans: if not (chunk.end_char <= s or chunk.start_char >= e): overlap = True break - + if not overlap and len(chunk.text) >= self.config.min_entity_length: left_context, right_context = self._get_context( text, chunk.start_char, chunk.end_char ) - - entities.append(L1Entity( - text=chunk.text, - label="NOUN_CHUNK", - start=chunk.start_char, - end=chunk.end_char, - left_context=left_context, - right_context=right_context - )) + + entities.append( + L1Entity( + text=chunk.text, + label="NOUN_CHUNK", + start=chunk.start_char, + end=chunk.end_char, + left_context=left_context, + right_context=right_context, + ) + ) seen_spans.add(span) - + return entities - + def _get_context(self, text: str, start: int, end: int) -> tuple[str, str]: - """Extract left and right context for entity""" + """Extract left and right context for entity.""" left_start = max(0, start - self.config.max_left_context) left_context = text[left_start:start].strip() @@ -138,23 +138,23 @@ def _get_context(self, text: str, start: int, end: int) -> tuple[str, str]: class L1GlinerComponent(BaseComponent[L1GlinerConfig]): - """GLiNER-based entity extraction component for L1""" + """GLiNER-based entity extraction component for L1.""" def _setup(self): - """Initialize GLiNER model""" + """Initialize GLiNER model.""" from gliner import GLiNER self.model = GLiNER.from_pretrained( - self.config.model, - token=self.config.token, - max_length=self.config.max_length + self.config.model, token=self.config.token, max_length=self.config.max_length ) self.model.to(self.config.device) # Fix labels tokenizer max_length for BiEncoder models - if (self.config.max_length is not None and - hasattr(self.model, 'data_processor') and - hasattr(self.model.data_processor, 'labels_tokenizer')): + if ( + self.config.max_length is not None + and hasattr(self.model, "data_processor") + and hasattr(self.model.data_processor, "labels_tokenizer") + ): tok = self.model.data_processor.labels_tokenizer if tok.model_max_length > 100000: tok.model_max_length = self.config.max_length @@ -166,20 +166,20 @@ def _setup(self): @property def supports_precomputed_embeddings(self) -> bool: - """Check if model supports precomputed embeddings (BiEncoder)""" - return hasattr(self.model, 'encode_labels') and self.model.config.labels_encoder is not None + """Check if model supports precomputed embeddings (BiEncoder).""" + return hasattr(self.model, "encode_labels") and self.model.config.labels_encoder is not None def get_available_methods(self) -> List[str]: - """Return list of available pipeline methods""" + """Return list of available pipeline methods.""" return [ "extract_entities", "filter_by_length", "deduplicate", "sort_by_position", - "encode_labels" + "encode_labels", ] - def encode_labels(self, labels: List[str], batch_size: int = None) -> torch.Tensor: + def encode_labels(self, labels: List[str], batch_size: int | None = None) -> torch.Tensor: """ Encode labels using GLiNER's native label encoder. @@ -203,7 +203,7 @@ def encode_labels(self, labels: List[str], batch_size: int = None) -> torch.Tens return self.model.encode_labels(labels, batch_size=batch_size) def extract_entities(self, text: str) -> List[L1Entity]: - """Extract named entities from text using GLiNER""" + """Extract named entities from text using GLiNER.""" if not self.config.labels: return [] @@ -215,7 +215,7 @@ def extract_entities(self, text: str) -> List[L1Entity]: self.config.labels, threshold=self.config.threshold, flat_ner=self.config.flat_ner, - multi_label=self.config.multi_label + multi_label=self.config.multi_label, ) else: raw_entities = self.model.predict_entities( @@ -223,7 +223,7 @@ def extract_entities(self, text: str) -> List[L1Entity]: self.config.labels, threshold=self.config.threshold, flat_ner=self.config.flat_ner, - multi_label=self.config.multi_label + multi_label=self.config.multi_label, ) entities = [] @@ -234,33 +234,31 @@ def extract_entities(self, text: str) -> List[L1Entity]: if span in seen_spans: continue - left_context, right_context = self._get_context( - text, ent["start"], ent["end"] - ) + left_context, right_context = self._get_context(text, ent["start"], ent["end"]) - entities.append(L1Entity( - text=ent["text"], - label=ent.get("label", None), # Safe access in case label is missing - start=ent["start"], - end=ent["end"], - left_context=left_context, - right_context=right_context - )) + entities.append( + L1Entity( + text=ent["text"], + label=ent.get("label", None), # Safe access in case label is missing + start=ent["start"], + end=ent["end"], + left_context=left_context, + right_context=right_context, + ) + ) seen_spans.add(span) return entities def filter_by_length( - self, - entities: List[L1Entity], - min_length: int = None + self, entities: List[L1Entity], min_length: int | None = None ) -> List[L1Entity]: - """Filter entities by minimum text length""" + """Filter entities by minimum text length.""" min_len = min_length if min_length is not None else self.config.min_entity_length return [e for e in entities if len(e.text) >= min_len] def deduplicate(self, entities: List[L1Entity]) -> List[L1Entity]: - """Remove duplicate entities by span""" + """Remove duplicate entities by span.""" seen_spans = set() unique = [] @@ -273,15 +271,15 @@ def deduplicate(self, entities: List[L1Entity]) -> List[L1Entity]: return unique def sort_by_position(self, entities: List[L1Entity]) -> List[L1Entity]: - """Sort entities by start position""" + """Sort entities by start position.""" return sorted(entities, key=lambda x: x.start) def _get_context(self, text: str, start: int, end: int) -> tuple[str, str]: - """Extract left and right context for entity""" + """Extract left and right context for entity.""" left_start = max(0, start - self.config.max_left_context) left_context = text[left_start:start].strip() right_end = min(len(text), end + self.config.max_right_context) right_context = text[end:right_end].strip() - return left_context, right_context \ No newline at end of file + return left_context, right_context diff --git a/src/glinker/l1/models.py b/src/glinker/l1/models.py index 7ab9b8c..b9c1c21 100644 --- a/src/glinker/l1/models.py +++ b/src/glinker/l1/models.py @@ -1,6 +1,8 @@ +from typing import List + from pydantic import Field -from typing import List, Optional -from glinker.core.base import BaseConfig, BaseInput, BaseOutput + +from glinker.core.base import BaseInput, BaseConfig, BaseOutput class L1Config(BaseConfig): @@ -14,21 +16,18 @@ class L1Config(BaseConfig): class L1GlinerConfig(L1Config): - """Configuration for GLiNER-based L1 entity extraction""" + """Configuration for GLiNER-based L1 entity extraction.""" + model: str = Field(..., description="GLiNER model identifier (overrides spaCy model)") labels: List[str] = Field(..., description="Fixed list of labels for entity extraction") - token: Optional[str] = Field(None, description="HuggingFace token") + token: str | None = Field(None, description="HuggingFace token") threshold: float = Field(0.3, description="Confidence threshold for entity extraction") flat_ner: bool = Field(True, description="Use flat NER (no nested entities)") multi_label: bool = Field(False, description="Allow multiple labels per entity") use_precomputed_embeddings: bool = Field( - False, - description="Use precomputed label embeddings (BiEncoder only)" - ) - max_length: Optional[int] = Field( - None, - description="Maximum sequence length for tokenization" + False, description="Use precomputed label embeddings (BiEncoder only)" ) + max_length: int | None = Field(None, description="Maximum sequence length for tokenization") class L1Input(BaseInput): @@ -37,7 +36,7 @@ class L1Input(BaseInput): class L1Entity(BaseOutput): text: str = Field(..., description="Extracted mention text") - label: Optional[str] = Field(None, description="Entity label/type") + label: str | None = Field(None, description="Entity label/type") start: int = Field(..., description="Start position") end: int = Field(..., description="End position") left_context: str = Field(..., description="Left context") @@ -45,4 +44,4 @@ class L1Entity(BaseOutput): class L1Output(BaseOutput): - entities: list[list[L1Entity]] = Field(..., description="Extracted entities per text") \ No newline at end of file + entities: list[list[L1Entity]] = Field(..., description="Extracted entities per text") diff --git a/src/glinker/l1/processor.py b/src/glinker/l1/processor.py index 98427f2..e9861ee 100644 --- a/src/glinker/l1/processor.py +++ b/src/glinker/l1/processor.py @@ -1,36 +1,29 @@ -from typing import Any, List, Union +from typing import Any, List + from glinker.core.base import BaseProcessor from glinker.core.registry import processor_registry -from .models import L1Config, L1GlinerConfig, L1Input, L1Output + +from .models import L1Input, L1Config, L1Output, L1GlinerConfig from .component import L1SpacyComponent, L1GlinerComponent class L1SpacyProcessor(BaseProcessor[L1Config, L1Input, L1Output]): - """Optimized batch processor using spaCy pipe""" + """Optimized batch processor using spaCy pipe.""" def __init__( self, config: L1Config, component: L1SpacyComponent, - pipeline: list[tuple[str, dict[str, Any]]] = None + pipeline: list[tuple[str, dict[str, Any]]] | None = None, ): super().__init__(config, component, pipeline) self._validate_pipeline() - + def _default_pipeline(self) -> list[tuple[str, dict[str, Any]]]: - return [ - ("extract_entities", {}), - ("deduplicate", {}), - ("sort_by_position", {}) - ] - - def __call__( - self, - texts: List[str] = None, - input_data: L1Input = None - ) -> L1Output: - """Process batch using spaCy's efficient pipe""" - + return [("extract_entities", {}), ("deduplicate", {}), ("sort_by_position", {})] + + def __call__(self, texts: List[str] | None = None, input_data: L1Input = None) -> L1Output: + """Process batch using spaCy's efficient pipe.""" # Support both direct texts and L1Input if texts is not None: texts_to_process = texts @@ -38,84 +31,72 @@ def __call__( texts_to_process = input_data.texts else: raise ValueError("Either 'texts' or 'input_data' must be provided") - + results = [] - + for doc, original_text in zip( - self.component.nlp.pipe( - texts_to_process, - batch_size=self.config.batch_size - ), - texts_to_process + self.component.nlp.pipe(texts_to_process, batch_size=self.config.batch_size), + texts_to_process, ): entities = self._extract_from_doc(doc, original_text) - + pipeline_rest = [ - (method, kwargs) - for method, kwargs in self.pipeline - if method != "extract_entities" + (method, kwargs) for method, kwargs in self.pipeline if method != "extract_entities" ] - + entities = self._execute_pipeline(entities, pipeline_rest) results.append(entities) - + return L1Output(entities=results) - + def _extract_from_doc(self, doc, text: str) -> list: - """Extract entities from already processed doc""" + """Extract entities from already processed doc.""" from .models import L1Entity - + entities = [] for ent in doc.ents: left_context, right_context = self.component._get_context( text, ent.start_char, ent.end_char ) - - entities.append(L1Entity( - text=ent.text, - start=ent.start_char, - end=ent.end_char, - left_context=left_context, - right_context=right_context - )) - + + entities.append( + L1Entity( + text=ent.text, + start=ent.start_char, + end=ent.end_char, + left_context=left_context, + right_context=right_context, + ) + ) + return entities @processor_registry.register("l1_spacy") -def create_l1_spacy_processor(config_dict: dict, pipeline: list = None) -> L1SpacyProcessor: - """Factory: creates component + batch processor""" +def create_l1_spacy_processor(config_dict: dict, pipeline: list | None = None) -> L1SpacyProcessor: + """Factory: creates component + batch processor.""" config = L1Config(**config_dict) component = L1SpacyComponent(config) return L1SpacyProcessor(config, component, pipeline) class L1GlinerProcessor(BaseProcessor[L1GlinerConfig, L1Input, L1Output]): - """GLiNER-based batch processor for L1 entity extraction""" + """GLiNER-based batch processor for L1 entity extraction.""" def __init__( self, config: L1GlinerConfig, component: L1GlinerComponent, - pipeline: list[tuple[str, dict[str, Any]]] = None + pipeline: list[tuple[str, dict[str, Any]]] | None = None, ): super().__init__(config, component, pipeline) self._validate_pipeline() def _default_pipeline(self) -> list[tuple[str, dict[str, Any]]]: - return [ - ("extract_entities", {}), - ("deduplicate", {}), - ("sort_by_position", {}) - ] - - def __call__( - self, - texts: List[str] = None, - input_data: L1Input = None - ) -> L1Output: - """Process batch of texts using GLiNER""" + return [("extract_entities", {}), ("deduplicate", {}), ("sort_by_position", {})] + def __call__(self, texts: List[str] | None = None, input_data: L1Input = None) -> L1Output: + """Process batch of texts using GLiNER.""" # Support both direct texts and L1Input if texts is not None: texts_to_process = texts @@ -133,9 +114,7 @@ def __call__( # Apply rest of pipeline (skip extract_entities as already done) pipeline_rest = [ - (method, kwargs) - for method, kwargs in self.pipeline - if method != "extract_entities" + (method, kwargs) for method, kwargs in self.pipeline if method != "extract_entities" ] entities = self._execute_pipeline(entities, pipeline_rest) @@ -145,8 +124,10 @@ def __call__( @processor_registry.register("l1_gliner") -def create_l1_gliner_processor(config_dict: dict, pipeline: list = None) -> L1GlinerProcessor: - """Factory: creates component + GLiNER processor""" +def create_l1_gliner_processor( + config_dict: dict, pipeline: list | None = None +) -> L1GlinerProcessor: + """Factory: creates component + GLiNER processor.""" config = L1GlinerConfig(**config_dict) component = L1GlinerComponent(config) - return L1GlinerProcessor(config, component, pipeline) \ No newline at end of file + return L1GlinerProcessor(config, component, pipeline) diff --git a/src/glinker/l2/__init__.py b/src/glinker/l2/__init__.py index 8ea7da4..7d046f4 100644 --- a/src/glinker/l2/__init__.py +++ b/src/glinker/l2/__init__.py @@ -1,19 +1,26 @@ -from .models import L2Config, L2Input, L2Output, LayerConfig, FuzzyConfig, DatabaseRecord -from .component import DatabaseChainComponent, DatabaseLayer, DictLayer, RedisLayer, ElasticsearchLayer, PostgresLayer +from .models import L2Input, L2Config, L2Output, FuzzyConfig, LayerConfig, DatabaseRecord +from .component import ( + DictLayer, + RedisLayer, + DatabaseLayer, + PostgresLayer, + ElasticsearchLayer, + DatabaseChainComponent, +) from .processor import L2Processor __all__ = [ - "L2Config", - "L2Input", - "L2Output", - "LayerConfig", - "FuzzyConfig", - "DatabaseRecord", "DatabaseChainComponent", "DatabaseLayer", + "DatabaseRecord", "DictLayer", - "RedisLayer", "ElasticsearchLayer", + "FuzzyConfig", + "L2Config", + "L2Input", + "L2Output", + "L2Processor", + "LayerConfig", "PostgresLayer", - "L2Processor" -] \ No newline at end of file + "RedisLayer", +] diff --git a/src/glinker/l2/component.py b/src/glinker/l2/component.py index 23e1cd5..e1cf05e 100644 --- a/src/glinker/l2/component.py +++ b/src/glinker/l2/component.py @@ -1,20 +1,22 @@ +import json from abc import ABC, abstractmethod -from typing import List, Dict, Any, Set, Union +from typing import Any, Set, Dict, List from pathlib import Path + import redis -import json -from elasticsearch import Elasticsearch -from elasticsearch.helpers import bulk as es_bulk import psycopg2 +from elasticsearch import Elasticsearch from psycopg2.extras import RealDictCursor, execute_batch +from elasticsearch.helpers import bulk as es_bulk from glinker.core.base import BaseComponent -from .models import L2Config, LayerConfig, FuzzyConfig, DatabaseRecord + +from .models import L2Config, FuzzyConfig, LayerConfig, DatabaseRecord class DatabaseLayer(ABC): - """Base class for all database layers""" - + """Base class for all database layers.""" + def __init__(self, config: LayerConfig): self.config = config self.priority = config.priority @@ -24,207 +26,205 @@ def __init__(self, config: LayerConfig): self.field_mapping = config.field_mapping self.fuzzy_config = config.fuzzy or FuzzyConfig() self._setup() - + @abstractmethod def _setup(self): - """Initialize layer resources""" + """Initialize layer resources.""" pass - + def normalize_query(self, query: str) -> str: - """Normalize query for search""" + """Normalize query for search.""" return query.lower().strip() - + @abstractmethod def search(self, query: str) -> List[DatabaseRecord]: - """Exact search""" + """Exact search.""" pass - + @abstractmethod def search_fuzzy(self, query: str) -> List[DatabaseRecord]: - """Fuzzy search""" + """Fuzzy search.""" pass - + def supports_fuzzy(self) -> bool: - """Check if layer supports fuzzy search""" + """Check if layer supports fuzzy search.""" return self.fuzzy_config is not None - + @abstractmethod def write_cache(self, key: str, records: List[DatabaseRecord], ttl: int): - """Write records to cache""" + """Write records to cache.""" pass - + @abstractmethod def is_available(self) -> bool: - """Check if layer is available""" + """Check if layer is available.""" pass - + @abstractmethod - def load_bulk(self, entities: List[DatabaseRecord], overwrite: bool = False, batch_size: int = 1000) -> int: - """Bulk load entities""" + def load_bulk( + self, entities: List[DatabaseRecord], overwrite: bool = False, batch_size: int = 1000 + ) -> int: + """Bulk load entities.""" pass - + def clear(self): - """Clear all data in layer""" + """Clear all data in layer.""" pass def count(self) -> int: - """Count entities in layer""" + """Count entities in layer.""" return 0 def get_all_entities(self) -> List[DatabaseRecord]: - """Get all entities from layer (for precompute)""" + """Get all entities from layer (for precompute).""" return [] def update_embeddings( - self, - entity_ids: List[str], - embeddings: List[List[float]], - model_id: str + self, entity_ids: List[str], embeddings: List[List[float]], model_id: str ) -> int: - """Update embeddings for entities""" + """Update embeddings for entities.""" return 0 def map_to_record(self, raw_data: Dict[str, Any]) -> DatabaseRecord: - """Map raw data to DatabaseRecord using field_mapping""" + """Map raw data to DatabaseRecord using field_mapping.""" mapped = {} for standard_field, db_field in self.field_mapping.items(): if db_field in raw_data: mapped[standard_field] = raw_data[db_field] # Handle embedding fields directly (not in field_mapping) - if 'embedding' in raw_data: - mapped['embedding'] = raw_data['embedding'] - if 'embedding_model_id' in raw_data: - mapped['embedding_model_id'] = raw_data['embedding_model_id'] + if "embedding" in raw_data: + mapped["embedding"] = raw_data["embedding"] + if "embedding_model_id" in raw_data: + mapped["embedding_model_id"] = raw_data["embedding_model_id"] - mapped['source'] = self.config.type + mapped["source"] = self.config.type return DatabaseRecord(**mapped) class DictLayer(DatabaseLayer): - """Simple dict-based storage for small entity sets (<5000)""" - + """Simple dict-based storage for small entity sets (<5000).""" + def _setup(self): self._storage: Dict[str, DatabaseRecord] = {} self._label_index: Dict[str, str] = {} self._alias_index: Dict[str, Set[str]] = {} - + def search(self, query: str) -> List[DatabaseRecord]: - """Fast O(1) exact search using indexes""" + """Fast O(1) exact search using indexes.""" query_key = self.normalize_query(query) results = [] seen = set() - + # Label lookup if query_key in self._label_index: eid = self._label_index[query_key] results.append(self._storage[eid]) seen.add(eid) - + # Alias lookup if query_key in self._alias_index: for eid in self._alias_index[query_key]: if eid not in seen: results.append(self._storage[eid]) seen.add(eid) - + return results - + def search_fuzzy(self, query: str) -> List[DatabaseRecord]: - """Simple fuzzy search for small datasets (O(n) is fine for <5000 entities)""" + """Simple fuzzy search for small datasets (O(n) is fine for <5000 entities).""" try: from rapidfuzz import fuzz except ImportError: print("[WARN DictLayer] rapidfuzz not installed, fuzzy search disabled") return [] - + query_key = self.normalize_query(query) results = [] - + # Check prefix requirement if self.fuzzy_config.prefix_length > 0: - prefix = query_key[:self.fuzzy_config.prefix_length] - + prefix = query_key[: self.fuzzy_config.prefix_length] + for entity in self._storage.values(): # Check label label_key = entity.label.lower() - + if self.fuzzy_config.prefix_length > 0: if not label_key.startswith(prefix): continue - + similarity = fuzz.ratio(query_key, label_key) / 100.0 if similarity >= self.fuzzy_config.min_similarity: results.append((entity, similarity)) continue - + # Check aliases for alias in entity.aliases: alias_key = alias.lower() if self.fuzzy_config.prefix_length > 0: if not alias_key.startswith(prefix): continue - + sim = fuzz.ratio(query_key, alias_key) / 100.0 if sim >= self.fuzzy_config.min_similarity: results.append((entity, sim)) break - + # Sort by similarity results.sort(key=lambda x: x[1], reverse=True) return [r[0] for r in results] - + def write_cache(self, key: str, records: List[DatabaseRecord], ttl: int): - """Write is same as load_bulk for dict layer""" + """Write is same as load_bulk for dict layer.""" self.load_bulk(records, overwrite=True) - - def load_bulk(self, entities: List[DatabaseRecord], overwrite: bool = False, batch_size: int = 1000) -> int: - """Bulk load entities with indexing""" + + def load_bulk( + self, entities: List[DatabaseRecord], overwrite: bool = False, batch_size: int = 1000 + ) -> int: + """Bulk load entities with indexing.""" count = 0 for entity in entities: entity_id = entity.entity_id - + if not overwrite and entity_id in self._storage: continue - + # Store entity self._storage[entity_id] = entity - + # Index by label label_key = entity.label.lower() self._label_index[label_key] = entity_id - + # Index by aliases for alias in entity.aliases: alias_key = alias.lower() if alias_key not in self._alias_index: self._alias_index[alias_key] = set() self._alias_index[alias_key].add(entity_id) - + count += 1 return count - + def clear(self): - """Clear all data""" + """Clear all data.""" self._storage.clear() self._label_index.clear() self._alias_index.clear() - + def count(self) -> int: - """Count entities""" + """Count entities.""" return len(self._storage) def get_all_entities(self) -> List[DatabaseRecord]: - """Get all entities from storage""" + """Get all entities from storage.""" return list(self._storage.values()) def update_embeddings( - self, - entity_ids: List[str], - embeddings: List[List[float]], - model_id: str + self, entity_ids: List[str], embeddings: List[List[float]], model_id: str ) -> int: - """Update embeddings for entities""" + """Update embeddings for entities.""" count = 0 for eid, emb in zip(entity_ids, embeddings): if eid in self._storage: @@ -234,113 +234,115 @@ def update_embeddings( return count def is_available(self) -> bool: - """Dict layer is always available""" + """Dict layer is always available.""" return True class RedisLayer(DatabaseLayer): - """Redis cache layer""" - + """Redis cache layer.""" + def _setup(self): self.client = redis.Redis( - host=self.config.config.get('host', 'localhost'), - port=self.config.config.get('port', 6379), - db=self.config.config.get('db', 0), - password=self.config.config.get('password'), - decode_responses=False + host=self.config.config.get("host", "localhost"), + port=self.config.config.get("port", 6379), + db=self.config.config.get("db", 0), + password=self.config.config.get("password"), + decode_responses=False, ) - + def supports_fuzzy(self) -> bool: return False - + def search(self, query: str) -> List[DatabaseRecord]: query = self.normalize_query(query) key = f"entity:{query}" - + try: data = self.client.get(key) if data: if isinstance(data, bytes): - data = data.decode('utf-8') - + data = data.decode("utf-8") + records_data = json.loads(data) - + if isinstance(records_data, list): results = [] for r in records_data: if isinstance(r, dict): - r['source'] = 'redis' + r["source"] = "redis" results.append(DatabaseRecord(**r)) else: results.append(r) return results - + elif isinstance(records_data, dict): - records_data['source'] = 'redis' + records_data["source"] = "redis" return [DatabaseRecord(**records_data)] - + except Exception as e: print(f"[ERROR Redis] Search error: {e}") - + return [] - + def search_fuzzy(self, query: str) -> List[DatabaseRecord]: return [] - + def write_cache(self, key: str, records: List[DatabaseRecord], ttl: int): key = self.normalize_query(key) cache_key = f"entity:{key}" - + try: data = json.dumps([r.dict() for r in records]) self.client.setex(cache_key, ttl, data) except Exception as e: print(f"[ERROR Redis] Write error: {e}") - - def load_bulk(self, entities: List[DatabaseRecord], overwrite: bool = False, batch_size: int = 1000) -> int: - """Bulk load to Redis""" + + def load_bulk( + self, entities: List[DatabaseRecord], overwrite: bool = False, batch_size: int = 1000 + ) -> int: + """Bulk load to Redis.""" count = 0 pipe = self.client.pipeline() - + for entity in entities: # Prepare data entity_data = entity.dict() data_json = json.dumps(entity_data) - + # Store by label label_key = f"entity:{entity.label.lower()}" if overwrite or not self.client.exists(label_key): pipe.setex(label_key, self.ttl, data_json) count += 1 - + # Store by aliases for alias in entity.aliases: alias_key = f"entity:{alias.lower()}" if overwrite or not self.client.exists(alias_key): pipe.setex(alias_key, self.ttl, data_json) - + # Execute in batches if len(pipe) >= batch_size: pipe.execute() pipe = self.client.pipeline() - + # Execute remaining if len(pipe) > 0: pipe.execute() - + return count - + def clear(self): - """Clear all entity keys""" + """Clear all entity keys.""" for key in self.client.scan_iter(match="entity:*"): self.client.delete(key) - + def count(self) -> int: - """Count entity keys""" + """Count entity keys.""" return sum(1 for _ in self.client.scan_iter(match="entity:*")) def get_all_entities(self) -> List[DatabaseRecord]: - """Get all entities from Redis (scans all entity:* keys)""" + """Get all entities from Redis (scans all entity:* keys).""" entities = [] seen_ids = set() @@ -349,32 +351,29 @@ def get_all_entities(self) -> List[DatabaseRecord]: data = self.client.get(key) if data: if isinstance(data, bytes): - data = data.decode('utf-8') + data = data.decode("utf-8") record_data = json.loads(data) if isinstance(record_data, dict): - if record_data.get('entity_id') not in seen_ids: - record_data['source'] = 'redis' + if record_data.get("entity_id") not in seen_ids: + record_data["source"] = "redis" entities.append(DatabaseRecord(**record_data)) - seen_ids.add(record_data.get('entity_id')) + seen_ids.add(record_data.get("entity_id")) elif isinstance(record_data, list): for r in record_data: - if r.get('entity_id') not in seen_ids: - r['source'] = 'redis' + if r.get("entity_id") not in seen_ids: + r["source"] = "redis" entities.append(DatabaseRecord(**r)) - seen_ids.add(r.get('entity_id')) - except Exception as e: + seen_ids.add(r.get("entity_id")) + except Exception: continue return entities def update_embeddings( - self, - entity_ids: List[str], - embeddings: List[List[float]], - model_id: str + self, entity_ids: List[str], embeddings: List[List[float]], model_id: str ) -> int: - """Update embeddings in Redis entities""" + """Update embeddings in Redis entities.""" count = 0 id_to_embedding = dict(zip(entity_ids, embeddings)) @@ -385,28 +384,28 @@ def update_embeddings( continue if isinstance(data, bytes): - data = data.decode('utf-8') + data = data.decode("utf-8") record_data = json.loads(data) updated = False if isinstance(record_data, dict): - if record_data.get('entity_id') in id_to_embedding: - record_data['embedding'] = id_to_embedding[record_data['entity_id']] - record_data['embedding_model_id'] = model_id + if record_data.get("entity_id") in id_to_embedding: + record_data["embedding"] = id_to_embedding[record_data["entity_id"]] + record_data["embedding_model_id"] = model_id updated = True elif isinstance(record_data, list): for r in record_data: - if r.get('entity_id') in id_to_embedding: - r['embedding'] = id_to_embedding[r['entity_id']] - r['embedding_model_id'] = model_id + if r.get("entity_id") in id_to_embedding: + r["embedding"] = id_to_embedding[r["entity_id"]] + r["embedding_model_id"] = model_id updated = True if updated: self.client.setex(key, self.ttl, json.dumps(record_data)) count += 1 - except Exception as e: + except Exception: continue return count @@ -420,15 +419,14 @@ def is_available(self) -> bool: class ElasticsearchLayer(DatabaseLayer): - """Elasticsearch full-text search layer""" + """Elasticsearch full-text search layer.""" def _setup(self): self.client = Elasticsearch( - self.config.config['hosts'], - api_key=self.config.config.get('api_key') + self.config.config["hosts"], api_key=self.config.config.get("api_key") ) - self.index_name = self.config.config['index_name'] - self.popularity_boost = self.config.config.get('popularity_boost', False) + self.index_name = self.config.config["index_name"] + self.popularity_boost = self.config.config.get("popularity_boost", False) def _build_query(self, match_query: dict) -> dict: """Wrap a match query with optional popularity boosting. @@ -447,15 +445,11 @@ def _build_query(self, match_query: dict) -> dict: "query": { "function_score": { "query": match_query, - "field_value_factor": { - "field": "popularity", - "modifier": "ln2p", - "missing": 1 - }, - "boost_mode": "multiply" + "field_value_factor": {"field": "popularity", "modifier": "ln2p", "missing": 1}, + "boost_mode": "multiply", } }, - "size": 50 + "size": 50, } def search(self, query: str) -> List[DatabaseRecord]: @@ -466,12 +460,12 @@ def search(self, query: str) -> List[DatabaseRecord]: "multi_match": { "query": query, "fields": ["label^2", "aliases^1.5", "description"], - "type": "best_fields" + "type": "best_fields", } } body = self._build_query(match_query) response = self.client.search(index=self.index_name, body=body) - return self._process_hits(response['hits']['hits']) + return self._process_hits(response["hits"]["hits"]) except Exception as e: print(f"[ERROR ES] Search error: {e}") return [] @@ -487,129 +481,115 @@ def search_fuzzy(self, query: str) -> List[DatabaseRecord]: "fields": ["label^2", "aliases^1.5", "description"], "fuzziness": fuzzy_distance, "prefix_length": self.fuzzy_config.prefix_length, - "max_expansions": 50 + "max_expansions": 50, } } body = self._build_query(match_query) response = self.client.search(index=self.index_name, body=body) - return self._process_hits(response['hits']['hits']) + return self._process_hits(response["hits"]["hits"]) except Exception as e: print(f"[ERROR ES] Fuzzy error: {e}") return [] - + def _process_hits(self, hits: List[Dict]) -> List[DatabaseRecord]: records = [] for hit in hits: - source = hit['_source'] - source['_id'] = hit['_id'] - source['source'] = 'elasticsearch' + source = hit["_source"] + source["_id"] = hit["_id"] + source["source"] = "elasticsearch" record = self.map_to_record(source) records.append(record) return records - + def write_cache(self, key: str, records: List[DatabaseRecord], ttl: int): if not records: return - + try: actions = [] for record in records: doc = self._map_from_record(record) - actions.append({ - "_index": self.index_name, - "_id": record.entity_id, - "_source": doc - }) - + actions.append({"_index": self.index_name, "_id": record.entity_id, "_source": doc}) + if actions: es_bulk(self.client, actions) self.client.indices.refresh(index=self.index_name) except Exception as e: print(f"[ERROR ES] Write error: {e}") - - def load_bulk(self, entities: List[DatabaseRecord], overwrite: bool = False, batch_size: int = 1000) -> int: - """Bulk load to Elasticsearch""" + + def load_bulk( + self, entities: List[DatabaseRecord], overwrite: bool = False, batch_size: int = 1000 + ) -> int: + """Bulk load to Elasticsearch.""" actions = [] for entity in entities: doc = self._map_from_record(entity) - - action = { - '_index': self.index_name, - '_id': entity.entity_id, - '_source': doc - } - + + action = {"_index": self.index_name, "_id": entity.entity_id, "_source": doc} + if overwrite: - action['_op_type'] = 'index' + action["_op_type"] = "index" else: - action['_op_type'] = 'create' - + action["_op_type"] = "create" + actions.append(action) - - success, failed = es_bulk( - self.client, - actions, - raise_on_error=False, - chunk_size=batch_size + + success, _failed = es_bulk( + self.client, actions, raise_on_error=False, chunk_size=batch_size ) - + self.client.indices.refresh(index=self.index_name) return success - + def _map_from_record(self, record: DatabaseRecord) -> dict: - """Map DatabaseRecord -> ES document using field_mapping""" + """Map DatabaseRecord -> ES document using field_mapping.""" reverse_mapping = {v: k for k, v in self.field_mapping.items()} - + doc = {} for standard_field, value in record.dict().items(): - if standard_field == 'source': + if standard_field == "source": continue - + es_field = reverse_mapping.get(standard_field, standard_field) doc[es_field] = value - + return doc - + def clear(self): - """Delete all documents in index""" + """Delete all documents in index.""" try: - self.client.delete_by_query( - index=self.index_name, - body={"query": {"match_all": {}}} - ) + self.client.delete_by_query(index=self.index_name, body={"query": {"match_all": {}}}) self.client.indices.refresh(index=self.index_name) except Exception as e: print(f"[ERROR ES] Clear error: {e}") - + def count(self) -> int: - """Count documents in index""" + """Count documents in index.""" try: result = self.client.count(index=self.index_name) - return result['count'] + return result["count"] except: return 0 def get_all_entities(self) -> List[DatabaseRecord]: - """Get all entities from Elasticsearch using scroll""" + """Get all entities from Elasticsearch using scroll.""" entities = [] try: # Use scroll API for large datasets response = self.client.search( - index=self.index_name, - body={"query": {"match_all": {}}, "size": 1000}, - scroll='2m' + index=self.index_name, body={"query": {"match_all": {}}, "size": 1000}, scroll="2m" ) - scroll_id = response['_scroll_id'] - hits = response['hits']['hits'] + scroll_id = response["_scroll_id"] + hits = response["hits"]["hits"] while hits: entities.extend(self._process_hits(hits)) - response = self.client.scroll(scroll_id=scroll_id, scroll='2m') - scroll_id = response['_scroll_id'] - hits = response['hits']['hits'] + response = self.client.scroll(scroll_id=scroll_id, scroll="2m") + scroll_id = response["_scroll_id"] + hits = response["hits"]["hits"] # Clear scroll self.client.clear_scroll(scroll_id=scroll_id) @@ -620,31 +600,22 @@ def get_all_entities(self) -> List[DatabaseRecord]: return entities def update_embeddings( - self, - entity_ids: List[str], - embeddings: List[List[float]], - model_id: str + self, entity_ids: List[str], embeddings: List[List[float]], model_id: str ) -> int: - """Update embeddings in Elasticsearch""" + """Update embeddings in Elasticsearch.""" try: actions = [] for eid, emb in zip(entity_ids, embeddings): - actions.append({ - "_op_type": "update", - "_index": self.index_name, - "_id": eid, - "doc": { - "embedding": emb, - "embedding_model_id": model_id + actions.append( + { + "_op_type": "update", + "_index": self.index_name, + "_id": eid, + "doc": {"embedding": emb, "embedding_model_id": model_id}, } - }) + ) - success, failed = es_bulk( - self.client, - actions, - raise_on_error=False, - chunk_size=500 - ) + success, _failed = es_bulk(self.client, actions, raise_on_error=False, chunk_size=500) self.client.indices.refresh(index=self.index_name) return success @@ -661,17 +632,17 @@ def is_available(self) -> bool: class PostgresLayer(DatabaseLayer): - """PostgreSQL database layer""" - + """PostgreSQL database layer.""" + def _setup(self): self.conn = psycopg2.connect( - host=self.config.config['host'], - port=self.config.config.get('port', 5432), - database=self.config.config['database'], - user=self.config.config['user'], - password=self.config.config['password'] + host=self.config.config["host"], + port=self.config.config.get("port", 5432), + database=self.config.config["database"], + user=self.config.config["user"], + password=self.config.config["password"], ) - + cursor = self.conn.cursor() try: cursor.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm;") @@ -680,14 +651,14 @@ def _setup(self): print(f"[WARN Postgres] pg_trgm: {e}") finally: cursor.close() - + def search(self, query: str) -> List[DatabaseRecord]: query = self.normalize_query(query) - + try: cursor = self.conn.cursor(cursor_factory=RealDictCursor) sql = """ - SELECT + SELECT e.entity_id, e.label, e.description, @@ -698,8 +669,8 @@ def search(self, query: str) -> List[DatabaseRecord]: LEFT JOIN aliases a ON e.entity_id = a.entity_id WHERE LOWER(e.label) LIKE %s OR EXISTS ( - SELECT 1 FROM aliases a2 - WHERE a2.entity_id = e.entity_id + SELECT 1 FROM aliases a2 + WHERE a2.entity_id = e.entity_id AND LOWER(a2.alias) LIKE %s ) GROUP BY e.entity_id, e.label, e.description, e.entity_type, e.popularity @@ -713,15 +684,15 @@ def search(self, query: str) -> List[DatabaseRecord]: except Exception as e: print(f"[ERROR Postgres] Search error: {e}") return [] - + def search_fuzzy(self, query: str) -> List[DatabaseRecord]: query = self.normalize_query(query) threshold = self.fuzzy_config.min_similarity - + try: cursor = self.conn.cursor(cursor_factory=RealDictCursor) sql = """ - SELECT + SELECT e.entity_id, e.label, e.description, @@ -743,30 +714,31 @@ def search_fuzzy(self, query: str) -> List[DatabaseRecord]: except Exception as e: print(f"[ERROR Postgres] Fuzzy error: {e}") return self.search(query) - + def _process_rows(self, rows: List[Dict]) -> List[DatabaseRecord]: records = [] for row in rows: row_dict = dict(row) - row_dict['source'] = 'postgres' + row_dict["source"] = "postgres" record = self.map_to_record(row_dict) records.append(record) return records - + def write_cache(self, key: str, records: List[DatabaseRecord], ttl: int): pass - - def load_bulk(self, entities: List[DatabaseRecord], overwrite: bool = False, batch_size: int = 1000) -> int: - """Bulk load to Postgres""" + + def load_bulk( + self, entities: List[DatabaseRecord], overwrite: bool = False, batch_size: int = 1000 + ) -> int: + """Bulk load to Postgres.""" cursor = self.conn.cursor() - + try: # Prepare entity data entity_values = [ - (e.entity_id, e.label, e.description, e.entity_type, e.popularity) - for e in entities + (e.entity_id, e.label, e.description, e.entity_type, e.popularity) for e in entities ] - + # Insert entities if overwrite: entity_query = """ @@ -784,44 +756,41 @@ def load_bulk(self, entities: List[DatabaseRecord], overwrite: bool = False, bat VALUES (%s, %s, %s, %s, %s) ON CONFLICT (entity_id) DO NOTHING """ - + execute_batch(cursor, entity_query, entity_values, page_size=batch_size) - + # Prepare alias data alias_values = [] for entity in entities: for alias in entity.aliases: alias_values.append((entity.entity_id, alias)) - + # Delete old aliases if overwrite if overwrite and alias_values: entity_ids = [e.entity_id for e in entities] - cursor.execute( - "DELETE FROM aliases WHERE entity_id = ANY(%s)", - (entity_ids,) - ) - + cursor.execute("DELETE FROM aliases WHERE entity_id = ANY(%s)", (entity_ids,)) + # Insert aliases if alias_values: execute_batch( cursor, "INSERT INTO aliases (entity_id, alias) VALUES (%s, %s) ON CONFLICT DO NOTHING", alias_values, - page_size=batch_size + page_size=batch_size, ) - + self.conn.commit() return len(entities) - + except Exception as e: self.conn.rollback() print(f"[ERROR Postgres] Load bulk failed: {e}") raise finally: cursor.close() - + def clear(self): - """Clear all data""" + """Clear all data.""" cursor = self.conn.cursor() try: cursor.execute("TRUNCATE entities, aliases CASCADE") @@ -831,9 +800,9 @@ def clear(self): print(f"[ERROR Postgres] Clear error: {e}") finally: cursor.close() - + def count(self) -> int: - """Count entities""" + """Count entities.""" cursor = self.conn.cursor() try: cursor.execute("SELECT COUNT(*) FROM entities") @@ -844,7 +813,7 @@ def count(self) -> int: cursor.close() def get_all_entities(self) -> List[DatabaseRecord]: - """Get all entities from PostgreSQL""" + """Get all entities from PostgreSQL.""" entities = [] try: @@ -867,13 +836,14 @@ def get_all_entities(self) -> List[DatabaseRecord]: for row in cursor.fetchall(): row_dict = dict(row) - row_dict['source'] = 'postgres' + row_dict["source"] = "postgres" # Deserialize embedding from bytes if needed - if row_dict.get('embedding'): + if row_dict.get("embedding"): import pickle - if isinstance(row_dict['embedding'], (bytes, memoryview)): - row_dict['embedding'] = pickle.loads(bytes(row_dict['embedding'])) + + if isinstance(row_dict["embedding"], (bytes, memoryview)): + row_dict["embedding"] = pickle.loads(bytes(row_dict["embedding"])) record = self.map_to_record(row_dict) entities.append(record) @@ -886,12 +856,9 @@ def get_all_entities(self) -> List[DatabaseRecord]: return entities def update_embeddings( - self, - entity_ids: List[str], - embeddings: List[List[float]], - model_id: str + self, entity_ids: List[str], embeddings: List[List[float]], model_id: str ) -> int: - """Update embeddings in PostgreSQL""" + """Update embeddings in PostgreSQL.""" cursor = self.conn.cursor() try: @@ -908,7 +875,7 @@ def update_embeddings( cursor, "UPDATE entities SET embedding = %s, embedding_model_id = %s WHERE entity_id = %s", batch_data, - page_size=500 + page_size=500, ) self.conn.commit() @@ -932,15 +899,15 @@ def is_available(self) -> bool: class DatabaseChainComponent(BaseComponent[L2Config]): - """Multi-layer database chain component""" - + """Multi-layer database chain component.""" + def _setup(self): self.layers: List[DatabaseLayer] = [] - + for layer_config in self.config.layers: if isinstance(layer_config, dict): layer_config = LayerConfig(**layer_config) - + if layer_config.type == "dict": layer = DictLayer(layer_config) elif layer_config.type == "redis": @@ -951,58 +918,58 @@ def _setup(self): layer = PostgresLayer(layer_config) else: raise ValueError(f"Unknown layer type: {layer_config.type}") - + self.layers.append(layer) - + self.layers.sort(key=lambda x: x.priority, reverse=True) # Higher priority checked first - + def get_available_methods(self) -> List[str]: return [ "search", "filter_by_popularity", "deduplicate_candidates", "limit_candidates", - "sort_by_popularity" + "sort_by_popularity", ] - + def search(self, mention: str) -> List[DatabaseRecord]: - """Search through layers with fallback""" + """Search through layers with fallback.""" found_in_layer = None results = [] - + for layer in self.layers: if not layer.is_available(): continue - + layer_results = [] - + for mode in layer.config.search_mode: if mode == "exact": layer_results.extend(layer.search(mention)) elif mode == "fuzzy": if layer.supports_fuzzy(): layer_results.extend(layer.search_fuzzy(mention)) - + if layer_results: layer_results = self.deduplicate_candidates(layer_results) results = layer_results found_in_layer = layer break - + if results and found_in_layer: self._cache_write(mention, results, found_in_layer) - + return results - + def _cache_write(self, query: str, results: List[DatabaseRecord], source_layer: DatabaseLayer): - """Write results to upper layers (higher priority = checked earlier)""" + """Write results to upper layers (higher priority = checked earlier).""" for layer in self.layers: # Skip source layer and all layers with lower priority if layer.priority <= source_layer.priority: continue if not layer.write: continue - + if layer.cache_policy == "always": layer.write_cache(query, results, layer.ttl) elif layer.cache_policy == "miss": @@ -1013,11 +980,13 @@ def _cache_write(self, query: str, results: List[DatabaseRecord], source_layer: existing = layer.search(query) if existing: layer.write_cache(query, results, layer.ttl) - - def filter_by_popularity(self, records: List[DatabaseRecord], min_popularity: int = None) -> List[DatabaseRecord]: + + def filter_by_popularity( + self, records: List[DatabaseRecord], min_popularity: int | None = None + ) -> List[DatabaseRecord]: threshold = min_popularity if min_popularity is not None else self.config.min_popularity return [r for r in records if r.popularity >= threshold] - + def deduplicate_candidates(self, records: List[DatabaseRecord]) -> List[DatabaseRecord]: seen = set() unique = [] @@ -1026,20 +995,22 @@ def deduplicate_candidates(self, records: List[DatabaseRecord]) -> List[Database unique.append(record) seen.add(record.entity_id) return unique - - def limit_candidates(self, records: List[DatabaseRecord], limit: int = None) -> List[DatabaseRecord]: + + def limit_candidates( + self, records: List[DatabaseRecord], limit: int | None = None + ) -> List[DatabaseRecord]: max_cands = limit if limit is not None else self.config.max_candidates return records[:max_cands] - + def sort_by_popularity(self, records: List[DatabaseRecord]) -> List[DatabaseRecord]: return sorted(records, key=lambda x: x.popularity, reverse=True) - + def load_entities( self, - source: Union[str, Path, List[Dict[str, Any]], Dict[str, Dict[str, Any]]], - target_layers: List[str] = None, + source: str | Path | List[Dict[str, Any]] | Dict[str, Dict[str, Any]], + target_layers: List[str] | None = None, batch_size: int = 1000, - overwrite: bool = False + overwrite: bool = False, ) -> Dict[str, int]: """ Load entities from JSONL file, list of dicts, or dict. @@ -1068,10 +1039,7 @@ def load_entities( for eid, data in source.items() ] elif isinstance(source, list): - entities = [ - DatabaseRecord(**e) if isinstance(e, dict) else e - for e in source - ] + entities = [DatabaseRecord(**e) if isinstance(e, dict) else e for e in source] else: raise TypeError(f"Expected file path, list, or dict; got {type(source)}") @@ -1085,7 +1053,7 @@ def load_entities( def load_records( self, entities: List[DatabaseRecord], - target_layers: List[str] = None, + target_layers: List[str] | None = None, batch_size: int = 1000, overwrite: bool = False, ) -> Dict[str, int]: @@ -1120,10 +1088,10 @@ def load_records( return results @staticmethod - def _parse_jsonl(filepath: Union[str, Path]) -> List[DatabaseRecord]: + def _parse_jsonl(filepath: str | Path) -> List[DatabaseRecord]: """Parse JSONL file into DatabaseRecord list.""" entities = [] - with open(filepath, 'r', encoding='utf-8') as f: + with open(filepath, encoding="utf-8") as f: for line_num, line in enumerate(f, 1): line = line.strip() if not line: @@ -1135,19 +1103,19 @@ def _parse_jsonl(filepath: Union[str, Path]) -> List[DatabaseRecord]: print(f"[WARN] Line {line_num} parse error: {e}") continue return entities - - def clear_layers(self, layer_names: List[str] = None): - """Clear all entities in specified layers""" + + def clear_layers(self, layer_names: List[str] | None = None): + """Clear all entities in specified layers.""" for layer in self.layers: if layer_names and layer.config.type not in layer_names: continue - + print(f"Clearing {layer.config.type}...") layer.clear() - print(f"✓ Cleared") - + print("✓ Cleared") + def get_all_entities(self) -> List[DatabaseRecord]: - """Get all entities from all available layers (deduplicated)""" + """Get all entities from all available layers (deduplicated).""" all_entities = [] for layer in self.layers: if layer.is_available(): @@ -1155,7 +1123,7 @@ def get_all_entities(self) -> List[DatabaseRecord]: return self.deduplicate_candidates(all_entities) def count_entities(self) -> Dict[str, int]: - """Count entities in each layer""" + """Count entities in each layer.""" counts = {} for layer in self.layers: counts[layer.config.type] = layer.count() @@ -1166,8 +1134,8 @@ def precompute_embeddings( encoder_fn, template: str, model_id: str, - target_layers: List[str] = None, - batch_size: int = 32 + target_layers: List[str] | None = None, + batch_size: int = 32, ) -> Dict[str, int]: """ Precompute embeddings for all entities in specified layers. @@ -1219,13 +1187,13 @@ def precompute_embeddings( # Encode in batches all_embeddings = [] for i in tqdm(range(0, len(labels), batch_size), desc="Encoding"): - batch_labels = labels[i:i + batch_size] + batch_labels = labels[i : i + batch_size] batch_embeddings = encoder_fn(batch_labels) # Convert to list if tensor - if hasattr(batch_embeddings, 'tolist'): + if hasattr(batch_embeddings, "tolist"): batch_embeddings = batch_embeddings.tolist() - elif hasattr(batch_embeddings, 'cpu'): + elif hasattr(batch_embeddings, "cpu"): batch_embeddings = batch_embeddings.cpu().numpy().tolist() all_embeddings.extend(batch_embeddings) @@ -1238,8 +1206,8 @@ def precompute_embeddings( return results def get_layer(self, layer_type: str) -> DatabaseLayer: - """Get layer by type""" + """Get layer by type.""" for layer in self.layers: if layer.config.type == layer_type: return layer - return None \ No newline at end of file + return None diff --git a/src/glinker/l2/models.py b/src/glinker/l2/models.py index 0a5a499..44be503 100644 --- a/src/glinker/l2/models.py +++ b/src/glinker/l2/models.py @@ -1,39 +1,38 @@ +from typing import Any, Dict, List, Literal + from pydantic import Field, BaseModel -from typing import List, Dict, Any, Optional, Literal -from glinker.core.base import BaseConfig, BaseInput, BaseOutput + +from glinker.core.base import BaseInput, BaseConfig, BaseOutput class DatabaseRecord(BaseModel): """ - Unified format for all database layers + Unified format for all database layers. All layers (Dict, Redis, Elasticsearch, Postgres) use this format. """ + entity_id: str = Field(..., description="Unique entity identifier") label: str = Field(..., description="Primary label/name") aliases: List[str] = Field(default_factory=list, description="Alternative names") description: str = Field(default="", description="Entity description") entity_type: str = Field(default="", description="Entity type/category") popularity: int = Field(default=0, description="Popularity score") - metadata: Dict[str, Any] = Field( - default_factory=dict, - description="Database-specific metadata" - ) + metadata: Dict[str, Any] = Field(default_factory=dict, description="Database-specific metadata") source: str = Field(default="", description="Source layer: dict|redis|elasticsearch|postgres") # Embedding fields for precomputed label embeddings - embedding: Optional[List[float]] = Field( - default=None, - description="Precomputed label embedding vector" + embedding: List[float] | None = Field( + default=None, description="Precomputed label embedding vector" ) - embedding_model_id: Optional[str] = Field( - default=None, - description="Model ID used to compute the embedding" + embedding_model_id: str | None = Field( + default=None, description="Model ID used to compute the embedding" ) class FuzzyConfig(BaseConfig): - """Fuzzy search configuration""" + """Fuzzy search configuration.""" + max_distance: int = Field(2, description="Maximum Levenshtein distance") min_similarity: float = Field(0.3, description="Minimum similarity threshold") n_gram_size: int = Field(3, description="N-gram size for matching") @@ -41,16 +40,16 @@ class FuzzyConfig(BaseConfig): class LayerConfig(BaseConfig): - """Database layer configuration""" + """Database layer configuration.""" + type: str = Field(..., description="Layer type: dict|redis|elasticsearch|postgres") priority: int = Field(..., description="Search priority (0 = highest)") config: Dict[str, Any] = Field(default_factory=dict, description="Layer-specific config") - + search_mode: List[Literal["exact", "fuzzy"]] = Field( - ["exact"], - description="Search methods: ['exact'], ['fuzzy'], or ['exact', 'fuzzy']" + ["exact"], description="Search methods: ['exact'], ['fuzzy'], or ['exact', 'fuzzy']" ) - + write: bool = Field(True, description="Enable write operations") cache_policy: str = Field("always", description="Cache policy: always|miss|hit") ttl: int = Field(3600, description="TTL in seconds (0 = no expiry)") @@ -61,39 +60,44 @@ class LayerConfig(BaseConfig): "aliases": "aliases", "description": "description", "entity_type": "entity_type", - "popularity": "popularity" + "popularity": "popularity", }, - description="Field mapping: DatabaseRecord field -> storage field" + description="Field mapping: DatabaseRecord field -> storage field", + ) + fuzzy: FuzzyConfig | None = Field( + default_factory=FuzzyConfig, description="Fuzzy search config" ) - fuzzy: Optional[FuzzyConfig] = Field(default_factory=FuzzyConfig, description="Fuzzy search config") class EmbeddingConfig(BaseModel): - """Configuration for precomputed label embeddings""" + """Configuration for precomputed label embeddings.""" + enabled: bool = Field(False, description="Enable embedding support") - model_name: Optional[str] = Field(None, description="Model name for encoding labels") + model_name: str | None = Field(None, description="Model name for encoding labels") dim: int = Field(768, description="Embedding dimension") precompute_on_load: bool = Field(False, description="Compute embeddings during load_bulk") batch_size: int = Field(32, description="Batch size for encoding") class L2Config(BaseConfig): - """L2 processor configuration""" + """L2 processor configuration.""" + layers: List[LayerConfig] = Field(..., description="Database layers in priority order") max_candidates: int = Field(30, description="Maximum candidates per mention") min_popularity: int = Field(0, description="Minimum popularity threshold") - embeddings: Optional[EmbeddingConfig] = Field( - default=None, - description="Embedding configuration for precomputed labels" + embeddings: EmbeddingConfig | None = Field( + default=None, description="Embedding configuration for precomputed labels" ) class L2Input(BaseInput): - """L2 processor input""" + """L2 processor input.""" + mentions: List[str] = Field(..., description="List of mentions to search") structure: List[List[str]] = Field(None, description="Optional grouping structure") class L2Output(BaseOutput): - """L2 processor output""" - candidates: List[List[DatabaseRecord]] = Field(..., description="Candidates per mention/group") \ No newline at end of file + """L2 processor output.""" + + candidates: List[List[DatabaseRecord]] = Field(..., description="Candidates per mention/group") diff --git a/src/glinker/l2/processor.py b/src/glinker/l2/processor.py index 109b987..3b8f4da 100644 --- a/src/glinker/l2/processor.py +++ b/src/glinker/l2/processor.py @@ -1,35 +1,34 @@ -from typing import Any, List, Union +from typing import Any, List + from glinker.core.base import BaseProcessor from glinker.core.registry import processor_registry -from .models import L2Config, L2Input, L2Output, DatabaseRecord + +from .models import L2Input, L2Config, L2Output, DatabaseRecord from .component import DatabaseChainComponent class L2Processor(BaseProcessor[L2Config, L2Input, L2Output]): - """Multi-layer database search processor""" + """Multi-layer database search processor.""" def __init__( self, config: L2Config, component: DatabaseChainComponent, - pipeline: list[tuple[str, dict[str, Any]]] = None + pipeline: list[tuple[str, dict[str, Any]]] | None = None, ): super().__init__(config, component, pipeline) self.schema = {} # Will be set by DAG executor from node config def format_label(self, record: DatabaseRecord) -> str: - """Format label using schema template""" - template = self.schema.get('template', '{label}') + """Format label using schema template.""" + template = self.schema.get("template", "{label}") try: return template.format(**record.model_dump()) except KeyError: return record.label def precompute_embeddings( - self, - encoder_fn, - target_layers: List[str] = None, - batch_size: int = 32 + self, encoder_fn, target_layers: List[str] | None = None, batch_size: int = 32 ): """ Precompute embeddings for entities using schema template. @@ -39,17 +38,17 @@ def precompute_embeddings( target_layers: Layer types to update batch_size: Batch size for encoding """ - template = self.schema.get('template', '{label}') - model_id = self.config.embeddings.model_name if self.config.embeddings else 'unknown' + template = self.schema.get("template", "{label}") + model_id = self.config.embeddings.model_name if self.config.embeddings else "unknown" return self.component.precompute_embeddings( encoder_fn=encoder_fn, template=template, model_id=model_id, target_layers=target_layers, - batch_size=batch_size + batch_size=batch_size, ) - + def _default_pipeline(self) -> list[tuple[str, dict[str, Any]]]: return [ ("search", {}), @@ -58,16 +57,16 @@ def _default_pipeline(self) -> list[tuple[str, dict[str, Any]]]: ("sort_by_popularity", {}), ("limit_candidates", {}), ] - + def __call__( self, - mentions: Union[List[str], List[List[Any]], L2Input] = None, - texts: List[str] = None, - structure: List[List[str]] = None, - input_data: L2Input = None + mentions: List[str] | List[List[Any]] | L2Input = None, + texts: List[str] | None = None, + structure: List[List[str]] | None = None, + input_data: L2Input = None, ) -> L2Output: """ - Process mentions and return candidates + Process mentions and return candidates. Supports: - List[str]: flat list of mention strings @@ -75,7 +74,6 @@ def __call__( - L2Input: structured input with mentions and structure - mentions=None: return entire entity database (one copy per text) """ - if input_data is not None: mentions = input_data.mentions structure = input_data.structure @@ -93,56 +91,54 @@ def __call__( if mentions and isinstance(mentions[0], (list, tuple)): # Nested structure: [[entities_text1], [entities_text2], ...] all_candidates = [] - + for text_entities in mentions: text_candidates = [] - + for entity in text_entities: # Extract text from L1Entity or dict mention_text = self._extract_mention_text(entity) - + # Search candidates for this mention candidates = self._execute_pipeline(mention_text, self.pipeline) text_candidates.extend(candidates) - + all_candidates.append(text_candidates) - + return L2Output(candidates=all_candidates) - + # Flat structure: ["mention1", "mention2", ...] else: all_candidates = [] - + for mention in mentions: mention_text = self._extract_mention_text(mention) candidates = self._execute_pipeline(mention_text, self.pipeline) all_candidates.append(candidates) - + if structure: grouped = self._group_by_structure(all_candidates, structure) else: # Flatten all into one group grouped = [self._flatten(all_candidates)] - + return L2Output(candidates=grouped) - + def _extract_mention_text(self, mention: Any) -> str: - """Extract text string from mention (can be L1Entity, dict, or str)""" + """Extract text string from mention (can be L1Entity, dict, or str).""" if isinstance(mention, str): return mention - elif hasattr(mention, 'text'): + elif hasattr(mention, "text"): return mention.text elif isinstance(mention, dict): - return mention.get('text', str(mention)) + return mention.get("text", str(mention)) else: return str(mention) - + def _group_by_structure( - self, - all_candidates: List[List[DatabaseRecord]], - structure: List[List[str]] + self, all_candidates: List[List[DatabaseRecord]], structure: List[List[str]] ) -> List[List[DatabaseRecord]]: - """Group candidates according to structure""" + """Group candidates according to structure.""" grouped = [] idx = 0 for text_mentions in structure: @@ -153,9 +149,9 @@ def _group_by_structure( idx += 1 grouped.append(text_candidates) return grouped - + def _flatten(self, nested: List[List[Any]]) -> List[Any]: - """Flatten nested list""" + """Flatten nested list.""" flat = [] for sublist in nested: flat.extend(sublist) @@ -163,8 +159,8 @@ def _flatten(self, nested: List[List[Any]]) -> List[Any]: @processor_registry.register("l2_chain") -def create_l2_processor(config_dict: dict, pipeline: list = None) -> L2Processor: - """Factory: creates component + processor""" +def create_l2_processor(config_dict: dict, pipeline: list | None = None) -> L2Processor: + """Factory: creates component + processor.""" config = L2Config(**config_dict) component = DatabaseChainComponent(config) - return L2Processor(config, component, pipeline) \ No newline at end of file + return L2Processor(config, component, pipeline) diff --git a/src/glinker/l3/__init__.py b/src/glinker/l3/__init__.py index 15f8604..ecce289 100644 --- a/src/glinker/l3/__init__.py +++ b/src/glinker/l3/__init__.py @@ -1,12 +1,12 @@ -from .models import L3Config, L3Input, L3Output, L3Entity +from .models import L3Input, L3Config, L3Entity, L3Output from .component import L3Component from .processor import L3Processor __all__ = [ + "L3Component", "L3Config", + "L3Entity", "L3Input", "L3Output", - "L3Entity", - "L3Component", "L3Processor", -] \ No newline at end of file +] diff --git a/src/glinker/l3/component.py b/src/glinker/l3/component.py index f753552..00ed15a 100644 --- a/src/glinker/l3/component.py +++ b/src/glinker/l3/component.py @@ -1,27 +1,30 @@ -from typing import Dict, List, Optional +from typing import List + import torch from gliner import GLiNER + from glinker.core.base import BaseComponent + from .models import L3Config, L3Entity class L3Component(BaseComponent[L3Config]): - """GLiNER-based entity linking component""" + """GLiNER-based entity linking component.""" def _setup(self): - """Initialize GLiNER model""" + """Initialize GLiNER model.""" self.model = GLiNER.from_pretrained( - self.config.model_name, - token=self.config.token, - max_length=self.config.max_length + self.config.model_name, token=self.config.token, max_length=self.config.max_length ) self.model.to(self.config.device) # Fix labels tokenizer max_length for BiEncoder models # Some models have model_max_length not properly set (> 10^18) - if (self.config.max_length is not None and - hasattr(self.model, 'data_processor') and - hasattr(self.model.data_processor, 'labels_tokenizer')): + if ( + self.config.max_length is not None + and hasattr(self.model, "data_processor") + and hasattr(self.model.data_processor, "labels_tokenizer") + ): tok = self.model.data_processor.labels_tokenizer if tok.model_max_length > 100000: tok.model_max_length = self.config.max_length @@ -32,8 +35,8 @@ def device(self): @property def supports_precomputed_embeddings(self) -> bool: - """Check if model supports precomputed embeddings (BiEncoder)""" - return hasattr(self.model, 'encode_labels') and self.model.config.labels_encoder is not None + """Check if model supports precomputed embeddings (BiEncoder).""" + return hasattr(self.model, "encode_labels") and self.model.config.labels_encoder is not None def get_available_methods(self) -> List[str]: return [ @@ -42,7 +45,7 @@ def get_available_methods(self) -> List[str]: "encode_labels", "filter_by_score", "sort_by_position", - "deduplicate_entities" + "deduplicate_entities", ] def encode_labels(self, labels: List[str], batch_size: int = 32) -> torch.Tensor: @@ -72,7 +75,7 @@ def predict_with_embeddings( text: str, labels: List[str], embeddings: torch.Tensor, - input_spans: List[List[dict]] = None + input_spans: List[List[dict]] | None = None, ) -> List[L3Entity]: """ Predict entities using pre-computed label embeddings. @@ -91,21 +94,16 @@ def predict_with_embeddings( # Fallback to regular prediction return self.predict_entities(text, labels, input_spans=input_spans) - kwargs = dict( - threshold=self.config.threshold, - flat_ner=self.config.flat_ner, - multi_label=self.config.multi_label, - return_class_probs=True - ) + kwargs = { + "threshold": self.config.threshold, + "flat_ner": self.config.flat_ner, + "multi_label": self.config.multi_label, + "return_class_probs": True, + } if input_spans is not None: kwargs["input_spans"] = input_spans - entities = self.model.predict_with_embeds( - text, - embeddings, - labels, - **kwargs - ) + entities = self.model.predict_with_embeds(text, embeddings, labels, **kwargs) return [ L3Entity( @@ -114,18 +112,15 @@ def predict_with_embeddings( start=e["start"], end=e["end"], score=e["score"], - class_probs=e.get("class_probs") + class_probs=e.get("class_probs"), ) for e in entities ] def predict_entities( - self, - text: str, - labels: List[str], - input_spans: List[List[dict]] = None + self, text: str, labels: List[str], input_spans: List[List[dict]] | None = None ) -> List[L3Entity]: - """Predict entities using GLiNER + """Predict entities using GLiNER. Args: text: Input text @@ -136,20 +131,16 @@ def predict_entities( if not labels: return [] - kwargs = dict( - threshold=self.config.threshold, - flat_ner=self.config.flat_ner, - multi_label=self.config.multi_label, - return_class_probs=True - ) + kwargs = { + "threshold": self.config.threshold, + "flat_ner": self.config.flat_ner, + "multi_label": self.config.multi_label, + "return_class_probs": True, + } if input_spans is not None: kwargs["input_spans"] = input_spans - entities = self.model.predict_entities( - text, - labels, - **kwargs - ) + entities = self.model.predict_entities(text, labels, **kwargs) return [ L3Entity( @@ -158,22 +149,24 @@ def predict_entities( start=e["start"], end=e["end"], score=e["score"], - class_probs=e.get("class_probs") + class_probs=e.get("class_probs"), ) for e in entities ] - - def filter_by_score(self, entities: List[L3Entity], threshold: float = None) -> List[L3Entity]: - """Filter entities by confidence score""" + + def filter_by_score( + self, entities: List[L3Entity], threshold: float | None = None + ) -> List[L3Entity]: + """Filter entities by confidence score.""" threshold = threshold if threshold is not None else self.config.threshold return [e for e in entities if e.score >= threshold] - + def sort_by_position(self, entities: List[L3Entity]) -> List[L3Entity]: - """Sort entities by position in text""" + """Sort entities by position in text.""" return sorted(entities, key=lambda e: e.start) - + def deduplicate_entities(self, entities: List[L3Entity]) -> List[L3Entity]: - """Remove duplicate entities""" + """Remove duplicate entities.""" seen = set() unique = [] for entity in entities: @@ -181,4 +174,4 @@ def deduplicate_entities(self, entities: List[L3Entity]) -> List[L3Entity]: if key not in seen: unique.append(entity) seen.add(key) - return unique \ No newline at end of file + return unique diff --git a/src/glinker/l3/models.py b/src/glinker/l3/models.py index 1a6e9aa..eb53a1f 100644 --- a/src/glinker/l3/models.py +++ b/src/glinker/l3/models.py @@ -1,11 +1,13 @@ +from typing import Any, Dict, List + from pydantic import Field -from typing import Dict, List, Any, Optional -from glinker.core.base import BaseConfig, BaseInput, BaseOutput + +from glinker.core.base import BaseInput, BaseConfig, BaseOutput class L3Config(BaseConfig): model_name: str = Field(...) - token: Optional[str] = Field(None) + token: str | None = Field(None) device: str = Field("cpu") threshold: float = Field(0.5) flat_ner: bool = Field(True) @@ -14,16 +16,12 @@ class L3Config(BaseConfig): # Embedding settings use_precomputed_embeddings: bool = Field( - True, - description="Use precomputed embeddings from L2 candidates if available" - ) - cache_embeddings: bool = Field( - False, - description="Cache computed embeddings back to L2" + True, description="Use precomputed embeddings from L2 candidates if available" ) + cache_embeddings: bool = Field(False, description="Cache computed embeddings back to L2") max_length: int = Field( None, - description="Maximum sequence length for tokenization. Passed to GLiNER.from_pretrained." + description="Maximum sequence length for tokenization. Passed to GLiNER.from_pretrained.", ) @@ -39,10 +37,10 @@ class L3Entity(BaseOutput): start: int end: int score: float - class_probs: Optional[Dict[str, float]] = Field( + class_probs: Dict[str, float] | None = Field( None, description="Per-label class probabilities from GLiNER" ) class L3Output(BaseOutput): - entities: List[List[L3Entity]] = Field(...) \ No newline at end of file + entities: List[List[L3Entity]] = Field(...) diff --git a/src/glinker/l3/processor.py b/src/glinker/l3/processor.py index a935d8f..01e4b8d 100644 --- a/src/glinker/l3/processor.py +++ b/src/glinker/l3/processor.py @@ -1,32 +1,31 @@ -from typing import Any, List, Optional +from typing import Any, List + import torch + from glinker.core.base import BaseProcessor from glinker.core.registry import processor_registry -from .models import L3Config, L3Input, L3Output, L3Entity + +from .models import L3Input, L3Config, L3Entity, L3Output from .component import L3Component class L3Processor(BaseProcessor[L3Config, L3Input, L3Output]): - """GLiNER entity linking processor""" + """GLiNER entity linking processor.""" def __init__( self, config: L3Config, component: L3Component, - pipeline: list[tuple[str, dict[str, Any]]] = None + pipeline: list[tuple[str, dict[str, Any]]] | None = None, ): super().__init__(config, component, pipeline) self._validate_pipeline() self.schema = {} self._l2_processor = None # Will be set by DAG executor for cache write-back - + def _default_pipeline(self) -> list[tuple[str, dict[str, Any]]]: - return [ - ("predict_entities", {}), - ("filter_by_score", {}), - ("sort_by_position", {}) - ] - + return [("predict_entities", {}), ("filter_by_score", {}), ("sort_by_position", {})] + @staticmethod def _build_input_spans(l1_entities_for_text: List[Any]) -> List[List[dict]]: """Convert L1 entities to GLiNER input_spans format. @@ -39,20 +38,22 @@ def _build_input_spans(l1_entities_for_text: List[Any]) -> List[List[dict]]: wrapped in an outer list as expected by GLiNER input_spans. """ spans = [ - {"start": e["start"] if isinstance(e, dict) else e.start, - "end": e["end"] if isinstance(e, dict) else e.end} + { + "start": e["start"] if isinstance(e, dict) else e.start, + "end": e["end"] if isinstance(e, dict) else e.end, + } for e in l1_entities_for_text ] return [spans] def __call__( self, - texts: List[str] = None, - candidates: List[List[Any]] = None, - l1_entities: List[List[Any]] = None, - input_data: L3Input = None + texts: List[str] | None = None, + candidates: List[List[Any]] | None = None, + l1_entities: List[List[Any]] | None = None, + input_data: L3Input = None, ) -> L3Output: - """Process texts with candidate labels + """Process texts with candidate labels. Args: texts: List of input texts @@ -61,7 +62,6 @@ def __call__( so L3 predicts on the same spans extracted in L1 input_data: Alternative L3Input object """ - # Support both direct params and L3Input if texts is not None and candidates is not None: texts_to_process = texts @@ -75,9 +75,8 @@ def __call__( all_entities = [] # Detect shared candidates (all texts use the same list, e.g. simple pipeline) - shared = ( - len(candidates_to_process) > 1 - and all(c is candidates_to_process[0] for c in candidates_to_process[1:]) + shared = len(candidates_to_process) > 1 and all( + c is candidates_to_process[0] for c in candidates_to_process[1:] ) # Pre-compute labels & embeddings once when candidates are shared @@ -89,8 +88,8 @@ def __call__( if shared: ref_candidates = candidates_to_process[0] if self.schema: - shared_labels, shared_label_to_candidate = ( - self._create_gliner_labels_with_mapping(ref_candidates) + shared_labels, shared_label_to_candidate = self._create_gliner_labels_with_mapping( + ref_candidates ) else: shared_labels = [self._extract_label(c) for c in ref_candidates] @@ -122,7 +121,9 @@ def __call__( else: # Create labels from candidates (per-text) if self.schema: - labels, label_to_candidate = self._create_gliner_labels_with_mapping(text_candidates) + labels, label_to_candidate = self._create_gliner_labels_with_mapping( + text_candidates + ) else: labels = [self._extract_label(c) for c in text_candidates] label_to_candidate = {} @@ -134,7 +135,9 @@ def __call__( ) embeddings = None if use_precomputed: - embeddings = self._get_embeddings_tensor(text_candidates, labels, label_to_candidate) + embeddings = self._get_embeddings_tensor( + text_candidates, labels, label_to_candidate + ) if use_precomputed and embeddings is not None: entities = self.component.predict_with_embeddings( @@ -154,19 +157,15 @@ def __call__( entities = method(entities, **kwargs) # Apply ranking if configured - if self.schema.get('ranking'): + if self.schema.get("ranking"): entities = self._rank_entities(entities, text_candidates) all_entities.append(entities) return L3Output(entities=all_entities) - def _can_use_precomputed( - self, - candidates: List[Any], - label_to_candidate: dict - ) -> bool: - """Check if all candidates have compatible precomputed embeddings""" + def _can_use_precomputed(self, candidates: List[Any], label_to_candidate: dict) -> bool: + """Check if all candidates have compatible precomputed embeddings.""" if not candidates: return False @@ -174,29 +173,26 @@ def _can_use_precomputed( for candidate in candidates: # Check if candidate has embedding - embedding = getattr(candidate, 'embedding', None) + embedding = getattr(candidate, "embedding", None) if embedding is None: return False # Check if model matches - model_id = getattr(candidate, 'embedding_model_id', None) + model_id = getattr(candidate, "embedding_model_id", None) if model_id != expected_model: return False return True def _get_embeddings_tensor( - self, - candidates: List[Any], - labels: List[str], - label_to_candidate: dict + self, candidates: List[Any], labels: List[str], label_to_candidate: dict ) -> torch.Tensor: - """Build embeddings tensor from candidates in same order as labels""" + """Build embeddings tensor from candidates in same order as labels.""" embeddings = [] for label in labels: candidate = label_to_candidate.get(label) - if candidate and hasattr(candidate, 'embedding') and candidate.embedding: + if candidate and hasattr(candidate, "embedding") and candidate.embedding: embeddings.append(candidate.embedding) else: # Should not happen if _can_use_precomputed returned True @@ -204,13 +200,8 @@ def _get_embeddings_tensor( return torch.tensor(embeddings, device=self.component.device) - def _cache_embeddings( - self, - candidates: List[Any], - labels: List[str], - label_to_candidate: dict - ): - """Compute and cache embeddings for candidates without them""" + def _cache_embeddings(self, candidates: List[Any], labels: List[str], label_to_candidate: dict): + """Compute and cache embeddings for candidates without them.""" if not self._l2_processor: return @@ -219,7 +210,7 @@ def _cache_embeddings( to_compute_ids = [] for candidate in candidates: - if not getattr(candidate, 'embedding', None): + if not getattr(candidate, "embedding", None): to_compute.append(candidate) to_compute_ids.append(candidate.entity_id) @@ -227,13 +218,13 @@ def _cache_embeddings( return # Format labels for these candidates - template = self.schema.get('template', '{label}') + template = self.schema.get("template", "{label}") compute_labels = [] for candidate in to_compute: try: - if hasattr(candidate, 'model_dump'): + if hasattr(candidate, "model_dump"): formatted = template.format(**candidate.model_dump()) - elif hasattr(candidate, 'dict'): + elif hasattr(candidate, "dict"): formatted = template.format(**candidate.dict()) else: formatted = candidate.label @@ -245,19 +236,17 @@ def _cache_embeddings( embeddings = self.component.encode_labels(compute_labels) # Update L2 layer - if hasattr(self._l2_processor, 'component'): + if hasattr(self._l2_processor, "component"): for layer in self._l2_processor.component.layers: if layer.is_available(): layer.update_embeddings( - to_compute_ids, - embeddings.tolist(), - self.config.model_name + to_compute_ids, embeddings.tolist(), self.config.model_name ) break # Update first available layer - + def _extract_label(self, candidate: Any) -> str: - """Extract label from candidate""" - if hasattr(candidate, 'label'): + """Extract label from candidate.""" + if hasattr(candidate, "label"): return candidate.label return str(candidate) @@ -268,16 +257,16 @@ def _create_gliner_labels_with_mapping(self, candidates: List[Any]) -> tuple: Returns: tuple: (labels: List[str], label_to_candidate: dict) """ - template = self.schema.get('template', '{label}') + template = self.schema.get("template", "{label}") labels = [] label_to_candidate = {} seen = set() for candidate in candidates: try: - if hasattr(candidate, 'model_dump'): + if hasattr(candidate, "model_dump"): cand_dict = candidate.model_dump() - elif hasattr(candidate, 'dict'): + elif hasattr(candidate, "dict"): cand_dict = candidate.dict() elif isinstance(candidate, dict): cand_dict = candidate @@ -295,7 +284,7 @@ def _create_gliner_labels_with_mapping(self, candidates: List[Any]) -> tuple: label_to_candidate[label] = candidate seen.add(label_lower) except (KeyError, AttributeError): - if hasattr(candidate, 'label'): + if hasattr(candidate, "label"): if candidate.label.lower() not in seen: labels.append(candidate.label) label_to_candidate[candidate.label] = candidate @@ -304,33 +293,33 @@ def _create_gliner_labels_with_mapping(self, candidates: List[Any]) -> tuple: return labels, label_to_candidate def _create_gliner_labels(self, candidates: List[Any]) -> List[str]: - """Create GLiNER labels using schema template (legacy, for compatibility)""" + """Create GLiNER labels using schema template (legacy, for compatibility).""" labels, _ = self._create_gliner_labels_with_mapping(candidates) return labels - + def _rank_entities(self, entities: List[L3Entity], candidates: List[Any]) -> List[L3Entity]: - """Re-rank entities using multiple scoring factors""" + """Re-rank entities using multiple scoring factors.""" # Build label to candidate mapping label_to_candidate = {} for c in candidates: - if hasattr(c, 'label'): + if hasattr(c, "label"): label_to_candidate[c.label] = c - if hasattr(c, 'aliases'): + if hasattr(c, "aliases"): for alias in c.aliases: if alias not in label_to_candidate: label_to_candidate[alias] = c - + # Calculate weighted scores for entity in entities: total_score = 0.0 total_weight = 0.0 - - for rank_spec in self.schema['ranking']: - field = rank_spec['field'] - weight = rank_spec['weight'] + + for rank_spec in self.schema["ranking"]: + field = rank_spec["field"] + weight = rank_spec["weight"] total_weight += weight - - if field == 'gliner_score': + + if field == "gliner_score": total_score += entity.score * weight else: candidate = label_to_candidate.get(entity.label) @@ -339,16 +328,16 @@ def _rank_entities(self, entities: List[L3Entity], candidates: List[Any]) -> Lis if isinstance(value, (int, float)): normalized = min(value / 1000000.0, 1.0) total_score += normalized * weight - + if total_weight > 0: entity.score = total_score / total_weight - + return sorted(entities, key=lambda x: x.score, reverse=True) @processor_registry.register("l3_batch") -def create_l3_processor(config_dict: dict, pipeline: list = None) -> L3Processor: - """Factory: creates component + processor""" +def create_l3_processor(config_dict: dict, pipeline: list | None = None) -> L3Processor: + """Factory: creates component + processor.""" config = L3Config(**config_dict) component = L3Component(config) - return L3Processor(config, component, pipeline) \ No newline at end of file + return L3Processor(config, component, pipeline) diff --git a/src/glinker/l4/__init__.py b/src/glinker/l4/__init__.py index cb67571..d900050 100644 --- a/src/glinker/l4/__init__.py +++ b/src/glinker/l4/__init__.py @@ -3,7 +3,7 @@ from .processor import L4Processor __all__ = [ - "L4Config", "L4Component", + "L4Config", "L4Processor", -] \ No newline at end of file +] diff --git a/src/glinker/l4/component.py b/src/glinker/l4/component.py index 389d109..03b5768 100644 --- a/src/glinker/l4/component.py +++ b/src/glinker/l4/component.py @@ -1,19 +1,20 @@ -from typing import List, Optional +from typing import List + from gliner import GLiNER + from glinker.core.base import BaseComponent from glinker.l3.models import L3Entity + from .models import L4Config class L4Component(BaseComponent[L4Config]): - """GLiNER-based reranking component (uni-encoder only, no precomputed embeddings)""" + """GLiNER-based reranking component (uni-encoder only, no precomputed embeddings).""" def _setup(self): - """Initialize GLiNER model""" + """Initialize GLiNER model.""" self.model = GLiNER.from_pretrained( - self.config.model_name, - token=self.config.token, - max_length=self.config.max_length + self.config.model_name, token=self.config.token, max_length=self.config.max_length ) self.model.to(self.config.device) @@ -23,14 +24,11 @@ def get_available_methods(self) -> List[str]: "predict_entities_chunked", "filter_by_score", "sort_by_position", - "deduplicate_entities" + "deduplicate_entities", ] def predict_entities( - self, - text: str, - labels: List[str], - input_spans: List[List[dict]] = None + self, text: str, labels: List[str], input_spans: List[List[dict]] | None = None ) -> List[L3Entity]: """Predict entities using GLiNER for a single label set. @@ -42,12 +40,12 @@ def predict_entities( if not labels: return [] - kwargs = dict( - threshold=self.config.threshold, - flat_ner=self.config.flat_ner, - multi_label=self.config.multi_label, - return_class_probs=True - ) + kwargs = { + "threshold": self.config.threshold, + "flat_ner": self.config.flat_ner, + "multi_label": self.config.multi_label, + "return_class_probs": True, + } if input_spans is not None: kwargs["input_spans"] = input_spans @@ -60,7 +58,7 @@ def predict_entities( start=e["start"], end=e["end"], score=e["score"], - class_probs=e.get("class_probs") + class_probs=e.get("class_probs"), ) for e in entities ] @@ -70,7 +68,7 @@ def predict_entities_chunked( text: str, labels: List[str], max_labels: int, - input_spans: List[List[dict]] = None + input_spans: List[List[dict]] | None = None, ) -> List[L3Entity]: """Predict entities with candidate chunking. @@ -90,10 +88,7 @@ def predict_entities_chunked( return self.predict_entities(text, labels, input_spans=input_spans) # Split labels into chunks - chunks = [ - labels[i:i + max_labels] - for i in range(0, len(labels), max_labels) - ] + chunks = [labels[i : i + max_labels] for i in range(0, len(labels), max_labels)] all_entities = [] for chunk in chunks: @@ -102,17 +97,19 @@ def predict_entities_chunked( return all_entities - def filter_by_score(self, entities: List[L3Entity], threshold: float = None) -> List[L3Entity]: - """Filter entities by confidence score""" + def filter_by_score( + self, entities: List[L3Entity], threshold: float | None = None + ) -> List[L3Entity]: + """Filter entities by confidence score.""" threshold = threshold if threshold is not None else self.config.threshold return [e for e in entities if e.score >= threshold] def sort_by_position(self, entities: List[L3Entity]) -> List[L3Entity]: - """Sort entities by position in text""" + """Sort entities by position in text.""" return sorted(entities, key=lambda e: e.start) def deduplicate_entities(self, entities: List[L3Entity]) -> List[L3Entity]: - """Remove duplicate entities, keeping the highest-scoring one per span""" + """Remove duplicate entities, keeping the highest-scoring one per span.""" best = {} for entity in entities: key = (entity.text, entity.start, entity.end) diff --git a/src/glinker/l4/models.py b/src/glinker/l4/models.py index 109e65d..3e6670f 100644 --- a/src/glinker/l4/models.py +++ b/src/glinker/l4/models.py @@ -1,11 +1,11 @@ from pydantic import Field -from typing import List, Any, Optional -from glinker.core.base import BaseConfig, BaseInput, BaseOutput + +from glinker.core.base import BaseConfig class L4Config(BaseConfig): model_name: str = Field(...) - token: Optional[str] = Field(None) + token: str | None = Field(None) device: str = Field("cpu") threshold: float = Field(0.5) flat_ner: bool = Field(True) @@ -13,9 +13,9 @@ class L4Config(BaseConfig): max_labels: int = Field( 20, description="Maximum number of candidate labels per inference call. " - "When candidates exceed this, they are split into chunks." + "When candidates exceed this, they are split into chunks.", ) max_length: int = Field( None, - description="Maximum sequence length for tokenization. Passed to GLiNER.from_pretrained." + description="Maximum sequence length for tokenization. Passed to GLiNER.from_pretrained.", ) diff --git a/src/glinker/l4/processor.py b/src/glinker/l4/processor.py index 722213e..714faf3 100644 --- a/src/glinker/l4/processor.py +++ b/src/glinker/l4/processor.py @@ -1,19 +1,21 @@ -from typing import Any, List, Optional +from typing import Any, List + from glinker.core.base import BaseProcessor +from glinker.l3.models import L3Input, L3Output from glinker.core.registry import processor_registry -from glinker.l3.models import L3Input, L3Output, L3Entity + from .models import L4Config from .component import L4Component class L4Processor(BaseProcessor[L4Config, L3Input, L3Output]): - """GLiNER reranking processor with candidate chunking""" + """GLiNER reranking processor with candidate chunking.""" def __init__( self, config: L4Config, component: L4Component, - pipeline: list[tuple[str, dict[str, Any]]] = None + pipeline: list[tuple[str, dict[str, Any]]] | None = None, ): super().__init__(config, component, pipeline) self._validate_pipeline() @@ -24,7 +26,7 @@ def _default_pipeline(self) -> list[tuple[str, dict[str, Any]]]: ("predict_entities_chunked", {}), ("deduplicate_entities", {}), ("filter_by_score", {}), - ("sort_by_position", {}) + ("sort_by_position", {}), ] @staticmethod @@ -35,10 +37,10 @@ def _build_input_spans(l1_entities_for_text: List[Any]) -> List[List[dict]]: def __call__( self, - texts: List[str] = None, - candidates: List[List[Any]] = None, - l1_entities: List[List[Any]] = None, - input_data: L3Input = None + texts: List[str] | None = None, + candidates: List[List[Any]] | None = None, + l1_entities: List[List[Any]] | None = None, + input_data: L3Input = None, ) -> L3Output: """Process texts with candidate labels using chunked GLiNER inference. @@ -61,9 +63,8 @@ def __call__( max_labels = self.config.max_labels # Detect shared candidates (all texts use the same list) - shared = ( - len(candidates_to_process) > 1 - and all(c is candidates_to_process[0] for c in candidates_to_process[1:]) + shared = len(candidates_to_process) > 1 and all( + c is candidates_to_process[0] for c in candidates_to_process[1:] ) shared_labels = None @@ -84,11 +85,10 @@ def __call__( if shared: labels = shared_labels + elif self.schema: + labels, _ = self._create_gliner_labels_with_mapping(text_candidates) else: - if self.schema: - labels, _ = self._create_gliner_labels_with_mapping(text_candidates) - else: - labels = [self._extract_label(c) for c in text_candidates] + labels = [self._extract_label(c) for c in text_candidates] # Run chunked prediction entities = self.component.predict_entities_chunked( @@ -105,23 +105,23 @@ def __call__( return L3Output(entities=all_entities) def _extract_label(self, candidate: Any) -> str: - """Extract label from candidate""" - if hasattr(candidate, 'label'): + """Extract label from candidate.""" + if hasattr(candidate, "label"): return candidate.label return str(candidate) def _create_gliner_labels_with_mapping(self, candidates: List[Any]) -> tuple: """Create GLiNER labels using schema template and return label->candidate mapping.""" - template = self.schema.get('template', '{label}') + template = self.schema.get("template", "{label}") labels = [] label_to_candidate = {} seen = set() for candidate in candidates: try: - if hasattr(candidate, 'model_dump'): + if hasattr(candidate, "model_dump"): cand_dict = candidate.model_dump() - elif hasattr(candidate, 'dict'): + elif hasattr(candidate, "dict"): cand_dict = candidate.dict() elif isinstance(candidate, dict): cand_dict = candidate @@ -139,7 +139,7 @@ def _create_gliner_labels_with_mapping(self, candidates: List[Any]) -> tuple: label_to_candidate[label] = candidate seen.add(label_lower) except (KeyError, AttributeError): - if hasattr(candidate, 'label'): + if hasattr(candidate, "label"): if candidate.label.lower() not in seen: labels.append(candidate.label) label_to_candidate[candidate.label] = candidate @@ -149,8 +149,8 @@ def _create_gliner_labels_with_mapping(self, candidates: List[Any]) -> tuple: @processor_registry.register("l4_reranker") -def create_l4_processor(config_dict: dict, pipeline: list = None) -> L4Processor: - """Factory: creates component + processor""" +def create_l4_processor(config_dict: dict, pipeline: list | None = None) -> L4Processor: + """Factory: creates component + processor.""" config = L4Config(**config_dict) component = L4Component(config) - return L4Processor(config, component, pipeline) \ No newline at end of file + return L4Processor(config, component, pipeline) diff --git a/tests/conftest.py b/tests/conftest.py index 2559c87..f50c345 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,6 +8,109 @@ import os from pathlib import Path from typing import List, Dict, Any +from unittest.mock import MagicMock, patch +import torch + + +# ============================================================ +# GLOBAL MOCKS FOR CI/CD +# ============================================================ + +@pytest.fixture(scope="session", autouse=True) +def mock_spacy_models(): + """Mock spacy.load globally to avoid downloading models in CI.""" + with patch("spacy.load") as mock_load: + mock_nlp = MagicMock() + + def mock_call(text): + mock_doc = MagicMock() + entities = [] + # Simple mock: find common keywords + keywords = { + "TP53": ("GENE", 0, 4), + "BRCA1": ("GENE", 0, 5), + "cancer": ("DISEASE", 0, 6), + "breast cancer": ("DISEASE", 0, 13), + } + for keyword, (label, _, _) in keywords.items(): + if keyword in text: + start = text.find(keyword) + mock_ent = MagicMock() + mock_ent.text = keyword + mock_ent.label_ = label + mock_ent.start_char = start + mock_ent.end_char = start + len(keyword) + entities.append(mock_ent) + + mock_doc.ents = entities + mock_doc.noun_chunks = [] + mock_doc.text = text + return mock_doc + + # Mock both direct call and pipe + mock_nlp.side_effect = mock_call + + # Mock pipe() for batch processing + def mock_pipe(texts, batch_size=None): + for text in texts: + yield mock_call(text) + + mock_nlp.pipe = mock_pipe + mock_load.return_value = mock_nlp + yield mock_load + + +@pytest.fixture(scope="session", autouse=True) +def mock_gliner_models(): + """Mock GLiNER.from_pretrained globally to avoid downloading models in CI.""" + with patch("gliner.GLiNER.from_pretrained") as mock_from_pretrained: + mock_model = MagicMock() + + # Mock tokenizer + mock_tokenizer = MagicMock() + mock_tokenizer.model_max_length = 512 + mock_data_processor = MagicMock() + mock_data_processor.labels_tokenizer = mock_tokenizer + mock_model.data_processor = mock_data_processor + + # Mock config for BiEncoder + mock_config = MagicMock() + mock_config.labels_encoder = MagicMock() + mock_model.config = mock_config + + # Mock predict_entities + def mock_predict(text, labels, **kwargs): + results = [] + keywords = { + "gene": ["TP53", "BRCA1", "EGFR"], + "disease": ["cancer", "carcinoma"], + "protein": ["protein", "p53"], + } + for label in labels: + for keyword in keywords.get(label, []): + if keyword.lower() in text.lower(): + start = text.lower().find(keyword.lower()) + results.append({ + "text": text[start : start + len(keyword)], + "label": label, + "start": start, + "end": start + len(keyword), + "score": 0.9, + "class_probs": {label: 0.9}, + }) + return results + + mock_model.predict_entities.side_effect = mock_predict + mock_model.to.return_value = mock_model + + # Mock encode_labels + def mock_encode_labels(labels, batch_size=32): + return torch.randn(len(labels), 768) + + mock_model.encode_labels.side_effect = mock_encode_labels + + mock_from_pretrained.return_value = mock_model + yield mock_from_pretrained # ============================================================ @@ -130,8 +233,9 @@ def l1_config(l1_config_dict): @pytest.fixture def l1_component(l1_config): - """L1SpacyComponent instance.""" + """L1SpacyComponent instance (uses global spacy mock).""" from glinker.l1.component import L1SpacyComponent + return L1SpacyComponent(l1_config) @@ -251,22 +355,55 @@ def l3_config(l3_config_dict): return L3Config(**l3_config_dict) -# L3 component is expensive - session scoped -@pytest.fixture(scope="session") -def l3_component(): - """L3Component instance (session-scoped for efficiency).""" +# L3 component (uses global GLiNER mock) +@pytest.fixture +def l3_component(l3_config): + """L3Component instance (uses global GLiNER mock).""" from glinker.l3.component import L3Component - from glinker.l3.models import L3Config - config = L3Config( - model_name="knowledgator/gliner-linker-large-v1.0", - token="hf_", - device="cpu", - threshold=0.3, - flat_ner=True, - multi_label=False, - max_length=512 - ) - return L3Component(config) + + return L3Component(l3_config) + + +# ============================================================ +# L4 FIXTURES +# ============================================================ + +@pytest.fixture +def l4_config_dict() -> Dict[str, Any]: + """L4 processor configuration dictionary.""" + return { + "model_name": "knowledgator/gliner-linker-large-v1.0", + "token": "hf_", + "device": "cpu", + "threshold": 0.5, + "flat_ner": True, + "multi_label": False, + "max_labels": 20, + "max_length": 512 + } + + +@pytest.fixture +def l4_config(l4_config_dict): + """L4Config instance.""" + from glinker.l4.models import L4Config + return L4Config(**l4_config_dict) + + +# L4 component (uses global GLiNER mock) +@pytest.fixture +def l4_component(l4_config): + """L4Component instance (uses global GLiNER mock).""" + from glinker.l4.component import L4Component + + return L4Component(l4_config) + + +@pytest.fixture +def l4_processor(l4_component, l4_config): + """L4Processor instance.""" + from glinker.l4.processor import L4Processor + return L4Processor(l4_config, l4_component) # ============================================================ diff --git a/tests/core/test_config_builder.py b/tests/core/test_config_builder.py index 4109245..f54a3c1 100644 --- a/tests/core/test_config_builder.py +++ b/tests/core/test_config_builder.py @@ -126,7 +126,7 @@ def test_l2_add_dict_layer_defaults(self): assert layer["write"] is True assert layer["search_mode"] == ["exact", "fuzzy"] assert layer["ttl"] == 0 - assert layer["fuzzy"]["min_similarity"] == 0.6 + assert layer["fuzzy"]["min_similarity"] == 0.75 def test_l2_add_redis_layer_defaults(self): from glinker.core.builders import ConfigBuilder @@ -326,13 +326,15 @@ def test_l0_builder_returns_parent(self): class TestConfigBuilding: """Tests for config building and validation.""" - def test_build_requires_l1(self): + def test_build_without_l1_works(self): + """L1 is now optional - pipeline can work with explicit entities.""" from glinker.core.builders import ConfigBuilder builder = ConfigBuilder(name="test") builder.l3.configure() - with pytest.raises(ValueError, match="L1 configuration is required"): - builder.build() + # Should not raise - L1 is optional now + config = builder.build() + assert config is not None def test_build_requires_l3(self): from glinker.core.builders import ConfigBuilder diff --git a/tests/l4/__init__.py b/tests/l4/__init__.py new file mode 100644 index 0000000..28e50cd --- /dev/null +++ b/tests/l4/__init__.py @@ -0,0 +1,3 @@ +""" +Tests for L4 (GLiNER reranking) layer. +""" diff --git a/tests/l4/test_component.py b/tests/l4/test_component.py new file mode 100644 index 0000000..ae51b79 --- /dev/null +++ b/tests/l4/test_component.py @@ -0,0 +1,237 @@ +""" +Tests for src/l4/component.py - L4 GLiNER reranking component. +""" + +import pytest +from unittest.mock import MagicMock, patch + + +class TestL4ComponentCreation: + """Tests for L4Component initialization.""" + + def test_import(self): + from glinker.l4.component import L4Component + assert L4Component is not None + + def test_creation(self, l4_component): + assert l4_component is not None + + def test_has_model(self, l4_component): + assert l4_component.model is not None + + def test_has_config(self, l4_component): + assert l4_component.config is not None + + def test_device_property(self, l4_component): + # Should be on CPU for tests + assert hasattr(l4_component.config, 'device') + + def test_get_available_methods(self, l4_component): + methods = l4_component.get_available_methods() + assert isinstance(methods, list) + assert "predict_entities" in methods + assert "predict_entities_chunked" in methods + assert "filter_by_score" in methods + assert "sort_by_position" in methods + assert "deduplicate_entities" in methods + + +class TestL4ComponentPredictEntities: + """Tests for predict_entities method.""" + + def test_predict_simple(self, l4_component): + text = "TP53 is a gene." + labels = ["gene", "disease", "protein"] + entities = l4_component.predict_entities(text, labels) + assert isinstance(entities, list) + + def test_predict_returns_l3entity(self, l4_component): + from glinker.l3.models import L3Entity + text = "BRCA1 mutations cause breast cancer." + labels = ["gene", "disease"] + entities = l4_component.predict_entities(text, labels) + for entity in entities: + assert isinstance(entity, L3Entity) + + def test_predict_empty_labels(self, l4_component): + entities = l4_component.predict_entities("Some text", []) + assert entities == [] + + def test_predict_with_input_spans(self, l4_component): + text = "TP53 mutations cause cancer." + labels = ["gene", "disease"] + # Provide spans to constrain prediction + input_spans = [[{"start": 0, "end": 4}]] # Only "TP53" + entities = l4_component.predict_entities( + text, labels, input_spans=input_spans + ) + assert isinstance(entities, list) + # If entities found, they should be within input_spans + for entity in entities: + assert entity.start >= 0 + assert entity.end <= len(text) + + def test_entity_has_all_fields(self, l4_component): + text = "TP53 is important." + labels = ["gene"] + entities = l4_component.predict_entities(text, labels) + for entity in entities: + assert hasattr(entity, 'text') + assert hasattr(entity, 'label') + assert hasattr(entity, 'start') + assert hasattr(entity, 'end') + assert hasattr(entity, 'score') + + def test_entity_positions_valid(self, l4_component): + text = "BRCA1 causes breast cancer." + labels = ["gene", "disease"] + entities = l4_component.predict_entities(text, labels) + for entity in entities: + assert entity.start >= 0 + assert entity.end > entity.start + assert entity.end <= len(text) + + def test_predict_with_class_probs(self, l4_component): + """Verify that class_probs are returned when requested.""" + text = "TP53 mutations." + labels = ["gene", "protein"] + entities = l4_component.predict_entities(text, labels) + for entity in entities: + # class_probs should be included (return_class_probs=True) + assert hasattr(entity, 'class_probs') + + +class TestL4ComponentPredictEntitiesChunked: + """Tests for predict_entities_chunked method.""" + + def test_predict_chunked_small_labels(self, l4_component): + """When labels <= max_labels, should behave like predict_entities.""" + text = "TP53 and BRCA1 are genes." + labels = ["gene", "disease"] + max_labels = 10 + entities = l4_component.predict_entities_chunked( + text, labels, max_labels + ) + assert isinstance(entities, list) + + def test_predict_chunked_large_labels(self, l4_component): + """When labels > max_labels, should split into chunks.""" + text = "TP53 is a gene." + # Create many labels to force chunking + labels = [f"label_{i}" for i in range(50)] + max_labels = 10 + entities = l4_component.predict_entities_chunked( + text, labels, max_labels + ) + assert isinstance(entities, list) + + def test_predict_chunked_empty_labels(self, l4_component): + entities = l4_component.predict_entities_chunked( + "Some text", [], max_labels=10 + ) + assert entities == [] + + def test_predict_chunked_with_input_spans(self, l4_component): + text = "TP53 mutations cause cancer." + labels = ["gene", "disease", "protein"] + input_spans = [[{"start": 0, "end": 4}]] + entities = l4_component.predict_entities_chunked( + text, labels, max_labels=2, input_spans=input_spans + ) + assert isinstance(entities, list) + + def test_predict_chunked_exact_boundary(self, l4_component): + """Test when len(labels) == max_labels.""" + text = "Test text." + labels = ["A", "B", "C"] + entities = l4_component.predict_entities_chunked( + text, labels, max_labels=3 + ) + assert isinstance(entities, list) + + +class TestL4ComponentFilterByScore: + """Tests for filter_by_score method.""" + + def test_filter_by_score(self, l4_component): + from glinker.l3.models import L3Entity + entities = [ + L3Entity(text="A", label="X", start=0, end=1, score=0.9), + L3Entity(text="B", label="X", start=5, end=6, score=0.4), + L3Entity(text="C", label="X", start=10, end=11, score=0.6), + ] + filtered = l4_component.filter_by_score(entities, threshold=0.5) + assert len(filtered) == 2 + assert all(e.score >= 0.5 for e in filtered) + + def test_filter_by_score_default_threshold(self, l4_component): + from glinker.l3.models import L3Entity + entities = [ + L3Entity(text="A", label="X", start=0, end=1, score=0.9), + L3Entity(text="B", label="X", start=5, end=6, score=0.1), + ] + # Should use config threshold (0.5) + filtered = l4_component.filter_by_score(entities) + assert len(filtered) == 1 + + def test_filter_empty(self, l4_component): + filtered = l4_component.filter_by_score([]) + assert filtered == [] + + +class TestL4ComponentSortByPosition: + """Tests for sort_by_position method.""" + + def test_sort_by_position(self, l4_component): + from glinker.l3.models import L3Entity + entities = [ + L3Entity(text="C", label="X", start=20, end=21, score=0.9), + L3Entity(text="A", label="X", start=0, end=1, score=0.8), + L3Entity(text="B", label="X", start=10, end=11, score=0.7), + ] + sorted_ents = l4_component.sort_by_position(entities) + assert sorted_ents[0].text == "A" + assert sorted_ents[1].text == "B" + assert sorted_ents[2].text == "C" + + def test_sort_empty(self, l4_component): + sorted_ents = l4_component.sort_by_position([]) + assert sorted_ents == [] + + +class TestL4ComponentDeduplicate: + """Tests for deduplicate_entities method.""" + + def test_deduplicate(self, l4_component): + from glinker.l3.models import L3Entity + entities = [ + L3Entity(text="TP53", label="gene", start=0, end=4, score=0.9), + L3Entity(text="TP53", label="gene", start=0, end=4, score=0.85), + L3Entity(text="BRCA1", label="gene", start=10, end=15, score=0.8), + ] + deduped = l4_component.deduplicate_entities(entities) + assert len(deduped) == 2 + + def test_deduplicate_keeps_highest_score(self, l4_component): + from glinker.l3.models import L3Entity + entities = [ + L3Entity(text="TP53", label="gene", start=0, end=4, score=0.5), + L3Entity(text="TP53", label="gene", start=0, end=4, score=0.9), + ] + deduped = l4_component.deduplicate_entities(entities) + assert len(deduped) == 1 + assert deduped[0].score == 0.9 # Keeps highest score + + def test_deduplicate_empty(self, l4_component): + deduped = l4_component.deduplicate_entities([]) + assert deduped == [] + + def test_deduplicate_different_spans(self, l4_component): + """Different spans should not be deduplicated.""" + from glinker.l3.models import L3Entity + entities = [ + L3Entity(text="TP53", label="gene", start=0, end=4, score=0.9), + L3Entity(text="TP53", label="gene", start=10, end=14, score=0.8), + ] + deduped = l4_component.deduplicate_entities(entities) + assert len(deduped) == 2 # Different positions diff --git a/tests/l4/test_models.py b/tests/l4/test_models.py new file mode 100644 index 0000000..a07bd1c --- /dev/null +++ b/tests/l4/test_models.py @@ -0,0 +1,91 @@ +""" +Tests for src/l4/models.py - L4 data models. +""" + +import pytest + + +class TestL4Config: + """Tests for L4Config model.""" + + def test_import(self): + from glinker.l4.models import L4Config + assert L4Config is not None + + def test_creation_minimal(self): + from glinker.l4.models import L4Config + config = L4Config(model_name="test-model") + assert config.model_name == "test-model" + + def test_creation_full(self, l4_config_dict): + from glinker.l4.models import L4Config + config = L4Config(**l4_config_dict) + assert config.model_name == l4_config_dict["model_name"] + assert config.device == l4_config_dict["device"] + assert config.threshold == l4_config_dict["threshold"] + + def test_defaults(self): + from glinker.l4.models import L4Config + config = L4Config(model_name="test") + + assert config.device == "cpu" + assert config.threshold == 0.5 + assert config.flat_ner is True + assert config.multi_label is False + assert config.max_labels == 20 + assert config.token is None + assert config.max_length is None + + def test_max_labels_default(self): + from glinker.l4.models import L4Config + config = L4Config(model_name="test") + assert config.max_labels == 20 + + def test_max_labels_custom(self): + from glinker.l4.models import L4Config + config = L4Config(model_name="test", max_labels=50) + assert config.max_labels == 50 + + def test_max_length_optional(self): + from glinker.l4.models import L4Config + config = L4Config(model_name="test") + assert config.max_length is None + + config_with_length = L4Config(model_name="test", max_length=256) + assert config_with_length.max_length == 256 + + def test_token_optional(self): + from glinker.l4.models import L4Config + config = L4Config(model_name="test") + assert config.token is None + + config_with_token = L4Config(model_name="test", token="hf_test") + assert config_with_token.token == "hf_test" + + def test_config_is_base_config(self): + from glinker.l4.models import L4Config + from glinker.core.base import BaseConfig + config = L4Config(model_name="test") + assert isinstance(config, BaseConfig) + + def test_field_types(self): + from glinker.l4.models import L4Config + config = L4Config( + model_name="test", + token="token", + device="cuda", + threshold=0.8, + flat_ner=False, + multi_label=True, + max_labels=30, + max_length=512 + ) + + assert isinstance(config.model_name, str) + assert isinstance(config.token, str) + assert isinstance(config.device, str) + assert isinstance(config.threshold, float) + assert isinstance(config.flat_ner, bool) + assert isinstance(config.multi_label, bool) + assert isinstance(config.max_labels, int) + assert isinstance(config.max_length, int) diff --git a/tests/l4/test_processor.py b/tests/l4/test_processor.py new file mode 100644 index 0000000..a2fc38e --- /dev/null +++ b/tests/l4/test_processor.py @@ -0,0 +1,325 @@ +""" +Tests for src/l4/processor.py - L4 GLiNER reranking processor. +""" + +import pytest + + +class TestL4ProcessorCreation: + """Tests for L4Processor initialization.""" + + def test_import(self): + from glinker.l4.processor import L4Processor + assert L4Processor is not None + + def test_create_via_factory(self, l4_config_dict): + from glinker.l4.processor import create_l4_processor + processor = create_l4_processor(l4_config_dict) + assert processor is not None + + def test_processor_has_component(self, l4_processor): + assert l4_processor.component is not None + + def test_processor_has_config(self, l4_processor): + assert l4_processor.config is not None + + def test_processor_has_default_pipeline(self, l4_processor): + assert l4_processor.pipeline is not None + assert len(l4_processor.pipeline) > 0 + + +class TestL4ProcessorCall: + """Tests for L4Processor __call__ method.""" + + def test_call_single_text(self, l4_processor, single_text): + texts = [single_text] + candidates = [["gene", "disease", "protein"]] + result = l4_processor(texts=texts, candidates=candidates) + + from glinker.l3.models import L3Output + assert isinstance(result, L3Output) + assert len(result.entities) == 1 + + def test_call_multiple_texts(self, l4_processor, sample_texts): + candidates = [ + ["gene", "disease"], + ["gene", "disease"], + ["gene", "disease", "drug"] + ] + result = l4_processor(texts=sample_texts, candidates=candidates) + + assert len(result.entities) == len(sample_texts) + + def test_call_empty_input(self, l4_processor): + result = l4_processor(texts=[], candidates=[]) + assert len(result.entities) == 0 + + def test_call_empty_text(self, l4_processor): + result = l4_processor(texts=[""], candidates=[[]]) + assert len(result.entities) == 1 + assert result.entities[0] == [] + + def test_result_entities_are_lists(self, l4_processor, single_text): + result = l4_processor( + texts=[single_text], + candidates=[["gene", "disease"]] + ) + for entities in result.entities: + assert isinstance(entities, list) + + def test_call_with_input_data(self, l4_processor): + """Test calling with L3Input object.""" + from glinker.l3.models import L3Input + + input_data = L3Input( + texts=["TP53 is a gene."], + labels=[["gene", "disease"]] + ) + result = l4_processor(input_data=input_data) + + from glinker.l3.models import L3Output + assert isinstance(result, L3Output) + + def test_call_raises_without_params(self, l4_processor): + """Should raise ValueError if neither texts+candidates nor input_data provided.""" + with pytest.raises(ValueError): + l4_processor() + + +class TestL4ProcessorChunking: + """Tests for candidate chunking functionality.""" + + def test_chunking_with_many_candidates(self, l4_processor): + """Test that processor handles many candidates via chunking.""" + text = "TP53 is a gene." + # Create many candidates to force chunking + many_candidates = [f"label_{i}" for i in range(50)] + + result = l4_processor( + texts=[text], + candidates=[many_candidates] + ) + + assert isinstance(result.entities, list) + assert len(result.entities) == 1 + + def test_max_labels_config_used(self, l4_config_dict): + """Test that max_labels from config is used.""" + from glinker.l4.processor import create_l4_processor + + l4_config_dict["max_labels"] = 5 + processor = create_l4_processor(l4_config_dict) + + assert processor.config.max_labels == 5 + + +class TestL4ProcessorWithL1Entities: + """Tests for using L1 entities as input_spans.""" + + def test_call_with_l1_entities(self, l4_processor, single_text): + """Test providing L1 entities to constrain predictions.""" + from glinker.l1.models import L1Entity + + # Create mock L1 entities + l1_entities = [[ + L1Entity( + text="TP53", + label="", + start=0, + end=4, + left_context="", + right_context=" mutations cause breast cancer." + ) + ]] + + result = l4_processor( + texts=[single_text], + candidates=[["gene", "disease"]], + l1_entities=l1_entities + ) + + assert isinstance(result.entities, list) + + def test_build_input_spans(self, l4_processor): + """Test _build_input_spans static method.""" + from glinker.l1.models import L1Entity + + l1_entities = [ + L1Entity(text="TP53", label="", start=0, end=4, left_context="", right_context=""), + L1Entity(text="BRCA1", label="", start=10, end=15, left_context="", right_context="") + ] + + spans = l4_processor._build_input_spans(l1_entities) + + assert isinstance(spans, list) + assert len(spans) == 1 # Returns [spans] + assert len(spans[0]) == 2 + assert spans[0][0] == {"start": 0, "end": 4} + assert spans[0][1] == {"start": 10, "end": 15} + + +class TestL4ProcessorSharedCandidates: + """Tests for shared candidate optimization.""" + + def test_shared_candidates_detected(self, l4_processor, sample_texts): + """Test that shared candidates (same list) are optimized.""" + shared_candidates = ["gene", "disease", "protein"] + candidates = [shared_candidates] * len(sample_texts) + + result = l4_processor(texts=sample_texts, candidates=candidates) + + assert len(result.entities) == len(sample_texts) + + def test_different_candidates_per_text(self, l4_processor): + """Test with different candidates per text.""" + texts = ["TP53 is a gene.", "Aspirin is a drug."] + candidates = [ + ["gene", "disease"], + ["drug", "chemical"] + ] + + result = l4_processor(texts=texts, candidates=candidates) + + assert len(result.entities) == 2 + + +class TestL4ProcessorWithSchema: + """Tests for schema template functionality.""" + + def test_processor_with_schema(self, l4_config_dict): + """Test using schema template for candidates.""" + from glinker.l4.processor import create_l4_processor + + processor = create_l4_processor(l4_config_dict) + processor.schema = {"template": "{label} - {description}"} + + # Mock candidates with attributes + class MockCandidate: + def __init__(self, label, description): + self.label = label + self.description = description + + def model_dump(self): + return {"label": self.label, "description": self.description} + + candidates = [ + MockCandidate("gene", "genetic element"), + MockCandidate("disease", "medical condition") + ] + + labels, mapping = processor._create_gliner_labels_with_mapping(candidates) + + assert len(labels) == 2 + assert "gene - genetic element" in labels + assert "disease - medical condition" in labels + + def test_extract_label_from_object(self, l4_processor): + """Test _extract_label with object having label attribute.""" + class MockCandidate: + def __init__(self, label): + self.label = label + + candidate = MockCandidate("test_label") + label = l4_processor._extract_label(candidate) + + assert label == "test_label" + + def test_extract_label_from_string(self, l4_processor): + """Test _extract_label with string candidate.""" + label = l4_processor._extract_label("string_label") + assert label == "string_label" + + def test_create_labels_with_pydantic(self, l4_processor): + """Test label creation with Pydantic model.""" + from pydantic import BaseModel + + class Candidate(BaseModel): + label: str + description: str + + processor = l4_processor + processor.schema = {"template": "{label}"} + + candidates = [ + Candidate(label="gene", description="test"), + Candidate(label="disease", description="test2") + ] + + labels, mapping = processor._create_gliner_labels_with_mapping(candidates) + + assert len(labels) == 2 + assert "gene" in labels + + def test_create_labels_deduplication(self, l4_processor): + """Test that duplicate labels are removed.""" + processor = l4_processor + processor.schema = {"template": "{label}"} + + class Candidate: + def __init__(self, label): + self.label = label + + def model_dump(self): + return {"label": self.label} + + candidates = [ + Candidate("gene"), + Candidate("Gene"), # Different case + Candidate("disease") + ] + + labels, mapping = processor._create_gliner_labels_with_mapping(candidates) + + # Should deduplicate case-insensitive + assert len(labels) <= 2 + + +class TestL4ProcessorPipeline: + """Tests for pipeline execution.""" + + def test_custom_pipeline(self, l4_component, l4_config): + """Test processor with custom pipeline.""" + from glinker.l4.processor import L4Processor + + custom_pipeline = [ + ("predict_entities_chunked", {}), + ("filter_by_score", {"threshold": 0.7}), + ] + + processor = L4Processor(l4_config, l4_component, custom_pipeline) + assert processor.pipeline == custom_pipeline + + def test_default_pipeline(self, l4_processor): + """Test default pipeline is set correctly.""" + pipeline_methods = [step[0] for step in l4_processor.pipeline] + + assert "predict_entities_chunked" in pipeline_methods + assert "deduplicate_entities" in pipeline_methods + assert "filter_by_score" in pipeline_methods + assert "sort_by_position" in pipeline_methods + + +class TestL4ProcessorModels: + """Tests for L4 models.""" + + def test_l4_config_import(self): + from glinker.l4.models import L4Config + assert L4Config is not None + + def test_l4_config_creation(self, l4_config_dict): + from glinker.l4.models import L4Config + config = L4Config(**l4_config_dict) + + assert config.model_name == l4_config_dict["model_name"] + assert config.threshold == l4_config_dict["threshold"] + assert config.max_labels == l4_config_dict.get("max_labels", 20) + + def test_l4_config_defaults(self): + from glinker.l4.models import L4Config + config = L4Config(model_name="test-model") + + assert config.device == "cpu" + assert config.threshold == 0.5 + assert config.max_labels == 20 + assert config.flat_ner is True + assert config.multi_label is False