-
Notifications
You must be signed in to change notification settings - Fork 16
Expand file tree
/
Copy pathutils.py
More file actions
174 lines (144 loc) · 5.04 KB
/
utils.py
File metadata and controls
174 lines (144 loc) · 5.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
#!/usr/bin/python3
"""Utility helpers shared by the csub CLI entrypoint."""
from __future__ import annotations
import base64
import re
import shlex
import subprocess
import sys
import tempfile
from contextlib import contextmanager
from datetime import timedelta
from pathlib import Path
from typing import Dict, Iterable, Iterator, List
# Default env-file name (not used inside this module; presumably consumed by
# the CLI entrypoint — confirm there).
DEFAULT_ENV_FILE = ".env"

# Private key that maybe_populate_github_ssh probes when GITHUB_SSH_KEY_PATH
# is not set in the env.
DEFAULT_GITHUB_KEY_PATH = Path("~/.ssh/github").expanduser()

# Fallback returned by parse_duration for an empty/None spec: 12 hours.
DEFAULT_TIME_SECONDS = int(timedelta(hours=12).total_seconds())

# Keys that add_secret_env_flags always considers for SECRET references
# (each is emitted only when present in the env with a non-empty value).
SECRET_KEYS = {
    # Service tokens.
    "GITHUB_TOKEN",
    "HF_TOKEN",
    "WANDB_API_KEY",
    # SSH material; the private/public key entries may be filled in by
    # maybe_populate_github_ssh.
    "SSH_KNOWN_HOSTS",
    "SSH_PRIVATE_KEY_B64",
    "SSH_PUBLIC_KEY",
    # Git identity.
    "GIT_USER_EMAIL",
    "GIT_USER_NAME",
}
def _parse_env_value(value: str) -> str:
    """Return the effective value for the (already stripped) text after ``=``."""
    if value[:1] in ("'", '"'):
        closing = value.find(value[0], 1)
        if closing != -1:
            # Quoted: keep the contents verbatim (including any '#') and drop
            # whatever follows the closing quote, e.g. a trailing comment.
            return value[1:closing]
    # Unquoted (or unterminated quote): everything from the first '#' on is
    # an inline comment.
    return value.split("#", 1)[0].strip()


def parse_env_file(path: Path) -> Dict[str, str]:
    """Parse a ``KEY=VALUE`` env file into a dict.

    Hand-rolled so users do not have to install python-dotenv.

    Rules:
    * blank lines and lines starting with ``#`` are skipped;
    * an optional leading ``export `` is ignored;
    * a value wrapped in single or double quotes is taken verbatim up to the
      first matching closing quote, so ``#`` inside quotes is preserved;
    * in an unquoted value, everything from the first ``#`` on is a comment;
    * a later duplicate key overwrites an earlier one.

    Args:
        path: env file to read (decoded as UTF-8).

    Returns:
        Mapping of key to string value, in file order.

    Raises:
        SystemExit: if the file does not exist, or a non-comment line has no
            ``=``, an empty key, or a ``#`` before its ``=``.
    """
    if not path.exists():
        sys.exit(
            f"Environment file {path} does not exist. Copy templates/user.env.example first."
        )
    env: Dict[str, str] = {}
    for raw_line in path.read_text(encoding="utf-8").splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#"):
            continue
        if line.startswith("export "):
            line = line[len("export "):].lstrip()
        key, sep, value = line.partition("=")
        key = key.strip()
        # A '#' before the '=' turns the rest of the line (including '=')
        # into a comment, so such a line is as invalid as one without '='.
        if not sep or not key or "#" in key:
            sys.exit(f"Invalid line in {path}: {raw_line}")
        env[key] = _parse_env_value(value.strip())
    return env
def _expand_path(raw: str | None, fallback: Path) -> Path:
    """Convert a user-supplied path string to a ``Path``, or return *fallback*.

    Any falsy *raw* (``None`` or ``""``) selects *fallback* as-is (it is not
    expanded); otherwise ``~`` in *raw* is expanded.
    """
    if not raw:
        return fallback
    return Path(raw).expanduser()
def maybe_populate_github_ssh(env: Dict[str, str]) -> None:
    """Fill empty SSH_* entries of *env* in place from a local GitHub key pair.

    The private key is read from ``GITHUB_SSH_KEY_PATH`` (default
    ``DEFAULT_GITHUB_KEY_PATH``) and stored base64-encoded as
    ``SSH_PRIVATE_KEY_B64``. The public key is read from
    ``GITHUB_SSH_PUBLIC_KEY_PATH`` (default: the private key path plus
    ``.pub``) and stored stripped as ``SSH_PUBLIC_KEY``. An entry is only
    written when it is missing/empty and its source file exists; existing
    values are never overwritten. Each load is reported on stderr.
    """
    private_path = _expand_path(env.get("GITHUB_SSH_KEY_PATH"), DEFAULT_GITHUB_KEY_PATH)
    public_default = Path(f"{private_path}.pub")
    public_path = _expand_path(env.get("GITHUB_SSH_PUBLIC_KEY_PATH"), public_default)

    private_missing = not env.get("SSH_PRIVATE_KEY_B64")
    if private_missing and private_path.exists():
        key_bytes = private_path.read_bytes()
        env["SSH_PRIVATE_KEY_B64"] = base64.b64encode(key_bytes).decode("ascii")
        print(f"[csub] Loaded SSH_PRIVATE_KEY_B64 from {private_path}", file=sys.stderr)

    public_missing = not env.get("SSH_PUBLIC_KEY")
    if public_missing and public_path.exists():
        env["SSH_PUBLIC_KEY"] = public_path.read_text().strip()
        print(f"[csub] Loaded SSH_PUBLIC_KEY from {public_path}", file=sys.stderr)
@contextmanager
def rendered_env_file(env: Dict[str, str]) -> Iterator[Path]:
    """Write *env* to a temporary ``KEY=VALUE`` file for kubectl and yield its path.

    One entry is written per line, in dict order, encoded as UTF-8. The file
    is deleted when the ``with`` block exits, including on error.

    Raises:
        SystemExit: if a key is empty or contains ``=``, or a key/value
            contains a line break. Such an entry would be split into a
            different (or an extra, injected) entry in the line-based file.
            The offending value is never echoed because it may be a secret.
    """
    tmp = tempfile.NamedTemporaryFile("w", encoding="utf-8", delete=False)
    tmp_path = Path(tmp.name)
    try:
        with tmp:
            for key, value in env.items():
                key_text, value_text = f"{key}", f"{value}"
                entry = f"{key_text}={value_text}"
                if not key_text or "=" in key_text or "\n" in entry or "\r" in entry:
                    sys.exit(
                        f"Cannot write env entry {key_text!r}: keys must be non-empty "
                        "without '=', and keys/values must be single-line."
                    )
                tmp.write(f"{entry}\n")
        yield tmp_path
    finally:
        tmp_path.unlink(missing_ok=True)
# Components must appear in the order d, h, m, s, each at most once; the
# trailing 's' is optional so a bare number counts as seconds.
_DURATION_PATTERN = re.compile(
    r"((?P<days>\d+)d)?((?P<hours>\d+)h)?((?P<minutes>\d+)m)?((?P<seconds>\d+)s?)?"
)


def parse_duration(spec: str | None) -> int:
    """Convert a compact duration such as ``2d6h30m`` into whole seconds.

    Components are days, hours, minutes, seconds, in that order and each at
    most once. The ``s`` suffix is optional, so ``"90"`` is 90 seconds and
    ``"1h30"`` is one hour plus 30 seconds.

    Args:
        spec: duration string; empty or ``None`` selects the default.

    Returns:
        Total seconds, or ``DEFAULT_TIME_SECONDS`` when *spec* is empty/None.

    Raises:
        SystemExit: if *spec* does not match the format, or is too large to
            represent.
    """
    if not spec:
        return DEFAULT_TIME_SECONDS
    # fullmatch rather than match + '$': '$' would also accept a trailing "\n".
    match = _DURATION_PATTERN.fullmatch(spec)
    if not match:
        sys.exit(f"Invalid duration '{spec}'. Use formats like 12h, 45m, 2d6h30m.")
    parts = {unit: int(amount) for unit, amount in match.groupdict().items() if amount}
    try:
        return int(timedelta(**parts).total_seconds())
    except OverflowError:
        # timedelta caps at 999999999 days; report it like other bad input
        # instead of crashing with a traceback.
        sys.exit(f"Duration '{spec}' is too large.")
def shlex_join(cmd: Iterable[str]) -> str:
    """Render *cmd* as one shell-safe command line.

    Each token is converted with ``str()`` first, so ``Path`` or ``int``
    tokens are accepted; each is then shell-quoted and the results are
    joined with single spaces.
    """
    return shlex.join(map(str, cmd))
def _run_kubectl(args: List[str], action: str, **run_kwargs) -> subprocess.CompletedProcess:
    """Run ``kubectl *args`` in text mode with check=True; exit on failure.

    *action* completes the error sentence "kubectl failed to <action>".
    Extra keyword arguments are forwarded to ``subprocess.run``.
    """
    try:
        return subprocess.run(["kubectl", *args], check=True, text=True, **run_kwargs)
    except FileNotFoundError:
        sys.exit("kubectl executable not found on PATH.")
    except subprocess.CalledProcessError as exc:
        # exc.stderr is only populated when the caller captured stderr.
        sys.exit(f"kubectl failed to {action}:\n{exc.stderr or ''}")


def ensure_secret(env_path: Path, namespace: str, secret_name: str) -> None:
    """Create or update the generic secret *secret_name* in *namespace*.

    The secret manifest is rendered client-side with
    ``kubectl create secret generic --from-env-file=<env_path> --dry-run=client
    -o yaml``. The manifest is then piped to ``kubectl apply -f -``, so the
    same call works whether or not the secret already exists.

    Raises:
        SystemExit: if kubectl is not installed, or either step fails. The
            message includes kubectl's stderr.
    """
    rendered = _run_kubectl(
        [
            "-n",
            namespace,
            "create",
            "secret",
            "generic",
            secret_name,
            f"--from-env-file={env_path}",
            "--dry-run=client",
            "-o",
            "yaml",
        ],
        "render the secret",
        capture_output=True,
    ).stdout
    # Capture only stderr so the failure message can include it (it used to
    # print "None"); stdout still goes straight to the terminal.
    applied = _run_kubectl(
        ["-n", namespace, "apply", "-f", "-"],
        "apply the secret",
        input=rendered,
        stderr=subprocess.PIPE,
    )
    # Replay any warnings kubectl printed on success, since we captured them.
    if applied.stderr:
        sys.stderr.write(applied.stderr)
def add_env_flags(cmd: List[str], values: Dict[str, str]) -> None:
    """Append ``--environment KEY=VALUE`` pairs to *cmd* in place.

    Entries whose value is exactly ``""`` are skipped. The remaining entries
    are added in the mapping's iteration order.
    """
    cmd.extend(
        token
        for name, setting in values.items()
        if setting != ""
        for token in ("--environment", f"{name}={setting}")
    )
def add_secret_env_flags(
    cmd: List[str],
    env: Dict[str, str],
    secret_name: str,
    extra_secret_keys: Iterable[str],
) -> None:
    """Append ``--environment KEY=SECRET:<secret_name>,KEY`` pairs to *cmd* in place.

    Candidate keys are ``SECRET_KEYS`` plus *extra_secret_keys*. Extra keys
    are whitespace-stripped, and blank ones are ignored. A candidate is
    emitted only if *env* contains it with a value other than ``""``. Keys
    are emitted once each, in sorted order.
    """
    candidates = set(SECRET_KEYS)
    for raw in extra_secret_keys:
        cleaned = raw.strip()
        if cleaned:
            candidates.add(cleaned)
    for key in sorted(candidates):
        if env.get(key, "") == "":
            continue
        cmd.extend(["--environment", f"{key}=SECRET:{secret_name},{key}"])
# Public API for star imports. Every name listed must be defined in this
# module, otherwise `from <module> import *` raises AttributeError (the
# previously listed "build_runai_command" does not exist here).
__all__ = [
    "DEFAULT_ENV_FILE",
    "DEFAULT_GITHUB_KEY_PATH",
    "DEFAULT_TIME_SECONDS",
    "SECRET_KEYS",
    "add_env_flags",
    "add_secret_env_flags",
    "ensure_secret",
    "maybe_populate_github_ssh",
    "parse_duration",
    "parse_env_file",
    "rendered_env_file",
    "shlex_join",
]