|
9 | 9 | from typing import Any, Optional |
10 | 10 |
|
11 | 11 | import click |
12 | | -import pglast |
13 | | -import psutil |
14 | | -import psycopg |
15 | 12 | import sqlalchemy |
| 13 | +from gymlib.pg import create_sqlalchemy_conn, sql_file_execute |
16 | 14 | from gymlib.shell import subprocess_run |
17 | 15 | from gymlib.symlinks_paths import ( |
18 | 16 | get_dbdata_tgz_symlink_path, |
|
28 | 26 | is_fully_resolved, |
29 | 27 | is_ssd, |
30 | 28 | ) |
31 | | -from sqlalchemy import create_engine, text |
| 29 | +from sqlalchemy import text |
32 | 30 |
|
33 | 31 | from benchmark.constants import DEFAULT_SCALE_FACTOR |
34 | 32 | from benchmark.job.load_info import JobLoadInfo |
@@ -365,91 +363,3 @@ def sqlalchemy_conn_execute( |
365 | 363 | conn: sqlalchemy.Connection, sql: str |
366 | 364 | ) -> sqlalchemy.engine.CursorResult[Any]: |
367 | 365 | return conn.execute(text(sql)) |
368 | | - |
369 | | - |
370 | | -def sql_file_queries(dbgym_workspace: DBGymWorkspace, filepath: Path) -> list[str]: |
371 | | - with dbgym_workspace.open_and_save(filepath) as f: |
372 | | - lines: list[str] = [] |
373 | | - for line in f: |
374 | | - if line.startswith("--"): |
375 | | - continue |
376 | | - if len(line.strip()) == 0: |
377 | | - continue |
378 | | - lines.append(line) |
379 | | - queries_str = "".join(lines) |
380 | | - queries: list[str] = pglast.split(queries_str) |
381 | | - return queries |
382 | | - |
383 | | - |
384 | | -def sql_file_execute( |
385 | | - dbgym_workspace: DBGymWorkspace, conn: sqlalchemy.Connection, filepath: Path |
386 | | -) -> None: |
387 | | - for sql in sql_file_queries(dbgym_workspace, filepath): |
388 | | - sqlalchemy_conn_execute(conn, sql) |
389 | | - |
390 | | - |
391 | | -# The reason pgport is an argument is because when doing agnet HPO, we want to run multiple instances of Postgres |
392 | | -# at the same time. In this situation, they need to have different ports |
393 | | -def get_connstr(pgport: int = DEFAULT_POSTGRES_PORT, use_psycopg: bool = True) -> str: |
394 | | - connstr_suffix = f"{DBGYM_POSTGRES_USER}:{DBGYM_POSTGRES_PASS}@localhost:{pgport}/{DBGYM_POSTGRES_DBNAME}" |
395 | | - # use_psycopg means whether or not we use the psycopg.connect() function |
396 | | - # counterintuively, you *don't* need psycopg in the connection string if you *are* |
397 | | - # using the psycopg.connect() function |
398 | | - connstr_prefix = "postgresql" if use_psycopg else "postgresql+psycopg" |
399 | | - return connstr_prefix + "://" + connstr_suffix |
400 | | - |
401 | | - |
402 | | -def get_kv_connstr(pgport: int = DEFAULT_POSTGRES_PORT) -> str: |
403 | | - return f"host=localhost port={pgport} user={DBGYM_POSTGRES_USER} password={DBGYM_POSTGRES_PASS} dbname={DBGYM_POSTGRES_DBNAME}" |
404 | | - |
405 | | - |
406 | | -def create_psycopg_conn(pgport: int = DEFAULT_POSTGRES_PORT) -> psycopg.Connection[Any]: |
407 | | - connstr = get_connstr(use_psycopg=True, pgport=pgport) |
408 | | - psycopg_conn = psycopg.connect(connstr, autocommit=True, prepare_threshold=None) |
409 | | - return psycopg_conn |
410 | | - |
411 | | - |
412 | | -def create_sqlalchemy_conn( |
413 | | - pgport: int = DEFAULT_POSTGRES_PORT, |
414 | | -) -> sqlalchemy.Connection: |
415 | | - connstr = get_connstr(use_psycopg=False, pgport=pgport) |
416 | | - engine: sqlalchemy.Engine = create_engine( |
417 | | - connstr, |
418 | | - execution_options={"isolation_level": "AUTOCOMMIT"}, |
419 | | - ) |
420 | | - return engine.connect() |
421 | | - |
422 | | - |
423 | | -def get_is_postgres_running() -> bool: |
424 | | - """ |
425 | | - This is often used in assertions to ensure that Postgres isn't running before we |
426 | | - execute some code. |
427 | | -
|
428 | | - I intentionally do not have a function that forcefully *stops* all Postgres instances. |
429 | | - This is risky because it could accidentally stop instances it wasn't supposed (e.g. |
430 | | - Postgres instances run by other users on the same machine). |
431 | | -
|
432 | | - Stopping Postgres instances is thus a responsibility of the human to take care of. |
433 | | - """ |
434 | | - return len(get_running_postgres_ports()) > 0 |
435 | | - |
436 | | - |
437 | | -def get_running_postgres_ports() -> list[int]: |
438 | | - """ |
439 | | - Returns a list of all ports on which Postgres is currently running. |
440 | | -
|
441 | | - There are ways to check with psycopg/sqlalchemy. However, I chose to check using |
442 | | - psutil to keep it as simple as possible and orthogonal to how connections work. |
443 | | - """ |
444 | | - running_ports = [] |
445 | | - |
446 | | - for conn in psutil.net_connections(kind="inet"): |
447 | | - if conn.status == "LISTEN": |
448 | | - try: |
449 | | - proc = psutil.Process(conn.pid) |
450 | | - if proc.name() == "postgres": |
451 | | - running_ports.append(conn.laddr.port) |
452 | | - except (psutil.NoSuchProcess, psutil.AccessDenied): |
453 | | - continue |
454 | | - |
455 | | - return running_ports |
0 commit comments