Skip to content

Commit 2e69561

Browse files
authored
CU-8699h2yv2: Avoid using pkg_resources (deprecated) (CogStack/MedCAT2#94)
* CU-8699h2yv2: Avoid using pkg_resources (deprecated) * CU-8699h2yv2: Update relevancy check during dependency calculations * CU-8699h2yv2: Simplify getting of transitive dependencies * CU-8699h2yv2: Unify metadata name access * CU-8699h2yv2: Imrpove getting of installed dependencies * CU-8699h2yv2: Add convenience method to figure out if a dependency is installed * CU-8699h2yv2: Remove unnecessary option for installation targets * CU-8699h2yv2: Remove unnecessary option for installed dependency targets * CU-8699h2yv2: Update tests to correctly identify installed dependencies
1 parent 7fc80eb commit 2e69561

File tree

2 files changed

+69
-35
lines changed

2 files changed

+69
-35
lines changed

medcat-v2/medcat/utils/envsnapshot.py

Lines changed: 68 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import pkg_resources
21
import platform
32
import logging
43
import importlib.metadata
@@ -37,30 +36,6 @@ def get_direct_dependencies(include_extras: bool) -> list[str]:
3736
return reqs
3837

3938

40-
def _update_installed_dependencies_recursive(
41-
gathered: dict[str, str],
42-
package: pkg_resources.Distribution) -> dict[str, str]:
43-
if package.project_name.lower() in gathered:
44-
logger.debug("Trying to update already found transitive dependency "
45-
"'%'", package.egg_name)
46-
return gathered
47-
for req in package.requires():
48-
if req.project_name.lower() in gathered:
49-
logger.debug("Trying to look up already found transitive "
50-
"dependency '%'", req.project_name)
51-
continue # don't look for it again
52-
try:
53-
dep = pkg_resources.get_distribution(req.project_name)
54-
except pkg_resources.DistributionNotFound as e:
55-
logger.warning("Unable to locate requirement '%s':",
56-
req.project_name, exc_info=e)
57-
continue
58-
_update_installed_dependencies_recursive(gathered, dep)
59-
# do this after so its dependencies get explored
60-
gathered[dep.project_name.lower()] = dep.version
61-
return gathered
62-
63-
6439
def get_transitive_deps(direct_deps: list[str]) -> dict[str, str]:
6540
"""Get the transitive dependencies of the direct dependencies.
6641
@@ -70,12 +45,45 @@ def get_transitive_deps(direct_deps: list[str]) -> dict[str, str]:
7045
Returns:
7146
dict[str, str]: The dependency names and their corresponding versions.
7247
"""
73-
# map from name to version so as to avoid multiples of the same package
74-
all_transitive_deps: dict[str, str] = {}
75-
for dep in direct_deps:
76-
package = pkg_resources.get_distribution(dep)
77-
_update_installed_dependencies_recursive(all_transitive_deps, package)
78-
return all_transitive_deps
48+
all_deps: dict[str, str] = {}
49+
to_process = set(direct_deps)
50+
processed = set()
51+
# list installed packages for ease of use
52+
installed_packages = {
53+
dist.metadata['name'].lower()
54+
for dist in importlib.metadata.distributions()}
55+
56+
while to_process:
57+
package = to_process.pop()
58+
if package in processed:
59+
continue
60+
61+
processed.add(package)
62+
63+
try:
64+
dist = importlib.metadata.distribution(package)
65+
except importlib.metadata.PackageNotFoundError:
66+
# NOTE: if not installed, we won't bother
67+
# after all, if we can save the model, clearly
68+
# everything is working
69+
continue
70+
requires = dist.requires or []
71+
72+
for req in requires:
73+
match = DEP_NAME_PATTERN.match(req)
74+
if match is None:
75+
raise ValueError(f"Malformed dependency: {req}")
76+
dep_name = match.group(0).lower()
77+
if (dep_name and dep_name not in processed and
78+
dep_name in installed_packages):
79+
all_deps[dep_name] = importlib.metadata.distribution(
80+
dep_name).version
81+
to_process.add(dep_name)
82+
83+
for direct in direct_deps:
84+
# remove direct dependencies if they were added
85+
all_deps.pop(direct, None)
86+
return all_deps
7987

8088

8189
def get_installed_dependencies(include_extras: bool) -> dict[str, str]:
@@ -89,13 +97,39 @@ def get_installed_dependencies(include_extras: bool) -> dict[str, str]:
8997
"""
9098
direct_deps = get_direct_dependencies(include_extras)
9199
installed_packages: dict[str, str] = {}
92-
for package in pkg_resources.working_set:
93-
if package.project_name.lower() not in direct_deps:
100+
for package in importlib.metadata.distributions():
101+
req_name = package.metadata["name"].lower()
102+
# NOTE: we're checking against the '-' typed package name not
103+
# the import name (which will have _ instead)
104+
req_name_dashes = req_name.replace("_", "-")
105+
if all(cn not in direct_deps for cn in
106+
[req_name, req_name_dashes]):
94107
continue
95-
installed_packages[package.project_name.lower()] = package.version
108+
installed_packages[req_name] = package.version
96109
return installed_packages
97110

98111

112+
def is_dependency_installed(dependency: str) -> bool:
113+
"""Checks whether a dependency is installed.
114+
115+
This takes into account changes such as '-' vs '_'.
116+
For example, `typing-extensions` is a direct dependency,
117+
but its module path will be `typing_extension` and that's
118+
how we can find it as an installed dependency.
119+
120+
Args:
121+
dependency (str): The dependency in question.
122+
123+
Returns:
124+
bool: Whether the depedency has been installed.
125+
"""
126+
installed_deps = get_installed_dependencies(True)
127+
dep_name = dependency.lower()
128+
dep_name_underscores = dependency.replace("-", "_")
129+
options = [dep_name, dep_name_underscores]
130+
return any(option in installed_deps for option in options)
131+
132+
99133
class Environment(BaseModel, AbstractSerialisable):
100134
dependencies: dict[str, str]
101135
transitive_deps: dict[str, str]

medcat-v2/tests/utils/test_envsnapshot.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def test_dir_deps_have_no_version(self):
3030
def test_all_dir_deps_have_been_installed(self):
3131
for dep in self.dir_deps:
3232
with self.subTest(dep):
33-
self.assertIn(dep, self.installed_deps)
33+
self.assertTrue(envsnapshot.is_dependency_installed(dep))
3434

3535
def test_all_deps_add_to_correct(self):
3636
# NOTE: didn't account for test/dev deps

0 commit comments

Comments
 (0)