1- """Define utilities needed by the MP web server."""
1+ """Define flask-dependent utilities for the web server."""
2+
23from __future__ import annotations
34
5+ from typing import TYPE_CHECKING
6+
47try :
5- import flask
8+ from flask import has_request_context as _has_request_context
9+ from flask import request
610except ImportError :
7- from mp_api .client .core .exceptions import MPRestError
11+ _has_request_context = None # type: ignore[assignment]
12+ request = None # type: ignore[assignment]
13+
14+ from mp_api .client .core .utils import validate_api_key
815
9- raise MPRestError ("`flask` must be installed to use server utilities." )
16+ if TYPE_CHECKING :
17+ from collections .abc import Sequence
18+ from typing import Any
1019
11- import requests
1220
13- from mp_api .client import MPRester
14- from mp_api .client .core .utils import validate_api_key
21+ def has_request_context () -> bool :
22+ """Determine if the current context is a request.
23+
24+ Returns:
25+ --------
26+ bool : True if in a request context
27+ False if flask is not installed or not in a request context.
28+ """
29+ return _has_request_context is not None and _has_request_context ()
30+
31+
32+ def get_request_headers () -> dict [str , Any ]:
33+ """Get the headers if operating in a request context.
34+
35+ Returns:
36+ --------
37+ dict of str to Any
38+ Empty dict if flask is not installed, or not in a request context.
39+ Request headers otherwise.
40+ """
41+ return request .headers if has_request_context () else {}
1542
16- SESSION = requests .Session ()
1743
44+ def is_dev_env (
45+ localhosts : Sequence [str ] = ("localhost:" , "127.0.0.1:" , "0.0.0.0:" )
46+ ) -> bool :
47+ """Determine if current env is local/developmental or production.
1848
19- def is_localhost () -> bool :
20- """Determine if current env is local or production.
49+ Args:
50+ localhosts (Sequence of str) : A set of host prefixes for checking
51+ if the current environment is locally deployed.
2152
2253 Returns:
2354 bool: True if the environment is locally hosted.
2455 """
2556 return (
2657 True
27- if not flask .has_request_context ()
28- else flask .request .headers .get ("Host" , "" ).startswith (
29- ("localhost:" , "127.0.0.1:" , "0.0.0.0:" )
30- )
58+ if not has_request_context ()
59+ else get_request_headers ().get ("Host" , "" ).startswith (localhosts )
3160 )
3261
3362
@@ -37,7 +66,7 @@ def get_consumer() -> dict[str, str]:
3766 Returns:
3867 dict of str to str, the headers associated with the consumer
3968 """
40- if not flask . has_request_context ():
69+ if not has_request_context ():
4170 return {}
4271
4372 names = [
@@ -48,7 +77,7 @@ def get_consumer() -> dict[str, str]:
4877 "X-Authenticated-Groups" , # groups this user belongs to
4978 "X-Consumer-Groups" , # same as X-Authenticated-Groups
5079 ]
51- headers = flask . request . headers
80+ headers = get_request_headers ()
5281 return {name : headers [name ] for name in names if headers .get (name ) is not None }
5382
5483
@@ -65,39 +94,23 @@ def is_logged_in_user(consumer: dict[str, str] | None = None) -> bool:
6594 return bool (not c .get ("X-Anonymous-Consumer" ) and c .get ("X-Consumer-Id" ))
6695
6796
68- def get_user_api_key (consumer : dict [str , str ] | None = None ) -> str | None :
97+ def get_user_api_key (
98+ api_key : str | None = None , consumer : dict [str , str ] | None = None
99+ ) -> str | None :
69100 """Get the api key that belongs to the current user.
70101
71102 If running on localhost, api key is obtained from
72103 the environment variable MP_API_KEY.
73104
74105 Args:
106+ api_key (str or None) : User API key
75107 consumer (dict of str to str, or None): Headers associated with the consumer
76108
77109 Returns:
78110 str, the API key, or None if no API key could be identified.
79111 """
80- c = consumer or get_consumer ()
81-
82- if is_localhost ():
83- return validate_api_key ()
84- elif is_logged_in_user (c ):
112+ if is_dev_env ():
113+ return validate_api_key (api_key = api_key )
114+ elif is_logged_in_user (c := consumer or get_consumer ()):
85115 return c .get ("X-Consumer-Custom-Id" )
86116 return None
87-
88-
89- def get_rester (** kwargs ) -> MPRester :
90- """Create MPRester with headers set for localhost and production compatibility.
91-
92- Args:
93- **kwargs : kwargs to pass to MPRester
94-
95- Returns:
96- MPRester
97- """
98- if is_localhost ():
99- dev_api_key = get_user_api_key ()
100- SESSION .headers ["x-api-key" ] = dev_api_key or ""
101- return MPRester (api_key = dev_api_key , session = SESSION , ** kwargs )
102-
103- return MPRester (headers = get_consumer (), session = SESSION , ** kwargs )
0 commit comments