Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion databricks-mcp-server/databricks_mcp_server/tools/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,4 +151,4 @@ def get_table_details(
warehouse_id=warehouse_id,
)
# Convert to dict for JSON serialization
return result.model_dump() if hasattr(result, "model_dump") else result
return result.model_dump(exclude_none=True) if hasattr(result, "model_dump") else result
Original file line number Diff line number Diff line change
Expand Up @@ -202,17 +202,27 @@ def keep_basic_stats(self) -> "TableSchemaResult":
)

def remove_stats(self) -> "TableSchemaResult":
"""Return a new TableSchemaResult with column_details removed.
"""Return a new TableSchemaResult with column statistics removed.

Creates a minimal version with just DDL/structure.
Keeps column names and types but removes all numeric/histogram stats.
"""
tables_no_stats = []
for table in self.tables:
# Strip stats from column details if they exist
basic_columns = None
if table.column_details:
basic_columns = {}
for col_name, col_detail in table.column_details.items():
basic_columns[col_name] = ColumnDetail(
name=col_detail.name,
data_type=col_detail.data_type,
)

table_no_stats = DataSourceInfo(
name=table.name,
comment=table.comment,
ddl=table.ddl,
column_details=None,
column_details=basic_columns,
updated_at=None,
error=table.error,
total_rows=None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def filter_tables_by_patterns(self, tables: List[Dict[str, Any]], patterns: List

def get_table_ddl(self, catalog: str, schema: str, table_name: str) -> str:
"""Get the DDL (CREATE TABLE statement) for a table."""
full_table_name = f"{catalog}.{schema}.{table_name}"
full_table_name = f"`{catalog}`.`{schema}`.`{table_name}`"
query = f"SHOW CREATE TABLE {full_table_name}"

try:
Expand Down Expand Up @@ -213,14 +213,19 @@ def get_table_ddl(self, catalog: str, schema: str, table_name: str) -> str:

def collect_column_stats(
self, catalog: str, schema: str, table_name: str
) -> Tuple[Dict[str, ColumnDetail], int, List[Dict[str, Any]]]:
) -> Tuple[Dict[str, ColumnDetail], Optional[int], List[Dict[str, Any]]]:
"""
Collect enhanced column statistics for a UC table.

Args:
catalog: Catalog name
schema: Schema name
table_name: Table name

Returns:
Tuple of (column_details dict, total_rows, sample_data)
"""
full_table_name = f"{catalog}.{schema}.{table_name}"
full_table_name = f"`{catalog}`.`{schema}`.`{table_name}`"
return self._collect_stats_for_ref(
table_ref=full_table_name,
catalog=catalog,
Expand All @@ -231,7 +236,7 @@ def collect_column_stats(

def collect_volume_stats(
self, volume_path: str, format: str
) -> Tuple[Dict[str, ColumnDetail], int, List[Dict[str, Any]]]:
) -> Tuple[Dict[str, ColumnDetail], Optional[int], List[Dict[str, Any]]]:
"""
Collect enhanced column statistics for volume folder data.

Expand All @@ -251,14 +256,40 @@ def collect_volume_stats(
fetch_value_counts_table=None,
)

def _describe_columns(self, catalog: str, schema: str, table_name: str) -> Dict[str, ColumnDetail]:
    """
    Return column names and types for a UC table without collecting statistics.

    Used by get_table_info when stat level is NONE.

    Args:
        catalog: Catalog name
        schema: Schema name
        table_name: Table name

    Returns:
        Mapping of column name -> ColumnDetail (name and lowercased type only).
        Empty dict if the DESCRIBE query fails.
    """
    # Backtick-quote each identifier so names with special characters work.
    full_table_name = f"`{catalog}`.`{schema}`.`{table_name}`"
    try:
        describe_result = self.executor.execute(
            sql_query=f"DESCRIBE TABLE {full_table_name}",
            catalog=catalog,
            schema=schema,
            timeout=45,
        )
        column_details: Dict[str, ColumnDetail] = {}
        for col in describe_result or []:
            col_name = col.get("col_name")
            # Guard against a present-but-None data_type (blank separator rows):
            # `col.get("data_type", "string")` would return None and crash .lower().
            data_type = (col.get("data_type") or "string").lower()
            # Skip blank rows and section headers such as "# Partition Information"
            # that DESCRIBE emits between column listings. `not col_name` already
            # covers both None and the empty string.
            if not col_name or col_name.startswith("#"):
                continue
            column_details[col_name] = ColumnDetail(name=col_name, data_type=data_type)
        return column_details
    except Exception as e:
        # Schema-only lookup is best-effort; log and fall back to "no columns"
        # rather than failing the whole get_table_info call.
        logger.warning(f"Failed to describe columns for {full_table_name}: {e}")
        return {}

def _collect_stats_for_ref(
self,
table_ref: str,
catalog: Optional[str],
schema: Optional[str],
use_describe_table: bool,
fetch_value_counts_table: Optional[str],
) -> Tuple[Dict[str, ColumnDetail], int, List[Dict[str, Any]]]:
) -> Tuple[Dict[str, ColumnDetail], Optional[int], List[Dict[str, Any]]]:
"""
Internal method to collect column statistics for any table reference.

Expand Down Expand Up @@ -305,6 +336,18 @@ def _collect_stats_for_ref(
data_type = "string"
describe_result.append({"col_name": col_name, "data_type": data_type})

# Map describe columns to ColumnDetail
column_details = {}
for col in describe_result:
col_name = col.get("col_name")
data_type = col.get("data_type", "string").lower()
if not col_name:
continue
# Handle empty rows/comments in DESCRIBE
if col_name.startswith("#") or col_name == "":
continue
column_details[col_name] = ColumnDetail(name=col_name, data_type=data_type)

# Step 2: Get row count
count_result = self.executor.execute(
sql_query=f"SELECT COUNT(*) as total_rows FROM {table_ref}",
Expand All @@ -330,17 +373,18 @@ def _collect_stats_for_ref(
data_type = col_info.get("data_type", "").lower()
escaped_col = f"`{col_name}`"

# Determine column type
# Determine column type for building query
is_numeric = any(t in data_type for t in NUMERIC_TYPES)
is_timestamp = "timestamp" in data_type
is_array = "array" in data_type
is_struct_or_map = "struct" in data_type or "map" in data_type or "variant" in data_type
is_boolean = "boolean" in data_type
is_id = any(p in col_name.lower() for p in ID_PATTERNS) and (
"bigint" in data_type or "string" in data_type
)

if is_array:
col_type = "array"
if is_array or is_struct_or_map:
col_type = "complex"
elif is_timestamp:
col_type = "timestamp"
elif is_boolean:
Expand All @@ -362,7 +406,7 @@ def _collect_stats_for_ref(
columns_needing_value_counts.append((col_name, "boolean"))

if not union_queries:
return {}, total_rows, []
return column_details, total_rows, []

# Execute combined stats query
combined_query = base_cte + "\nUNION ALL\n".join(union_queries)
Expand All @@ -372,26 +416,35 @@ def _collect_stats_for_ref(
schema=schema,
timeout=60,
)
# Step 6: Parse stats results (updates column_details in-place)
self._parse_stats_results(stats_result, column_types, column_details)

# Step 4: Get sample data
sample_result = self.executor.execute(
sql_query=f"SELECT * FROM {table_ref} LIMIT {SAMPLE_ROW_COUNT}",
catalog=catalog,
schema=schema,
timeout=45,
)
# Filter out _rescued_data column from samples
if sample_result:
sample_result = [{k: v for k, v in row.items() if k != "_rescued_data"} for row in sample_result]
sample_result = []
try:
sample_result = self.executor.execute(
sql_query=f"SELECT * FROM {table_ref} LIMIT {SAMPLE_ROW_COUNT}",
catalog=catalog,
schema=schema,
timeout=45,
)
# Filter out _rescued_data column from samples
if sample_result:
sample_result = [{k: v for k, v in row.items() if k != "_rescued_data"} for row in sample_result]
except Exception as e:
logger.warning(f"Failed to get sample data for {table_ref}: {e}")

# Step 5: Build column samples from sample data
column_samples = self._extract_column_samples(describe_result, sample_result)

# Step 6: Parse stats results
column_details = self._parse_stats_results(stats_result, column_types, column_samples)
# Update column details with samples
for col_name, samples in column_samples.items():
if col_name in column_details:
column_details[col_name].samples = samples

# Step 7: Get value counts for categorical columns (only for UC tables)
if fetch_value_counts_table:
columns_needing_value_counts = []
for col_name, detail in column_details.items():
if column_types.get(col_name) == "categorical":
approx_unique = detail.unique_count or 0
Expand All @@ -413,7 +466,7 @@ def _build_column_stats_query(
self, col_name: str, escaped_col: str, data_type: str, col_type: str, base_ref: str
) -> str:
"""Build stats query for a column based on its type."""
if col_type == "array":
if col_type == "complex":
return f"""
SELECT
'{col_name}' AS column_name,
Expand Down Expand Up @@ -510,20 +563,22 @@ def _parse_stats_results(
self,
stats_result: List[Dict],
column_types: Dict[str, str],
column_samples: Dict[str, List[str]],
) -> Dict[str, ColumnDetail]:
"""Parse stats query results into ColumnDetail objects."""
column_details: Dict[str, ColumnDetail] = {}

column_details: Dict[str, ColumnDetail],
) -> None:
"""Parse stats query results into existing ColumnDetail objects."""
for row in stats_result:
col_name = row.get("column_name")
if not col_name:
if not col_name or col_name not in column_details:
continue

detail = column_details[col_name]
col_type = column_types.get(col_name, "categorical")
samples = column_samples.get(col_name, [])
approx_unique = int(row.get("unique_count") or 0) if row.get("unique_count") is not None else None

# Update base stats
detail.null_count = int(row.get("null_count") or 0) if row.get("null_count") is not None else 0
detail.unique_count = approx_unique

# Parse histogram if present
histogram_bins = None
if row.get("histogram_data"):
Expand All @@ -542,65 +597,27 @@ def _parse_stats_results(
except Exception as e:
logger.debug(f"Failed to parse histogram for {col_name}: {e}")

# Build ColumnDetail based on type
detail.histogram = histogram_bins

# Update numeric/timestamp/etc based on type and stats row
if col_type == "numeric":
detail = ColumnDetail(
name=col_name,
data_type="numeric",
samples=samples,
total_count=int(row.get("total_count") or 0),
null_count=int(row.get("null_count") or 0),
unique_count=approx_unique,
min=float(row["min_val"]) if row.get("min_val") else None,
max=float(row["max_val"]) if row.get("max_val") else None,
avg=float(row["mean_val"]) if row.get("mean_val") else None,
mean=float(row["mean_val"]) if row.get("mean_val") else None,
stddev=float(row["stddev_val"]) if row.get("stddev_val") else None,
q1=float(row["q1_val"]) if row.get("q1_val") else None,
median=float(row["median_val"]) if row.get("median_val") else None,
q3=float(row["q3_val"]) if row.get("q3_val") else None,
histogram=histogram_bins,
)
detail.total_count = int(row.get("total_count") or 0)
detail.min = float(row["min_val"]) if row.get("min_val") else None
detail.max = float(row["max_val"]) if row.get("max_val") else None
detail.avg = float(row["mean_val"]) if row.get("mean_val") else None
detail.mean = float(row["mean_val"]) if row.get("mean_val") else None
detail.stddev = float(row["stddev_val"]) if row.get("stddev_val") else None
detail.q1 = float(row["q1_val"]) if row.get("q1_val") else None
detail.median = float(row["median_val"]) if row.get("median_val") else None
detail.q3 = float(row["q3_val"]) if row.get("q3_val") else None
elif col_type == "timestamp":
detail = ColumnDetail(
name=col_name,
data_type="timestamp",
samples=samples,
total_count=int(row.get("total_count") or 0),
null_count=int(row.get("null_count") or 0),
unique_count=approx_unique,
min_date=str(row["min_val"]) if row.get("min_val") else None,
max_date=str(row["max_val"]) if row.get("max_val") else None,
histogram=histogram_bins,
)
elif col_type == "array":
detail = ColumnDetail(
name=col_name,
data_type="array",
samples=samples,
total_count=int(row.get("total_count") or 0),
null_count=int(row.get("null_count") or 0),
)
detail.total_count = int(row.get("total_count") or 0)
detail.min_date = str(row["min_val"]) if row.get("min_val") else None
detail.max_date = str(row["max_val"]) if row.get("max_val") else None
else:
# boolean, id, categorical, date
final_type = col_type
if col_type == "categorical" and approx_unique and approx_unique >= MAX_CATEGORICAL_VALUES:
final_type = "string"

detail = ColumnDetail(
name=col_name,
data_type=final_type,
samples=samples,
total_count=int(row.get("total_count") or 0),
null_count=int(row.get("null_count") or 0),
unique_count=approx_unique,
min=str(row["min_val"]) if row.get("min_val") else None,
max=str(row["max_val"]) if row.get("max_val") else None,
)

column_details[col_name] = detail

return column_details
detail.total_count = int(row.get("total_count") or 0)
detail.min = str(row["min_val"]) if row.get("min_val") else None
detail.max = str(row["max_val"]) if row.get("max_val") else None

def _fetch_value_counts(
self,
Expand All @@ -611,7 +628,9 @@ def _fetch_value_counts(
column_details: Dict[str, ColumnDetail],
) -> None:
"""Fetch exact value counts for small-cardinality columns."""
full_table_name = f"{catalog}.{schema}.{table_name}"
# Ensure table name is properly quoted for SQL
quoted_table = table_name if table_name.startswith("`") else f"`{table_name}`"
full_table_name = f"`{catalog}`.`{schema}`.{quoted_table}"

for col_name, _col_type in columns:
if col_name not in column_details:
Expand Down Expand Up @@ -679,16 +698,18 @@ def get_table_info(
# Get DDL
ddl = self.get_table_ddl(catalog, schema, table_name)

# Collect stats if requested
# Collect schema and stats (or schema-only if collect_stats=False)
column_details = None
total_rows = None
sample_data = None

if collect_stats:
try:
try:
if collect_stats:
column_details, total_rows, sample_data = self.collect_column_stats(catalog, schema, table_name)
except Exception as e:
logger.warning(f"Failed to collect stats for {full_table_name}: {e}")
else:
column_details = self._describe_columns(catalog, schema, table_name)
except Exception as e:
logger.warning(f"Failed to collect stats for {full_table_name}: {e}")

table_info = TableInfo(
name=full_table_name,
Expand Down
Loading