1 change: 1 addition & 0 deletions bigquery_etl/cli/dryrun.py
@@ -154,6 +154,7 @@ def _sql_file_valid(
respect_skip=respect_skip,
id_token=id_token,
billing_project=billing_project,
strip_dml=validate_schemas,
)
if validate_schemas:
try:
116 changes: 115 additions & 1 deletion bigquery_etl/dryrun.py
@@ -25,13 +25,14 @@

import click
import google.auth
import sqlglot
from google.auth.transport.requests import Request as GoogleAuthRequest
from google.cloud import bigquery
from google.oauth2.id_token import fetch_id_token

from .config import ConfigLoader
from .metadata.parse_metadata import Metadata
from .util.common import render
from .util.common import random_str, render

try:
from functools import cached_property # type: ignore
@@ -79,6 +80,89 @@ def get_id_token(dry_run_url=ConfigLoader.get("dry_run", "function"), credential
return id_token


def wrap_in_view_for_dryrun(sql: str) -> str:
"""
Wrap SELECT queries in a CREATE VIEW statement for faster dry runs.

CREATE VIEW statements don't scan partition metadata, which makes dry runs faster.
"""
try:
statements = [
stmt for stmt in sqlglot.parse(sql, dialect="bigquery") if stmt is not None
]

# Only wrap if the last statement is a SELECT statement
if not statements or not isinstance(statements[-1], sqlglot.exp.Select):
return sql

# Replace CREATE TEMP FUNCTION with CREATE FUNCTION using fully qualified names
# CREATE VIEW doesn't support temp functions
test_project = ConfigLoader.get(
"default", "test_project", fallback="bigquery-etl-integration-test"
)

def replace_temp_function(match):
func_name = match.group(1)
# If function name is already qualified, keep it; otherwise add project.dataset prefix
if "." not in func_name and "`" not in func_name:
return f"CREATE FUNCTION `{test_project}.tmp.{func_name}`"
else:
return f"CREATE FUNCTION {func_name}"

sql = re.sub(
r"\bCREATE\s+TEMP(?:ORARY)?\s+FUNCTION\s+([^\s(]+)",
replace_temp_function,
sql,
flags=re.IGNORECASE,
)
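# e.g. CREATE TEMP FUNCTION udf_mode(x INT64) becomes
# CREATE FUNCTION `<test_project>.tmp.udf_mode`(x INT64);
# only the captured name is rewritten, the argument list is left untouched.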

# Single statement - just wrap it
if len(statements) == 1:
view_name = f"_dryrun_view_{random_str(8)}"
query_sql = sql.strip().rstrip(";")
return f"CREATE VIEW `{test_project}.tmp.{view_name}` AS\n{query_sql}"

# Multiple statements: use sqlglot tokenizer to find statement boundaries
# This handles semicolons in strings and comments
tokens = list(sqlglot.tokens.Tokenizer(dialect="bigquery").tokenize(sql))

# Find semicolon tokens that separate statements (not in strings/comments)
semicolon_positions = []
for token in tokens:
if token.token_type == sqlglot.tokens.TokenType.SEMICOLON:
semicolon_positions.append(token.end)
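# token.end is the character offset of the semicolon in the original
# SQL string, so it can be used to slice the raw text below.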

# We need (len(statements) - 1) semicolons to separate statements
if len(semicolon_positions) >= len(statements) - 1:
# The (n-1)th semicolon separates the prefix from the last statement
split_pos = semicolon_positions[len(statements) - 2]
prefix_sql = sql[:split_pos].strip()
query_sql = sql[split_pos:].strip().lstrip(";").strip()
else:
# Fallback: regenerate prefix statements, use regenerated query
prefix_statements = statements[:-1]
prefix_sql = ";\n".join(
stmt.sql(dialect="bigquery") for stmt in prefix_statements
)
query_sql = statements[-1].sql(dialect="bigquery")
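# Regenerating through sqlglot may normalize whitespace and quoting,
# which is acceptable for a dry run.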

# Wrap in view
view_name = f"_dryrun_view_{random_str(8)}"
wrapped_query = f"CREATE VIEW `{test_project}.tmp.{view_name}` AS\n{query_sql}"

return f"{prefix_sql};\n\n{wrapped_query}"

except Exception as e:
print(f"Warning: Failed to wrap SQL in view: {e}")
return sql
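
# A sketch of the resulting rewrite (assuming the default test-project
# fallback; the random suffix varies per call):
#   wrap_in_view_for_dryrun("SELECT 123 AS field")
#   -> CREATE VIEW `bigquery-etl-integration-test.tmp._dryrun_view_<rand>` AS
#      SELECT 123 AS field
# For multi-statement scripts, everything before the final SELECT is kept
# as a prefix and only that trailing SELECT is wrapped.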


class Errors(Enum):
"""DryRun errors that require special handling."""

@@ -254,6 +338,30 @@ def dry_run_result(self):
)
)

# Wrap the query in a CREATE VIEW for faster dry runs
# Skip wrapping when strip_dml=True, since that flag is used for special analysis modes
if not self.strip_dml:
# If query has parameters, replace them with literal values in the wrapped version
# since CREATE VIEW cannot use parameterized queries
sql_for_wrapping = sql
if query_parameters:
for param in query_parameters:
param_name = f"@{param.name}"
# Convert parameter value to SQL literal
if param.type_ == "DATE":
param_value = f"DATE '{param.value}'"
elif param.type_ in ("STRING", "DATETIME", "TIMESTAMP"):
param_value = f"'{param.value}'"
elif param.type_ == "BOOL":
param_value = str(param.value).upper()
else:
param_value = str(param.value)
# Match on a word boundary so @submission_date doesn't also rewrite
# the prefix of a longer name like @submission_date_offset
sql_for_wrapping = re.sub(
rf"{re.escape(param_name)}\b",
lambda _: param_value,
sql_for_wrapping,
)
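# The net effect: with a DATE parameter submission_date = "2020-01-01",
# "WHERE submission_date = @submission_date" becomes
# "WHERE submission_date = DATE '2020-01-01'"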

sql = wrap_in_view_for_dryrun(sql_for_wrapping)

project = basename(dirname(dirname(dirname(self.sqlfile))))
dataset = basename(dirname(dirname(self.sqlfile)))
try:
@@ -387,6 +495,7 @@ def get_referenced_tables(self):
filtered_content,
client=self.client,
id_token=self.id_token,
strip_dml=self.strip_dml,
).get_error()
== Errors.DATE_FILTER_NEEDED_AND_SYNTAX
):
@@ -408,6 +517,7 @@ def get_referenced_tables(self):
content=filtered_content,
client=self.client,
id_token=self.id_token,
strip_dml=self.strip_dml,
).get_error()
== Errors.DATE_FILTER_NEEDED_AND_SYNTAX
):
@@ -420,6 +530,7 @@ def get_referenced_tables(self):
content=filtered_content,
client=self.client,
id_token=self.id_token,
strip_dml=self.strip_dml,
)
if (
stripped_dml_result.get_error() is None
@@ -582,8 +693,11 @@ def validate_schema(self):
client=self.client,
id_token=self.id_token,
partitioned_by=partitioned_by,
strip_dml=self.strip_dml,
)

# This check relies on the new schema being deployed to prod
if not query_schema.compatible(table_schema):
click.echo(
3 changes: 2 additions & 1 deletion bigquery_etl/schema/stable_table_schema.py
@@ -59,7 +59,8 @@ def prod_schemas_uri():
with the most recent production schemas deploy.
"""
dryrun = DryRun(
"moz-fx-data-shared-prod/telemetry_derived/foo/query.sql", content="SELECT 1"
"moz-fx-data-shared-prod/telemetry_derived/foo/query.sql",
content="SELECT 1 AS field",
)
build_id = dryrun.get_dataset_labels()["schemas_build_id"]
commit_hash = build_id.split("_")[-1]
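# build_id is assumed to have the form "<timestamp>_<commit_hash>", so the
# split extracts the trailing commit hash of the most recent schemas deploy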
5 changes: 4 additions & 1 deletion bigquery_etl/view/__init__.py
@@ -221,7 +221,10 @@ def dryrun_schema(self):
"""
)
return Schema.from_query_file(
Path(self.path), content=schema_query, id_token=self.id_token
Path(self.path),
content=schema_query,
id_token=self.id_token,
strip_dml=True,
)
except Exception as e:
print(f"Error dry-running view {self.view_identifier} to get schema: {e}")
@@ -2,6 +2,6 @@
SELECT
*
FROM
(SELECT 1)
(SELECT 1 AS field)
WHERE
FALSE
1 change: 1 addition & 0 deletions sql_generators/README.md
@@ -9,3 +9,4 @@ The directories in `sql_generators/` represent the generated queries and will co
Each `__init__.py` file needs to implement a `generate()` method that is configured as a [click command](https://click.palletsprojects.com/en/8.0.x/). The `bqetl` CLI will automatically add these commands to the `./bqetl query generate` command group.
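
For illustration, a minimal generator module might look like the following sketch (the module layout and option names here are hypothetical, not a prescribed interface):

```python
"""Example generator: writes a trivial query into the target directory."""
import os

import click


@click.command()
@click.option("--output-dir", default="sql/", help="Directory to write generated SQL into.")
def generate(output_dir):
    """Generate SQL files."""
    os.makedirs(output_dir, exist_ok=True)
    with open(os.path.join(output_dir, "query.sql"), "w") as f:
        f.write("SELECT 1 AS field\n")
```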

After a schema change or the addition of new tables, the schema is automatically derived from the query and deployed the next day by the [bqetl_artifact_deployment](https://workflow.telemetry.mozilla.org/dags/bqetl_artifact_deployment/grid) DAG. Alternatively, it can be generated and deployed manually using `./bqetl generate all` and `./bqetl query schema deploy`.

2 changes: 1 addition & 1 deletion sql_generators/stable_views/__init__.py
@@ -389,7 +389,7 @@ def write_view_if_not_exists(
content = VIEW_CREATE_REGEX.sub("", target_file.read_text())
content += " WHERE DATE(submission_timestamp) = '2020-01-01'"
view_schema = Schema.from_query_file(
target_file, content=content, sql_dir=sql_dir, id_token=id_token
target_file, content=content, sql_dir=sql_dir, id_token=id_token, strip_dml=True
)

stable_table_schema = Schema.from_json({"fields": schema.schema})
11 changes: 6 additions & 5 deletions tests/test_dryrun.py
@@ -15,9 +15,10 @@ def tmp_query_path(tmp_path):
class TestDryRun:
def test_dry_run_sql_file(self, tmp_query_path):
query_file = tmp_query_path / "query.sql"
query_file.write_text("SELECT 123")
query_file.write_text("SELECT 123 AS field")

dryrun = DryRun(str(query_file))
response = dryrun.dry_run_result
assert response["valid"]

@@ -31,7 +32,7 @@ def test_dry_run_invalid_sql_file(self, tmp_query_path):

def test_sql_file_valid(self, tmp_query_path):
query_file = tmp_query_path / "query.sql"
query_file.write_text("SELECT 123")
query_file.write_text("SELECT 123 AS field")

dryrun = DryRun(str(query_file))
assert dryrun.is_valid()
@@ -61,7 +62,7 @@ def test_sql_file_invalid(self, tmp_query_path):

def test_get_referenced_tables_empty(self, tmp_query_path):
query_file = tmp_query_path / "query.sql"
query_file.write_text("SELECT 123")
query_file.write_text("SELECT 123 AS field")

dryrun = DryRun(str(query_file))
assert dryrun.get_referenced_tables() == []
Expand All @@ -70,7 +71,7 @@ def test_get_sql(self, tmp_path):
os.makedirs(tmp_path / "telemetry_derived")
query_file = tmp_path / "telemetry_derived" / "query.sql"

sql_content = "SELECT 123 "
sql_content = "SELECT 123 AS field"
query_file.write_text(sql_content)

assert DryRun(sqlfile=str(query_file)).get_sql() == sql_content
@@ -83,7 +84,7 @@ def test_get_referenced_tables(self, tmp_query_path):
"SELECT * FROM `moz-fx-data-shared-prod.telemetry_derived.clients_daily_v6` "
"WHERE submission_date = '2020-01-01'"
)
query_dryrun = DryRun(str(query_file)).get_referenced_tables()
query_dryrun = DryRun(str(query_file), strip_dml=True).get_referenced_tables()

assert len(query_dryrun) == 1
assert query_dryrun[0]["datasetId"] == "telemetry_derived"