1- import pkg_resources
21import platform
32import logging
43import importlib .metadata
@@ -37,30 +36,6 @@ def get_direct_dependencies(include_extras: bool) -> list[str]:
3736 return reqs
3837
3938
40- def _update_installed_dependencies_recursive (
41- gathered : dict [str , str ],
42- package : pkg_resources .Distribution ) -> dict [str , str ]:
43- if package .project_name .lower () in gathered :
44- logger .debug ("Trying to update already found transitive dependency "
45- "'%'" , package .egg_name )
46- return gathered
47- for req in package .requires ():
48- if req .project_name .lower () in gathered :
49- logger .debug ("Trying to look up already found transitive "
50- "dependency '%'" , req .project_name )
51- continue # don't look for it again
52- try :
53- dep = pkg_resources .get_distribution (req .project_name )
54- except pkg_resources .DistributionNotFound as e :
55- logger .warning ("Unable to locate requirement '%s':" ,
56- req .project_name , exc_info = e )
57- continue
58- _update_installed_dependencies_recursive (gathered , dep )
59- # do this after so its dependencies get explored
60- gathered [dep .project_name .lower ()] = dep .version
61- return gathered
62-
63-
6439def get_transitive_deps (direct_deps : list [str ]) -> dict [str , str ]:
6540 """Get the transitive dependencies of the direct dependencies.
6641
@@ -70,12 +45,45 @@ def get_transitive_deps(direct_deps: list[str]) -> dict[str, str]:
7045 Returns:
7146 dict[str, str]: The dependency names and their corresponding versions.
7247 """
73- # map from name to version so as to avoid multiples of the same package
74- all_transitive_deps : dict [str , str ] = {}
75- for dep in direct_deps :
76- package = pkg_resources .get_distribution (dep )
77- _update_installed_dependencies_recursive (all_transitive_deps , package )
78- return all_transitive_deps
48+ all_deps : dict [str , str ] = {}
49+ to_process = set (direct_deps )
50+ processed = set ()
51+ # list installed packages for ease of use
52+ installed_packages = {
53+ dist .metadata ['name' ].lower ()
54+ for dist in importlib .metadata .distributions ()}
55+
56+ while to_process :
57+ package = to_process .pop ()
58+ if package in processed :
59+ continue
60+
61+ processed .add (package )
62+
63+ try :
64+ dist = importlib .metadata .distribution (package )
65+ except importlib .metadata .PackageNotFoundError :
66+ # NOTE: if not installed, we won't bother
67+ # after all, if we can save the model, clearly
68+ # everything is working
69+ continue
70+ requires = dist .requires or []
71+
72+ for req in requires :
73+ match = DEP_NAME_PATTERN .match (req )
74+ if match is None :
75+ raise ValueError (f"Malformed dependency: { req } " )
76+ dep_name = match .group (0 ).lower ()
77+ if (dep_name and dep_name not in processed and
78+ dep_name in installed_packages ):
79+ all_deps [dep_name ] = importlib .metadata .distribution (
80+ dep_name ).version
81+ to_process .add (dep_name )
82+
83+ for direct in direct_deps :
84+ # remove direct dependencies if they were added
85+ all_deps .pop (direct , None )
86+ return all_deps
7987
8088
8189def get_installed_dependencies (include_extras : bool ) -> dict [str , str ]:
@@ -89,13 +97,39 @@ def get_installed_dependencies(include_extras: bool) -> dict[str, str]:
8997 """
9098 direct_deps = get_direct_dependencies (include_extras )
9199 installed_packages : dict [str , str ] = {}
92- for package in pkg_resources .working_set :
93- if package .project_name .lower () not in direct_deps :
100+ for package in importlib .metadata .distributions ():
101+ req_name = package .metadata ["name" ].lower ()
102+ # NOTE: we're checking against the '-' typed package name not
103+ # the import name (which will have _ instead)
104+ req_name_dashes = req_name .replace ("_" , "-" )
105+ if all (cn not in direct_deps for cn in
106+ [req_name , req_name_dashes ]):
94107 continue
95- installed_packages [package . project_name . lower () ] = package .version
108+ installed_packages [req_name ] = package .version
96109 return installed_packages
97110
98111
112+ def is_dependency_installed (dependency : str ) -> bool :
113+ """Checks whether a dependency is installed.
114+
115+ This takes into account changes such as '-' vs '_'.
116+ For example, `typing-extensions` is a direct dependency,
117+ but its module path will be `typing_extension` and that's
118+ how we can find it as an installed dependency.
119+
120+ Args:
121+ dependency (str): The dependency in question.
122+
123+ Returns:
124+ bool: Whether the depedency has been installed.
125+ """
126+ installed_deps = get_installed_dependencies (True )
127+ dep_name = dependency .lower ()
128+ dep_name_underscores = dependency .replace ("-" , "_" )
129+ options = [dep_name , dep_name_underscores ]
130+ return any (option in installed_deps for option in options )
131+
132+
99133class Environment (BaseModel , AbstractSerialisable ):
100134 dependencies : dict [str , str ]
101135 transitive_deps : dict [str , str ]
0 commit comments