Skip to content

Commit bd4edbb

Browse files
[WIP] Web server utilities (#1067)
2 parents 5994c58 + cd1718e commit bd4edbb

7 files changed

Lines changed: 162 additions & 96 deletions

File tree

mp_api/client/_server_utils.py

Lines changed: 52 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,62 @@
1-
"""Define utilities needed by the MP web server."""
1+
"""Define flask-dependent utilities for the web server."""
2+
23
from __future__ import annotations
34

5+
from typing import TYPE_CHECKING
6+
47
try:
5-
import flask
8+
from flask import has_request_context as _has_request_context
9+
from flask import request
610
except ImportError:
7-
from mp_api.client.core.exceptions import MPRestError
11+
_has_request_context = None # type: ignore[assignment]
12+
request = None # type: ignore[assignment]
13+
14+
from mp_api.client.core.utils import validate_api_key
815

9-
raise MPRestError("`flask` must be installed to use server utilities.")
16+
if TYPE_CHECKING:
17+
from collections.abc import Sequence
18+
from typing import Any
1019

11-
import requests
1220

13-
from mp_api.client import MPRester
14-
from mp_api.client.core.utils import validate_api_key
21+
def has_request_context() -> bool:
22+
"""Determine if the current context is a request.
23+
24+
Returns:
25+
--------
26+
bool : True if in a request context
27+
False if flask is not installed or not in a request context.
28+
"""
29+
return _has_request_context is not None and _has_request_context()
30+
31+
32+
def get_request_headers() -> dict[str, Any]:
33+
"""Get the headers if operating in a request context.
34+
35+
Returns:
36+
--------
37+
dict of str to Any
38+
Empty dict if flask is not installed, or not in a request context.
39+
Request headers otherwise.
40+
"""
41+
return request.headers if has_request_context() else {}
1542

16-
SESSION = requests.Session()
1743

44+
def is_dev_env(
45+
localhosts: Sequence[str] = ("localhost:", "127.0.0.1:", "0.0.0.0:")
46+
) -> bool:
47+
"""Determine if current env is local/developmental or production.
1848
19-
def is_localhost() -> bool:
20-
"""Determine if current env is local or production.
49+
Args:
50+
localhosts (Sequence of str) : A set of host prefixes for checking
51+
if the current environment is locally deployed.
2152
2253
Returns:
2354
bool: True if the environment is locally hosted.
2455
"""
2556
return (
2657
True
27-
if not flask.has_request_context()
28-
else flask.request.headers.get("Host", "").startswith(
29-
("localhost:", "127.0.0.1:", "0.0.0.0:")
30-
)
58+
if not has_request_context()
59+
else get_request_headers().get("Host", "").startswith(localhosts)
3160
)
3261

3362

@@ -37,7 +66,7 @@ def get_consumer() -> dict[str, str]:
3766
Returns:
3867
dict of str to str, the headers associated with the consumer
3968
"""
40-
if not flask.has_request_context():
69+
if not has_request_context():
4170
return {}
4271

4372
names = [
@@ -48,7 +77,7 @@ def get_consumer() -> dict[str, str]:
4877
"X-Authenticated-Groups", # groups this user belongs to
4978
"X-Consumer-Groups", # same as X-Authenticated-Groups
5079
]
51-
headers = flask.request.headers
80+
headers = get_request_headers()
5281
return {name: headers[name] for name in names if headers.get(name) is not None}
5382

5483

@@ -65,39 +94,23 @@ def is_logged_in_user(consumer: dict[str, str] | None = None) -> bool:
6594
return bool(not c.get("X-Anonymous-Consumer") and c.get("X-Consumer-Id"))
6695

6796

68-
def get_user_api_key(consumer: dict[str, str] | None = None) -> str | None:
97+
def get_user_api_key(
98+
api_key: str | None = None, consumer: dict[str, str] | None = None
99+
) -> str | None:
69100
"""Get the api key that belongs to the current user.
70101
71102
If running on localhost, api key is obtained from
72103
the environment variable MP_API_KEY.
73104
74105
Args:
106+
api_key (str or None) : User API key
75107
consumer (dict of str to str, or None): Headers associated with the consumer
76108
77109
Returns:
78110
str, the API key, or None if no API key could be identified.
79111
"""
80-
c = consumer or get_consumer()
81-
82-
if is_localhost():
83-
return validate_api_key()
84-
elif is_logged_in_user(c):
112+
if is_dev_env():
113+
return validate_api_key(api_key=api_key)
114+
elif is_logged_in_user(c := consumer or get_consumer()):
85115
return c.get("X-Consumer-Custom-Id")
86116
return None
87-
88-
89-
def get_rester(**kwargs) -> MPRester:
90-
"""Create MPRester with headers set for localhost and production compatibility.
91-
92-
Args:
93-
**kwargs : kwargs to pass to MPRester
94-
95-
Returns:
96-
MPRester
97-
"""
98-
if is_localhost():
99-
dev_api_key = get_user_api_key()
100-
SESSION.headers["x-api-key"] = dev_api_key or ""
101-
return MPRester(api_key=dev_api_key, session=SESSION, **kwargs)
102-
103-
return MPRester(headers=get_consumer(), session=SESSION, **kwargs)

mp_api/client/core/client.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -57,13 +57,6 @@
5757
validate_ids,
5858
)
5959

60-
try:
61-
import flask
62-
63-
_flask_is_installed = True
64-
except ImportError:
65-
_flask_is_installed = False
66-
6760
if TYPE_CHECKING:
6861
from collections.abc import Callable, Iterable, Iterator
6962
from typing import Any
@@ -1177,17 +1170,13 @@ def _submit_request_and_process(
11771170
Returns:
11781171
Tuple with data and total number of docs in matching the query in the database.
11791172
"""
1180-
headers = None
1181-
if _flask_is_installed and flask.has_request_context():
1182-
headers = flask.request.headers
1183-
11841173
try:
11851174
response = self.session.get(
11861175
url=url,
11871176
verify=verify,
11881177
params=params,
11891178
timeout=timeout,
1190-
headers=headers if headers else self.headers,
1179+
headers=self.headers,
11911180
)
11921181
except requests.exceptions.ConnectTimeout:
11931182
raise MPRestError(

mp_api/client/core/settings.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,11 @@ class MAPIClientSettings(BaseSettings):
7272
description="Angle tolerance for structure matching in degrees.",
7373
)
7474

75+
LOG_FILE: Path = Field(
76+
Path("~/.mprester.log.yaml").expanduser(),
77+
description="Path for storing last accessed database version.",
78+
)
79+
7580
LOCAL_DATASET_CACHE: Path = Field(
7681
Path("~/mp_datasets").expanduser(),
7782
description="Target directory for downloading full datasets",

mp_api/client/mprester.py

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
2222
from requests import Session, get
2323

24+
from mp_api.client._server_utils import get_consumer, get_user_api_key, is_dev_env
2425
from mp_api.client.core import BaseRester
2526
from mp_api.client.core._oxygen_evolution import OxygenEvolution
2627
from mp_api.client.core.exceptions import (
@@ -32,7 +33,6 @@
3233
from mp_api.client.core.utils import (
3334
LazyImport,
3435
load_json,
35-
validate_api_key,
3636
validate_endpoint,
3737
validate_ids,
3838
)
@@ -141,16 +141,18 @@ def __init__(
141141
force_renew: Option to overwrite existing local dataset
142142
**kwargs: access to legacy kwargs that may be in the process of being deprecated
143143
"""
144-
self.api_key = validate_api_key(api_key)
144+
self.api_key = get_user_api_key(api_key=api_key)
145145

146146
self.endpoint = validate_endpoint(endpoint)
147147

148-
self.headers = headers or {}
148+
self.headers = headers or get_consumer()
149149
self.session = session or BaseRester._create_session(
150150
api_key=self.api_key,
151151
include_user_agent=include_user_agent,
152152
headers=self.headers,
153153
)
154+
if is_dev_env():
155+
self.session.headers["x-api-key"] = self.api_key or ""
154156
self._include_user_agent = include_user_agent
155157
self.use_document_model = use_document_model
156158
self.mute_progress_bars = mute_progress_bars
@@ -209,7 +211,7 @@ def __init__(
209211
)
210212

211213
if notify_db_version:
212-
raise NotImplementedError("This has not yet been implemented.")
214+
self._db_version_check()
213215

214216
# Dynamically set rester attributes.
215217
# First, materials and molecules top level resters are set.
@@ -296,6 +298,10 @@ def __dir__(self):
296298
+ [r.split("/", 1)[0] for r in TOP_LEVEL_RESTERS if not r.startswith("_")]
297299
)
298300

301+
def __repr__(self) -> str:
302+
db_version = self.get_database_version()
303+
return f"MPRester({'v' + db_version if db_version else 'unknown version'})"
304+
299305
def get_task_ids_associated_with_material_id(
300306
self, material_id: str, calc_types: list[CalcType] | None = None
301307
) -> list[str]:
@@ -367,7 +373,7 @@ def get_database_version(self) -> str | None:
367373
where "_DD" may be optional. An additional numerical suffix
368374
might be added if multiple releases happen on the same day.
369375
370-
Returns: database version as a string
376+
Returns: database version as a string if accessible, None otherwise
371377
"""
372378
if (get_resp := get(url=self.endpoint + "heartbeat")).status_code == 403:
373379
_emit_status_warning()
@@ -1636,3 +1642,31 @@ def get_oxygen_evolution(
16361642
phase_diagram,
16371643
unique_composition,
16381644
)
1645+
1646+
def _db_version_check(self) -> None:
1647+
"""Check if the database version has drifted."""
1648+
import yaml # type: ignore[import-untyped]
1649+
1650+
db_version = self.get_database_version()
1651+
old_db_version = None
1652+
if MAPI_CLIENT_SETTINGS.LOG_FILE.exists():
1653+
old_db_version = (
1654+
yaml.safe_load(MAPI_CLIENT_SETTINGS.LOG_FILE.read_text()) or {}
1655+
).get("MAPI_DB_VERSION", None)
1656+
1657+
# Handle legacy pymatgen behavior
1658+
if not isinstance(old_db_version, str):
1659+
old_db_version = None
1660+
1661+
if old_db_version != db_version:
1662+
MAPI_CLIENT_SETTINGS.LOG_FILE.write_text(
1663+
yaml.safe_dump({"MAPI_DB_VERSION": db_version})
1664+
)
1665+
1666+
if old_db_version:
1667+
warnings.warn(
1668+
"Materials Project database version has changed "
1669+
f"from v{old_db_version} to v{db_version}.",
1670+
category=MPRestWarning,
1671+
stacklevel=2,
1672+
)

0 commit comments

Comments
 (0)