diff --git a/databricks-mcp-server/databricks_mcp_server/tools/sql.py b/databricks-mcp-server/databricks_mcp_server/tools/sql.py
index 3483c173..01355525 100644
--- a/databricks-mcp-server/databricks_mcp_server/tools/sql.py
+++ b/databricks-mcp-server/databricks_mcp_server/tools/sql.py
@@ -151,4 +151,4 @@ def get_table_details(
         warehouse_id=warehouse_id,
     )
     # Convert to dict for JSON serialization
-    return result.model_dump() if hasattr(result, "model_dump") else result
+    return result.model_dump(exclude_none=True) if hasattr(result, "model_dump") else result
diff --git a/databricks-tools-core/databricks_tools_core/sql/sql_utils/models.py b/databricks-tools-core/databricks_tools_core/sql/sql_utils/models.py
index 09501f7c..2353e546 100644
--- a/databricks-tools-core/databricks_tools_core/sql/sql_utils/models.py
+++ b/databricks-tools-core/databricks_tools_core/sql/sql_utils/models.py
@@ -202,17 +202,27 @@ def keep_basic_stats(self) -> "TableSchemaResult":
         )
 
     def remove_stats(self) -> "TableSchemaResult":
-        """Return a new TableSchemaResult with column_details removed.
+        """Return a new TableSchemaResult with column statistics removed.
 
-        Creates a minimal version with just DDL/structure.
+        Keeps column names and types but removes all numeric/histogram stats.
         """
         tables_no_stats = []
         for table in self.tables:
+            # Strip stats from column details if they exist
+            basic_columns = None
+            if table.column_details:
+                basic_columns = {}
+                for col_name, col_detail in table.column_details.items():
+                    basic_columns[col_name] = ColumnDetail(
+                        name=col_detail.name,
+                        data_type=col_detail.data_type,
+                    )
+
             table_no_stats = DataSourceInfo(
                 name=table.name,
                 comment=table.comment,
                 ddl=table.ddl,
-                column_details=None,
+                column_details=basic_columns,
                 updated_at=None,
                 error=table.error,
                 total_rows=None,
diff --git a/databricks-tools-core/databricks_tools_core/sql/sql_utils/table_stats_collector.py b/databricks-tools-core/databricks_tools_core/sql/sql_utils/table_stats_collector.py
index ded04f7e..7ed5de26 100644
--- a/databricks-tools-core/databricks_tools_core/sql/sql_utils/table_stats_collector.py
+++ b/databricks-tools-core/databricks_tools_core/sql/sql_utils/table_stats_collector.py
@@ -151,7 +151,7 @@ def filter_tables_by_patterns(self, tables: List[Dict[str, Any]], patterns: List
 
     def get_table_ddl(self, catalog: str, schema: str, table_name: str) -> str:
         """Get the DDL (CREATE TABLE statement) for a table."""
-        full_table_name = f"{catalog}.{schema}.{table_name}"
+        full_table_name = f"`{catalog}`.`{schema}`.`{table_name}`"
        query = f"SHOW CREATE TABLE {full_table_name}"
 
         try:
@@ -213,14 +213,19 @@ def get_table_ddl(self, catalog: str, schema: str, table_name: str) -> str:
 
     def collect_column_stats(
         self, catalog: str, schema: str, table_name: str
-    ) -> Tuple[Dict[str, ColumnDetail], int, List[Dict[str, Any]]]:
+    ) -> Tuple[Dict[str, ColumnDetail], Optional[int], List[Dict[str, Any]]]:
         """
         Collect enhanced column statistics for a UC table.
 
+        Args:
+            catalog: Catalog name
+            schema: Schema name
+            table_name: Table name
+
         Returns:
             Tuple of (column_details dict, total_rows, sample_data)
         """
-        full_table_name = f"{catalog}.{schema}.{table_name}"
+        full_table_name = f"`{catalog}`.`{schema}`.`{table_name}`"
         return self._collect_stats_for_ref(
             table_ref=full_table_name,
             catalog=catalog,
@@ -231,7 +236,7 @@ def collect_column_stats(
 
     def collect_volume_stats(
         self, volume_path: str, format: str
-    ) -> Tuple[Dict[str, ColumnDetail], int, List[Dict[str, Any]]]:
+    ) -> Tuple[Dict[str, ColumnDetail], Optional[int], List[Dict[str, Any]]]:
         """
         Collect enhanced column statistics for volume folder data.
 
@@ -251,6 +256,32 @@ def collect_volume_stats(
             fetch_value_counts_table=None,
         )
 
+    def _describe_columns(self, catalog: str, schema: str, table_name: str) -> Dict[str, ColumnDetail]:
+        """
+        Return column names and types for a UC table without collecting statistics.
+
+        Used by get_table_info when stat level is NONE.
+        """
+        full_table_name = f"`{catalog}`.`{schema}`.`{table_name}`"
+        try:
+            describe_result = self.executor.execute(
+                sql_query=f"DESCRIBE TABLE {full_table_name}",
+                catalog=catalog,
+                schema=schema,
+                timeout=45,
+            )
+            column_details: Dict[str, ColumnDetail] = {}
+            for col in describe_result or []:
+                col_name = col.get("col_name")
+                if not col_name or col_name.startswith("#"):
+                    continue
+                data_type = col.get("data_type", "string").lower()
+                column_details[col_name] = ColumnDetail(name=col_name, data_type=data_type)
+            return column_details
+        except Exception as e:
+            logger.warning(f"Failed to describe columns for {full_table_name}: {e}")
+            return {}
+
     def _collect_stats_for_ref(
         self,
         table_ref: str,
@@ -258,7 +289,7 @@ def _collect_stats_for_ref(
         schema: Optional[str],
         use_describe_table: bool,
         fetch_value_counts_table: Optional[str],
-    ) -> Tuple[Dict[str, ColumnDetail], int, List[Dict[str, Any]]]:
+    ) -> Tuple[Dict[str, ColumnDetail], Optional[int], List[Dict[str, Any]]]:
         """
         Internal method to collect column statistics for any table reference.
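
Review note on the identifier quoting introduced above: wrapping each part of the three-level name in backticks protects names with hyphens, spaces, or reserved words, but it assumes the parts never contain backticks themselves. A minimal sketch of a fuller quoting rule, using a hypothetical quote_ident helper that is not part of this diff (Databricks SQL escapes an embedded backtick by doubling it):

    def quote_ident(name: str) -> str:
        # Double any embedded backticks, then wrap, per Databricks SQL identifier rules.
        return f"`{name.replace('`', '``')}`"

    catalog, schema, table_name = "main", "my-schema", "weird`name"
    full_table_name = ".".join(quote_ident(p) for p in (catalog, schema, table_name))
    # -> `main`.`my-schema`.`weird``name`

If identifiers are ever user-supplied, something along these lines would be safer than bare f-string interpolation.
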
@@ -305,6 +336,18 @@ def _collect_stats_for_ref(
                     data_type = "string"
                 describe_result.append({"col_name": col_name, "data_type": data_type})
 
+        # Map describe columns to ColumnDetail
+        column_details = {}
+        for col in describe_result:
+            col_name = col.get("col_name")
+            if not col_name:
+                continue
+            # Skip section-marker and empty rows in DESCRIBE output
+            if col_name.startswith("#"):
+                continue
+            data_type = col.get("data_type", "string").lower()
+            column_details[col_name] = ColumnDetail(name=col_name, data_type=data_type)
+
         # Step 2: Get row count
         count_result = self.executor.execute(
             sql_query=f"SELECT COUNT(*) as total_rows FROM {table_ref}",
@@ -330,17 +373,18 @@ def _collect_stats_for_ref(
             data_type = col_info.get("data_type", "").lower()
             escaped_col = f"`{col_name}`"
 
-            # Determine column type
+            # Determine column type for building query
             is_numeric = any(t in data_type for t in NUMERIC_TYPES)
             is_timestamp = "timestamp" in data_type
             is_array = "array" in data_type
+            is_struct_or_map = "struct" in data_type or "map" in data_type or "variant" in data_type
             is_boolean = "boolean" in data_type
             is_id = any(p in col_name.lower() for p in ID_PATTERNS) and (
                 "bigint" in data_type or "string" in data_type
             )
 
-            if is_array:
-                col_type = "array"
+            if is_array or is_struct_or_map:
+                col_type = "complex"
             elif is_timestamp:
                 col_type = "timestamp"
             elif is_boolean:
@@ -362,7 +406,7 @@ def _collect_stats_for_ref(
                 columns_needing_value_counts.append((col_name, "boolean"))
 
         if not union_queries:
-            return {}, total_rows, []
+            return column_details, total_rows, []
 
         # Execute combined stats query
         combined_query = base_cte + "\nUNION ALL\n".join(union_queries)
@@ -372,26 +416,35 @@ def _collect_stats_for_ref(
             schema=schema,
             timeout=60,
         )
+        # Parse stats results into column_details (updated in-place)
+        self._parse_stats_results(stats_result, column_types, column_details)
 
         # Step 4: Get sample data
-        sample_result = self.executor.execute(
-            sql_query=f"SELECT * FROM {table_ref} LIMIT {SAMPLE_ROW_COUNT}",
-            catalog=catalog,
-            schema=schema,
-            timeout=45,
-        )
-        # Filter out _rescued_data column from samples
-        if sample_result:
-            sample_result = [{k: v for k, v in row.items() if k != "_rescued_data"} for row in sample_result]
+        sample_result = []
+        try:
+            sample_result = self.executor.execute(
+                sql_query=f"SELECT * FROM {table_ref} LIMIT {SAMPLE_ROW_COUNT}",
+                catalog=catalog,
+                schema=schema,
+                timeout=45,
+            )
+            # Filter out _rescued_data column from samples
+            if sample_result:
+                sample_result = [{k: v for k, v in row.items() if k != "_rescued_data"} for row in sample_result]
+        except Exception as e:
+            logger.warning(f"Failed to get sample data for {table_ref}: {e}")
 
         # Step 5: Build column samples from sample data
         column_samples = self._extract_column_samples(describe_result, sample_result)
 
-        # Step 6: Parse stats results
-        column_details = self._parse_stats_results(stats_result, column_types, column_samples)
+        # Update column details with samples
+        for col_name, samples in column_samples.items():
+            if col_name in column_details:
+                column_details[col_name].samples = samples
 
         # Step 7: Get value counts for categorical columns (only for UC tables)
         if fetch_value_counts_table:
+            columns_needing_value_counts = []
             for col_name, detail in column_details.items():
                 if column_types.get(col_name) == "categorical":
                     approx_unique = detail.unique_count or 0
@@ -413,7 +466,7 @@ def _build_column_stats_query(
         self, col_name: str, escaped_col: str, data_type: str, col_type: str, base_ref: str
     ) -> str:
         """Build stats query for a column based on its type."""
-        if col_type == "array":
+        if col_type == "complex":
             return f"""
                 SELECT
                     '{col_name}' AS column_name,
@@ -510,20 +563,22 @@ def _parse_stats_results(
         self,
         stats_result: List[Dict],
         column_types: Dict[str, str],
-        column_samples: Dict[str, List[str]],
-    ) -> Dict[str, ColumnDetail]:
-        """Parse stats query results into ColumnDetail objects."""
-        column_details: Dict[str, ColumnDetail] = {}
-
+        column_details: Dict[str, ColumnDetail],
+    ) -> None:
+        """Parse stats query results into existing ColumnDetail objects."""
         for row in stats_result:
             col_name = row.get("column_name")
-            if not col_name:
+            if not col_name or col_name not in column_details:
                 continue
+            detail = column_details[col_name]
             col_type = column_types.get(col_name, "categorical")
-            samples = column_samples.get(col_name, [])
             approx_unique = int(row.get("unique_count") or 0) if row.get("unique_count") is not None else None
 
+            # Update base stats
+            detail.null_count = int(row.get("null_count") or 0)
+            detail.unique_count = approx_unique
+
             # Parse histogram if present
             histogram_bins = None
             if row.get("histogram_data"):
@@ -542,65 +597,27 @@ def _parse_stats_results(
             except Exception as e:
                 logger.debug(f"Failed to parse histogram for {col_name}: {e}")
 
-            # Build ColumnDetail based on type
+            detail.histogram = histogram_bins
+
+            # Update type-specific stats based on the column type
             if col_type == "numeric":
-                detail = ColumnDetail(
-                    name=col_name,
-                    data_type="numeric",
-                    samples=samples,
-                    total_count=int(row.get("total_count") or 0),
-                    null_count=int(row.get("null_count") or 0),
-                    unique_count=approx_unique,
-                    min=float(row["min_val"]) if row.get("min_val") else None,
-                    max=float(row["max_val"]) if row.get("max_val") else None,
-                    avg=float(row["mean_val"]) if row.get("mean_val") else None,
-                    mean=float(row["mean_val"]) if row.get("mean_val") else None,
-                    stddev=float(row["stddev_val"]) if row.get("stddev_val") else None,
-                    q1=float(row["q1_val"]) if row.get("q1_val") else None,
-                    median=float(row["median_val"]) if row.get("median_val") else None,
-                    q3=float(row["q3_val"]) if row.get("q3_val") else None,
-                    histogram=histogram_bins,
-                )
+                detail.total_count = int(row.get("total_count") or 0)
+                detail.min = float(row["min_val"]) if row.get("min_val") else None
+                detail.max = float(row["max_val"]) if row.get("max_val") else None
+                detail.avg = float(row["mean_val"]) if row.get("mean_val") else None
+                detail.mean = float(row["mean_val"]) if row.get("mean_val") else None
+                detail.stddev = float(row["stddev_val"]) if row.get("stddev_val") else None
+                detail.q1 = float(row["q1_val"]) if row.get("q1_val") else None
+                detail.median = float(row["median_val"]) if row.get("median_val") else None
+                detail.q3 = float(row["q3_val"]) if row.get("q3_val") else None
             elif col_type == "timestamp":
-                detail = ColumnDetail(
-                    name=col_name,
-                    data_type="timestamp",
-                    samples=samples,
-                    total_count=int(row.get("total_count") or 0),
-                    null_count=int(row.get("null_count") or 0),
-                    unique_count=approx_unique,
-                    min_date=str(row["min_val"]) if row.get("min_val") else None,
-                    max_date=str(row["max_val"]) if row.get("max_val") else None,
-                    histogram=histogram_bins,
-                )
-            elif col_type == "array":
-                detail = ColumnDetail(
-                    name=col_name,
-                    data_type="array",
-                    samples=samples,
-                    total_count=int(row.get("total_count") or 0),
-                    null_count=int(row.get("null_count") or 0),
-                )
+                detail.total_count = int(row.get("total_count") or 0)
+                detail.min_date = str(row["min_val"]) if row.get("min_val") else None
+                detail.max_date = str(row["max_val"]) if row.get("max_val") else None
             else:
-                # boolean, id, categorical, date
-                final_type = col_type
-                if col_type == "categorical" and approx_unique and approx_unique >= MAX_CATEGORICAL_VALUES:
-                    final_type = "string"
-
-                detail = ColumnDetail(
-                    name=col_name,
-                    data_type=final_type,
-                    samples=samples,
-                    total_count=int(row.get("total_count") or 0),
-                    null_count=int(row.get("null_count") or 0),
-                    unique_count=approx_unique,
-                    min=str(row["min_val"]) if row.get("min_val") else None,
-                    max=str(row["max_val"]) if row.get("max_val") else None,
-                )
-
-            column_details[col_name] = detail
-
-        return column_details
+                detail.total_count = int(row.get("total_count") or 0)
+                detail.min = str(row["min_val"]) if row.get("min_val") else None
+                detail.max = str(row["max_val"]) if row.get("max_val") else None
 
     def _fetch_value_counts(
         self,
@@ -611,7 +628,9 @@ def _fetch_value_counts(
         column_details: Dict[str, ColumnDetail],
     ) -> None:
         """Fetch exact value counts for small-cardinality columns."""
-        full_table_name = f"{catalog}.{schema}.{table_name}"
+        # Ensure table name is properly quoted for SQL
+        quoted_table = table_name if table_name.startswith("`") else f"`{table_name}`"
+        full_table_name = f"`{catalog}`.`{schema}`.{quoted_table}"
 
         for col_name, _col_type in columns:
             if col_name not in column_details:
@@ -679,16 +698,18 @@ def get_table_info(
         # Get DDL
         ddl = self.get_table_ddl(catalog, schema, table_name)
 
-        # Collect stats if requested
+        # Collect schema and stats (or schema-only if collect_stats=False)
         column_details = None
         total_rows = None
         sample_data = None
-        if collect_stats:
-            try:
+        try:
+            if collect_stats:
                 column_details, total_rows, sample_data = self.collect_column_stats(catalog, schema, table_name)
-            except Exception as e:
-                logger.warning(f"Failed to collect stats for {full_table_name}: {e}")
+            else:
+                column_details = self._describe_columns(catalog, schema, table_name)
+        except Exception as e:
+            logger.warning(f"Failed to collect stats for {full_table_name}: {e}")
 
         table_info = TableInfo(
             name=full_table_name,
diff --git a/databricks-tools-core/databricks_tools_core/sql/table_stats.py b/databricks-tools-core/databricks_tools_core/sql/table_stats.py
index 24a3cdef..5e5c4ae9 100644
--- a/databricks-tools-core/databricks_tools_core/sql/table_stats.py
+++ b/databricks-tools-core/databricks_tools_core/sql/table_stats.py
@@ -102,6 +102,7 @@ def get_table_details(
     table_names = table_names or []
     has_patterns = any(_has_glob_pattern(name) for name in table_names)
     needs_listing = len(table_names) == 0 or has_patterns
+    failed_tables: List[DataSourceInfo] = []
 
     if needs_listing:
         # List all tables first
@@ -120,9 +121,28 @@
     else:
         # Direct lookup - build table info without listing
         logger.debug(f"Direct lookup for tables: {table_names}")
-        tables_to_fetch = [{"name": name, "updated_at": None, "comment": None} for name in table_names]
+        tables_to_fetch = []
+        for name in table_names:
+            try:
+                # Fetch metadata via SDK to get the comment and updated_at
+                t = collector.client.tables.get(f"{catalog}.{schema}.{name}")
+                tables_to_fetch.append(
+                    {
+                        "name": t.name,
+                        "updated_at": getattr(t, "updated_at", None),
+                        "comment": getattr(t, "comment", None),
+                    }
+                )
+            except Exception as e:
+                logger.warning(f"Failed to fetch metadata for {catalog}.{schema}.{name}: {e}")
+                failed_tables.append(
+                    DataSourceInfo(
+                        name=f"{catalog}.{schema}.{name}",
+                        error=f"Failed to fetch table metadata: {e}",
+                    )
+                )
 
-    if not tables_to_fetch:
+    if not tables_to_fetch and not failed_tables:
         return TableSchemaResult(catalog=catalog, schema_name=schema, tables=[])
 
     # Determine whether to collect stats
@@ -137,6 +157,10 @@
         collect_stats=collect_stats,
     )
 
+    # Append any tables that failed metadata lookup with their error info
+    if failed_tables:
+        table_infos.extend(failed_tables)
+
     # Build result
     result = TableSchemaResult(
         catalog=catalog,
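
Review note: to make the behavior change in remove_stats() concrete, here is a minimal sketch of the new schema-preserving stripping. It assumes the three models are importable from databricks_tools_core.sql.sql_utils.models as the paths in this diff suggest, and that model fields not set here default to None:

    from databricks_tools_core.sql.sql_utils.models import (
        ColumnDetail,
        DataSourceInfo,
        TableSchemaResult,
    )

    # Hypothetical example table; names are illustrative only.
    amount = ColumnDetail(name="amount", data_type="numeric", mean=12.5, stddev=3.1)
    orders = DataSourceInfo(name="main.sales.orders", column_details={"amount": amount})
    result = TableSchemaResult(catalog="main", schema_name="sales", tables=[orders])

    stripped = result.remove_stats()
    col = stripped.tables[0].column_details["amount"]
    assert col.name == "amount" and col.data_type == "numeric"  # schema kept
    assert col.mean is None and col.stddev is None              # stats dropped

Before this change, remove_stats() set column_details to None wholesale, so callers lost even the column names and types; combined with exclude_none=True in the MCP tool above, the serialized payload now stays small while still describing the table shape.
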