Skip to content

Commit 8c899d6

Browse files
committed
[run-tests] Substitute UDFs
1 parent 983addc commit 8c899d6

File tree

5 files changed

+88
-22
lines changed

5 files changed

+88
-22
lines changed

bigquery_etl/dryrun.py

Lines changed: 78 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -95,15 +95,59 @@ def wrap_in_view_for_dryrun(sql: str) -> str:
9595
if not statements or not isinstance(statements[-1], sqlglot.exp.Select):
9696
return sql
9797

98-
# Split original SQL by semicolons to preserve formatting;
99-
# stripping formatting causes some query dry runs to fail
100-
parts = [p for p in sql.split(";") if p.strip()]
98+
# Replace CREATE TEMP FUNCTION with CREATE FUNCTION using fully qualified names
99+
# CREATE VIEW doesn't support temp functions
100+
test_project = ConfigLoader.get(
101+
"default", "test_project", fallback="bigquery-etl-integration-test"
102+
)
101103

102-
if len(parts) != len(statements):
103-
return sql
104+
def replace_temp_function(match):
105+
func_name = match.group(1)
106+
# If function name is already qualified, keep it; otherwise add project.dataset prefix
107+
if "." not in func_name and "`" not in func_name:
108+
return f"CREATE FUNCTION `{test_project}.tmp.{func_name}`"
109+
else:
110+
return f"CREATE FUNCTION {func_name}"
104111

105-
prefix_sql = ";\n".join(parts[:-1]) + ";" if len(parts) > 1 else ""
106-
query_sql = parts[-1].strip()
112+
sql = re.sub(
113+
r"\bCREATE\s+TEMP(?:ORARY)?\s+FUNCTION\s+([^\s(]+)",
114+
replace_temp_function,
115+
sql,
116+
flags=re.IGNORECASE,
117+
)
118+
119+
# Single statement - just wrap it
120+
if len(statements) == 1:
121+
view_name = f"_dryrun_view_{random_str(8)}"
122+
test_project = ConfigLoader.get(
123+
"default", "test_project", fallback="bigquery-etl-integration-test"
124+
)
125+
query_sql = sql.strip().rstrip(";")
126+
return f"CREATE VIEW `{test_project}.tmp.{view_name}` AS\n{query_sql}"
127+
128+
# Multiple statements: use sqlglot tokenizer to find statement boundaries
129+
# This handles semicolons in strings and comments
130+
tokens = list(sqlglot.tokens.Tokenizer(dialect="bigquery").tokenize(sql))
131+
132+
# Find semicolon tokens that separate statements (not in strings/comments)
133+
semicolon_positions = []
134+
for token in tokens:
135+
if token.token_type == sqlglot.tokens.TokenType.SEMICOLON:
136+
semicolon_positions.append(token.end)
137+
138+
# We need (len(statements) - 1) semicolons to separate statements
139+
if len(semicolon_positions) >= len(statements) - 1:
140+
# The (n-1)th semicolon separates the prefix from the last statement
141+
split_pos = semicolon_positions[len(statements) - 2]
142+
prefix_sql = sql[:split_pos].strip()
143+
query_sql = sql[split_pos:].strip().lstrip(";").strip()
144+
else:
145+
# Fallback: regenerate prefix statements, use regenerated query
146+
prefix_statements = statements[:-1]
147+
prefix_sql = ";\n".join(
148+
stmt.sql(dialect="bigquery") for stmt in prefix_statements
149+
)
150+
query_sql = statements[-1].sql(dialect="bigquery")
107151

108152
# Wrap in view
109153
view_name = f"_dryrun_view_{random_str(8)}"
@@ -112,7 +156,7 @@ def wrap_in_view_for_dryrun(sql: str) -> str:
112156
)
113157
wrapped_query = f"CREATE VIEW `{test_project}.tmp.{view_name}` AS\n{query_sql}"
114158

115-
return f"{prefix_sql}\n\n{wrapped_query}" if prefix_sql else wrapped_query
159+
return f"{prefix_sql};\n\n{wrapped_query}"
116160

117161
except Exception as e:
118162
print(f"Warning: Failed to wrap SQL in view: {e}")
@@ -271,11 +315,6 @@ def dry_run_result(self):
271315
else:
272316
sql = self.get_sql()
273317

274-
# Wrap the query in a CREATE VIEW for faster dry runs
275-
# Skip wrapping when strip_dml=True as it's used for special analysis modes
276-
if not self.strip_dml:
277-
sql = wrap_in_view_for_dryrun(sql)
278-
279318
query_parameters = []
280319
scheduling_metadata = self.metadata.scheduling if self.metadata else {}
281320
if date_partition_parameter := scheduling_metadata.get(
@@ -299,6 +338,30 @@ def dry_run_result(self):
299338
)
300339
)
301340

341+
# Wrap the query in a CREATE VIEW for faster dry runs
342+
# Skip wrapping when strip_dml=True as it's used for special analysis modes
343+
if not self.strip_dml:
344+
# If query has parameters, replace them with literal values in the wrapped version
345+
# since CREATE VIEW cannot use parameterized queries
346+
sql_for_wrapping = sql
347+
if query_parameters:
348+
for param in query_parameters:
349+
param_name = f"@{param.name}"
350+
# Convert parameter value to SQL literal
351+
if param.type_ == "DATE":
352+
param_value = f"DATE '{param.value}'"
353+
elif param.type_ in ("STRING", "DATETIME", "TIMESTAMP"):
354+
param_value = f"'{param.value}'"
355+
elif param.type_ == "BOOL":
356+
param_value = str(param.value).upper()
357+
else:
358+
param_value = str(param.value)
359+
sql_for_wrapping = sql_for_wrapping.replace(param_name, param_value)
360+
361+
sql = wrap_in_view_for_dryrun(sql_for_wrapping)
362+
363+
# print(sql)
364+
302365
project = basename(dirname(dirname(dirname(self.sqlfile))))
303366
dataset = basename(dirname(dirname(self.sqlfile)))
304367
try:
@@ -633,6 +696,8 @@ def validate_schema(self):
633696
strip_dml=self.strip_dml,
634697
)
635698

699+
# print(table_schema)
700+
636701
# This check relies on the new schema being deployed to prod
637702
if not query_schema.compatible(table_schema):
638703
click.echo(

bigquery_etl/schema/stable_table_schema.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,7 @@ def prod_schemas_uri():
6060
"""
6161
dryrun = DryRun(
6262
"moz-fx-data-shared-prod/telemetry_derived/foo/query.sql",
63-
content="SELECT 1",
64-
strip_dml=True,
63+
content="SELECT 1 AS field",
6564
)
6665
build_id = dryrun.get_dataset_labels()["schemas_build_id"]
6766
commit_hash = build_id.split("_")[-1]

sql/moz-fx-data-shared-prod/analysis/bqetl_default_task_v1/query.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,6 @@
22
SELECT
33
*
44
FROM
5-
(SELECT 1)
5+
(SELECT 1 AS field)
66
WHERE
77
FALSE

sql_generators/derived_view_schemas/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def _get_reference_partition_column(ref_path):
112112
logging.debug("No reference partition column, dry running without one.")
113113

114114
view = View.from_file(
115-
view_file, partition_column=reference_partition_column, id_token=id_token
115+
view_file, partition_column=reference_partition_column, id_token=id_token, strip_dml=True
116116
)
117117

118118
# `View.schema` prioritizes the configured schema over the dryrun schema, but here

tests/test_dryrun.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,12 @@ def tmp_query_path(tmp_path):
1515
class TestDryRun:
1616
def test_dry_run_sql_file(self, tmp_query_path):
1717
query_file = tmp_query_path / "query.sql"
18-
query_file.write_text("SELECT 123")
18+
query_file.write_text("SELECT 123 AS field")
1919

2020
dryrun = DryRun(str(query_file))
21-
assert dryrun.is_valid()
21+
print(dryrun.dry_run_result)
22+
response = dryrun.dry_run_result
23+
assert response["valid"]
2224

2325
def test_dry_run_invalid_sql_file(self, tmp_query_path):
2426
query_file = tmp_query_path / "query.sql"
@@ -30,7 +32,7 @@ def test_dry_run_invalid_sql_file(self, tmp_query_path):
3032

3133
def test_sql_file_valid(self, tmp_query_path):
3234
query_file = tmp_query_path / "query.sql"
33-
query_file.write_text("SELECT 123")
35+
query_file.write_text("SELECT 123 AS field")
3436

3537
dryrun = DryRun(str(query_file))
3638
assert dryrun.is_valid()
@@ -60,7 +62,7 @@ def test_sql_file_invalid(self, tmp_query_path):
6062

6163
def test_get_referenced_tables_empty(self, tmp_query_path):
6264
query_file = tmp_query_path / "query.sql"
63-
query_file.write_text("SELECT 123")
65+
query_file.write_text("SELECT 123 AS field")
6466

6567
dryrun = DryRun(str(query_file))
6668
assert dryrun.get_referenced_tables() == []
@@ -69,7 +71,7 @@ def test_get_sql(self, tmp_path):
6971
os.makedirs(tmp_path / "telmetry_derived")
7072
query_file = tmp_path / "telmetry_derived" / "query.sql"
7173

72-
sql_content = "SELECT 123 "
74+
sql_content = "SELECT 123 AS field"
7375
query_file.write_text(sql_content)
7476

7577
assert DryRun(sqlfile=str(query_file)).get_sql() == sql_content

0 commit comments

Comments
 (0)