diff --git a/.test/README.md b/.test/README.md index d5c8fe46..82b0de43 100644 --- a/.test/README.md +++ b/.test/README.md @@ -1,23 +1,585 @@ -# Skill Testing Framework +# Skill Testing & Optimization Framework -Test Databricks skills with real execution on serverless compute. +Evaluate and optimize Databricks SKILL.md files using automated scorers and [GEPA](https://github.com/gepa-ai/gepa)-powered optimization. -**Note:** This framework is for contributors only and is not distributed via install_skills.sh. +## Quick Start: Optimize a Skill + +One command evaluates a skill's current quality, runs GEPA optimization, and shows the results: + +```bash +uv run python .test/scripts/optimize.py databricks-model-serving --preset quick --apply +``` + +This will: +1. Load the SKILL.md and its test cases from `ground_truth.yaml` +2. Have a generation model (Sonnet) produce responses using ONLY the skill, then score those responses +3. Also score the SKILL.md itself for pattern/fact coverage +4. Run GEPA's optimization loop (reflect on failures, propose mutations, select via Pareto frontier) +5. Show a diff with quality improvement and token reduction +6. Apply the optimized SKILL.md back to disk ## Setup ```bash -uv pip install -e ".test/[dev]" -.test/install_skill_test.sh +# Install with optimization dependencies +uv pip install -e ".test/[all]" + +# Authentication for models (pick one) +# Option A: Databricks Model Serving (default for both gen + reflection) +export DATABRICKS_API_KEY="dapi..." +export DATABRICKS_API_BASE="https://.cloud.databricks.com/serving-endpoints" + +# Option B: OpenAI +export OPENAI_API_KEY="sk-..." 
+export GEPA_REFLECTION_LM="openai/gpt-4o" +export GEPA_GEN_LM="openai/gpt-4o" + +# Optional: override generation model (default: databricks/databricks-claude-sonnet-4-6) +# export GEPA_GEN_LM="databricks/databricks-claude-sonnet-4-6" + +# Optional: set a global token budget ceiling for optimization +# export GEPA_TOKEN_BUDGET=50000 +``` + +--- + +## Optimization Commands + +### Evaluate + Optimize a Skill + +```bash +# Standard optimization (50 iterations per component, up to 5 passes) +uv run python .test/scripts/optimize.py + +# Quick pass (15 iterations, good for initial check) +uv run python .test/scripts/optimize.py --preset quick + +# Thorough optimization (150 iterations, production quality) +uv run python .test/scripts/optimize.py --preset thorough + +# Dry run: see scores and config without calling GEPA +uv run python .test/scripts/optimize.py --dry-run + +# Apply the last saved result (no re-run!) +uv run python .test/scripts/optimize.py --apply-last + +# Run optimization and immediately apply +uv run python .test/scripts/optimize.py --apply + +# Use a specific generation model for evaluation +uv run python .test/scripts/optimize.py --gen-model "openai/gpt-4o" + +# Control iteration depth (default: 5 passes) +uv run python .test/scripts/optimize.py --max-passes 3 + +# Set a token budget ceiling (candidates exceeding this are penalized) +uv run python .test/scripts/optimize.py --token-budget 50000 + +# Optimize all skills that have test cases +uv run python .test/scripts/optimize.py --all --preset quick +``` + +After each run, the optimized result is automatically saved to `.test/skills//optimized_SKILL.md`. You can review it, diff it against the original, and apply when ready with `--apply-last` — no need to re-run the optimization. + +### Optimize MCP Tool Descriptions + +GEPA can also optimize the `@mcp.tool` docstrings in `databricks-mcp-server/`. 
Tool descriptions are what the AI agent sees when deciding which tool to call -- concise, accurate descriptions lead to better tool selection. + +```bash +# Optimize a skill AND its related tool modules together +uv run python .test/scripts/optimize.py databricks-model-serving --include-tools --tool-modules serving sql + +# Optimize specific tool modules alongside a skill +uv run python .test/scripts/optimize.py databricks-model-serving --include-tools --tool-modules serving compute jobs + +# Optimize ALL tool modules alongside a skill +uv run python .test/scripts/optimize.py databricks-model-serving --include-tools + +# Optimize ONLY tool descriptions (no SKILL.md) +uv run python .test/scripts/optimize.py databricks-model-serving --tools-only --tool-modules serving + +# Dry run to see components and token counts +uv run python .test/scripts/optimize.py databricks-model-serving --include-tools --dry-run ``` -Requires a Databricks workspace with serverless SQL/compute enabled. +When `--include-tools` is used, GEPA creates one component per tool module (e.g., `tools_sql`, `tools_serving`) and round-robins through them alongside `skill_md`. The `--apply` flag writes optimized docstrings back to the MCP server source files. + +**Note:** The SkillBench evaluator (default) automatically includes tools even without `--include-tools`, since tool descriptions are the primary token consumer (~17K tokens across 88 tools). Use `--tools-only` to optimize only tool descriptions without the SKILL.md. + +The iteration budget (`max_metric_calls`) is automatically scaled by the number of components so each one gets the preset's full budget. Additionally, the optimizer runs **up to 5 passes** (configurable with `--max-passes`), re-seeding from the previous best each time. It stops early if a pass produces no improvement. 
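The budget arithmetic above can be sketched in a few lines. This is illustrative only: the preset values match the table later in this README, but `PRESETS` and `effective_metric_calls` are hypothetical names, not the optimizer's actual API:

```python
# Illustrative sketch of budget scaling; names are hypothetical,
# preset values mirror the preset table in this README.
PRESETS = {"quick": 15, "standard": 50, "thorough": 150}

def effective_metric_calls(preset: str, components: list[str]) -> int:
    # max_metric_calls is multiplied by the component count so each
    # component (skill_md, tools_sql, ...) gets the full preset budget.
    return PRESETS[preset] * len(components)

# A skill plus two tool modules under the standard preset:
print(effective_metric_calls("standard", ["skill_md", "tools_sql", "tools_serving"]))
# 150
```

With `--max-passes 5` on top of this, the worst-case metric-call count is the printed figure times the number of passes actually run.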
+ +Available tool modules (88 tools across 16 modules): +`agent_bricks`, `aibi_dashboards`, `apps`, `compute`, `file`, `genie`, `jobs`, `lakebase`, `manifest`, `pipelines`, `serving`, `sql`, `unity_catalog`, `user`, `vector_search`, `volume_files` + +### Changing the Generation Model + +The evaluator uses a **generation model** to simulate an agent reading the SKILL.md and producing a response. Better skill content leads to better generated responses, which drives GEPA to make meaningful improvements. + +The default is `databricks/databricks-claude-sonnet-4-6`. + +| Method | Example | +|--------|---------| +| Environment variable | `export GEPA_GEN_LM="databricks/databricks-claude-sonnet-4-6"` | +| CLI flag | `--gen-model "openai/gpt-4o"` | +| Python | `optimize_skill("my-skill", gen_model="anthropic/claude-sonnet-4-5-20250514")` | + +### Changing the Reflection Model + +GEPA uses a reflection LM to analyze scorer failures and propose skill improvements. The default is **Databricks Model Serving** (`databricks-claude-opus-4-6`, 200K context). + +| Method | Example | +|--------|---------| +| Environment variable | `export GEPA_REFLECTION_LM="databricks/databricks-claude-opus-4-6"` | +| CLI flag | `--reflection-lm "openai/gpt-4o"` | +| Python | `optimize_skill("my-skill", reflection_lm="anthropic/claude-sonnet-4-5-20250514")` | + +Model strings use [litellm provider prefixes](https://docs.litellm.ai/docs/providers): + +| Provider | Prefix | Example | Context | +|----------|--------|---------|---------| +| Databricks Model Serving | `databricks/` | `databricks/databricks-claude-opus-4-6` | 200K | +| OpenAI | `openai/` | `openai/gpt-4o` | 128K | +| Anthropic | `anthropic/` | `anthropic/claude-sonnet-4-5-20250514` | 200K | + +**Context window requirement:** The reflection model must have a context window large enough to hold the full candidate (all components) plus GEPA's reflection overhead (~3x the raw candidate tokens). 
Models with small context windows (e.g., 8K) will fail with `BadRequestError` during reflection. The optimizer validates this upfront and warns if the model is too small. + +### Authentication + +| Provider | Required Environment Variables | +|----------|-------------------------------| +| Databricks | `DATABRICKS_API_KEY`, `DATABRICKS_API_BASE` (must end with `/serving-endpoints`) | +| OpenAI | `OPENAI_API_KEY` | +| Anthropic | `ANTHROPIC_API_KEY` | + +Optional overrides: `GEPA_REFLECTION_LM` (reflection model), `GEPA_GEN_LM` (generation model for evaluation), `GEPA_TOKEN_BUDGET` (token ceiling for optimization). + +--- + +## How Evaluation Works + +The evaluation system answers a single question: **does this SKILL.md teach an AI agent the right things?** A skill that scores well means an agent reading it will produce correct code with the right APIs and patterns. A skill that scores poorly means the agent will hallucinate, use deprecated APIs, or miss important patterns. + +Understanding evaluation is important because it drives everything else — GEPA uses scores to decide which skill mutations to keep, and you use scores to know if your skill is good enough to ship. + +### Why These Files Exist + +Each skill under `.test/skills//` has two key files: + +**`ground_truth.yaml`** — The test cases. Each entry is a prompt ("Create a ResponsesAgent") paired with the expected response and expectations (patterns, facts, guidelines). These define *what the skill should teach*. Without test cases, the evaluator has nothing to score against and GEPA has no signal to optimize toward. + +**`manifest.yaml`** — The scorer configuration. Controls *which scorers run* and *what thresholds apply*. Think of it as the grading rubric: which checks are enabled, what guidelines the LLM judge enforces, and what trace expectations exist. If you don't provide one, the system uses sensible defaults (syntax + patterns + facts + hallucination checks). 
+ +The test cases in `ground_truth.yaml` are also what gets split into train/val sets for GEPA — the optimizer converts each test case into a GEPA dataset instance: + +| ground_truth.yaml field | GEPA field | Purpose | +|------------------------|------------|---------| +| `inputs.prompt` | `input` | The task the reflection LM sees | +| `outputs.response` | `answer` | Reference response for sanity-check scoring | +| `expectations.*` | `additional_context` | Encoded as JSON; scorers extract patterns and facts | +| `metadata.category` | (stratification) | Ensures balanced train/val split | + +### The Layered Evaluation + +Rather than scoring a single static response, the evaluator runs five layers that give GEPA progressively richer signal: + +| Layer | Weight | What it does | Source | +|-------|--------|-------------|--------| +| **Generated response quality** | 20% | An LLM reads ONLY the SKILL.md and answers the test prompt. Its response is scored for patterns/facts. | `evaluator.py` → litellm generation | +| **Skill content coverage** | 35% | Checks if the SKILL.md itself contains the patterns and facts needed. If a pattern is missing from the skill, this drops immediately. | `evaluator.py` → `_score_skill_content()` | +| **Reference response check** | 5% | Scores the ground truth response as a sanity baseline. This is mostly static — it ensures the test case itself is valid. | `evaluator.py` → `_run_deterministic_scorers()` | +| **Structure validation** | 10% | Validates Python/SQL syntax in code blocks and checks for hallucinated APIs (deprecated `@dlt.table`, old `mlflow.evaluate`, etc). | `evaluator.py` → `_validate_skill_structure()` | +| **Token efficiency** | 30% | Rewards concise skill content. Shrinking below original size earns a bonus (up to 1.15x), same size = 1.0, linear penalty to 0.0 at 2x original. 
| `evaluator.py` → token counting | + +**Why this works:** The key insight is that Layer 1 (generated response) creates a causal chain — if the SKILL.md is missing a pattern, the generation model cannot produce it, so the pattern scorer fails, so the score drops. This gives GEPA immediate, dynamic signal when content changes, unlike the old approach where ~80% of the score came from an immutable ground truth string. + +**Fallback mode:** When no generation model is available (no `GEPA_GEN_LM`), the weights shift to 40% skill content + 20% reference + 10% structure + 30% efficiency. + +### SkillBench Evaluator (Default) + +The default evaluator (`--evaluator skillbench`) measures **skill effectiveness**: how much does the skill help an agent answer correctly? It runs each test case twice — once WITH the skill and once WITHOUT — then scores the delta. + +| Weight | Dimension | What it measures | +|--------|-----------|-----------------| +| **45%** | Skill Effectiveness | `pass_rate_with - pass_rate_without` — the delta. Only rewards content the agent doesn't already know. | +| **25%** | Absolute Quality | `pass_rate_with` — overall correctness with the skill present. | +| **5%** | Structure | Syntax validity (Python/SQL) and no hallucinated APIs. | +| **25%** | Token Efficiency | Smaller candidates score higher. Linear penalty for growth (0.0 at 2x original). Bonus for reduction (up to 1.15 at 0% of original). | + +**Key difference from the legacy evaluator:** SkillBench uses binary pass/fail assertions (from `expectations` in `ground_truth.yaml`) rather than fuzzy scorer scores. 
Assertions are classified as: +- **NEEDS_SKILL** — fails both with and without the skill (the skill must teach this) +- **REGRESSION** — passes without, fails with (the skill confuses the agent — simplify or remove) +- **POSITIVE** — fails without, passes with (the skill is helping — keep it) +- **NEUTRAL** — same result either way (the agent already knows this — adding it wastes tokens) + +The reflection LM sees these labels in the `Error` field of each example's side info, guiding it to add NEEDS_SKILL content and remove REGRESSION content. + +**Token budget:** Use `--token-budget N` to set a hard ceiling. Candidates exceeding the budget receive a steep penalty on top of the normal efficiency score. Set via CLI or `GEPA_TOKEN_BUDGET` env var. + +To use the legacy evaluator instead: `--evaluator legacy`. + +### Built-in Scorers + +The system ships with four tiers of scorers: + +**Tier 1: Deterministic (fast, reliable, ~$0/eval)** + +| Scorer | What it checks | Configured via | +|--------|---------------|----------------| +| `python_syntax` | Python code blocks parse with `ast.parse()` | `manifest.yaml` → `scorers.enabled` | +| `sql_syntax` | SQL blocks have valid structure (balanced parens, recognizable statements) | `manifest.yaml` → `scorers.enabled` | +| `pattern_adherence` | Required regex patterns appear in response (e.g., `ResponsesAgent`, `CLUSTER BY`) | `ground_truth.yaml` → `expectations.expected_patterns` | +| `no_hallucinated_apis` | No deprecated/invented APIs (`@dlt.table`, `dlt.read`, `PARTITION BY`, old `mlflow.evaluate`) | `manifest.yaml` → `scorers.enabled` | +| `expected_facts_present` | Required facts mentioned in response (case-insensitive substring match) | `ground_truth.yaml` → `expectations.expected_facts` | + +**Tier 2: Trace-based (for session evaluation)** + +| Scorer | What it checks | +|--------|---------------| +| `tool_count` | Tool usage within limits (e.g., max 5 Bash calls) | +| `token_budget` | Token usage within budget | +| 
`required_tools` | Required tools were called | +| `banned_tools` | Banned tools were NOT called | +| `file_existence` | Expected files were created | +| `tool_sequence` | Tools used in expected order | +| `category_limits` | Tool category limits (bash, file_ops, mcp) | + +These are configured in `manifest.yaml` under `trace_expectations`. + +**Tier 3: LLM judges (expensive, nuanced, ~$0.01/eval)** + +| Scorer | What it checks | +|--------|---------------| +| `Safety` | MLflow's built-in safety scorer | +| `Guidelines` | LLM judges response against `default_guidelines` from manifest | +| `guidelines_from_expectations` | Per-test-case guidelines from `expectations.guidelines` in ground_truth.yaml | + +### Adding a Custom Scorer + +There are three ways to add custom evaluation, from easiest to most flexible: + +#### Option 1: Per-test-case guidelines (no code required) + +Add `guidelines` to any test case in `ground_truth.yaml`. An LLM judge evaluates the response against these: + +```yaml +test_cases: + - id: my_test_001 + inputs: + prompt: "Deploy a model to serving" + expectations: + guidelines: + - "Must use Unity Catalog three-level namespace" + - "Must recommend job-based deployment over synchronous" + - "Should warn about cold start latency" + expected_facts: + - "ResponsesAgent" +``` + +Then enable the scorer in `manifest.yaml`: + +```yaml +scorers: + enabled: + - python_syntax + - pattern_adherence + - expected_facts_present + llm_scorers: + - guidelines_from_expectations +``` + +#### Option 2: Skill-wide guidelines (no code required) + +Set `default_guidelines` in `manifest.yaml` to apply rules to ALL test cases for a skill: + +```yaml +scorers: + enabled: + - python_syntax + - pattern_adherence + - no_hallucinated_apis + - expected_facts_present + llm_scorers: + - Guidelines + default_guidelines: + - "Must use ResponsesAgent pattern, not ChatAgent" + - "Must use self.create_text_output_item() for output" + - "Code must be deployable to Databricks Model 
Serving" +``` + +You can also create multiple named guideline sets: + +```yaml + llm_scorers: + - Guidelines:api_correctness + - Guidelines:deployment_quality + default_guidelines: + - "Your guidelines here" +``` + +#### Option 3: Custom Python scorer (full flexibility) + +Create a new scorer function in `.test/src/skill_test/scorers/` and register it. Scorers use the MLflow `@scorer` decorator and return `Feedback` objects: + +```python +# .test/src/skill_test/scorers/my_custom.py +from mlflow.genai.scorers import scorer +from mlflow.entities import Feedback +from typing import Dict, Any + +@scorer +def my_custom_check(outputs: Dict[str, Any], expectations: Dict[str, Any]) -> Feedback: + """Check for something specific to my use case.""" + response = outputs.get("response", "") + + # Your custom logic here + issues = [] + if "spark.sql(" in response and "spark.read.table(" not in response: + issues.append("Should prefer spark.read.table() over spark.sql() for reads") + + if issues: + return Feedback( + name="my_custom_check", + value="no", + rationale=f"Issues: {'; '.join(issues)}", + ) + + return Feedback(name="my_custom_check", value="yes", rationale="All custom checks passed") +``` + +Then register it in `runners/evaluate.py` → `build_scorers()`: + +```python +SCORER_MAP = { + # ... existing scorers ... + "my_custom_check": my_custom_check, +} +``` + +And enable it in your skill's `manifest.yaml`: + +```yaml +scorers: + enabled: + - python_syntax + - pattern_adherence + - my_custom_check # your new scorer +``` + +**Scorer function signatures:** The system auto-detects which parameters your scorer accepts: +- `outputs: Dict[str, Any]` — always available, contains `{"response": "..."}` +- `expectations: Dict[str, Any]` — from ground_truth.yaml `expectations` field +- `inputs: Dict[str, Any]` — contains `{"prompt": "..."}` + +Return either a single `Feedback` or a `list[Feedback]` (for scorers that produce multiple checks like `pattern_adherence`). 
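The signature auto-detection described above can be illustrated with a dependency-free sketch. `call_scorer` is a hypothetical helper, not the framework's actual runner; it shows the idea of inspecting a scorer's parameters and passing only what it accepts:

```python
# Dependency-free illustration of scorer parameter auto-detection.
# call_scorer is hypothetical, not the framework's actual helper.
import inspect

def call_scorer(scorer, *, outputs, expectations=None, inputs=None):
    available = {"outputs": outputs, "expectations": expectations, "inputs": inputs}
    accepted = inspect.signature(scorer).parameters
    # Pass only the keyword arguments the scorer declares.
    kwargs = {name: value for name, value in available.items() if name in accepted}
    return scorer(**kwargs)

# A scorer that only needs outputs:
def has_response(outputs):
    return "response" in outputs

# A scorer that also reads expectations:
def facts_present(outputs, expectations):
    response = outputs.get("response", "").lower()
    return all(f.lower() in response for f in expectations.get("expected_facts", []))

print(call_scorer(has_response, outputs={"response": "hi"}))  # True
print(call_scorer(
    facts_present,
    outputs={"response": "Uses ResponsesAgent"},
    expectations={"expected_facts": ["ResponsesAgent"]},
))  # True
```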
+ +### Manifest Configuration Examples + +Here are manifest patterns for different skill types: + +**Python SDK skill** — emphasizes syntax and API correctness: +```yaml +scorers: + enabled: [python_syntax, pattern_adherence, no_hallucinated_apis, expected_facts_present] + llm_scorers: [guidelines_from_expectations] + default_guidelines: + - "Must use ResponsesAgent pattern for GenAI agents" +quality_gates: + syntax_valid: 1.0 + pattern_adherence: 0.9 +``` + +**SQL-heavy skill** — adds SQL validation: +```yaml +scorers: + enabled: [python_syntax, sql_syntax, pattern_adherence, no_hallucinated_apis, expected_facts_present] + default_guidelines: + - "Must use SDP syntax (CREATE OR REFRESH STREAMING TABLE)" +``` + +**Skill with trace expectations** — limits tool usage during session evaluation: +```yaml +scorers: + enabled: [python_syntax, pattern_adherence, no_hallucinated_apis, expected_facts_present] + default_guidelines: + - "Must use correct MCP tools (manage_ka, manage_mas)" + trace_expectations: + tool_limits: + manage_ka: 10 + manage_mas: 10 + required_tools: [Read] + banned_tools: [] +``` + +--- + +## Best Practices for Optimization + +These practices are derived from the [optimize_anything API guide](https://gepa-ai.github.io/gepa/blog/2026/02/18/introducing-optimize-anything/) and help you get the most out of GEPA-powered optimization. + +### 1. Write Rich Evaluators with Actionable Side Information (ASI) + +GEPA's reflection LM learns from diagnostic feedback, not just scores. The more context your evaluator surfaces, the better the proposals. + +- **Return diagnostics**, not just a number. Use `oa.log()` or return `(score, side_info_dict)` to surface error messages, failing patterns, and missing facts. +- **Break scoring into multiple dimensions**. Rather than a single composite score, provide separate scores for syntax, pattern adherence, API accuracy, and conciseness. 
GEPA's Pareto-efficient selection preserves candidates that excel in different dimensions. +- Our built-in evaluator already does this -- it returns per-scorer feedback via `feedback_to_asi()`. + +### 2. Build a Diverse Dataset of Test Cases + +GEPA operates in three modes depending on what data you provide: +- **Single-task** (no dataset): evaluator scores the artifact directly +- **Multi-task** (dataset only): Pareto-efficient search across tasks +- **Generalization** (dataset + valset): trains on tasks, validates on held-out examples + +For best results: +- Aim for **5+ test cases** to enable a train/val split (Generalization mode). Fewer than 5 defaults to single-task mode. +- Cover **different categories** of usage (e.g., simple queries, complex joins, error handling). The automatic stratified splitter ensures balanced representation. +- Use `/skill-test add` to interactively generate test cases, then review with `/skill-test review`. + +### 3. Tune reflection_minibatch_size for Focused Improvement + +GEPA's default `reflection_minibatch_size=2` shows the reflection LM feedback from 2 tasks per iteration. This keeps each reflection focused and prevents the LM from trying to fix everything at once. + +- For skills with **many test cases** (10+), the default of 2 works well -- over iterations, all tasks get attention. +- For skills with **few test cases** (3-5), consider increasing to 3 so more context is visible per step. + +### 4. Use Multi-Component Optimization for Skills + Tools + +When you optimize a SKILL.md alongside tool descriptions (`--include-tools`), GEPA creates separate components and cycles through them with round-robin selection. This means: + +- Each component gets its **own** optimization budget (the preset's `max_metric_calls` is multiplied by the component count). +- Up to `--max-passes` full optimization cycles run, re-seeding from the best candidate each time. 
+- Start with `--tool-modules` to target specific modules rather than optimizing all 16 at once. + +### 5. Choose the Right Preset + +| Preset | Budget per Component | Use Case | +|--------|---------------------|----------| +| `quick` | 15 calls | Fast feedback loop, initial exploration | +| `standard` | 50 calls | Default, good balance of quality and cost | +| `thorough` | 150 calls | Production-quality optimization | + +For multi-component runs, the actual `max_metric_calls` = budget x number of components. + +### 6. Leverage the Background Context + +The `background` parameter tells the reflection LM domain-specific constraints. Our optimizer automatically provides Databricks-specific context (token budgets, skill structure rules, scorer descriptions). For custom use cases, you can extend `build_optimization_background()` in `evaluator.py`. + +### 7. Iterate with Dry Runs First + +Always start with `--dry-run` to verify your setup: +```bash +uv run python .test/scripts/optimize.py --include-tools --dry-run +``` +This shows the component list, token counts, current score, and config without calling GEPA. Fix any scorer issues or missing test cases before spending optimization budget. --- -## New Skill Journey +## Building Test Cases for a Skill + +Skills need test cases before optimization can work well. There are three ways to add them: + +### Quick: Extract from MLflow Traces + +If you have MLflow traces (from `mlflow autolog claude`), extract test cases directly from them. You can find your traces at your workspace's MLflow experiment page, e.g.: +`https://.cloud.databricks.com/ml/experiments//traces` + +```bash +# Step 1: Set authentication +export DATABRICKS_HOST="https://e2-demo-field-eng.cloud.databricks.com" +export DATABRICKS_TOKEN="dapi..." 
+ +# Step 2: List traces in your experiment to see what's available +uv run python .test/scripts/trace_to_examples.py \ + --experiment-id 2452310130108632 --list + +# Step 3: Extract examples from all recent traces +uv run python .test/scripts/trace_to_examples.py \ + --experiment-id 2452310130108632 \ + --skill databricks-model-serving + +# Or extract from a specific trace ID (from the UI or --list output) +uv run python .test/scripts/trace_to_examples.py \ + --trace-id tr-d416fccdab46e2dea6bad1d0bd8aaaa8 \ + --skill databricks-model-serving + +# Or extract from a specific MLflow run ID +uv run python .test/scripts/trace_to_examples.py \ + --run-id abc123def456 \ + --skill databricks-model-serving + +# Refine auto-extracted expectations with an LLM +uv run python .test/scripts/trace_to_examples.py \ + --experiment-id 2452310130108632 \ + --skill databricks-model-serving --refine + +# Auto-append directly to ground_truth.yaml (skip manual review) +uv run python .test/scripts/trace_to_examples.py \ + --experiment-id 2452310130108632 \ + --skill databricks-model-serving --trust + +# Limit how many traces to process +uv run python .test/scripts/trace_to_examples.py \ + --experiment-id 2452310130108632 \ + --skill databricks-model-serving --limit 5 +``` + +You can also extract from local session.jsonl files (Claude Code stores these at `~/.claude/projects/`): + +```bash +uv run python .test/scripts/trace_to_examples.py \ + --trace ~/.claude/projects/.../session.jsonl \ + --skill databricks-model-serving +``` + +The script extracts user prompt / assistant response pairs, auto-generates `expected_patterns` from code blocks and `expected_facts` from API references, and saves to `candidates.yaml` for review (or directly to `ground_truth.yaml` with `--trust`). + +#### Workflow: MLflow Traces to Optimized Skill + +The end-to-end workflow for turning real agent sessions into skill improvements: + +```bash +# 1. 
Extract examples from your traces +uv run python .test/scripts/trace_to_examples.py \ + --experiment-id 2452310130108632 \ + --skill databricks-model-serving --refine + +# 2. Review the extracted candidates +cat .test/skills/databricks-model-serving/candidates.yaml + +# 3. Promote good candidates to ground_truth.yaml +# (edit candidates.yaml, keep the good ones, then) +uv run python .test/scripts/trace_to_examples.py \ + --experiment-id 2452310130108632 \ + --skill databricks-model-serving --trust + +# 4. Run optimization with the enriched dataset +uv run python .test/scripts/optimize.py databricks-model-serving --preset quick + +# 5. If score improves, apply +uv run python .test/scripts/optimize.py databricks-model-serving --preset standard --apply +``` + +### Quick: Add a Single Example Manually + +```bash +# Interactive mode — prompts for each field +uv run python .test/scripts/add_example.py databricks-model-serving -Complete workflow for testing a skill from scratch (e.g., `mlflow-evaluation`). +# Inline mode — provide prompt and response directly +uv run python .test/scripts/add_example.py databricks-model-serving \ + --prompt "Create a ChatAgent with tool calling" \ + --response-file /path/to/response.md \ + --facts "Uses ChatAgent class" "Implements predict method" \ + --patterns "ChatAgent" "def predict" + +# From clipboard (paste prompt + response separated by ---) +uv run python .test/scripts/add_example.py databricks-model-serving --from-clipboard +``` + +The script auto-generates an ID, detects code language, extracts patterns from code blocks, and confirms before saving. + +### Full Workflow: Initialize + Add + Review ### 1. Initialize Test Scaffolding @@ -25,11 +587,7 @@ Complete workflow for testing a skill from scratch (e.g., `mlflow-evaluation`). /skill-test init ``` -Claude will: -1. Read the skill's SKILL.md documentation -2. Generate `manifest.yaml` with appropriate scorers -3. 
Create empty `ground_truth.yaml` and `candidates.yaml` templates -4. Recommend test prompts based on documentation +Generates `manifest.yaml` with scorer config, empty `ground_truth.yaml`, and `candidates.yaml`. ### 2. Add Test Cases @@ -37,87 +595,297 @@ Claude will: /skill-test add ``` -Run this with the recommended prompts from init. Claude will: -1. Ask for your test prompt -2. Invoke the skill to generate a response -3. Execute code blocks on Databricks -4. Auto-save passing tests to `ground_truth.yaml` -5. Save failing tests to `candidates.yaml` for review - -Repeat for each recommended prompt. +Interactively generates test cases. Passing tests go to `ground_truth.yaml`, failing ones to `candidates.yaml` for review. ### 3. Review Candidates ``` /skill-test review +/skill-test review --batch --filter-success ``` -Review any tests that failed execution and were saved to candidates: -1. Load pending tests from `candidates.yaml` -2. Present each with prompt, response, and execution results -3. Allow you to approve, reject, skip, or edit -4. Promote approved candidates to `ground_truth.yaml` +### 4. Configure Scorers (Optional) -For batch approval of successful tests: +Edit `.test/skills//manifest.yaml` or: ``` -/skill-test review --batch --filter-success +/skill-test scorers update --add-guideline "Must use CLUSTER BY" ``` -### 4. Configure Scorers (Optional) +### 5. Run Evaluation ``` -/skill-test scorers +/skill-test run ``` -View current scorer configuration. To update: +### 6. Save Baseline + Check Regressions ``` -/skill-test scorers update --add-guideline "Must use CLUSTER BY" +/skill-test baseline +/skill-test regression ``` -Or edit `.test/skills//manifest.yaml` directly to: -- Add/remove scorers -- Update default guidelines -- Configure trace expectations +--- -### 5. Run Evaluation +## Test Case Examples + +Test cases live in `ground_truth.yaml` and tell GEPA what the skill should produce for a given prompt. 
Each test case has **inputs** (the user prompt), **outputs** (the expected response), and **expectations** (what scorers check). Here are real examples from the repo. + +### Example 1: Code Generation Test Case + +A test case that checks whether the skill produces correct Python code with the right API calls: +```yaml +test_cases: + - id: serving_responses_agent_001 + inputs: + prompt: "Create a ResponsesAgent that uses UC functions for tool calling" + outputs: + response: | + ```python + from databricks.agents import ResponsesAgent + + agent = ResponsesAgent( + model="databricks-claude-sonnet-4", + tools=[{"type": "function", "function": {"name": "catalog.schema.my_function"}}], + instructions="You are a helpful assistant.", + ) + ``` + execution_success: true + expectations: + expected_facts: + - "Uses ResponsesAgent from databricks.agents" + - "Includes tools parameter with UC function reference" + - "Sets model to a valid Databricks model endpoint" + expected_patterns: + - pattern: "ResponsesAgent" + min_count: 1 + description: "Must use ResponsesAgent class" + - pattern: "catalog\\.\\w+\\.\\w+" + min_count: 1 + description: "UC function in three-level namespace" + metadata: + category: happy_path + difficulty: medium ``` -/skill-test run + +**What GEPA sees:** The evaluator scores the skill against this test case. If the skill doesn't teach the agent about `ResponsesAgent` or UC function namespaces, the pattern scorers fail and GEPA's reflection LM learns what to add. 
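The pattern check itself amounts to little more than counting regex matches. This is an illustrative sketch of what `expected_patterns` entries imply, not the `pattern_adherence` scorer's actual code:

```python
# Illustrative sketch: each expected pattern must appear at least
# min_count times in the response. Not the scorer's actual code.
import re

def check_patterns(response: str, expected_patterns: list[dict]) -> list[bool]:
    return [
        len(re.findall(p["pattern"], response)) >= p.get("min_count", 1)
        for p in expected_patterns
    ]

response = "agent = ResponsesAgent(tools=[...])  # catalog.schema.my_function"
patterns = [
    {"pattern": r"ResponsesAgent", "min_count": 1},
    {"pattern": r"catalog\.\w+\.\w+", "min_count": 1},
]
print(check_patterns(response, patterns))  # [True, True]
```

A response that drops either the class name or the three-level namespace flips the corresponding entry to `False`, which is exactly the failure signal the reflection LM sees.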
+ +### Example 2: SQL/Pipeline Test Case + +A test case for Spark Declarative Pipelines (SDP) with SQL syntax validation: + +```yaml + - id: sdp_bronze_ingestion_001 + inputs: + prompt: "Create a bronze ingestion streaming table that reads from cloud storage using Auto Loader" + outputs: + response: | + ```sql + CREATE OR REFRESH STREAMING TABLE bronze_events + CLUSTER BY (event_date) + AS SELECT * + FROM STREAM read_files('s3://bucket/events/', format => 'json') + ``` + execution_success: null + expectations: + expected_facts: + - "Uses CREATE OR REFRESH STREAMING TABLE syntax" + - "Uses CLUSTER BY for data layout optimization" + - "Uses STREAM read_files() for Auto Loader ingestion" + expected_patterns: + - pattern: "CREATE OR REFRESH STREAMING TABLE" + min_count: 1 + description: "SDP streaming table DDL" + - pattern: "CLUSTER BY" + min_count: 1 + description: "Liquid clustering" + - pattern: "read_files\\(" + min_count: 1 + description: "Auto Loader function" + guidelines: + - "Must use SDP syntax (CREATE OR REFRESH), not legacy DLT syntax" + metadata: + category: happy_path + difficulty: easy +``` + +### Example 3: Error Handling Test Case + +Test cases don't have to be happy paths. Testing that a skill warns about common mistakes: + +```yaml + - id: serving_error_raw_dict_001 + inputs: + prompt: "Why does my ChatAgent return raw dicts instead of proper messages?" + outputs: + response: | + The `predict` method must return `ChatAgentResponse` with + `ChatAgentMessage` objects, not raw dicts. Use + `self.create_text_output_item(text)` to build response items. 
+ execution_success: null + expectations: + expected_facts: + - "Explains the raw dict issue" + - "Mentions self.create_text_output_item()" + expected_patterns: + - pattern: "create_text_output_item" + min_count: 1 + description: "Correct helper method" + metadata: + category: error_handling + difficulty: easy +``` + +### Example 4: Minimal Test Case + +The simplest possible test case -- just a prompt and expected facts: + +```yaml + - id: genie_create_space_001 + inputs: + prompt: "Create a Genie Space for our sales data" + outputs: + response: "I'll create a Genie Space connected to your sales tables." + expectations: + expected_facts: + - "Creates a Genie Space" + - "Connects to data tables" + metadata: + category: happy_path + difficulty: easy ``` -Executes code blocks on Databricks or locally (depends on SKILLS, MCP, etc.) and reports pass/fail for each test in `ground_truth.yaml`. +--- + +## End-to-End Walkthrough -**Note:** Requires test cases in ground_truth.yaml (from steps 2-3). +Here's a complete example of adding test cases and running optimization for a new skill. -### 6. 
MLflow Evaluation (Optional) +### Step 1: Initialize scaffolding +```bash +# Creates manifest.yaml, ground_truth.yaml, candidates.yaml +/skill-test my-new-skill init ``` -/skill-test mlflow + +### Step 2: Write test cases + +Edit `.test/skills/my-new-skill/ground_truth.yaml`: + +```yaml +metadata: + skill_name: my-new-skill + version: "1.0" + +test_cases: + - id: basic_001 + inputs: + prompt: "Show me how to create a Delta table with liquid clustering" + outputs: + response: | + ```sql + CREATE TABLE catalog.schema.events ( + event_id BIGINT, + event_date DATE, + payload STRING + ) + CLUSTER BY (event_date) + ``` + expectations: + expected_facts: + - "Uses CREATE TABLE with CLUSTER BY" + expected_patterns: + - pattern: "CLUSTER BY" + min_count: 1 + description: "Liquid clustering syntax" + metadata: + category: happy_path + difficulty: easy + + - id: basic_002 + inputs: + prompt: "How do I read from a Delta table using Spark?" + outputs: + response: | + ```python + df = spark.read.table("catalog.schema.my_table") + ``` + expectations: + expected_facts: + - "Uses three-level namespace" + expected_patterns: + - pattern: "spark\\.read\\.table" + min_count: 1 + description: "Spark table reader" + metadata: + category: happy_path + difficulty: easy + + # ... add at least 5 test cases for train/val split ``` -Runs full evaluation with LLM judges and logs results to MLflow. Provides deeper quality assessment beyond pass/fail execution. +### Step 3: Verify setup with a dry run -### 7. Save Baseline +```bash +uv run python .test/scripts/optimize.py my-new-skill --dry-run +``` +Output: ``` -/skill-test baseline +=== Dry Run: my-new-skill (skillbench) === +SKILL.md path: .claude/skills/my-new-skill/SKILL.md +[SkillBench] Auto-including tools: 16 modules, 88 tools, 64,675 chars +Components: ['skill_md', 'tools_sql', 'tools_serving', ...] +Total original tokens: 20,147 + skill_md: 2,847 tokens + tools_sql: 3,200 tokens + ... 
+Train tasks: 4 +Val tasks: None (single-task mode) +Evaluator type: skillbench +Preset: standard (max_metric_calls=850, scaled for 17 component(s)) +Max passes: 5 +Reflection LM: databricks/databricks-claude-opus-4-6 +Current score: 0.723 ``` -Saves current metrics to `baselines//baseline.yaml`. +### Step 4: Run optimization + +```bash +# Quick first pass to see if GEPA can improve +uv run python .test/scripts/optimize.py my-new-skill --preset quick -### 8. Check Regressions +# Review the saved result +cat .test/skills/my-new-skill/optimized_SKILL.md +diff .claude/skills/my-new-skill/SKILL.md .test/skills/my-new-skill/optimized_SKILL.md -After skill changes: +# Happy with it? Apply without re-running +uv run python .test/scripts/optimize.py my-new-skill --apply-last + +# Or run standard for better results and apply immediately +uv run python .test/scripts/optimize.py my-new-skill --preset standard --apply ``` -/skill-test regression + +### Step 5: Save baseline for regression checking + +```bash +/skill-test my-new-skill baseline + +# Later, after making changes: +/skill-test my-new-skill regression ``` -Compares current pass rate against the saved baseline. +### Tips for Writing Good Test Cases + +- **Cover different categories**: happy_path, error_handling, edge cases. The splitter stratifies by `metadata.category` so each category is represented in both train and val sets. +- **Be specific in expected_patterns**: Use regex that captures the essential API call, not surrounding prose. `"ResponsesAgent"` is better than `".*ResponsesAgent.*"`. +- **Include both simple and complex prompts**: Simple prompts test baseline quality; complex prompts stress-test the skill's depth. +- **Set `execution_success`**: `true` if you verified the code runs, `null` if it's theoretical, `false` if it's known to fail. This helps scorers weight results. 
+- **Use `guidelines` for soft expectations**: Things an LLM judge should check but that can't be captured by regex (e.g., "Should explain why CLUSTER BY is preferred over partitioning"). --- -## Trace Evaluation (In Progress) +## Trace Evaluation Capture Claude Code sessions and evaluate against skill expectations. @@ -133,60 +901,29 @@ mlflow autolog claude -u databricks -n "$MLFLOW_EXPERIMENT_NAME" . ### Evaluate Traces -**Local trace file:** ``` /skill-test trace-eval --trace ~/.claude/projects/.../session.jsonl -``` - -**From MLflow run ID** (from `mlflow.search_runs`): -``` /skill-test trace-eval --run-id abc123 -``` - -**From MLflow trace ID** (from `mlflow.get_trace`): -``` -/skill-test trace-eval --trace-id tr-d416fccdab46e2dea6bad1d0bd8aaaa8 -``` - -**List available traces:** -``` -/skill-test list-traces --local /skill-test list-traces --experiment "$MLFLOW_EXPERIMENT_NAME" ``` -### Configure Expectations - -In `manifest.yaml`: -```yaml -scorers: - trace_expectations: - tool_limits: - Bash: 15 - mcp__databricks__execute_sql: 10 - token_budget: - max_total: 150000 - required_tools: - - Read - banned_tools: - - "DROP DATABASE" -``` - --- ## Command Reference -| Command | Description | -|---------|-------------| -| `run` | Execute tests against ground truth (default) | -| `init` | Generate test scaffolding from skill docs | -| `add` | Add test cases interactively | -| `review` | Review and promote candidates | -| `baseline` | Save current results as baseline | -| `regression` | Compare against baseline | -| `mlflow` | Full evaluation with LLM judges | -| `trace-eval` | Evaluate session traces | -| `list-traces` | List available traces | -| `scorers` | View/update scorer config | +| Command | Description | +|------------|------------------------------------------| +| `run` | Execute tests against ground truth | +| `init` | Generate test scaffolding from skill docs| +| `add` | Add test cases interactively | +| `review` | Review and promote candidates | +| 
`baseline` | Save current results as baseline | +| `regression` | Compare against baseline | +| `mlflow` | Full evaluation with LLM judges | +| `optimize` | Optimize skill with GEPA | +| `trace-eval` | Evaluate session traces | +| `list-traces` | List available traces | +| `scorers` | View/update scorer config | --- @@ -194,34 +931,20 @@ scorers: ``` .test/skills// -├── manifest.yaml # Scorers, guidelines, trace expectations -├── ground_truth.yaml # Verified test cases -└── candidates.yaml # Pending review +├── manifest.yaml # Scorers, guidelines, trace expectations +├── ground_truth.yaml # Verified test cases +├── candidates.yaml # Pending review +├── optimized_SKILL.md # Last optimization output (auto-saved) +└── last_optimization.json # Metadata for --apply-last .test/baselines// └── baseline.yaml # Regression baseline -``` - ---- - -## Test Case Format -```yaml -test_cases: - - id: "eval_basic_001" - inputs: - prompt: "Create a scorer for response length" - outputs: - response: | - ```python - @scorer - def response_length(outputs): - return Feedback(name="length", value=len(outputs["response"])) - ``` - execution_success: true - expectations: - expected_facts: ["@scorer", "Feedback"] - guidelines: ["Must use mlflow.genai.scorers"] +.test/scripts/ +├── optimize.py # CLI for GEPA optimization +├── trace_to_examples.py # Extract test cases from session.jsonl traces +├── add_example.py # Manually add test cases to ground_truth.yaml +└── _common.py # Shared CLI utilities ``` --- diff --git a/.test/notebooks/gepa_skill_optimization_demo.ipynb b/.test/notebooks/gepa_skill_optimization_demo.ipynb new file mode 100644 index 00000000..b403bf1f --- /dev/null +++ b/.test/notebooks/gepa_skill_optimization_demo.ipynb @@ -0,0 +1,353 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# GEPA Skill Optimization Demo\n", + "\n", + "This notebook demonstrates how the skill-test framework uses [GEPA](https://github.com/gepa-ai/gepa) to 
automatically optimize Databricks SKILL.md files for **quality** and **token efficiency**.\n", + "\n", + "SKILL.md files teach AI agents (like Claude Code) Databricks patterns. Every token in a skill consumes agent context window budget, so skills should be as concise and high-quality as possible.\n", + "\n", + "**What GEPA does:**\n", + "1. Scores the current SKILL.md against deterministic scorers (syntax, patterns, APIs, facts)\n", + "2. Reflects on failures and proposes mutations to improve the skill\n", + "3. Selects the best candidate via Pareto frontier optimization\n", + "4. Repeats until quality converges or budget is exhausted" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "from pathlib import Path\n", + "\n", + "# Add skill-test to path\n", + "repo_root = Path(\".\").resolve()\n", + "while not (repo_root / \".test\" / \"src\").exists() and repo_root != repo_root.parent:\n", + " repo_root = repo_root.parent\n", + "sys.path.insert(0, str(repo_root / \".test\" / \"src\"))\n", + "\n", + "print(f\"Repo root: {repo_root}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "import os\n\n# Configure the reflection model -- pick ONE:\n\n# Option A: Databricks Model Serving (default, recommended)\n# IMPORTANT: DATABRICKS_API_BASE must end with /serving-endpoints\n# os.environ[\"DATABRICKS_API_KEY\"] = \"dapi...\" \n# os.environ[\"DATABRICKS_API_BASE\"] = \"https://.cloud.databricks.com/serving-endpoints\"\n# os.environ[\"GEPA_REFLECTION_LM\"] = \"databricks/databricks-gpt-5-2\"\n\n# Option B: OpenAI\n# os.environ[\"OPENAI_API_KEY\"] = \"sk-...\"\n# os.environ[\"GEPA_REFLECTION_LM\"] = \"openai/gpt-4o\"\n\nprint(f\"Reflection LM: {os.environ.get('GEPA_REFLECTION_LM', 'databricks/databricks-gpt-5-2 (default)')}\")" + }, + { + "cell_type": 
"markdown", + "metadata": {}, + "source": [ + "## Step 1: Inspect the Skill\n", + "\n", + "Let's look at the `databricks-model-serving` skill -- its current size, test cases, and baseline score." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "SKILL_NAME = \"databricks-model-serving\"\n", + "\n", + "from skill_test.optimize.evaluator import _find_skill_md, count_tokens\n", + "from skill_test.optimize.splitter import create_gepa_datasets\n", + "\n", + "# Load skill\n", + "skill_path = _find_skill_md(SKILL_NAME)\n", + "original_content = skill_path.read_text()\n", + "original_tokens = count_tokens(original_content)\n", + "\n", + "# Load test cases\n", + "train, val = create_gepa_datasets(SKILL_NAME)\n", + "\n", + "print(f\"Skill: {SKILL_NAME}\")\n", + "print(f\"Path: {skill_path}\")\n", + "print(f\"Lines: {len(original_content.splitlines())}\")\n", + "print(f\"Tokens: {original_tokens:,}\")\n", + "print(f\"Train cases: {len(train)}\")\n", + "print(f\"Val cases: {len(val) if val else 'None'}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Show first few test cases\n", + "for t in train[:3]:\n", + " print(f\"\\n--- {t['id']} ---\")\n", + " print(f\"Prompt: {t['input'][:100]}...\")\n", + " if t.get('answer'):\n", + " print(f\"Answer: {t['answer'][:100]}...\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 2: Evaluate Current Quality (Baseline)\n", + "\n", + "Before optimizing, measure the current skill quality using the scorer pipeline." 
+   ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "from skill_test.optimize.evaluator import create_skill_evaluator, SKILL_KEY\nfrom skill_test.optimize.splitter import to_gepa_instances\n\nevaluator = create_skill_evaluator(SKILL_NAME)\nseed_candidate = {SKILL_KEY: original_content}\n\n# Evaluate on all train tasks, collecting scores as we go so the\n# baseline doesn't require a second evaluator pass\ngepa_instances = to_gepa_instances(train)\nscores = []\n\nprint(f\"{'Task ID':<35} {'Score':>8}\")\nprint(\"-\" * 45)\nfor i, inst in enumerate(gepa_instances):\n score, side_info = evaluator(seed_candidate, inst)\n scores.append(score)\n task_id = train[i]['id']\n status = 'PASS' if score >= 0.5 else 'FAIL'\n print(f\"{task_id:<35} {score:>7.3f} {status}\")\n\n# Quick baseline\nbaseline_score = sum(scores) / len(scores)\nprint(f\"\\nBaseline Score: {baseline_score:.3f}\")\nprint(f\"Token Count: {original_tokens:,}\")" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 3: Run GEPA Optimization\n", + "\n", + "Now run the optimization. 
GEPA will:\n", + "- Use the current SKILL.md as the seed candidate\n", + "- Run scorers against each test case\n", + "- Reflect on failures to propose mutations\n", + "- Select the best candidate via Pareto frontier\n", + "- Penalize token bloat (80% quality, 20% efficiency weighting)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from skill_test.optimize.runner import optimize_skill\n", + "\n", + "result = optimize_skill(\n", + " skill_name=SKILL_NAME,\n", + " mode=\"static\",\n", + " preset=\"quick\", # 15 iterations -- increase to \"standard\" (50) or \"thorough\" (150) for better results\n", + ")\n", + "\n", + "print(f\"Optimization complete!\")\n", + "print(f\"GEPA metric calls: {result.gepa_result.total_metric_calls}\")\n", + "print(f\"Candidates explored: {result.gepa_result.num_candidates}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 4: Results Comparison\n", + "\n", + "Compare the original vs. optimized skill across quality and token efficiency." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(\"=\" * 60)\n", + "print(f\" OPTIMIZATION RESULTS: {SKILL_NAME}\")\n", + "print(\"=\" * 60)\n", + "print()\n", + "\n", + "# Quality comparison\n", + "quality_delta = result.improvement\n", + "quality_pct = (quality_delta / result.original_score * 100) if result.original_score > 0 else 0\n", + "print(f\" Quality Score\")\n", + "print(f\" Before: {result.original_score:.3f}\")\n", + "print(f\" After: {result.optimized_score:.3f}\")\n", + "print(f\" Delta: {quality_delta:+.3f} ({quality_pct:+.1f}%)\")\n", + "print()\n", + "\n", + "# Token comparison \n", + "token_delta = result.original_token_count - result.optimized_token_count\n", + "print(f\" Token Count\")\n", + "print(f\" Before: {result.original_token_count:,}\")\n", + "print(f\" After: {result.optimized_token_count:,}\")\n", + "print(f\" Saved: {token_delta:,} tokens ({result.token_reduction_pct:.1f}% reduction)\")\n", + "print()\n", + "\n", + "# Line count comparison\n", + "orig_lines = len(result.original_content.splitlines())\n", + "opt_lines = len(result.optimized_content.splitlines())\n", + "print(f\" Lines\")\n", + "print(f\" Before: {orig_lines}\")\n", + "print(f\" After: {opt_lines}\")\n", + "print(f\" Saved: {orig_lines - opt_lines} lines\")\n", + "print()\n", + "\n", + "# Validation scores\n", + "if result.val_scores:\n", + " avg_val = sum(result.val_scores.values()) / len(result.val_scores)\n", + " print(f\" Validation (held-out test cases)\")\n", + " for tid, score in result.val_scores.items():\n", + " print(f\" {tid}: {score:.3f}\")\n", + " print(f\" Average: {avg_val:.3f}\")\n", + "\n", + "print()\n", + "print(\"=\" * 60)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Visual comparison bar chart\n", + "try:\n", + " import matplotlib.pyplot as plt\n", + " import matplotlib\n", + " 
matplotlib.rcParams['font.family'] = 'monospace'\n", + "\n", + " fig, axes = plt.subplots(1, 2, figsize=(12, 5))\n", + "\n", + " # Quality scores\n", + " ax = axes[0]\n", + " bars = ax.bar(\n", + " ['Before', 'After'],\n", + " [result.original_score, result.optimized_score],\n", + " color=['#d4534b', '#4a9c5d'],\n", + " width=0.5\n", + " )\n", + " ax.set_ylim(0, 1.1)\n", + " ax.set_ylabel('Quality Score')\n", + " ax.set_title(f'Quality: {result.original_score:.3f} → {result.optimized_score:.3f}')\n", + " for bar, val in zip(bars, [result.original_score, result.optimized_score]):\n", + " ax.text(bar.get_x() + bar.get_width()/2., bar.get_height() + 0.02,\n", + " f'{val:.3f}', ha='center', fontweight='bold')\n", + "\n", + " # Token counts\n", + " ax = axes[1]\n", + " bars = ax.bar(\n", + " ['Before', 'After'],\n", + " [result.original_token_count, result.optimized_token_count],\n", + " color=['#d4534b', '#4a9c5d'],\n", + " width=0.5\n", + " )\n", + " ax.set_ylabel('Token Count')\n", + " ax.set_title(f'Tokens: {result.original_token_count:,} → {result.optimized_token_count:,} ({result.token_reduction_pct:.0f}% reduction)')\n", + " for bar, val in zip(bars, [result.original_token_count, result.optimized_token_count]):\n", + " ax.text(bar.get_x() + bar.get_width()/2., bar.get_height() + 50,\n", + " f'{val:,}', ha='center', fontweight='bold')\n", + "\n", + " fig.suptitle(f'GEPA Optimization: {SKILL_NAME}', fontsize=14, fontweight='bold')\n", + " plt.tight_layout()\n", + " plt.show()\n", + "except ImportError:\n", + " print(\"(matplotlib not installed -- skipping chart)\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 5: Review the Diff\n", + "\n", + "Inspect what GEPA changed in the SKILL.md." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from skill_test.optimize.review import review_optimization\n", + "\n", + "review_optimization(result)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 6: Apply (Optional)\n", + "\n", + "If the results look good, apply the optimized SKILL.md. Uncomment the cell below to write it." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Uncomment to apply:\n", + "# from skill_test.optimize.review import apply_optimization\n", + "# apply_optimization(result)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": "## Multi-Component Optimization: Skills + Tools\n\nGEPA supports optimizing multiple text components simultaneously. You can optimize SKILL.md files **alongside** MCP tool descriptions in a single run.\n\nGEPA's `RoundRobinReflectionComponentSelector` cycles through components one at a time, so each gets dedicated reflection and mutation." 
+ }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "# Inspect available MCP tools\nfrom skill_test.optimize.tools import get_tool_stats, extract_tool_descriptions, tools_to_gepa_components\n\nstats = get_tool_stats()\nprint(f\"MCP Tool Modules: {stats['modules']}\")\nprint(f\"Total Tools: {stats['total_tools']}\")\nprint(f\"Total Chars: {stats['total_description_chars']:,}\")\nprint()\nfor mod, info in stats[\"per_module\"].items():\n print(f\" {mod:<20} {info['tools']:>2} tools {info['chars']:>6,} chars\")\n\n# Show what GEPA components look like for selected modules\ntool_map = extract_tool_descriptions(modules=[\"serving\", \"sql\"])\ncomponents = tools_to_gepa_components(tool_map, per_module=True)\nprint(f\"\\nGEPA components for serving + sql: {list(components.keys())}\")\nfor name, text in components.items():\n from skill_test.optimize.evaluator import count_tokens\n print(f\" {name}: {count_tokens(text):,} tokens\")" + }, + { + "cell_type": "markdown", + "source": "## Changing the Reflection Model\n\nBy default, GEPA uses `databricks/databricks-gpt-5-2` via Databricks Model Serving.\nOverride per-call or via environment variable:\n\n```python\n# Per-call\nresult = optimize_skill(\"my-skill\", reflection_lm=\"openai/gpt-4o\")\n\n# Environment variable (persistent)\nos.environ[\"GEPA_REFLECTION_LM\"] = \"databricks/databricks-gpt-5-2\"\n```\n\nSee README.md for full model configuration options.", + "metadata": {} + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Summary\n", + "\n", + "The GEPA optimization pipeline:\n", + "\n", + "| Metric | Before | After | Change |\n", + "|--------|--------|-------|--------|\n", + "| Quality Score | `result.original_score` | `result.optimized_score` | `result.improvement` |\n", + "| Token Count | `result.original_token_count` | `result.optimized_token_count` | `result.token_reduction_pct`% |\n", + "\n", + 
"Key points:\n", + "- **Quality gate**: Existing scorers (syntax, patterns, APIs, facts) are reused as-is\n", + "- **Token efficiency**: 80/20 quality/efficiency weighting penalizes bloated skills\n", + "- **Validation split**: Held-out test cases detect overfitting\n", + "- **Reflection LM**: Configurable via `--reflection-lm` flag or `GEPA_REFLECTION_LM` env var\n", + "- **Default model**: `databricks/databricks-gpt-5-2` via Databricks Model Serving" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.12.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} \ No newline at end of file diff --git a/.test/pyproject.toml b/.test/pyproject.toml index 8feaeed3..5e114eb4 100644 --- a/.test/pyproject.toml +++ b/.test/pyproject.toml @@ -17,7 +17,8 @@ dependencies = [ [project.optional-dependencies] databricks = ["databricks-sdk>=0.20.0"] dev = ["pytest>=8.0", "pytest-asyncio>=0.23"] -all = ["skill-test[databricks,dev]"] +optimize = ["gepa>=0.1.0", "tiktoken>=0.7.0"] +all = ["skill-test[databricks,dev,optimize]"] [project.scripts] skill-test = "skill_test.cli:main" diff --git a/.test/scripts/add_example.py b/.test/scripts/add_example.py new file mode 100644 index 00000000..68f5c920 --- /dev/null +++ b/.test/scripts/add_example.py @@ -0,0 +1,272 @@ +#!/usr/bin/env python3 +"""Manually add a test case to a skill's ground_truth.yaml. 
+ +Usage: + # Interactive mode — prompts for each field + uv run python .test/scripts/add_example.py databricks-model-serving + + # Inline mode — provide prompt and response directly + uv run python .test/scripts/add_example.py databricks-model-serving \ + --prompt "Create a ChatAgent with tool calling" \ + --response-file /path/to/response.md \ + --facts "Uses ChatAgent class" "Implements predict method" \ + --patterns "ChatAgent" "def predict" + + # From clipboard + uv run python .test/scripts/add_example.py databricks-model-serving --from-clipboard +""" + +import argparse +import re +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parent)) +from _common import setup_path + +setup_path() + + +def _detect_language(response: str) -> str | None: + """Auto-detect code language from response code blocks.""" + langs = re.findall(r"```(\w+)\n", response) + if not langs: + return None + # Most common language wins + from collections import Counter + counts = Counter(l for l in langs if l != "text") + return counts.most_common(1)[0][0] if counts else None + + +def _auto_extract_patterns(response: str) -> list[str]: + """Extract patterns from code blocks.""" + patterns = [] + for match in re.finditer(r"```(?:python)\n(.*?)```", response, re.DOTALL): + code = match.group(1) + for m in re.finditer(r"class\s+(\w+)", code): + patterns.append(m.group(1)) + for m in re.finditer(r"def\s+(\w+)", code): + patterns.append(m.group(1)) + + for match in re.finditer(r"```(?:sql)\n(.*?)```", response, re.DOTALL): + code = match.group(1) + for m in re.finditer(r"(?:CREATE|ALTER)\s+(?:TABLE|VIEW)\s+(\S+)", code, re.I): + # Escape dots in table names so they match literally when + # the extracted name is later used as a regex pattern + patterns.append(re.escape(m.group(1))) + + return list(dict.fromkeys(patterns)) + + +def _next_id(skill_name: str, existing_ids: set[str]) -> str: + """Generate the next sequential ID for a skill.""" + prefix = skill_name.replace("-", "_") + idx = 1 + while True: + candidate = f"{prefix}_{idx:03d}" + if candidate not in existing_ids: + return
candidate + idx += 1 + + +def _read_clipboard() -> str: + """Read text from system clipboard.""" + import subprocess + try: + result = subprocess.run(["pbpaste"], capture_output=True, text=True, check=True) + return result.stdout + except (subprocess.CalledProcessError, FileNotFoundError): + try: + result = subprocess.run( + ["xclip", "-selection", "clipboard", "-o"], + capture_output=True, text=True, check=True, + ) + return result.stdout + except (subprocess.CalledProcessError, FileNotFoundError): + print("Error: Could not read clipboard (tried pbpaste and xclip)") + sys.exit(1) + + +def main(): + parser = argparse.ArgumentParser( + description="Add a test case to a skill's ground_truth.yaml", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__, + ) + parser.add_argument( + "skill_name", + help="Name of the skill (e.g., databricks-model-serving)", + ) + parser.add_argument( + "--prompt", "-p", + default=None, + help="The user prompt for the test case", + ) + parser.add_argument( + "--response", "-r", + default=None, + help="The expected response text (inline)", + ) + parser.add_argument( + "--response-file", + type=Path, + default=None, + help="Path to a file containing the expected response", + ) + parser.add_argument( + "--facts", "-f", + nargs="*", + default=None, + help="Expected facts that must appear in the response", + ) + parser.add_argument( + "--patterns", + nargs="*", + default=None, + help="Expected patterns (regex) that must match in the response", + ) + parser.add_argument( + "--category", "-c", + default="happy_path", + help="Test case category (default: happy_path)", + ) + parser.add_argument( + "--from-clipboard", + action="store_true", + help="Read prompt and response from clipboard (separated by ---)", + ) + parser.add_argument( + "--id", + default=None, + help="Override the auto-generated test case ID", + ) + + args = parser.parse_args() + + import yaml + from skill_test.dataset import get_dataset_source, YAMLDatasetSource 
+ + # Validate skill exists + skill_dir = Path(".test/skills") / args.skill_name + gt_path = skill_dir / "ground_truth.yaml" + + if not skill_dir.exists(): + print(f"Error: skill directory not found: {skill_dir}") + print("Available skills:") + for d in sorted(Path(".test/skills").iterdir()): + if d.is_dir() and not d.name.startswith("_"): + print(f" {d.name}") + sys.exit(1) + + # Load existing records + existing_ids = set() + if gt_path.exists(): + try: + source = YAMLDatasetSource(gt_path) + existing = source.load() + existing_ids = {r.id for r in existing} + except Exception: + pass + + # Get prompt + prompt = args.prompt + response = args.response + + if args.from_clipboard: + clipboard = _read_clipboard() + if "---" in clipboard: + parts = clipboard.split("---", 1) + prompt = parts[0].strip() + response = parts[1].strip() + else: + prompt = clipboard.strip() + print("Clipboard content set as prompt (no --- separator found for response)") + + if args.response_file: + response = args.response_file.read_text() + + if not prompt: + print("Enter the user prompt (Ctrl+D to finish):") + prompt = sys.stdin.read().strip() + + if not prompt: + print("Error: prompt is required") + sys.exit(1) + + if not response: + print("Enter the expected response (Ctrl+D to finish):") + response = sys.stdin.read().strip() + + # Generate ID + test_id = args.id or _next_id(args.skill_name, existing_ids) + + # Auto-extract patterns and facts + auto_patterns = _auto_extract_patterns(response) if response else [] + auto_facts = args.facts or [] + user_patterns = args.patterns or [] + + # Merge auto and user patterns + all_patterns = list(dict.fromkeys(user_patterns + auto_patterns)) + + # Detect language + language = _detect_language(response) if response else None + + # Build test case + test_case = { + "id": test_id, + "inputs": {"prompt": prompt}, + "metadata": { + "category": args.category, + "source": "manual", + }, + } + + if response: + test_case["outputs"] = {"response": response} 
+ if language: + test_case["metadata"]["language"] = language + + expectations = {} + if auto_facts: + expectations["expected_facts"] = auto_facts + if all_patterns: + expectations["expected_patterns"] = all_patterns + if expectations: + test_case["expectations"] = expectations + + # Show summary + print(f"\n--- Test Case Preview ---") + print(f"ID: {test_id}") + print(f"Skill: {args.skill_name}") + print(f"Prompt: {prompt[:100]}{'...' if len(prompt) > 100 else ''}") + if response: + print(f"Response: {len(response)} chars") + if all_patterns: + print(f"Patterns: {all_patterns}") + if auto_facts: + print(f"Facts: {auto_facts}") + print(f"Category: {args.category}") + + # Confirm + if sys.stdin.isatty(): + confirm = input("\nAppend to ground_truth.yaml? [Y/n] ").strip().lower() + if confirm and confirm != "y": + print("Aborted.") + sys.exit(0) + + # Save + if gt_path.exists(): + with open(gt_path) as f: + data = yaml.safe_load(f) or {"test_cases": []} + else: + gt_path.parent.mkdir(parents=True, exist_ok=True) + data = {"test_cases": []} + + data["test_cases"].append(test_case) + + with open(gt_path, "w") as f: + yaml.dump(data, f, default_flow_style=False, sort_keys=False, allow_unicode=True) + + print(f"Added test case '{test_id}' to {gt_path}") + + +if __name__ == "__main__": + main() diff --git a/.test/scripts/generate_examples.py b/.test/scripts/generate_examples.py new file mode 100644 index 00000000..890b5bc8 --- /dev/null +++ b/.test/scripts/generate_examples.py @@ -0,0 +1,428 @@ +#!/usr/bin/env python3 +"""Generate test cases from user requirements for skill optimization. + +Converts natural-language requirements into ground_truth.yaml test cases +with binary assertions (patterns + facts) and guidelines for LLM judge +evaluation. 
+ +Usage: + # From a requirements file (one requirement per line) + uv run python .test/scripts/generate_examples.py databricks-metric-views \ + --requirements requirements.txt + + # Inline requirements (repeatable) + uv run python .test/scripts/generate_examples.py databricks-metric-views \ + --requirement "Must explain MEASURE() wrapping for all measure references" \ + --requirement "Should show error handling when SELECT * is used on metric views" + + # Interactive mode (prompts for requirements) + uv run python .test/scripts/generate_examples.py databricks-metric-views --interactive + + # Auto-append to ground_truth.yaml (skip manual review) + uv run python .test/scripts/generate_examples.py databricks-metric-views \ + --requirement "Must explain MEASURE() wrapping" --trust + + # With a second LLM pass to tighten assertions + uv run python .test/scripts/generate_examples.py databricks-metric-views \ + --requirement "Must explain MEASURE() wrapping" --refine +""" + +import argparse +import json +import re +import sys +from pathlib import Path +from typing import Any + +sys.path.insert(0, str(Path(__file__).resolve().parent)) +from _common import setup_path + +setup_path() + + +def _find_skill_md(skill_name: str) -> str: + """Load SKILL.md content for the given skill.""" + from skill_test.optimize.evaluator import _find_skill_md as find_md + path = find_md(skill_name) + if path is None: + raise FileNotFoundError(f"Could not find SKILL.md for '{skill_name}'") + return path.read_text() + + +def _load_existing_ids(skill_name: str) -> set[str]: + """Load existing test case IDs from ground_truth.yaml.""" + import yaml + gt_path = Path(".test/skills") / skill_name / "ground_truth.yaml" + if not gt_path.exists(): + return set() + with open(gt_path) as f: + data = yaml.safe_load(f) or {} + return {tc["id"] for tc in data.get("test_cases", []) if "id" in tc} + + +def generate_examples_from_requirements( + skill_name: str, + requirements: list[str], + skill_md: str, + 
gen_model: str, + count_per_requirement: int = 3, +) -> list[dict]: + """Generate test cases from requirements using an LLM. + + For each requirement, generates ``count_per_requirement`` test cases + grounded in the SKILL.md content. + + Returns: + List of test case dicts in ground_truth.yaml format. + """ + import litellm + + existing_ids = _load_existing_ids(skill_name) + all_examples: list[dict] = [] + + for req_idx, requirement in enumerate(requirements): + print(f"\n Generating for requirement {req_idx + 1}/{len(requirements)}:") + print(f" {requirement[:100]}") + + messages = [ + { + "role": "system", + "content": ( + "You are an expert test case generator for Databricks skills. " + "Given a SKILL.md document and a user requirement, generate " + f"{count_per_requirement} realistic test cases that evaluate " + "whether the skill satisfies the requirement.\n\n" + "Each test case must include:\n" + "- A realistic user prompt\n" + "- A reference response grounded in SKILL.md\n" + "- Binary assertions: expected_patterns (regex) and expected_facts (substring)\n" + "- Guidelines derived from the requirement (for LLM judge evaluation)\n" + "- Category and difficulty metadata\n\n" + "Return a JSON array of test cases. 
Each test case:\n" + "{\n" + ' "prompt": "user question",\n' + ' "response": "reference answer grounded in SKILL.md",\n' + ' "expected_patterns": [{"pattern": "regex", "min_count": 1, "description": "what it checks"}],\n' + ' "expected_facts": ["substring that must appear"],\n' + ' "guidelines": ["evaluation guideline from the requirement"],\n' + ' "category": "happy_path|edge_case|error_handling",\n' + ' "difficulty": "easy|medium|hard"\n' + "}\n\n" + "Important:\n" + "- Patterns should be regex that work with re.findall(pattern, response, re.IGNORECASE)\n" + "- Facts should be exact substrings (case-insensitive) from the response\n" + "- Guidelines should be evaluable by an LLM judge looking at the response\n" + "- Ground everything in SKILL.md — don't invent APIs or syntax" + ), + }, + { + "role": "user", + "content": ( + f"## SKILL.md for '{skill_name}':\n\n" + f"{skill_md[:8000]}\n\n" + f"## Requirement:\n{requirement}\n\n" + f"## Existing test case IDs (avoid duplicates):\n" + f"{', '.join(sorted(existing_ids)[:20]) or 'None'}\n\n" + f"Generate {count_per_requirement} test cases as a JSON array." 
+ ), + }, + ] + + try: + resp = litellm.completion( + model=gen_model, + messages=messages, + temperature=0.7, + ) + content = resp.choices[0].message.content or "" + + # Extract JSON array from response + json_match = re.search(r"\[.*\]", content, re.DOTALL) + if not json_match: + print(f" Warning: no JSON array found in response") + continue + + cases = json.loads(json_match.group()) + if not isinstance(cases, list): + print(f" Warning: expected JSON array, got {type(cases)}") + continue + + for case_idx, case in enumerate(cases): + test_id = f"{skill_name}_gen_{req_idx:02d}_{case_idx:02d}" + # Avoid collisions with existing IDs + while test_id in existing_ids: + test_id += "_x" + existing_ids.add(test_id) + + example = { + "id": test_id, + "inputs": {"prompt": case.get("prompt", "")}, + "outputs": { + "response": case.get("response", ""), + "execution_success": True, + }, + "expectations": {}, + "metadata": { + "category": case.get("category", "happy_path"), + "difficulty": case.get("difficulty", "medium"), + "source": "generated_from_requirement", + "requirement": requirement[:200], + }, + } + + if case.get("expected_patterns"): + example["expectations"]["expected_patterns"] = case["expected_patterns"] + if case.get("expected_facts"): + example["expectations"]["expected_facts"] = case["expected_facts"] + if case.get("guidelines"): + example["expectations"]["guidelines"] = case["guidelines"] + + all_examples.append(example) + + print(f" Generated {len(cases)} test case(s)") + + except Exception as e: + print(f" Error generating for requirement: {e}") + + return all_examples + + +def refine_examples(examples: list[dict], gen_model: str) -> list[dict]: + """Second LLM pass to validate and tighten assertions.""" + import litellm + + for ex in examples: + prompt_text = ex["inputs"]["prompt"][:200] + response_text = ex["outputs"]["response"][:1000] + patterns = ex["expectations"].get("expected_patterns", []) + facts = ex["expectations"].get("expected_facts", []) + 
guidelines = ex["expectations"].get("guidelines", []) + + messages = [ + { + "role": "system", + "content": ( + "You are reviewing auto-generated test expectations. " + "Validate that patterns actually match the response, " + "facts are actually present as substrings, and guidelines " + "are clear and evaluable. Tighten or fix as needed. " + "Return JSON only." + ), + }, + { + "role": "user", + "content": ( + f"Prompt: {prompt_text}\n\n" + f"Response: {response_text}\n\n" + f"Patterns: {json.dumps(patterns)}\n" + f"Facts: {json.dumps(facts)}\n" + f"Guidelines: {json.dumps(guidelines)}\n\n" + "Return a JSON object with:\n" + '- "expected_patterns": refined list\n' + '- "expected_facts": refined list (must be exact substrings of response)\n' + '- "guidelines": refined list\n' + "Remove any patterns/facts that don't actually match the response." + ), + }, + ] + + try: + resp = litellm.completion(model=gen_model, messages=messages, temperature=0) + content = resp.choices[0].message.content or "" + json_match = re.search(r"\{.*\}", content, re.DOTALL) + if json_match: + refined = json.loads(json_match.group()) + if "expected_patterns" in refined: + ex["expectations"]["expected_patterns"] = refined["expected_patterns"] + if "expected_facts" in refined: + ex["expectations"]["expected_facts"] = refined["expected_facts"] + if "guidelines" in refined: + ex["expectations"]["guidelines"] = refined["guidelines"] + except Exception as e: + print(f" Warning: refinement failed for {ex['id']}: {e}") + + return examples + + +def save_candidates(examples: list[dict], skill_name: str) -> Path: + """Save generated examples to candidates.yaml for review.""" + import yaml + + output_path = Path(".test/skills") / skill_name / "candidates.yaml" + output_path.parent.mkdir(parents=True, exist_ok=True) + + data = {"test_cases": examples} + with open(output_path, "w") as f: + yaml.dump(data, f, default_flow_style=False, sort_keys=False, allow_unicode=True) + + print(f"\nSaved {len(examples)} 
examples to {output_path}") + print(f"Review and then append to ground_truth.yaml, or re-run with --trust") + return output_path + + +def append_to_ground_truth(examples: list[dict], skill_name: str) -> None: + """Append generated examples to ground_truth.yaml.""" + import yaml + + gt_path = Path(".test/skills") / skill_name / "ground_truth.yaml" + gt_path.parent.mkdir(parents=True, exist_ok=True) + + if gt_path.exists(): + with open(gt_path) as f: + data = yaml.safe_load(f) or {"test_cases": []} + else: + data = {"test_cases": []} + + existing_ids = {tc["id"] for tc in data.get("test_cases", []) if "id" in tc} + new_examples = [ex for ex in examples if ex["id"] not in existing_ids] + + if not new_examples: + print("No new examples to add (all IDs already exist).") + return + + data["test_cases"].extend(new_examples) + + with open(gt_path, "w") as f: + yaml.dump(data, f, default_flow_style=False, sort_keys=False, allow_unicode=True) + + print(f"\nAppended {len(new_examples)} examples to {gt_path}") + + +def run_generation( + skill_name: str, + requirements: list[str], + gen_model: str, + trust: bool = False, + refine: bool = False, + count_per_requirement: int = 3, +) -> list[dict]: + """Run the full generation pipeline. Called by optimize.py --generate-from. + + Returns: + List of generated test case dicts. 
+ """ + skill_md = _find_skill_md(skill_name) + print(f"Generating test cases for '{skill_name}' from {len(requirements)} requirement(s)") + + examples = generate_examples_from_requirements( + skill_name=skill_name, + requirements=requirements, + skill_md=skill_md, + gen_model=gen_model, + count_per_requirement=count_per_requirement, + ) + + if not examples: + print("No examples generated.") + return [] + + if refine: + print("\nRefining assertions with LLM...") + examples = refine_examples(examples, gen_model) + + if trust: + append_to_ground_truth(examples, skill_name) + else: + save_candidates(examples, skill_name) + + return examples + + +def main(): + parser = argparse.ArgumentParser( + description="Generate test cases from requirements for skill optimization", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__, + ) + parser.add_argument( + "skill_name", + help="Name of the skill (e.g., databricks-metric-views)", + ) + parser.add_argument( + "--requirements", + type=str, + default=None, + metavar="FILE", + help="Path to a requirements file (one requirement per line)", + ) + parser.add_argument( + "--requirement", + action="append", + default=None, + dest="inline_requirements", + help="Inline requirement (repeatable)", + ) + parser.add_argument( + "--interactive", + action="store_true", + help="Interactively enter requirements", + ) + parser.add_argument( + "--gen-model", + default=None, + help="LLM model for generation (default: GEPA_GEN_LM env or Sonnet)", + ) + parser.add_argument( + "--count", + type=int, + default=3, + help="Number of test cases per requirement (default: 3)", + ) + parser.add_argument( + "--trust", + action="store_true", + help="Auto-append to ground_truth.yaml instead of writing candidates.yaml", + ) + parser.add_argument( + "--refine", + action="store_true", + help="Second LLM pass to validate and tighten assertions", + ) + + args = parser.parse_args() + + # Collect requirements from all sources + requirements: 
list[str] = [] + + if args.requirements: + req_path = Path(args.requirements) + if not req_path.exists(): + print(f"Error: requirements file not found: {req_path}") + sys.exit(1) + requirements.extend( + line.strip() for line in req_path.read_text().splitlines() + if line.strip() and not line.strip().startswith("#") + ) + + if args.inline_requirements: + requirements.extend(args.inline_requirements) + + if args.interactive: + print("Enter requirements (one per line, empty line to finish):") + while True: + line = input(" > ").strip() + if not line: + break + requirements.append(line) + + if not requirements: + parser.error("Provide requirements via --requirements, --requirement, or --interactive") + + # Resolve gen_model + gen_model = args.gen_model + if gen_model is None: + from skill_test.optimize.config import DEFAULT_GEN_LM + gen_model = DEFAULT_GEN_LM + + run_generation( + skill_name=args.skill_name, + requirements=requirements, + gen_model=gen_model, + trust=args.trust, + refine=args.refine, + count_per_requirement=args.count, + ) + + +if __name__ == "__main__": + main() diff --git a/.test/scripts/generate_ground_truth.py b/.test/scripts/generate_ground_truth.py new file mode 100644 index 00000000..083b08a2 --- /dev/null +++ b/.test/scripts/generate_ground_truth.py @@ -0,0 +1,645 @@ +#!/usr/bin/env python3 +"""Generate ground_truth.yaml and manifest.yaml for skills missing test cases. + +Reads each SKILL.md, extracts code examples, headers, and key patterns, +then generates test cases that enable GEPA scorers to produce real signal. 
+ +Usage: + # Generate for a specific skill + uv run python .test/scripts/generate_ground_truth.py databricks-metric-views + + # Generate for all missing skills + uv run python .test/scripts/generate_ground_truth.py --all + + # Preview without writing (dry run) + uv run python .test/scripts/generate_ground_truth.py --all --dry-run +""" + +import argparse +import re +import sys +import textwrap +from dataclasses import dataclass, field +from datetime import datetime, timezone +from pathlib import Path + +import yaml + + +# --------------------------------------------------------------------------- +# Data structures +# --------------------------------------------------------------------------- + +@dataclass +class CodeBlock: + language: str + code: str + context: str # surrounding prose/header text + + +@dataclass +class Section: + level: int # 2 for ##, 3 for ### + title: str + content: str + code_blocks: list[CodeBlock] = field(default_factory=list) + + +@dataclass +class Pattern: + pattern: str + description: str + min_count: int = 1 + + +# --------------------------------------------------------------------------- +# SKILL.md parsing +# --------------------------------------------------------------------------- + +def extract_frontmatter(content: str) -> dict[str, str]: + """Extract YAML frontmatter from SKILL.md.""" + m = re.match(r"^---\s*\n(.*?)\n---", content, re.DOTALL) + if m: + try: + return yaml.safe_load(m.group(1)) or {} + except yaml.YAMLError: + return {} + return {} + + +def extract_code_blocks(content: str) -> list[CodeBlock]: + """Extract fenced code blocks with their language and surrounding context.""" + blocks = [] + # Match ```lang ... 
```
+    for m in re.finditer(
+        r"```(\w+)\s*\n(.*?)```",
+        content,
+        re.DOTALL,
+    ):
+        lang = m.group(1).lower()
+        code = m.group(2).strip()
+        # Get surrounding context (up to 200 chars before)
+        start = max(0, m.start() - 200)
+        ctx = content[start : m.start()].strip()
+        # Find the nearest header: the LAST one before the block, not the first in the window
+        headers = re.findall(r"#+\s+(.+)", ctx)
+        context = headers[-1] if headers else (ctx[-100:] if ctx else "")
+        blocks.append(CodeBlock(language=lang, code=code, context=context))
+    return blocks
+
+
+def extract_sections(content: str) -> list[Section]:
+    """Extract H2 and H3 sections with their content and code blocks."""
+    # Remove frontmatter
+    content = re.sub(r"^---\s*\n.*?\n---\s*\n", "", content, flags=re.DOTALL)
+
+    sections = []
+    # Split by headers
+    parts = re.split(r"(^#{2,3}\s+.+$)", content, flags=re.MULTILINE)
+
+    current_title = ""
+    current_level = 0
+    current_content = ""
+
+    for part in parts:
+        header_match = re.match(r"^(#{2,3})\s+(.+)$", part)
+        if header_match:
+            # Save previous section
+            if current_title:
+                code_blocks = extract_code_blocks(current_content)
+                sections.append(Section(
+                    level=current_level,
+                    title=current_title,
+                    content=current_content.strip(),
+                    code_blocks=code_blocks,
+                ))
+            current_level = len(header_match.group(1))
+            current_title = header_match.group(2).strip()
+            current_content = ""
+        else:
+            current_content += part
+
+    # Don't forget the last section
+    if current_title:
+        code_blocks = extract_code_blocks(current_content)
+        sections.append(Section(
+            level=current_level,
+            title=current_title,
+            content=current_content.strip(),
+            code_blocks=code_blocks,
+        ))
+
+    return sections
+
+
+def extract_patterns_from_code(code: str, language: str) -> list[Pattern]:
+    """Extract function/class/keyword patterns from a code block."""
+    patterns = []
+
+    if language in ("python", "py"):
+        # Function calls: word(
+        for m in re.finditer(r"\b([a-z_]\w+)\s*\(", code):
+            name = m.group(1)
+            if name not in ("print", "str", 
"int", "float", "len", "range", "list", + "dict", "set", "tuple", "type", "isinstance", "if", "for", + "while", "return", "import", "from", "as", "with", "round", + "max", "min", "abs", "sum", "enumerate", "zip", "map", + "filter", "sorted", "any", "all", "open", "format", "bool", + "append", "extend"): + patterns.append(Pattern( + pattern=re.escape(name), + description=f"Uses {name}()", + )) + # Class names: CapitalWord + for m in re.finditer(r"\b([A-Z][a-zA-Z]+(?:[A-Z][a-zA-Z]+)*)\b", code): + name = m.group(1) + if name not in ("True", "False", "None", "String", "Int", "Float", + "IMPORTANT", "NOTE", "WARNING", "TODO"): + patterns.append(Pattern( + pattern=re.escape(name), + description=f"References {name}", + )) + elif language in ("sql",): + # SQL keywords and functions + for m in re.finditer(r"\b(CREATE\s+(?:OR\s+REPLACE\s+)?(?:MATERIALIZED\s+VIEW|VIEW|TABLE|PROCEDURE|CONNECTION))", code, re.IGNORECASE): + patterns.append(Pattern( + pattern=m.group(1).replace(" ", " "), + description=f"Uses {m.group(1).strip()}", + )) + # Function calls in SQL + for m in re.finditer(r"\b([a-z_]\w+)\s*\(", code, re.IGNORECASE): + name = m.group(1).lower() + if name not in ("select", "from", "where", "and", "or", "not", + "group", "order", "having", "limit", "insert", + "update", "delete", "values", "into", "set", + "join", "on", "as", "case", "when", "then", + "else", "end", "in", "between", "like", "is", + "null", "exists", "count", "sum", "avg", "min", "max"): + patterns.append(Pattern( + pattern=re.escape(name), + description=f"Uses {name}() function", + )) + elif language in ("yaml", "yml"): + # Key YAML keys + for m in re.finditer(r"^\s*(\w[\w_-]+):", code, re.MULTILINE): + key = m.group(1) + if key not in ("name", "description", "type", "default", "value", + "true", "false"): + patterns.append(Pattern( + pattern=re.escape(key), + description=f"Includes {key} configuration", + )) + elif language in ("bash", "sh"): + # CLI commands + for m in 
re.finditer(r"\b(databricks\s+\w+(?:\s+\w+)?)", code): + patterns.append(Pattern( + pattern=re.escape(m.group(1)), + description=f"Uses {m.group(1)} command", + )) + + # Deduplicate by pattern string + seen = set() + unique = [] + for p in patterns: + if p.pattern not in seen: + seen.add(p.pattern) + unique.append(p) + return unique + + +def extract_facts_from_section(section: Section) -> list[str]: + """Extract key factual statements from a section's prose.""" + facts = [] + # Look for bullet points with key info + for line in section.content.split("\n"): + line = line.strip() + # Bullet points with bold terms + m = re.match(r"[-*]\s+\*\*(.+?)\*\*\s*[-:]\s*(.+)", line) + if m: + facts.append(f"{m.group(1)}: {m.group(2).strip()}") + continue + # Table rows with useful info + m = re.match(r"\|\s*`?(\w[\w_.-]+)`?\s*\|\s*(.+?)\s*\|", line) + if m and not m.group(1).startswith("-"): + facts.append(f"{m.group(1)}: {m.group(2).strip()}") + + return facts[:5] # Limit to top 5 + + +# --------------------------------------------------------------------------- +# Test case generation +# --------------------------------------------------------------------------- + +def generate_prompt_from_section(section: Section, skill_name: str) -> str: + """Generate a natural user prompt from a section's content.""" + title = section.title + + # Map section titles to natural prompts + prompt_templates = { + "Quick Start": f"Show me how to get started with {skill_name.replace('databricks-', '')}", + "Create": f"Create a {title.lower().replace('create ', '')}", + "Common Patterns": f"Show me common patterns for {skill_name.replace('databricks-', '')}", + "Configuration": f"How do I configure {skill_name.replace('databricks-', '')}?", + "Filtering": f"How do I filter results when querying?", + "Common Issues": f"What are common issues with {skill_name.replace('databricks-', '')}?", + } + + # Check if any template matches + for key, template in prompt_templates.items(): + if key.lower() in 
title.lower(): + return template + + # Generate from code blocks if present + if section.code_blocks: + block = section.code_blocks[0] + if block.language in ("python", "py"): + return f"Write Python code to {title.lower()}" + elif block.language == "sql": + return f"Write SQL to {title.lower()}" + elif block.language in ("yaml", "yml"): + return f"Show me the YAML configuration for {title.lower()}" + elif block.language in ("bash", "sh"): + return f"Show me the CLI commands to {title.lower()}" + + # Default: use section title + return f"How do I {title.lower()} with {skill_name.replace('databricks-', '')}?" + + +def generate_response_from_section(section: Section) -> str: + """Generate an expected response from a section's code blocks and content.""" + parts = [] + + # Add brief explanation from prose + prose_lines = [] + for line in section.content.split("\n"): + stripped = line.strip() + if stripped and not stripped.startswith("|") and not stripped.startswith("```"): + if not stripped.startswith("#"): + prose_lines.append(stripped) + if len(prose_lines) >= 3: + break + + if prose_lines: + parts.append(" ".join(prose_lines[:3])) + + # Add code blocks + for block in section.code_blocks[:2]: # Max 2 code blocks per test case + parts.append(f"\n```{block.language}\n{block.code}\n```") + + return "\n\n".join(parts) if parts else section.content[:500] + + +def build_test_case( + skill_name: str, + section: Section, + index: int, + difficulty: str = "easy", +) -> dict: + """Build a single test case from a section.""" + # Generate ID + feature = re.sub(r"[^a-z0-9]+", "_", section.title.lower()).strip("_")[:30] + case_id = f"{skill_name}_{feature}_{index:03d}" + + prompt = generate_prompt_from_section(section, skill_name) + response = generate_response_from_section(section) + + # Collect patterns from all code blocks + all_patterns = [] + for block in section.code_blocks: + all_patterns.extend(extract_patterns_from_code(block.code, block.language)) + + # Limit patterns 
to top 5 most relevant
+    patterns_data = []
+    seen = set()
+    for p in all_patterns[:8]:
+        if p.pattern not in seen:
+            seen.add(p.pattern)
+            patterns_data.append({
+                "pattern": p.pattern,
+                "min_count": p.min_count,
+                "description": p.description,
+            })
+            if len(patterns_data) >= 5:
+                break
+
+    # Extract facts
+    facts = extract_facts_from_section(section)
+    if not facts:
+        # Fall back to key terms from code blocks
+        for block in section.code_blocks:
+            if block.language in ("python", "py"):
+                facts.append("Uses Python syntax")
+            elif block.language == "sql":
+                facts.append("Uses SQL syntax")
+
+    # Build guidelines from section context
+    guidelines = []
+    if any(b.language in ("python", "py") for b in section.code_blocks):
+        guidelines.append("Code must be valid Python syntax")
+    if any(b.language == "sql" for b in section.code_blocks):
+        guidelines.append("SQL must follow Databricks SQL syntax")
+    if section.code_blocks:
+        guidelines.append("Response must include working code examples")
+
+    return {
+        "id": case_id,
+        "inputs": {"prompt": prompt},
+        "outputs": {
+            "response": response,
+            "execution_success": True,
+        },
+        "expectations": {
+            "expected_facts": facts if facts else [],
+            "expected_patterns": patterns_data if patterns_data else [],
+            "guidelines": guidelines if guidelines else [],
+        },
+        "metadata": {
+            "category": "happy_path",
+            "difficulty": difficulty,
+            "source": "auto_generated",
+            "section": section.title,
+        },
+    }
+
+
+def detect_languages(sections: list[Section]) -> set[str]:
+    """Detect which languages are used across all sections."""
+    langs = set()
+    for s in sections:
+        for b in s.code_blocks:
+            langs.add(b.language)
+    return langs
+
+
+def generate_manifest(skill_name: str, description: str, languages: set[str]) -> dict:
+    """Generate a manifest.yaml for a skill."""
+    enabled_scorers = ["pattern_adherence", "no_hallucinated_apis", "expected_facts_present"]
+    if "python" in languages or "py" in languages:
+        enabled_scorers.insert(0, 
"python_syntax") + if "sql" in languages: + enabled_scorers.insert(0, "sql_syntax") + + default_guidelines = [ + "Response must address the user's request completely", + "Code examples must follow documented best practices", + "Response must use modern APIs (not deprecated ones)", + ] + + return { + "skill_name": skill_name, + "description": description or f"Test cases for {skill_name} skill", + "scorers": { + "enabled": enabled_scorers, + "llm_scorers": ["Safety", "guidelines_from_expectations"], + "default_guidelines": default_guidelines, + "trace_expectations": { + "tool_limits": {"Bash": 10, "Read": 20}, + "token_budget": {"max_total": 100000}, + "required_tools": ["Read"], + "banned_tools": [], + "expected_files": [], + }, + }, + "quality_gates": { + "syntax_valid": 1.0, + "pattern_adherence": 0.9, + "execution_success": 0.8, + }, + } + + +# --------------------------------------------------------------------------- +# Section selection: pick the best sections for test cases +# --------------------------------------------------------------------------- + +def select_sections_for_tests(sections: list[Section], target: int = 7) -> list[Section]: + """Select the best sections for test case generation. + + Prefers sections with code blocks and diverse topics. 
+    """
+    # Score sections by relevance
+    scored = []
+    for s in sections:
+        score = 0
+        # Sections with code are much more valuable
+        score += len(s.code_blocks) * 3
+        # Prefer H2 over H3
+        if s.level == 2:
+            score += 1
+        # Skip meta sections
+        skip_titles = {"related skills", "resources", "reference files", "notes",
+                       "common issues", "current limitations", "sdk version requirements",
+                       "prerequisites", "prerequisites check", "when to use",
+                       "environment configuration", "best practices"}
+        if s.title.lower() in skip_titles:
+            score -= 5
+        # Boost pattern/example sections
+        if any(kw in s.title.lower() for kw in ("pattern", "example", "start", "create", "common")):
+            score += 2
+        # Boost if has substantial content
+        if len(s.content) > 200:
+            score += 1
+
+        scored.append((score, s))
+
+    # Sort by score descending, take top N with a positive score
+    scored.sort(key=lambda x: x[0], reverse=True)
+    selected = [s for score, s in scored if score > 0][:target]
+
+    return selected
+
+
+# ---------------------------------------------------------------------------
+# Main generation logic
+# ---------------------------------------------------------------------------
+
+def generate_skill_tests(skill_name: str, repo_root: Path) -> tuple[list[dict], dict]:
+    """Generate test cases and manifest for a single skill. 
+ + Returns: + (test_cases, manifest) tuple + """ + skill_md_path = repo_root / "databricks-skills" / skill_name / "SKILL.md" + if not skill_md_path.exists(): + raise FileNotFoundError(f"SKILL.md not found: {skill_md_path}") + + content = skill_md_path.read_text() + frontmatter = extract_frontmatter(content) + description = frontmatter.get("description", "") + sections = extract_sections(content) + + # Select best sections for test cases + selected = select_sections_for_tests(sections, target=7) + + if not selected: + raise ValueError(f"No suitable sections found in {skill_md_path}") + + # Generate test cases + test_cases = [] + difficulties = ["easy", "easy", "easy", "medium", "medium", "medium", "hard", "hard"] + for i, section in enumerate(selected): + difficulty = difficulties[i] if i < len(difficulties) else "medium" + tc = build_test_case(skill_name, section, i + 1, difficulty) + test_cases.append(tc) + + # Generate manifest + languages = detect_languages(sections) + manifest = generate_manifest(skill_name, description, languages) + + return test_cases, manifest + + +def write_skill_tests( + skill_name: str, + test_cases: list[dict], + manifest: dict, + output_dir: Path, + dry_run: bool = False, +) -> None: + """Write ground_truth.yaml and manifest.yaml for a skill.""" + skill_dir = output_dir / skill_name + skill_dir.mkdir(parents=True, exist_ok=True) + + gt_data = { + "metadata": { + "skill_name": skill_name, + "version": "0.1.0", + "created_at": datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.%f"), + }, + "test_cases": test_cases, + } + + gt_path = skill_dir / "ground_truth.yaml" + manifest_path = skill_dir / "manifest.yaml" + + if dry_run: + print(f" [DRY RUN] Would write {gt_path} ({len(test_cases)} test cases)") + print(f" [DRY RUN] Would write {manifest_path}") + return + + # Custom YAML representer for multiline strings + class MultilineDumper(yaml.SafeDumper): + pass + + def str_representer(dumper, data): + if "\n" in data: + return 
dumper.represent_scalar("tag:yaml.org,2002:str", data, style="|") + return dumper.represent_scalar("tag:yaml.org,2002:str", data) + + MultilineDumper.add_representer(str, str_representer) + + with open(gt_path, "w") as f: + yaml.dump(gt_data, f, Dumper=MultilineDumper, default_flow_style=False, + sort_keys=False, allow_unicode=True, width=120) + + with open(manifest_path, "w") as f: + yaml.dump(manifest, f, Dumper=MultilineDumper, default_flow_style=False, + sort_keys=False, allow_unicode=True, width=120) + + print(f" Wrote {gt_path} ({len(test_cases)} test cases)") + print(f" Wrote {manifest_path}") + + +# --------------------------------------------------------------------------- +# Skills that are missing test cases +# --------------------------------------------------------------------------- + +MISSING_SKILLS = [ + "databricks-app-apx", + "databricks-config", + "databricks-dbsql", + "databricks-docs", + "databricks-jobs", + "databricks-lakebase-autoscale", + "databricks-lakebase-provisioned", + "databricks-metric-views", + "databricks-mlflow-evaluation", + "databricks-python-sdk", + "databricks-spark-structured-streaming", + "databricks-synthetic-data-generation", + "databricks-unity-catalog", + "databricks-unstructured-pdf-generation", + "databricks-vector-search", +] + + +# --------------------------------------------------------------------------- +# CLI entry point +# --------------------------------------------------------------------------- + +def main(): + parser = argparse.ArgumentParser( + description="Generate ground_truth.yaml test cases for skills", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__, + ) + parser.add_argument( + "skill_name", + nargs="?", + help="Skill name (e.g., databricks-metric-views)", + ) + parser.add_argument( + "--all", + action="store_true", + help="Generate for all missing skills", + ) + parser.add_argument( + "--dry-run", + action="store_true", + help="Preview without writing files", + ) + 
parser.add_argument( + "--force", + action="store_true", + help="Overwrite existing ground_truth.yaml files", + ) + + args = parser.parse_args() + + if not args.skill_name and not args.all: + parser.error("Provide a skill name or use --all") + + # Find repo root + repo_root = Path(__file__).resolve().parent.parent.parent + if not (repo_root / "databricks-skills").exists(): + print(f"ERROR: databricks-skills/ not found at {repo_root}", file=sys.stderr) + sys.exit(1) + + output_dir = repo_root / ".test" / "skills" + + # Determine which skills to process + if args.all: + skills = MISSING_SKILLS + else: + skills = [args.skill_name] + + success = 0 + errors = 0 + + for skill_name in skills: + print(f"\n{'=' * 50}") + print(f" {skill_name}") + print(f"{'=' * 50}") + + # Check if already exists + if not args.force and (output_dir / skill_name / "ground_truth.yaml").exists(): + print(f" SKIP: ground_truth.yaml already exists (use --force to overwrite)") + continue + + try: + test_cases, manifest = generate_skill_tests(skill_name, repo_root) + write_skill_tests(skill_name, test_cases, manifest, output_dir, dry_run=args.dry_run) + success += 1 + except Exception as e: + print(f" ERROR: {e}") + errors += 1 + + print(f"\n{'=' * 50}") + print(f" Done: {success} generated, {errors} errors") + print(f"{'=' * 50}") + + sys.exit(1 if errors > 0 else 0) + + +if __name__ == "__main__": + main() diff --git a/.test/scripts/optimize.py b/.test/scripts/optimize.py new file mode 100644 index 00000000..bbaee3ee --- /dev/null +++ b/.test/scripts/optimize.py @@ -0,0 +1,305 @@ +#!/usr/bin/env python3 +"""CLI entry point for GEPA skill optimization. 
+ +Usage: + # Standard workflow: evaluate + optimize a skill + uv run python .test/scripts/optimize.py databricks-metric-views + + # Quick pass (15 iterations) + uv run python .test/scripts/optimize.py databricks-metric-views --preset quick + + # Thorough optimization (150 iterations) + uv run python .test/scripts/optimize.py databricks-metric-views --preset thorough + + # Dry run (show config, dataset info, estimate cost) + uv run python .test/scripts/optimize.py databricks-metric-views --dry-run + + # Review the saved result then apply (no re-run needed) + uv run python .test/scripts/optimize.py databricks-metric-views --apply-last + + # Run optimization and immediately apply + uv run python .test/scripts/optimize.py databricks-metric-views --apply + + # Optimize all skills that have ground_truth.yaml test cases + uv run python .test/scripts/optimize.py --all +""" + +import argparse +import sys +from pathlib import Path + +# Setup path using shared utilities +sys.path.insert(0, str(Path(__file__).resolve().parent)) +from _common import setup_path, handle_error, print_result + +setup_path() + + +def main(): + parser = argparse.ArgumentParser( + description="Optimize Databricks skills using GEPA", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__, + ) + parser.add_argument( + "skill_name", + nargs="?", + help="Name of the skill to optimize (e.g., databricks-model-serving)", + ) + parser.add_argument( + "--all", + action="store_true", + help="Optimize all skills that have ground_truth.yaml", + ) + parser.add_argument( + "--preset", "-p", + choices=["quick", "standard", "thorough"], + default="standard", + help="GEPA optimization preset (default: standard)", + ) + parser.add_argument( + "--mode", "-m", + choices=["static", "generative"], + default="static", + help="Evaluation mode (default: static)", + ) + parser.add_argument( + "--task-lm", + default=None, + help="(Deprecated, use --gen-model) LLM model for generative mode", + ) + 
parser.add_argument( + "--gen-model", + default=None, + help="LLM model for generative evaluation (default: GEPA_GEN_LM env or " + "databricks/databricks-claude-sonnet-4-6). The evaluator sends the " + "candidate SKILL.md to this model and scores the generated response.", + ) + parser.add_argument( + "--reflection-lm", + default=None, + help="Override GEPA reflection model (default: GEPA_REFLECTION_LM env or databricks/databricks-claude-opus-4-6)", + ) + parser.add_argument( + "--dry-run", + action="store_true", + help="Show config and cost estimate without running optimization", + ) + parser.add_argument( + "--apply", + action="store_true", + help="Run optimization and immediately apply the result", + ) + parser.add_argument( + "--apply-last", + action="store_true", + help="Apply the last saved optimization result without re-running " + "(reads from .test/skills/<skill_name>/optimized_SKILL.md)", + ) + parser.add_argument( + "--include-tools", + action="store_true", + help="Include MCP tool descriptions as additional optimization components", + ) + parser.add_argument( + "--tool-modules", + nargs="*", + default=None, + help="Specific tool modules to optimize (e.g., sql compute serving). Default: all.", + ) + parser.add_argument( + "--tools-only", + action="store_true", + help="Optimize ONLY tool descriptions, not the SKILL.md", + ) + parser.add_argument( + "--max-passes", + type=int, + default=5, + help="Maximum optimization passes per component (default: 5). " + "Each pass re-seeds from the previous best. Stops early if no improvement.", + ) + parser.add_argument( + "--max-metric-calls", + type=int, + default=None, + help="Override max metric calls per pass (default: auto-scaled by preset × components, " + "capped at 300 for non-Opus models). 
Example: --max-metric-calls 100", + ) + parser.add_argument( + "--evaluator", + choices=["legacy", "skillbench"], + default="skillbench", + help="Evaluator type: 'skillbench' (measures skill effectiveness via WITH vs " + "WITHOUT comparison, default) or 'legacy' (weighted scoring with keyword " + "matching and token efficiency)", + ) + parser.add_argument( + "--token-budget", + type=int, + default=None, + help="Token budget ceiling. Candidates exceeding this are penalized. " + "Recommended: 50000. Default: GEPA_TOKEN_BUDGET env or disabled.", + ) + parser.add_argument( + "--use-judges", + action="store_true", + help="Enable MLflow LLM judges (Correctness + Guidelines) for richer NL " + "feedback to GEPA's reflection LM. Adds ~10%% judge_quality weight.", + ) + parser.add_argument( + "--generate-from", + type=str, + default=None, + metavar="REQUIREMENTS_FILE", + help="Generate test cases from a requirements file before optimizing. " + "Each line in the file is a requirement.", + ) + parser.add_argument( + "--requirement", + action="append", + default=None, + dest="requirements", + help="Inline requirement for test case generation (repeatable). 
" + "Example: --requirement 'Must explain MEASURE() wrapping'", + ) + + args = parser.parse_args() + + if not args.skill_name and not args.all: + parser.error("Either provide a skill name or use --all") + + from skill_test.optimize.runner import optimize_skill + from skill_test.optimize.review import review_optimization, apply_optimization, load_last_result + + # Handle requirements-driven example generation + if args.generate_from or args.requirements: + if not args.skill_name: + parser.error("Test case generation requires a skill name") + requirements = [] + if args.generate_from: + req_path = Path(args.generate_from) + if not req_path.exists(): + print(f"Error: requirements file not found: {req_path}") + sys.exit(1) + requirements.extend( + line.strip() for line in req_path.read_text().splitlines() + if line.strip() and not line.strip().startswith("#") + ) + if args.requirements: + requirements.extend(args.requirements) + if requirements: + from generate_examples import run_generation + gen_model = args.gen_model + if gen_model is None: + from skill_test.optimize.config import DEFAULT_GEN_LM + gen_model = DEFAULT_GEN_LM + run_generation( + skill_name=args.skill_name, + requirements=requirements, + gen_model=gen_model, + trust=True, # append directly since we're about to optimize + ) + print() + + # Handle --apply-last: load saved result and apply without re-running + if args.apply_last: + if not args.skill_name: + parser.error("--apply-last requires a skill name") + result = load_last_result(args.skill_name) + if result is None: + print(f"No saved optimization found for '{args.skill_name}'.") + print(f"Run optimization first: uv run python .test/scripts/optimize.py {args.skill_name}") + sys.exit(1) + print(f"Applying saved optimization for '{args.skill_name}':") + print(f" Score: {result.original_score:.3f} -> {result.optimized_score:.3f} " + f"({result.improvement:+.3f})") + print(f" Tokens: {result.original_token_count:,} -> {result.optimized_token_count:,}") 
+ try: + apply_optimization(result) + sys.exit(0) + except Exception as e: + print(f"Error applying: {e}") + sys.exit(1) + + if args.all: + # Find all skills with ground_truth.yaml + skills_dir = Path(".test/skills") + skill_names = [ + d.name + for d in sorted(skills_dir.iterdir()) + if d.is_dir() and (d / "ground_truth.yaml").exists() and not d.name.startswith("_") + ] + print(f"Found {len(skill_names)} skills to optimize: {', '.join(skill_names)}\n") + + results = [] + for name in skill_names: + print(f"\n{'=' * 60}") + print(f" Optimizing: {name}") + print(f"{'=' * 60}") + try: + result = optimize_skill( + skill_name=name, + mode=args.mode, + preset=args.preset, + task_lm=args.task_lm, + gen_model=args.gen_model, + reflection_lm=args.reflection_lm, + include_tools=args.include_tools, + tool_modules=args.tool_modules, + tools_only=args.tools_only, + dry_run=args.dry_run, + max_passes=args.max_passes, + max_metric_calls=args.max_metric_calls, + evaluator_type=args.evaluator, + token_budget=args.token_budget, + use_judges=args.use_judges, + ) + review_optimization(result) + if args.apply and not args.dry_run: + apply_optimization(result) + results.append({"skill": name, "success": True, "improvement": result.improvement}) + except Exception as e: + print(f" ERROR: {e}") + results.append({"skill": name, "success": False, "error": str(e)}) + + # Summary + print(f"\n{'=' * 60}") + print(" Summary") + print(f"{'=' * 60}") + for r in results: + status = "OK" if r["success"] else "FAIL" + detail = f"+{r['improvement']:.3f}" if r["success"] else r["error"] + print(f" [{status}] {r['skill']}: {detail}") + + sys.exit(0 if all(r["success"] for r in results) else 1) + + else: + try: + result = optimize_skill( + skill_name=args.skill_name, + mode=args.mode, + preset=args.preset, + task_lm=args.task_lm, + gen_model=args.gen_model, + reflection_lm=args.reflection_lm, + include_tools=args.include_tools, + tool_modules=args.tool_modules, + tools_only=args.tools_only, + 
dry_run=args.dry_run, + max_passes=args.max_passes, + max_metric_calls=args.max_metric_calls, + evaluator_type=args.evaluator, + token_budget=args.token_budget, + use_judges=args.use_judges, + ) + review_optimization(result) + if args.apply and not args.dry_run: + apply_optimization(result) + sys.exit(0) + except Exception as e: + sys.exit(handle_error(e, args.skill_name)) + + +if __name__ == "__main__": + main() diff --git a/.test/scripts/trace_to_examples.py b/.test/scripts/trace_to_examples.py new file mode 100644 index 00000000..cb056b3f --- /dev/null +++ b/.test/scripts/trace_to_examples.py @@ -0,0 +1,710 @@ +#!/usr/bin/env python3 +"""Extract test cases from Claude Code traces (local or MLflow). + +Parses real agent transcripts and auto-generates ground_truth.yaml entries +from user prompt / assistant response pairs. + +Supports three trace sources: +1. Local session.jsonl files (--trace) +2. MLflow experiment traces (--experiment-id) +3. Individual MLflow traces (--trace-id or --run-id) + +Usage: + # --- From local session.jsonl --- + uv run python .test/scripts/trace_to_examples.py \ + --trace ~/.claude/projects/.../session.jsonl \ + --skill databricks-model-serving + + # --- From MLflow experiment (browse traces, pick best ones) --- + # List recent traces in an experiment + uv run python .test/scripts/trace_to_examples.py \ + --experiment-id 2452310130108632 --list + + # Extract from all recent traces in an experiment + uv run python .test/scripts/trace_to_examples.py \ + --experiment-id 2452310130108632 \ + --skill databricks-model-serving + + # Extract from a specific MLflow run + uv run python .test/scripts/trace_to_examples.py \ + --run-id abc123def456 \ + --skill databricks-model-serving + + # Extract from a specific MLflow trace ID + uv run python .test/scripts/trace_to_examples.py \ + --trace-id tr-d416fccdab46e2dea6bad1d0bd8aaaa8 \ + --skill databricks-model-serving + + # --- Common options --- + # With LLM refinement of expectations + uv run 
python .test/scripts/trace_to_examples.py \ + --experiment-id 2452310130108632 \ + --skill databricks-model-serving --refine + + # Auto-append to ground_truth.yaml (skip manual review) + uv run python .test/scripts/trace_to_examples.py \ + --experiment-id 2452310130108632 \ + --skill databricks-model-serving --trust + + # Limit number of traces to process from an experiment + uv run python .test/scripts/trace_to_examples.py \ + --experiment-id 2452310130108632 \ + --skill databricks-model-serving --limit 5 +""" + +import argparse +import re +import sys +from pathlib import Path +from typing import Any + +sys.path.insert(0, str(Path(__file__).resolve().parent)) +from _common import setup_path + +setup_path() + + +def _extract_text_content(message: dict) -> str: + """Extract text from a message's content array.""" + content = message.get("content", []) + if isinstance(content, str): + return content + if isinstance(content, list): + parts = [] + for part in content: + if isinstance(part, str): + parts.append(part) + elif isinstance(part, dict) and part.get("type") == "text": + parts.append(part.get("text", "")) + return "\n".join(parts) + return "" + + +def _extract_code_blocks(text: str) -> list[dict]: + """Extract fenced code blocks with language tags.""" + blocks = [] + for match in re.finditer(r"```(\w*)\n(.*?)```", text, re.DOTALL): + lang = match.group(1) or "text" + code = match.group(2).strip() + blocks.append({"language": lang, "code": code}) + return blocks + + +def _extract_patterns_from_code(code_blocks: list[dict]) -> list[str]: + """Auto-extract patterns from code blocks (function/class names, SQL keywords).""" + patterns = [] + for block in code_blocks: + code = block["code"] + lang = block["language"] + + if lang == "python": + for m in re.finditer(r"class\s+(\w+)", code): + patterns.append(m.group(1)) + for m in re.finditer(r"def\s+(\w+)", code): + patterns.append(m.group(1)) + for m in re.finditer(r"from\s+([\w.]+)\s+import", code): + 
patterns.append(m.group(1)) + + elif lang == "sql": + for m in re.finditer(r"(?:CREATE|ALTER|DROP)\s+(?:TABLE|VIEW|SCHEMA)\s+(\S+)", code, re.I): + patterns.append(m.group(1)) + for m in re.finditer(r"(?:STREAMING TABLE|MATERIALIZED VIEW)\s+(\S+)", code, re.I): + patterns.append(m.group(1)) + + return list(dict.fromkeys(patterns)) # dedupe preserving order + + +def _extract_facts_from_response(response: str) -> list[str]: + """Auto-extract key facts: API names, function calls, class references.""" + facts = [] + for m in re.finditer(r"(mlflow\.\w+(?:\.\w+)*)\(", response): + facts.append(m.group(1)) + for m in re.finditer(r"(spark\.\w+(?:\.\w+)*)\(", response): + facts.append(m.group(1)) + for m in re.finditer(r"\b([A-Z]\w+(?:Agent|Client|Config|Builder))\b", response): + facts.append(m.group(1)) + return list(dict.fromkeys(facts))[:10] + + +def _categorize_by_tools(tool_names: list[str]) -> str: + """Infer category from tool usage in the turn.""" + if any("sql" in t.lower() or "dbsql" in t.lower() for t in tool_names): + return "sql" + if any("bash" in t.lower() for t in tool_names): + return "deployment" + if any("write" in t.lower() or "edit" in t.lower() for t in tool_names): + return "code_generation" + return "general" + + +# --------------------------------------------------------------------------- +# Local trace extraction (session.jsonl) +# --------------------------------------------------------------------------- + +def extract_examples_from_file(trace_path: Path, skill_name: str | None = None) -> list[dict]: + """Parse a session.jsonl and extract test case candidates.""" + from skill_test.trace.parser import parse_transcript_file, link_tool_results + + entries = parse_transcript_file(trace_path) + link_tool_results(entries) + + examples = [] + idx = 0 + + for i, entry in enumerate(entries): + if entry.type != "user": + continue + if entry.tool_use_result: + continue + + user_text = _extract_text_content(entry.message) + if not user_text or 
len(user_text) < 10: + continue + + assistant_text = "" + tool_names = [] + for j in range(i + 1, len(entries)): + if entries[j].type == "assistant": + assistant_text = _extract_text_content(entries[j].message) + tool_names = [tc.name for tc in entries[j].tool_calls] + break + if entries[j].type == "user" and not entries[j].tool_use_result: + break + + if not assistant_text or len(assistant_text) < 50: + continue + + code_blocks = _extract_code_blocks(assistant_text) + auto_patterns = _extract_patterns_from_code(code_blocks) + auto_facts = _extract_facts_from_response(assistant_text) + category = _categorize_by_tools(tool_names) + + prefix = skill_name or "trace" + example = { + "id": f"{prefix}_{idx:03d}", + "inputs": {"prompt": user_text}, + "outputs": {"response": assistant_text}, + "expectations": {}, + "metadata": { + "category": category, + "source": "trace", + "trace_file": str(trace_path.name), + }, + } + + if auto_patterns: + example["expectations"]["expected_patterns"] = auto_patterns + if auto_facts: + example["expectations"]["expected_facts"] = auto_facts + if code_blocks: + langs = list({b["language"] for b in code_blocks if b["language"] != "text"}) + if langs: + example["metadata"]["languages"] = langs + + examples.append(example) + idx += 1 + + return examples + + +# --------------------------------------------------------------------------- +# MLflow trace extraction +# --------------------------------------------------------------------------- + +def _extract_examples_from_mlflow_trace(trace: Any, skill_name: str | None, idx_offset: int = 0) -> list[dict]: + """Extract test case candidates from an MLflow Trace object. + + MLflow traces from `mlflow autolog claude` contain spans representing + the agent conversation. We look for the root span's input/output which + contains the user prompt and final assistant response. 
+ """ + examples = [] + prefix = skill_name or "mlflow" + + trace_info = trace.info + trace_id = trace_info.request_id if hasattr(trace_info, "request_id") else "unknown" + + # Try to get input/output from the trace data + user_text = "" + assistant_text = "" + + if trace.data: + # The root span typically has the full conversation + spans = trace.data.spans if hasattr(trace.data, "spans") else [] + + # Look for the root span (no parent) or the first AGENT/CHAIN span + root_span = None + for span in spans: + parent = getattr(span, "parent_id", None) + if parent is None or parent == "0": + root_span = span + break + + if root_span is None and spans: + root_span = spans[0] + + if root_span: + inputs = getattr(root_span, "inputs", None) + outputs = getattr(root_span, "outputs", None) + + # Extract user prompt from inputs + if isinstance(inputs, dict): + # Common patterns: {"messages": [...]}, {"input": "..."}, {"prompt": "..."} + if "messages" in inputs: + msgs = inputs["messages"] + if isinstance(msgs, list): + for msg in reversed(msgs): + if isinstance(msg, dict) and msg.get("role") == "user": + user_text = msg.get("content", "") + break + elif "input" in inputs: + user_text = str(inputs["input"]) + elif "prompt" in inputs: + user_text = str(inputs["prompt"]) + elif isinstance(inputs, str): + user_text = inputs + + # Extract assistant response from outputs + if isinstance(outputs, dict): + if "choices" in outputs: + choices = outputs["choices"] + if isinstance(choices, list) and choices: + msg = choices[0].get("message", {}) + assistant_text = msg.get("content", "") + elif "output" in outputs: + assistant_text = str(outputs["output"]) + elif "response" in outputs: + assistant_text = str(outputs["response"]) + elif isinstance(outputs, str): + assistant_text = outputs + + if not user_text or len(user_text) < 10: + return examples + if not assistant_text or len(assistant_text) < 50: + return examples + + # Build the test case + code_blocks = 
_extract_code_blocks(assistant_text) + auto_patterns = _extract_patterns_from_code(code_blocks) + auto_facts = _extract_facts_from_response(assistant_text) + + # Categorize by looking at tool spans + tool_names = [] + if trace.data and hasattr(trace.data, "spans"): + for span in trace.data.spans: + span_type = getattr(span, "span_type", "") + if span_type == "TOOL" or "tool" in getattr(span, "name", "").lower(): + tool_names.append(getattr(span, "name", "unknown")) + + category = _categorize_by_tools(tool_names) + + example = { + "id": f"{prefix}_{idx_offset:03d}", + "inputs": {"prompt": user_text}, + "outputs": {"response": assistant_text}, + "expectations": {}, + "metadata": { + "category": category, + "source": "mlflow_trace", + "trace_id": trace_id, + }, + } + + if auto_patterns: + example["expectations"]["expected_patterns"] = auto_patterns + if auto_facts: + example["expectations"]["expected_facts"] = auto_facts + if code_blocks: + langs = list({b["language"] for b in code_blocks if b["language"] != "text"}) + if langs: + example["metadata"]["languages"] = langs + + examples.append(example) + return examples + + +def list_mlflow_traces(experiment_id: str, limit: int = 20) -> None: + """List recent traces in an MLflow experiment.""" + import mlflow + + from skill_test.trace.mlflow_integration import _configure_mlflow + _configure_mlflow() + + print(f"Fetching traces from experiment {experiment_id}...") + try: + traces_df = mlflow.search_traces( + experiment_ids=[experiment_id], + max_results=limit, + ) + except Exception as e: + print(f"Error fetching traces: {e}") + print("\nMake sure you have authentication configured:") + print(" export DATABRICKS_HOST='https://<workspace>.cloud.databricks.com'") + print(" export DATABRICKS_TOKEN='dapi...'") + sys.exit(1) + + if traces_df.empty: + print("No traces found in experiment.") + return + + print(f"\nFound {len(traces_df)} traces:\n") + print(f"{'Trace ID':<45} {'Status':<10} {'Timestamp':<25} {'Duration'}") + print("-" * 100)
 + + from datetime import datetime + + for _, row in traces_df.iterrows(): + trace_id = row.get("request_id", "unknown") + status = row.get("status", "?") + ts = row.get("timestamp_ms", 0) + duration = row.get("execution_time_ms", 0) + + ts_str = datetime.fromtimestamp(ts / 1000).strftime("%Y-%m-%d %H:%M:%S") if ts else "?" + dur_str = f"{duration / 1000:.1f}s" if duration else "?" + + print(f"{trace_id:<45} {status:<10} {ts_str:<25} {dur_str}") + + print("\nTo extract examples from a specific trace:") + print(" uv run python .test/scripts/trace_to_examples.py --trace-id <trace-id> --skill <skill-name>") + print("\nTo extract from all traces in this experiment:") + print(f" uv run python .test/scripts/trace_to_examples.py --experiment-id {experiment_id} --skill <skill-name>") + + +def extract_examples_from_experiment(experiment_id: str, skill_name: str | None, limit: int = 10) -> list[dict]: + """Extract examples from recent traces in an MLflow experiment.""" + import mlflow + + from skill_test.trace.mlflow_integration import _configure_mlflow + _configure_mlflow() + + print(f"Fetching up to {limit} traces from experiment {experiment_id}...") + try: + traces_df = mlflow.search_traces( + experiment_ids=[experiment_id], + max_results=limit, + filter_string="status = 'OK'", + ) + except Exception as e: + print(f"Error: {e}") + sys.exit(1) + + if traces_df.empty: + print("No successful traces found.") + return [] + + print(f"Processing {len(traces_df)} traces...") + all_examples = [] + + for _, row in traces_df.iterrows(): + trace_id = row.get("request_id") + if not trace_id: + continue + try: + trace = mlflow.get_trace(trace_id) + if trace: + examples = _extract_examples_from_mlflow_trace( + trace, skill_name, idx_offset=len(all_examples) + ) + all_examples.extend(examples) + if examples: + print(f" {trace_id}: extracted {len(examples)} example(s)") + except Exception as e: + print(f" {trace_id}: skipped ({e})") + + return all_examples + + +def extract_examples_from_trace_id(trace_id: str, 
skill_name: str | None) -> list[dict]: + """Extract examples from a single MLflow trace by ID.""" + import mlflow + + from skill_test.trace.mlflow_integration import _configure_mlflow + _configure_mlflow() + + print(f"Fetching trace {trace_id}...") + trace = mlflow.get_trace(trace_id) + if trace is None: + print(f"Trace not found: {trace_id}") + return [] + + return _extract_examples_from_mlflow_trace(trace, skill_name) + + +def extract_examples_from_run_id(run_id: str, skill_name: str | None) -> list[dict]: + """Extract examples from an MLflow run (downloads session.jsonl artifact).""" + from skill_test.trace.mlflow_integration import _configure_mlflow + _configure_mlflow() + + import mlflow + + print(f"Fetching artifacts from run {run_id}...") + + # Try to download session.jsonl artifact + artifact_names = ["trace.jsonl", "session.jsonl", "transcript.jsonl"] + artifact_path = None + + for name in artifact_names: + try: + artifact_path = mlflow.artifacts.download_artifacts(run_id=run_id, artifact_path=name) + print(f" Downloaded {name}") + break + except Exception: + continue + + if artifact_path: + return extract_examples_from_file(Path(artifact_path), skill_name) + + # Fallback: try MLflow traces API + print(" No JSONL artifact found, trying traces API...") + try: + traces_df = mlflow.search_traces( + experiment_ids=[mlflow.get_run(run_id).info.experiment_id], + filter_string=f"run_id = '{run_id}'", + max_results=10, + ) + if not traces_df.empty: + all_examples = [] + for _, row in traces_df.iterrows(): + tid = row.get("request_id") + if tid: + trace = mlflow.get_trace(tid) + if trace: + all_examples.extend( + _extract_examples_from_mlflow_trace(trace, skill_name, len(all_examples)) + ) + return all_examples + except Exception as e: + print(f" Traces API failed: {e}") + + print(" No extractable data found in this run.") + return [] + + +# --------------------------------------------------------------------------- +# LLM refinement and output +# 
--------------------------------------------------------------------------- + +def refine_with_llm(examples: list[dict], skill_name: str) -> list[dict]: + """Use an LLM to review and refine auto-extracted expectations.""" + import litellm + import json + + for ex in examples: + prompt_text = ex["inputs"]["prompt"][:200] + response_text = ex["outputs"]["response"][:1000] + current_patterns = ex["expectations"].get("expected_patterns", []) + current_facts = ex["expectations"].get("expected_facts", []) + + messages = [ + { + "role": "system", + "content": ( + "You are reviewing auto-extracted test expectations for a " + f"Databricks skill called '{skill_name}'. Refine the patterns " + "and facts to be more precise and meaningful. Return JSON only." + ), + }, + { + "role": "user", + "content": ( + f"Prompt: {prompt_text}\n\n" + f"Response excerpt: {response_text}\n\n" + f"Auto-extracted patterns: {json.dumps(current_patterns)}\n" + f"Auto-extracted facts: {json.dumps(current_facts)}\n\n" + "Return a JSON object with:\n" + '- "expected_patterns": list of regex pattern strings\n' + '- "expected_facts": list of fact strings that must appear\n' + "Keep only patterns/facts that are genuinely important for correctness." 
+ ), + }, + ] + + try: + from skill_test.optimize.config import DEFAULT_GEN_LM + resp = litellm.completion(model=DEFAULT_GEN_LM, messages=messages) + content = resp.choices[0].message.content + json_match = re.search(r"\{.*\}", content, re.DOTALL) + if json_match: + refined = json.loads(json_match.group()) + if "expected_patterns" in refined: + ex["expectations"]["expected_patterns"] = refined["expected_patterns"] + if "expected_facts" in refined: + ex["expectations"]["expected_facts"] = refined["expected_facts"] + except Exception as e: + print(f" Warning: LLM refinement failed for {ex['id']}: {e}") + + return examples + + +def save_examples(examples: list[dict], output_path: Path) -> None: + """Save examples to a YAML file.""" + import yaml + + data = {"test_cases": examples} + with open(output_path, "w") as f: + yaml.dump(data, f, default_flow_style=False, sort_keys=False, allow_unicode=True) + print(f"Saved {len(examples)} examples to {output_path}") + + +def append_to_ground_truth(examples: list[dict], skill_name: str) -> None: + """Append examples directly to a skill's ground_truth.yaml.""" + import yaml + + from skill_test.dataset import get_dataset_source + + try: + source = get_dataset_source(skill_name) + existing = source.load() + existing_ids = {r.id for r in existing} + gt_path = source.yaml_path + except FileNotFoundError: + gt_path = Path(".test/skills") / skill_name / "ground_truth.yaml" + gt_path.parent.mkdir(parents=True, exist_ok=True) + existing_ids = set() + + new_examples = [ex for ex in examples if ex["id"] not in existing_ids] + if not new_examples: + print("No new examples to add (all IDs already exist).") + return + + if gt_path.exists(): + with open(gt_path) as f: + data = yaml.safe_load(f) or {"test_cases": []} + else: + data = {"test_cases": []} + + data["test_cases"].extend(new_examples) + + with open(gt_path, "w") as f: + yaml.dump(data, f, default_flow_style=False, sort_keys=False, allow_unicode=True) + + print(f"Appended 
{len(new_examples)} examples to {gt_path}") + + +# --------------------------------------------------------------------------- +# CLI +# --------------------------------------------------------------------------- + +def main(): + parser = argparse.ArgumentParser( + description="Extract test cases from Claude Code traces (local or MLflow)", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__, + ) + + # Trace sources (mutually exclusive) + source_group = parser.add_argument_group("trace source (pick one)") + source_group.add_argument( + "--trace", "-t", + type=Path, + default=None, + help="Path to local session.jsonl transcript file", + ) + source_group.add_argument( + "--experiment-id", + default=None, + help="MLflow experiment ID to extract traces from (e.g., 2452310130108632)", + ) + source_group.add_argument( + "--run-id", + default=None, + help="MLflow run ID to extract traces from", + ) + source_group.add_argument( + "--trace-id", + default=None, + help="MLflow trace ID (e.g., tr-d416fccdab46e2dea6bad1d0bd8aaaa8)", + ) + + # Common options + parser.add_argument( + "--skill", "-s", + default=None, + help="Skill name to tag examples with (e.g., databricks-model-serving)", + ) + parser.add_argument( + "--refine", + action="store_true", + help="Use LLM to review and refine auto-extracted expectations", + ) + parser.add_argument( + "--trust", + action="store_true", + help="Auto-append to ground_truth.yaml instead of writing candidates.yaml", + ) + parser.add_argument( + "--output", "-o", + type=Path, + default=None, + help="Output file path (default: candidates.yaml in skill dir or cwd)", + ) + parser.add_argument( + "--limit", + type=int, + default=10, + help="Max number of traces to process from an experiment (default: 10)", + ) + parser.add_argument( + "--list", + action="store_true", + dest="list_traces", + help="List traces in the experiment without extracting (use with --experiment-id)", + ) + + args = parser.parse_args() + + # Validate: 
at least one source required + sources = [args.trace, args.experiment_id, args.run_id, args.trace_id] + if not any(sources): + parser.error("Provide a trace source: --trace, --experiment-id, --run-id, or --trace-id") + + # List mode + if args.list_traces: + if not args.experiment_id: + parser.error("--list requires --experiment-id") + list_mlflow_traces(args.experiment_id, limit=args.limit) + return + + # Extract examples based on source + examples = [] + + if args.trace: + if not args.trace.exists(): + print(f"Error: trace file not found: {args.trace}") + sys.exit(1) + print(f"Parsing local trace: {args.trace}") + examples = extract_examples_from_file(args.trace, args.skill) + + elif args.experiment_id: + examples = extract_examples_from_experiment(args.experiment_id, args.skill, limit=args.limit) + + elif args.run_id: + examples = extract_examples_from_run_id(args.run_id, args.skill) + + elif args.trace_id: + examples = extract_examples_from_trace_id(args.trace_id, args.skill) + + print(f"\nExtracted {len(examples)} candidate test cases") + + if not examples: + print("No suitable prompt/response pairs found.") + sys.exit(0) + + if args.refine: + print("Refining expectations with LLM...") + examples = refine_with_llm(examples, args.skill or "unknown") + + if args.trust and args.skill: + append_to_ground_truth(examples, args.skill) + else: + output_path = args.output + if output_path is None: + if args.skill: + output_path = Path(".test/skills") / args.skill / "candidates.yaml" + else: + output_path = Path("candidates.yaml") + output_path.parent.mkdir(parents=True, exist_ok=True) + save_examples(examples, output_path) + + +if __name__ == "__main__": + main() diff --git a/.test/skills/databricks-app-apx/ground_truth.yaml b/.test/skills/databricks-app-apx/ground_truth.yaml new file mode 100644 index 00000000..ab9c8bbf --- /dev/null +++ b/.test/skills/databricks-app-apx/ground_truth.yaml @@ -0,0 +1,311 @@ +metadata: + skill_name: databricks-app-apx + version: 0.1.0 + 
created_at: '2026-02-23T08:00:00.000000' + +test_cases: +- id: apx_initialize_project_001 + inputs: + prompt: "I want to start building a new full-stack Databricks app. How do I initialize an APX project?" + outputs: + response: | + To initialize an APX project, use the APX MCP tools: + + ```bash + # Start the APX development server + mcp-cli call apx/start '{}' + + # Check that the server is running + mcp-cli call apx/status '{}' + ``` + + After initialization, APX scaffolds a FastAPI backend and React frontend with TypeScript support. + execution_success: true + expectations: + expected_facts: + - "APX uses MCP tools invoked via mcp-cli call" + - "apx/start initializes the development server" + - "apx/status checks the current server status" + expected_patterns: + - pattern: "apx/start" + min_count: 1 + description: "Mentions the apx/start MCP tool" + - pattern: "apx/status" + min_count: 1 + description: "Mentions the apx/status MCP tool" + guidelines: + - "Response must show the MCP CLI commands for starting APX" + - "Response should explain what happens after initialization" + metadata: + category: happy_path + difficulty: easy + source: auto_generated + +- id: apx_pydantic_models_002 + inputs: + prompt: "What is the recommended pattern for defining backend data models in an APX app?" 
+ outputs: + response: | + In `src/{app_name}/backend/models.py`, follow the 3-model pattern: + + ```python + from datetime import datetime + + from pydantic import BaseModel + + # Input validation - what the client sends + class EntityIn(BaseModel): + name: str + description: str + + # Database record - stored internally + class EntityRecord(EntityIn): + id: str + created_at: datetime + + # API response - what the client receives + class EntityOut(EntityRecord): + pass + ``` + + - `EntityIn` handles input validation + - `EntityRecord` extends with database fields + - `EntityOut` controls API response shape + execution_success: true + expectations: + expected_facts: + - "APX backend models live in src/{app_name}/backend/models.py" + - "The 3-model pattern uses EntityIn, EntityRecord, and EntityOut" + - "Models are built with Pydantic BaseModel" + expected_patterns: + - pattern: "class\\s+\\w+In\\(" + min_count: 1 + description: "Defines an input model class" + - pattern: "class\\s+\\w+Record\\(" + min_count: 1 + description: "Defines a record model class" + - pattern: "BaseModel" + min_count: 1 + description: "Uses Pydantic BaseModel" + guidelines: + - "Response must show the 3-model pattern with EntityIn, EntityRecord, and EntityOut" + - "Code must be valid Python with Pydantic imports" + metadata: + category: happy_path + difficulty: medium + source: auto_generated + +- id: apx_api_routes_003 + inputs: + prompt: "How should I define API routes in my APX backend?" 
+ outputs:
+ response: |
+ In `src/{app_name}/backend/router.py`, define FastAPI routes with these requirements:
+
+ ```python
+ from fastapi import APIRouter
+ from .models import EntityIn, EntityOut
+
+ router = APIRouter(prefix="/api/entities", tags=["entities"])
+
+ @router.get("/", response_model=list[EntityOut])
+ async def list_entities():
+ return entities
+
+ @router.get("/{entity_id}", response_model=EntityOut)
+ async def get_entity(entity_id: str):
+ return find_entity(entity_id)
+
+ @router.post("/", response_model=EntityOut)
+ async def create_entity(entity: EntityIn):
+ return save_entity(entity)
+ ```
+
+ Always include `response_model` on every route, as this enables OpenAPI spec generation for the frontend client.
+ execution_success: true
+ expectations:
+ expected_facts:
+ - "API routes are defined in src/{app_name}/backend/router.py"
+ - "Every route must include response_model for OpenAPI generation"
+ - "Routes use FastAPI's APIRouter"
+ expected_patterns:
+ - pattern: "response_model"
+ min_count: 2
+ description: "Includes response_model on route decorators"
+ - pattern: "APIRouter"
+ min_count: 1
+ description: "Uses FastAPI APIRouter"
+ - pattern: "@router\\.(get|post|put|delete)"
+ min_count: 2
+ description: "Defines HTTP method route decorators"
+ guidelines:
+ - "Response must emphasize the requirement for response_model on every route"
+ - "Code must show standard CRUD route patterns"
+ metadata:
+ category: happy_path
+ difficulty: medium
+ source: auto_generated
+
+- id: apx_add_ui_components_004
+ inputs:
+ prompt: "How do I add shadcn UI components to my APX frontend?"
+ outputs:
+ response: |
+ Use the shadcn MCP tool to add components:
+
+ ```bash
+ mcp-cli call shadcn/get_add_command_for_items '{
+ "items": ["@shadcn/button", "@shadcn/card", "@shadcn/table",
+ "@shadcn/badge", "@shadcn/select", "@shadcn/skeleton"]
+ }'
+ ```
+
+ This generates the install command for the specified shadcn components. Run the output command to add them to your frontend.
+ execution_success: true + expectations: + expected_facts: + - "shadcn components are added via the shadcn MCP tool" + - "The tool is invoked with shadcn/get_add_command_for_items" + expected_patterns: + - pattern: "shadcn/get_add_command_for_items" + min_count: 1 + description: "Uses the shadcn MCP tool" + - pattern: "@shadcn/" + min_count: 2 + description: "References shadcn component packages" + guidelines: + - "Response must show the MCP tool call with JSON parameters" + - "Response should list common UI components like button, card, table" + metadata: + category: happy_path + difficulty: easy + source: auto_generated + +- id: apx_testing_endpoints_005 + inputs: + prompt: "How do I test my APX app's backend API endpoints and run type checks?" + outputs: + response: | + APX provides built-in testing and type checking: + + ```bash + # Run type checks on both backend and frontend + mcp-cli call apx/dev_check '{}' + + # Test API endpoints manually + curl http://localhost:8000/api/{entities} | jq . + curl http://localhost:8000/api/{entities}/{id} | jq . + + # Get the frontend development URL + mcp-cli call apx/get_frontend_url '{}' + ``` + + Fix any Python type errors reported by basedpyright before deploying. 
+ execution_success: true
+ expectations:
+ expected_facts:
+ - "apx/dev_check runs type checks on both backend and frontend"
+ - "Backend API runs on localhost:8000"
+ - "APX uses basedpyright for Python type checking"
+ expected_patterns:
+ - pattern: "apx/dev_check"
+ min_count: 1
+ description: "Uses the dev_check MCP tool"
+ - pattern: "curl.*localhost:8000"
+ min_count: 1
+ description: "Tests endpoints via curl on localhost"
+ - pattern: "apx/get_frontend_url"
+ min_count: 1
+ description: "Uses the get_frontend_url MCP tool"
+ guidelines:
+ - "Response must include both type checking and API testing commands"
+ - "Response should mention basedpyright for Python type checking"
+ metadata:
+ category: happy_path
+ difficulty: medium
+ source: auto_generated
+
+- id: apx_deployment_006
+ inputs:
+ prompt: "How do I deploy my APX app to Databricks and check the logs?"
+ outputs:
+ response: |
+ Deploy and monitor your APX app:
+
+ ```bash
+ # Deploy the app to Databricks
+ mcp-cli call apx/deploy '{}'
+
+ # Check deployed application logs
+ databricks apps logs <app-name> --profile <profile>
+ ```
+
+ The APX MCP server can also automatically check deployed application logs. Simply ask: "Please check the deployed app logs for <app-name>".
+ execution_success: true + expectations: + expected_facts: + - "APX apps are deployed using the apx/deploy MCP tool" + - "Application logs are accessed via databricks apps logs CLI command" + expected_patterns: + - pattern: "apx/deploy" + min_count: 1 + description: "Uses the deploy MCP tool" + - pattern: "databricks apps logs" + min_count: 1 + description: "Uses the databricks apps logs CLI command" + guidelines: + - "Response must show both deployment and log monitoring commands" + - "Response should mention the APX MCP automatic log checking feature" + metadata: + category: happy_path + difficulty: medium + source: auto_generated + +- id: apx_full_workflow_007 + inputs: + prompt: "Walk me through the complete APX development workflow from start to deployment." + outputs: + response: | + The APX development workflow has 5 phases: + + **Phase 1: Initialize** + ```bash + mcp-cli call apx/start '{}' + mcp-cli call apx/status '{}' + ``` + + **Phase 2: Backend** - Create Pydantic models in `models.py` and FastAPI routes in `router.py` + + **Phase 3: Frontend** - Add shadcn components and build React UI + + **Phase 4: Testing** + ```bash + mcp-cli call apx/dev_check '{}' + curl http://localhost:8000/api/{entities} | jq . 
+ ```
+
+ **Phase 5: Deploy**
+ ```bash
+ mcp-cli call apx/deploy '{}'
+ databricks apps logs <app-name> --profile <profile>
+ ```
+ execution_success: true
+ expectations:
+ expected_facts:
+ - "APX development follows 5 phases: Initialize, Backend, Frontend, Testing, Deploy"
+ - "Backend uses Pydantic models and FastAPI routes"
+ - "Frontend uses React with shadcn components"
+ expected_patterns:
+ - pattern: "apx/start"
+ min_count: 1
+ description: "Includes initialization step"
+ - pattern: "apx/dev_check"
+ min_count: 1
+ description: "Includes testing step"
+ - pattern: "apx/deploy"
+ min_count: 1
+ description: "Includes deployment step"
+ guidelines:
+ - "Response must cover all 5 phases of the APX workflow"
+ - "Response should include code examples for key phases"
+ metadata:
+ category: happy_path
+ difficulty: hard
+ source: auto_generated
diff --git a/.test/skills/databricks-app-apx/manifest.yaml b/.test/skills/databricks-app-apx/manifest.yaml
new file mode 100644
index 00000000..e819a067
--- /dev/null
+++ b/.test/skills/databricks-app-apx/manifest.yaml
@@ -0,0 +1,28 @@
+skill_name: databricks-app-apx
+description: Build full-stack Databricks applications using the APX framework (FastAPI + React).
+scorers: + enabled: + - pattern_adherence + - no_hallucinated_apis + - expected_facts_present + llm_scorers: + - Safety + - guidelines_from_expectations + default_guidelines: + - Response must address the user's request completely + - Code examples must follow documented best practices + - Response must use modern APIs (not deprecated ones) + trace_expectations: + tool_limits: + Bash: 10 + Read: 20 + token_budget: + max_total: 100000 + required_tools: + - Read + banned_tools: [] + expected_files: [] +quality_gates: + syntax_valid: 1.0 + pattern_adherence: 0.9 + execution_success: 0.8 diff --git a/.test/skills/databricks-config/ground_truth.yaml b/.test/skills/databricks-config/ground_truth.yaml new file mode 100644 index 00000000..10dbad09 --- /dev/null +++ b/.test/skills/databricks-config/ground_truth.yaml @@ -0,0 +1,199 @@ +metadata: + skill_name: databricks-config + version: 0.1.0 + created_at: '2026-02-23T08:00:00.000000' + +test_cases: +- id: databricks-config_setup_default_profile_001 + inputs: + prompt: Configure my Databricks profile for the DEFAULT workspace + outputs: + response: | + I'll help you configure your Databricks DEFAULT profile. First, let's authenticate: + + ```bash + databricks auth login -p DEFAULT + ``` + + This will open a browser for OAuth authentication. After authenticating, I'll update your `~/.databrickscfg` with the compute configuration. 
+ + Would you like to use: + - **Cluster ID**: A specific interactive/all-purpose cluster + - **Serverless**: Serverless compute (`serverless_compute_id = auto`) + + Example configuration with a cluster: + ``` + [DEFAULT] + host = https://adb-123456789.11.azuredatabricks.net/ + cluster_id = 1217-064531-c9c3ngyn + auth_type = databricks-cli + ``` + execution_success: true + expectations: + expected_facts: + - Uses databricks auth login command + - Configures ~/.databrickscfg file + - Offers choice between cluster_id and serverless + expected_patterns: + - pattern: databricks auth login + min_count: 1 + description: Must use databricks auth login command + - pattern: databrickscfg + min_count: 1 + description: Must reference the config file + guidelines: + - Must run databricks auth login for authentication + - Must offer choice between cluster_id and serverless compute + - Must never print token values in plain text + metadata: + category: happy_path + difficulty: easy + source: auto_generated + +- id: databricks-config_serverless_config_002 + inputs: + prompt: Set up my Databricks profile to use serverless compute + outputs: + response: | + I'll configure your profile with serverless compute. + + ```bash + databricks auth login -p DEFAULT + ``` + + Then update ~/.databrickscfg: + ``` + [DEFAULT] + host = https://adb-123456789.11.azuredatabricks.net/ + serverless_compute_id = auto + auth_type = databricks-cli + ``` + + Note: For serverless, set `serverless_compute_id = auto` (not just `serverless = true`). 
+ execution_success: true + expectations: + expected_facts: + - Sets serverless_compute_id to auto + - Does not use cluster_id with serverless + expected_patterns: + - pattern: serverless_compute_id + min_count: 1 + description: Must use serverless_compute_id parameter + - pattern: auto + min_count: 1 + description: Must set serverless to auto + guidelines: + - Must use serverless_compute_id = auto, not serverless = true + - Must not combine cluster_id and serverless_compute_id + metadata: + category: happy_path + difficulty: easy + source: auto_generated + +- id: databricks-config_workspace_url_003 + inputs: + prompt: Configure Databricks profile using workspace URL https://adb-1234567890123456.7.azuredatabricks.net/ + outputs: + response: | + I'll extract the profile name from the workspace URL and configure it. + + Profile name: `adb-1234567890123456` + + ```bash + databricks auth login -p adb-1234567890123456 --host https://adb-1234567890123456.7.azuredatabricks.net/ + ``` + + After authentication, choose your compute option and I'll update ~/.databrickscfg. + execution_success: true + expectations: + expected_facts: + - Extracts profile name from workspace URL + - Uses --host flag with the URL + expected_patterns: + - pattern: databricks auth login + min_count: 1 + description: Must use databricks auth login + - pattern: "\\-\\-host" + min_count: 1 + description: Must pass --host when URL is provided + guidelines: + - Must extract profile name from workspace URL + - Must pass --host flag when workspace URL is provided + metadata: + category: happy_path + difficulty: medium + source: auto_generated + +- id: databricks-config_cluster_id_004 + inputs: + prompt: Configure my Databricks profile with cluster ID 0123-456789-abcdef12 + outputs: + response: | + I'll configure your profile with the specified cluster ID. 
+ + ```bash + databricks auth login -p DEFAULT + ``` + + Updated ~/.databrickscfg: + ``` + [DEFAULT] + host = https://adb-123456789.11.azuredatabricks.net/ + cluster_id = 0123-456789-abcdef12 + auth_type = databricks-cli + ``` + execution_success: true + expectations: + expected_facts: + - Sets cluster_id in configuration + - Preserves existing host and auth settings + expected_patterns: + - pattern: cluster_id + min_count: 1 + description: Must set cluster_id + - pattern: 0123-456789-abcdef12 + min_count: 1 + description: Must use the specified cluster ID + guidelines: + - Must set cluster_id, not serverless_compute_id + - Must preserve existing profile settings + metadata: + category: happy_path + difficulty: easy + source: auto_generated + +- id: databricks-config_token_security_005 + inputs: + prompt: Show me my Databricks profile configuration + outputs: + response: | + Here's your current profile configuration: + + ``` + [DEFAULT] + host = https://adb-123456789.11.azuredatabricks.net/ + token = [REDACTED] + cluster_id = 1217-064531-c9c3ngyn + ``` + + View full configuration at: ~/.databrickscfg + execution_success: true + expectations: + expected_facts: + - Token values must be redacted + - Directs user to ~/.databrickscfg for full config + expected_patterns: + - pattern: REDACTED + min_count: 1 + description: Must redact token values + - pattern: databrickscfg + min_count: 1 + description: Must reference config file location + guidelines: + - Must NEVER print token values in plain text + - Must redact any token field values + - Must inform user they can view full config at ~/.databrickscfg + metadata: + category: happy_path + difficulty: medium + source: auto_generated diff --git a/.test/skills/databricks-config/manifest.yaml b/.test/skills/databricks-config/manifest.yaml new file mode 100644 index 00000000..6e8db2b6 --- /dev/null +++ b/.test/skills/databricks-config/manifest.yaml @@ -0,0 +1,28 @@ +skill_name: databricks-config +description: Configure 
Databricks profile and authenticate for Databricks Connect, Databricks CLI, and Databricks SDK. +scorers: + enabled: + - pattern_adherence + - no_hallucinated_apis + - expected_facts_present + llm_scorers: + - Safety + - guidelines_from_expectations + default_guidelines: + - Response must address the user's request completely + - Code examples must follow documented best practices + - Response must use modern APIs (not deprecated ones) + trace_expectations: + tool_limits: + Bash: 10 + Read: 20 + token_budget: + max_total: 100000 + required_tools: + - Read + banned_tools: [] + expected_files: [] +quality_gates: + syntax_valid: 1.0 + pattern_adherence: 0.9 + execution_success: 0.8 diff --git a/.test/skills/databricks-dbsql/ground_truth.yaml b/.test/skills/databricks-dbsql/ground_truth.yaml new file mode 100644 index 00000000..f6a3c6d6 --- /dev/null +++ b/.test/skills/databricks-dbsql/ground_truth.yaml @@ -0,0 +1,366 @@ +metadata: + skill_name: databricks-dbsql + version: 0.1.0 + created_at: '2026-02-23T08:00:00.000000' + +test_cases: +- id: dbsql_sql_scripting_etl_001 + inputs: + prompt: "Write a SQL script that conditionally processes new orders into a processed table using Databricks SQL scripting." 
+ outputs: + response: | + ```sql + BEGIN + DECLARE v_count INT; + DECLARE v_status STRING DEFAULT 'pending'; + + SET v_count = (SELECT COUNT(*) FROM catalog.schema.raw_orders WHERE status = 'new'); + + IF v_count > 0 THEN + INSERT INTO catalog.schema.processed_orders + SELECT *, current_timestamp() AS processed_at + FROM catalog.schema.raw_orders + WHERE status = 'new'; + + SET v_status = 'completed'; + ELSE + SET v_status = 'skipped'; + END IF; + + SELECT v_status AS result, v_count AS rows_processed; + END + ``` + execution_success: true + expectations: + expected_facts: + - "Databricks SQL scripting uses BEGIN/END blocks for procedural logic" + - "Variables are declared with DECLARE and assigned with SET" + - "IF/ELSE conditional logic is supported in SQL scripts" + expected_patterns: + - pattern: "BEGIN" + min_count: 1 + description: "Uses SQL scripting BEGIN block" + - pattern: "DECLARE" + min_count: 1 + description: "Declares variables using DECLARE" + - pattern: "IF.*THEN" + min_count: 1 + description: "Uses IF/THEN conditional logic" + - pattern: "current_timestamp\\(\\)" + min_count: 1 + description: "Uses current_timestamp() function" + guidelines: + - "SQL must use Databricks SQL scripting syntax with BEGIN/END blocks" + - "Response must include variable declarations and conditional logic" + metadata: + category: happy_path + difficulty: easy + source: auto_generated + +- id: dbsql_stored_procedure_002 + inputs: + prompt: "Create a stored procedure that upserts customer data with error handling." 
+ outputs: + response: | + ```sql + CREATE OR REPLACE PROCEDURE catalog.schema.upsert_customers( + IN p_source STRING, + OUT p_rows_affected INT + ) + LANGUAGE SQL + SQL SECURITY INVOKER + BEGIN + DECLARE EXIT HANDLER FOR SQLEXCEPTION + BEGIN + SET p_rows_affected = -1; + SIGNAL SQLSTATE '45000' + SET MESSAGE_TEXT = concat('Upsert failed for source: ', p_source); + END; + + MERGE INTO catalog.schema.dim_customer AS t + USING (SELECT * FROM identifier(p_source)) AS s + ON t.customer_id = s.customer_id + WHEN MATCHED THEN UPDATE SET * + WHEN NOT MATCHED THEN INSERT *; + + SET p_rows_affected = (SELECT COUNT(*) FROM identifier(p_source)); + END; + + -- Invoke the procedure: + CALL catalog.schema.upsert_customers('catalog.schema.staging_customers', ?); + ``` + execution_success: true + expectations: + expected_facts: + - "Stored procedures are created with CREATE OR REPLACE PROCEDURE" + - "SQL SECURITY INVOKER runs with the caller's permissions" + - "The identifier() function dynamically resolves table names from string parameters" + - "Error handling uses DECLARE EXIT HANDLER FOR SQLEXCEPTION" + expected_patterns: + - pattern: "CREATE OR REPLACE PROCEDURE" + min_count: 1 + description: "Creates a stored procedure" + - pattern: "MERGE INTO" + min_count: 1 + description: "Uses MERGE for upsert logic" + - pattern: "identifier\\(" + min_count: 1 + description: "Uses identifier() for dynamic table references" + - pattern: "DECLARE EXIT HANDLER" + min_count: 1 + description: "Includes error handling with exit handler" + guidelines: + - "Procedure must include IN/OUT parameters" + - "Must include error handling with SQLEXCEPTION handler" + - "Must use MERGE INTO for upsert pattern" + metadata: + category: happy_path + difficulty: medium + source: auto_generated + +- id: dbsql_materialized_view_003 + inputs: + prompt: "Create a materialized view that refreshes hourly and aggregates daily revenue by region." 
+ outputs: + response: | + ```sql + CREATE OR REPLACE MATERIALIZED VIEW catalog.schema.daily_revenue + CLUSTER BY (order_date) + SCHEDULE EVERY 1 HOUR + COMMENT 'Hourly-refreshed daily revenue by region' + AS SELECT + order_date, + region, + SUM(amount) AS total_revenue, + COUNT(DISTINCT customer_id) AS unique_customers + FROM catalog.schema.fact_orders + JOIN catalog.schema.dim_store USING (store_id) + GROUP BY order_date, region; + ``` + execution_success: true + expectations: + expected_facts: + - "Materialized views support scheduled refresh with SCHEDULE EVERY syntax" + - "CLUSTER BY optimizes data layout for query performance" + - "Materialized views are created with CREATE OR REPLACE MATERIALIZED VIEW" + expected_patterns: + - pattern: "CREATE OR REPLACE MATERIALIZED VIEW" + min_count: 1 + description: "Creates a materialized view" + - pattern: "SCHEDULE EVERY" + min_count: 1 + description: "Configures scheduled refresh" + - pattern: "CLUSTER BY" + min_count: 1 + description: "Uses CLUSTER BY for data layout optimization" + - pattern: "GROUP BY" + min_count: 1 + description: "Aggregates data with GROUP BY" + guidelines: + - "Must include SCHEDULE EVERY for automated refresh" + - "Must include CLUSTER BY for performance optimization" + - "Response should use realistic aggregation patterns" + metadata: + category: happy_path + difficulty: easy + source: auto_generated + +- id: dbsql_pipe_syntax_004 + inputs: + prompt: "Show me how to use Databricks SQL pipe syntax to transform and filter order data." 
+ outputs: + response: | + ```sql + -- Pipe syntax: read left-to-right instead of inside-out + FROM catalog.schema.fact_orders + |> WHERE order_date >= current_date() - INTERVAL 30 DAYS + |> AGGREGATE SUM(amount) AS total, COUNT(*) AS cnt GROUP BY region, product_category + |> WHERE total > 10000 + |> ORDER BY total DESC + |> LIMIT 20; + ``` + + Pipe syntax starts with FROM and chains transformations with `|>`, making complex queries more readable than nested subqueries. + execution_success: true + expectations: + expected_facts: + - "Pipe syntax uses |> to chain transformations left-to-right" + - "Queries start with FROM instead of SELECT when using pipe syntax" + - "AGGREGATE replaces GROUP BY in pipe syntax context" + expected_patterns: + - pattern: "\\|>" + min_count: 3 + description: "Uses pipe operator |> for chaining" + - pattern: "^FROM" + min_count: 1 + description: "Starts the query with FROM" + - pattern: "AGGREGATE" + min_count: 1 + description: "Uses AGGREGATE keyword in pipe syntax" + guidelines: + - "Response must demonstrate the pipe |> operator" + - "Query must start with FROM, not SELECT" + - "Response should explain the readability benefit of pipe syntax" + metadata: + category: happy_path + difficulty: medium + source: auto_generated + +- id: dbsql_ai_functions_005 + inputs: + prompt: "How can I use Databricks AI functions to classify and analyze text data in SQL?" 
+ outputs:
+ response: |
+ ```sql
+ -- Classify support tickets
+ SELECT
+ ticket_id,
+ description,
+ ai_classify(description, ARRAY('billing', 'technical', 'account', 'feature_request')) AS category,
+ ai_analyze_sentiment(description) AS sentiment
+ FROM catalog.schema.support_tickets
+ LIMIT 100;
+
+ -- Extract structured entities from text
+ SELECT
+ doc_id,
+ ai_extract(content, ARRAY('person_name', 'company', 'dollar_amount')) AS entities
+ FROM catalog.schema.contracts;
+
+ -- General-purpose AI query with typed output
+ SELECT ai_query(
+ 'databricks-meta-llama-3-3-70b-instruct',
+ concat('Summarize this feedback in JSON: ', feedback),
+ returnType => 'STRUCT<summary: STRING, sentiment: STRING>'
+ ) AS analysis
+ FROM catalog.schema.customer_feedback
+ LIMIT 50;
+ ```
+ execution_success: true
+ expectations:
+ expected_facts:
+ - "ai_classify() categorizes text into predefined labels"
+ - "ai_analyze_sentiment() returns sentiment analysis of text"
+ - "ai_extract() extracts structured entities from unstructured text"
+ - "ai_query() calls an LLM endpoint with optional structured return types"
+ expected_patterns:
+ - pattern: "ai_classify"
+ min_count: 1
+ description: "Uses ai_classify() for text classification"
+ - pattern: "ai_analyze_sentiment"
+ min_count: 1
+ description: "Uses ai_analyze_sentiment() for sentiment analysis"
+ - pattern: "ai_extract"
+ min_count: 1
+ description: "Uses ai_extract() for entity extraction"
+ - pattern: "ai_query"
+ min_count: 1
+ description: "Uses ai_query() for general LLM queries"
+ guidelines:
+ - "Response must demonstrate at least three different AI functions"
+ - "ai_query must include the returnType parameter for structured output"
+ metadata:
+ category: happy_path
+ difficulty: medium
+ source: auto_generated
+
+- id: dbsql_geospatial_h3_006
+ inputs:
+ prompt: "Write a query that finds all stores within 5km of each customer using H3 geospatial indexing."
+ outputs: + response: | + ```sql + -- Find stores within 5km using H3 spatial indexing + WITH customer_h3 AS ( + SELECT *, h3_longlatash3(longitude, latitude, 7) AS h3_cell + FROM catalog.schema.customers + ), + store_h3 AS ( + SELECT *, h3_longlatash3(longitude, latitude, 7) AS h3_cell + FROM catalog.schema.stores + ) + SELECT + c.customer_id, + s.store_id, + ST_Distance( + ST_Point(c.longitude, c.latitude), + ST_Point(s.longitude, s.latitude) + ) AS distance_m + FROM customer_h3 c + JOIN store_h3 s ON h3_ischildof(c.h3_cell, h3_toparent(s.h3_cell, 5)) + WHERE ST_Distance( + ST_Point(c.longitude, c.latitude), + ST_Point(s.longitude, s.latitude) + ) < 5000; + ``` + execution_success: true + expectations: + expected_facts: + - "h3_longlatash3() converts longitude/latitude to H3 cell index at a given resolution" + - "ST_Distance() calculates distance in meters between two geometric points" + - "ST_Point() creates a geometric point from longitude and latitude" + - "h3_ischildof() and h3_toparent() navigate the H3 hierarchy for spatial joins" + expected_patterns: + - pattern: "h3_longlatash3" + min_count: 1 + description: "Uses H3 cell indexing function" + - pattern: "ST_Distance" + min_count: 1 + description: "Uses ST_Distance for distance calculation" + - pattern: "ST_Point" + min_count: 1 + description: "Uses ST_Point for point geometry" + - pattern: "h3_ischildof" + min_count: 1 + description: "Uses H3 hierarchy for efficient spatial join" + guidelines: + - "Must use H3 indexing for efficient spatial join instead of brute-force distance" + - "Must include ST_Distance for final distance filtering" + metadata: + category: happy_path + difficulty: hard + source: auto_generated + +- id: dbsql_collation_007 + inputs: + prompt: "How do I create a table with case-insensitive text columns in Databricks SQL?" 
+ outputs: + response: | + ```sql + -- Create table with case-insensitive collation + CREATE TABLE catalog.schema.products ( + product_id BIGINT GENERATED ALWAYS AS IDENTITY, + name STRING COLLATE UTF8_LCASE, + category STRING COLLATE UTF8_LCASE, + price DECIMAL(10, 2) + ); + + -- Queries are automatically case-insensitive (no LOWER() needed) + SELECT * FROM catalog.schema.products + WHERE name = 'MacBook Pro'; -- matches 'macbook pro', 'MACBOOK PRO', etc. + ``` + + The `UTF8_LCASE` collation makes string comparisons case-insensitive at the column level, eliminating the need for `LOWER()` or `UPPER()` in WHERE clauses. + execution_success: true + expectations: + expected_facts: + - "UTF8_LCASE collation enables case-insensitive comparisons at the column level" + - "Collation is specified with COLLATE keyword in column definitions" + - "Case-insensitive collation eliminates the need for LOWER() or UPPER() in queries" + expected_patterns: + - pattern: "COLLATE UTF8_LCASE" + min_count: 1 + description: "Uses UTF8_LCASE collation" + - pattern: "CREATE TABLE" + min_count: 1 + description: "Creates a table with collation" + - pattern: "GENERATED ALWAYS AS IDENTITY" + min_count: 1 + description: "Uses auto-generated identity column" + guidelines: + - "Response must show COLLATE UTF8_LCASE in column definitions" + - "Response should explain the benefit over using LOWER() in queries" + metadata: + category: happy_path + difficulty: easy + source: auto_generated diff --git a/.test/skills/databricks-dbsql/manifest.yaml b/.test/skills/databricks-dbsql/manifest.yaml new file mode 100644 index 00000000..625132dc --- /dev/null +++ b/.test/skills/databricks-dbsql/manifest.yaml @@ -0,0 +1,34 @@ +skill_name: databricks-dbsql +description: 'Databricks SQL (DBSQL) advanced features and SQL warehouse capabilities. 
This skill MUST be invoked when the + user mentions: "DBSQL", "Databricks SQL", "SQL warehouse", "SQL scripting", "stored procedure", "CALL procedure", "materialized + view", "CREATE MATERIALIZED VIEW", "pipe syntax", "|>", "geospatial", "H3", "ST_", "spatial SQL", "collation", "COLLATE", + "ai_query", "ai_classify", "ai_extract", "ai_gen", "AI function", "http_request", "remote_query", "read_files", "Lakehouse + Federation", "recursive CTE", "WITH RECURSIVE", "multi-statement transaction", "temp table", "temporary view", "pipe operator". + SHOULD also invoke when the user asks about SQL best practices, data modeling patterns, or advanced SQL features on Databricks.' +scorers: + enabled: + - sql_syntax + - pattern_adherence + - no_hallucinated_apis + - expected_facts_present + llm_scorers: + - Safety + - guidelines_from_expectations + default_guidelines: + - Response must address the user's request completely + - Code examples must follow documented best practices + - Response must use modern APIs (not deprecated ones) + trace_expectations: + tool_limits: + Bash: 10 + Read: 20 + token_budget: + max_total: 100000 + required_tools: + - Read + banned_tools: [] + expected_files: [] +quality_gates: + syntax_valid: 1.0 + pattern_adherence: 0.9 + execution_success: 0.8 diff --git a/.test/skills/databricks-docs/ground_truth.yaml b/.test/skills/databricks-docs/ground_truth.yaml new file mode 100644 index 00000000..39179c0c --- /dev/null +++ b/.test/skills/databricks-docs/ground_truth.yaml @@ -0,0 +1,170 @@ +metadata: + skill_name: databricks-docs + version: 0.1.0 + created_at: '2026-02-23T08:00:00.000000' + +test_cases: +- id: databricks-docs_lookup_reference_001 + inputs: + prompt: How do I look up Databricks documentation for a specific feature? + outputs: + response: | + Use the Databricks documentation index at `https://docs.databricks.com/llms.txt`. + + Fetch it with WebFetch, then search for relevant sections and links. 
The index is organized by category: + - Data Engineering (Lakeflow, Spark, Delta Lake, pipelines) + - SQL & Analytics (Warehouses, queries, dashboards) + - AI/ML (MLflow, model serving, GenAI) + - Governance (Unity Catalog, permissions, security) + - Developer Tools (SDKs, CLI, APIs, Terraform) + execution_success: true + expectations: + expected_facts: + - Use llms.txt documentation index + - Index is organized by categories + expected_patterns: + - pattern: docs\.databricks\.com + min_count: 1 + description: Must reference Databricks docs URL + - pattern: llms\.txt + min_count: 1 + description: Must reference the llms.txt index + guidelines: + - Must reference the llms.txt documentation index URL + - Should explain the documentation structure + metadata: + category: happy_path + difficulty: easy + source: auto_generated + +- id: databricks-docs_reference_role_002 + inputs: + prompt: What is the role of the databricks-docs skill? + outputs: + response: | + This is a reference skill, not an action skill. Use it to: + - Look up documentation when other skills don't cover a topic + - Get authoritative guidance on Databricks concepts and APIs + - Find detailed information to inform how you use MCP tools + - Discover features and capabilities + + Always prefer using MCP tools for actions and load specific skills for workflows. + execution_success: true + expectations: + expected_facts: + - Is a reference skill, not an action skill + - Supplements other skills and MCP tools + expected_patterns: + - pattern: reference + min_count: 1 + description: Must identify as reference skill + guidelines: + - Must clarify this is a reference skill, not an action skill + - Must recommend MCP tools for actual actions + metadata: + category: happy_path + difficulty: easy + source: auto_generated + +- id: databricks-docs_complement_skills_003 + inputs: + prompt: How do I use databricks-docs alongside other skills? 
+ outputs: + response: | + The docs skill complements other skills: + + 1. Load the specific skill for workflow patterns (e.g., databricks-spark-declarative-pipelines) + 2. Use databricks-docs to fetch docs if you need clarification on specific features + 3. Use MCP tools to actually perform the action (e.g., create_or_update_pipeline) + + Example: User wants to create a Delta Live Tables pipeline + - Load `databricks-spark-declarative-pipelines` for workflow patterns + - Fetch docs via llms.txt if you need clarification on DLT features + - Use `create_or_update_pipeline` MCP tool to create the pipeline + execution_success: true + expectations: + expected_facts: + - Load specific skills for workflows + - Use docs for clarification + - Use MCP tools for actions + expected_patterns: + - pattern: MCP + min_count: 1 + description: Must mention MCP tools for actions + - pattern: WebFetch|llms\.txt|docs + min_count: 1 + description: Must reference docs lookup method + guidelines: + - Must explain the three-step pattern (skill + docs + MCP tools) + metadata: + category: happy_path + difficulty: medium + source: auto_generated + +- id: databricks-docs_fetch_docs_004 + inputs: + prompt: How do I fetch and search the Databricks documentation? + outputs: + response: | + Fetch the llms.txt documentation index: + + URL: `https://docs.databricks.com/llms.txt` + + Use WebFetch to retrieve this index, then: + 1. Search for relevant sections/links + 2. Fetch specific documentation pages for detailed guidance + 3. 
Apply what you learn using the appropriate MCP tools + execution_success: true + expectations: + expected_facts: + - Fetch llms.txt with WebFetch + - Search for relevant sections + - Fetch specific pages for details + expected_patterns: + - pattern: WebFetch + min_count: 1 + description: Must use WebFetch to retrieve docs + - pattern: llms\.txt + min_count: 1 + description: Must reference the llms.txt index + guidelines: + - Must describe the WebFetch workflow for docs + metadata: + category: happy_path + difficulty: easy + source: auto_generated + +- id: databricks-docs_unfamiliar_feature_005 + inputs: + prompt: I want to learn about an unfamiliar Databricks feature, how should I approach it? + outputs: + response: | + When encountering an unfamiliar Databricks feature: + + 1. Fetch `https://docs.databricks.com/llms.txt` to find relevant documentation + 2. Read the specific docs to understand the feature + 3. Determine which skill/tools apply, then use them + + The documentation is organized by category: + - Overview & Getting Started + - Data Engineering + - SQL & Analytics + - AI/ML + - Governance + - Developer Tools + execution_success: true + expectations: + expected_facts: + - Start with llms.txt documentation index + - Read specific docs for the feature + - Determine which skills/tools to use + expected_patterns: + - pattern: llms\.txt + min_count: 1 + description: Must reference llms.txt + guidelines: + - Must outline a systematic approach to learning new features + metadata: + category: happy_path + difficulty: easy + source: auto_generated diff --git a/.test/skills/databricks-docs/manifest.yaml b/.test/skills/databricks-docs/manifest.yaml new file mode 100644 index 00000000..287f44e4 --- /dev/null +++ b/.test/skills/databricks-docs/manifest.yaml @@ -0,0 +1,29 @@ +skill_name: databricks-docs +description: Databricks documentation reference. Use as a lookup resource alongside other skills and MCP tools for comprehensive + guidance. 
+scorers: + enabled: + - pattern_adherence + - no_hallucinated_apis + - expected_facts_present + llm_scorers: + - Safety + - guidelines_from_expectations + default_guidelines: + - Response must address the user's request completely + - Code examples must follow documented best practices + - Response must use modern APIs (not deprecated ones) + trace_expectations: + tool_limits: + Bash: 10 + Read: 20 + token_budget: + max_total: 100000 + required_tools: + - Read + banned_tools: [] + expected_files: [] +quality_gates: + syntax_valid: 1.0 + pattern_adherence: 0.9 + execution_success: 0.8 diff --git a/.test/skills/databricks-jobs/ground_truth.yaml b/.test/skills/databricks-jobs/ground_truth.yaml new file mode 100644 index 00000000..f504bfc0 --- /dev/null +++ b/.test/skills/databricks-jobs/ground_truth.yaml @@ -0,0 +1,348 @@ +metadata: + skill_name: databricks-jobs + version: 0.1.0 + created_at: '2026-02-23T08:00:00.000000' + +test_cases: +- id: jobs_create_sdk_001 + inputs: + prompt: "How do I create a Databricks job using the Python SDK?" 
+ outputs: + response: | + ```python + from databricks.sdk import WorkspaceClient + from databricks.sdk.service.jobs import Task, NotebookTask, Source + + w = WorkspaceClient() + + job = w.jobs.create( + name="my-etl-job", + tasks=[ + Task( + task_key="extract", + notebook_task=NotebookTask( + notebook_path="/Workspace/Users/user@example.com/extract", + source=Source.WORKSPACE + ) + ) + ] + ) + print(f"Created job: {job.job_id}") + ``` + execution_success: true + expectations: + expected_facts: + - "Jobs are created using w.jobs.create() from the Databricks Python SDK" + - "Each task requires a task_key and a task type such as NotebookTask" + - "WorkspaceClient auto-detects authentication from the environment" + expected_patterns: + - pattern: "WorkspaceClient" + min_count: 1 + description: "Uses WorkspaceClient for SDK access" + - pattern: "w\\.jobs\\.create" + min_count: 1 + description: "Calls jobs.create() method" + - pattern: "NotebookTask" + min_count: 1 + description: "Uses NotebookTask for notebook-based tasks" + - pattern: "task_key" + min_count: 1 + description: "Specifies task_key identifier" + guidelines: + - "Code must be valid Python using the databricks-sdk package" + - "Must import Task and NotebookTask from databricks.sdk.service.jobs" + metadata: + category: happy_path + difficulty: easy + source: auto_generated + +- id: jobs_create_cli_002 + inputs: + prompt: "Show me how to create a job using the Databricks CLI." 
+ outputs: + response: | + ```bash + databricks jobs create --json '{ + "name": "my-etl-job", + "tasks": [{ + "task_key": "extract", + "notebook_task": { + "notebook_path": "/Workspace/Users/user@example.com/extract", + "source": "WORKSPACE" + } + }] + }' + ``` + execution_success: true + expectations: + expected_facts: + - "The Databricks CLI creates jobs with databricks jobs create --json" + - "Job configuration is passed as a JSON payload" + expected_patterns: + - pattern: "databricks jobs create" + min_count: 1 + description: "Uses the databricks jobs create CLI command" + - pattern: "task_key" + min_count: 1 + description: "Includes task_key in JSON payload" + - pattern: "notebook_task" + min_count: 1 + description: "Includes notebook_task configuration" + guidelines: + - "Response must show the databricks jobs create command with --json flag" + - "JSON payload must include task_key and notebook_task fields" + metadata: + category: happy_path + difficulty: easy + source: auto_generated + +- id: jobs_dabs_config_003 + inputs: + prompt: "How do I define a Databricks job in a DABs (Databricks Asset Bundles) YAML file?" + outputs: + response: | + ```yaml + # resources/jobs.yml + resources: + jobs: + my_etl_job: + name: "[${bundle.target}] My ETL Job" + tasks: + - task_key: extract + notebook_task: + notebook_path: ../src/notebooks/extract.py + ``` + + The `${bundle.target}` variable is replaced with the deployment target (e.g., dev, staging, prod) at deploy time. 
+ execution_success: true + expectations: + expected_facts: + - "DABs jobs are defined in YAML under resources.jobs" + - "bundle.target variable enables environment-specific job naming" + - "notebook_path uses relative paths from the bundle root" + expected_patterns: + - pattern: "resources:" + min_count: 1 + description: "Uses DABs resources block" + - pattern: "jobs:" + min_count: 1 + description: "Defines jobs section" + - pattern: "bundle\\.target" + min_count: 1 + description: "Uses bundle.target variable for environment naming" + - pattern: "notebook_task:" + min_count: 1 + description: "Configures a notebook task" + guidelines: + - "YAML must follow DABs resource configuration structure" + - "Response should explain the bundle.target variable" + metadata: + category: happy_path + difficulty: medium + source: auto_generated + +- id: jobs_multi_task_workflow_004 + inputs: + prompt: "How do I set up a multi-task ETL workflow with dependencies between tasks?" + outputs: + response: | + ```yaml + tasks: + - task_key: extract + notebook_task: + notebook_path: ../src/extract.py + + - task_key: transform + depends_on: + - task_key: extract + notebook_task: + notebook_path: ../src/transform.py + + - task_key: load + depends_on: + - task_key: transform + run_if: ALL_SUCCESS + notebook_task: + notebook_path: ../src/load.py + ``` + + Tasks form a DAG (directed acyclic graph). Use `depends_on` to define execution order and `run_if` to control conditional execution (ALL_SUCCESS, AT_LEAST_ONE_SUCCESS, etc.). 
+ execution_success: true + expectations: + expected_facts: + - "Multi-task workflows use depends_on to define task dependencies" + - "Tasks form a DAG (directed acyclic graph) for execution ordering" + - "run_if controls conditional execution with values like ALL_SUCCESS" + expected_patterns: + - pattern: "depends_on:" + min_count: 2 + description: "Uses depends_on for task ordering" + - pattern: "run_if:" + min_count: 1 + description: "Uses run_if for conditional execution" + - pattern: "task_key:" + min_count: 3 + description: "Defines at least three tasks in the workflow" + guidelines: + - "Must show at least 3 tasks with dependencies forming a pipeline" + - "Must include run_if to demonstrate conditional execution" + metadata: + category: happy_path + difficulty: medium + source: auto_generated + +- id: jobs_compute_config_005 + inputs: + prompt: "What are the options for configuring compute resources for Databricks jobs?" + outputs: + response: | + **Option 1: Job Clusters (recommended for most jobs)** + ```yaml + job_clusters: + - job_cluster_key: shared_cluster + new_cluster: + spark_version: "15.4.x-scala2.12" + node_type_id: "i3.xlarge" + num_workers: 2 + spark_conf: + spark.speculation: "true" + + tasks: + - task_key: my_task + job_cluster_key: shared_cluster + notebook_task: + notebook_path: ../src/notebook.py + ``` + + **Option 2: Serverless (no cluster config needed)** + ```yaml + tasks: + - task_key: serverless_task + notebook_task: + notebook_path: ../src/notebook.py + # No cluster config = serverless + ``` + execution_success: true + expectations: + expected_facts: + - "Job clusters are reusable cluster configurations shared across tasks" + - "Omitting cluster configuration enables serverless compute" + - "job_cluster_key links tasks to shared cluster definitions" + expected_patterns: + - pattern: "job_clusters:" + min_count: 1 + description: "Defines job cluster configuration" + - pattern: "job_cluster_key:" + min_count: 1 + description: 
"References a shared job cluster" + - pattern: "spark_version:" + min_count: 1 + description: "Specifies Spark runtime version" + - pattern: "node_type_id:" + min_count: 1 + description: "Specifies instance type" + guidelines: + - "Response must show both job cluster and serverless options" + - "Job cluster example must include spark_version and node_type_id" + metadata: + category: happy_path + difficulty: medium + source: auto_generated + +- id: jobs_permissions_006 + inputs: + prompt: "How do I configure permissions on a Databricks job in DABs?" + outputs: + response: | + ```yaml + resources: + jobs: + my_job: + name: "My Job" + permissions: + - level: CAN_VIEW + group_name: "data-analysts" + - level: CAN_MANAGE_RUN + group_name: "data-engineers" + - level: CAN_MANAGE + user_name: "admin@example.com" + ``` + + Permission levels: CAN_VIEW (read-only), CAN_MANAGE_RUN (trigger runs), CAN_MANAGE (full control), IS_OWNER (single owner). + execution_success: true + expectations: + expected_facts: + - "Job permissions are set in the permissions block under the job resource" + - "Permission levels include CAN_VIEW, CAN_MANAGE_RUN, CAN_MANAGE, and IS_OWNER" + - "Permissions can be assigned to group_name or user_name" + expected_patterns: + - pattern: "permissions:" + min_count: 1 + description: "Defines permissions block" + - pattern: "level:" + min_count: 2 + description: "Specifies permission levels" + - pattern: "group_name:" + min_count: 1 + description: "Assigns permissions to a group" + - pattern: "user_name:" + min_count: 1 + description: "Assigns permissions to a user" + guidelines: + - "Response must show at least three different permission levels" + - "Must demonstrate both group and user permission assignments" + metadata: + category: happy_path + difficulty: easy + source: auto_generated + +- id: jobs_triggers_007 + inputs: + prompt: "What trigger options are available for scheduling Databricks jobs?" 
+ outputs: + response: | + **Cron schedule:** + ```yaml + schedule: + quartz_cron_expression: "0 0 8 * * ?" # Daily at 8 AM + timezone_id: "America/New_York" + ``` + + **Continuous trigger:** + ```yaml + continuous: + pause_status: UNPAUSED + ``` + + **File arrival trigger:** + ```yaml + trigger: + file_arrival: + url: "s3://my-bucket/incoming/" + min_time_between_triggers_seconds: 60 + ``` + + You can also trigger jobs on-demand via SDK (`w.jobs.run_now(job_id=123)`) or CLI (`databricks jobs run-now --job-id 123`). + execution_success: true + expectations: + expected_facts: + - "Jobs support cron scheduling with quartz_cron_expression" + - "Continuous trigger runs the job repeatedly without pause" + - "File arrival trigger monitors a cloud storage path for new files" + expected_patterns: + - pattern: "quartz_cron_expression" + min_count: 1 + description: "Shows cron schedule configuration" + - pattern: "continuous:" + min_count: 1 + description: "Shows continuous trigger option" + - pattern: "file_arrival:" + min_count: 1 + description: "Shows file arrival trigger option" + guidelines: + - "Response must show at least three different trigger types" + - "Cron example must include timezone_id" + metadata: + category: happy_path + difficulty: medium + source: auto_generated diff --git a/.test/skills/databricks-jobs/manifest.yaml b/.test/skills/databricks-jobs/manifest.yaml new file mode 100644 index 00000000..5773f227 --- /dev/null +++ b/.test/skills/databricks-jobs/manifest.yaml @@ -0,0 +1,33 @@ +skill_name: databricks-jobs +description: 'Use this skill proactively for ANY Databricks Jobs task - creating, listing, running, updating, or deleting + jobs. 
Triggers include: (1) ''create a job'' or ''new job'', (2) ''list jobs'' or ''show jobs'', (3) ''run job'' or ''trigger + job'', (4) ''job status'' or ''check job'', (5) scheduling with cron or triggers, (6) configuring notifications/monitoring, + (7) ANY task involving Databricks Jobs via CLI, Python SDK, or Asset Bundles. ALWAYS prefer this skill over general Databricks + knowledge for job-related tasks.' +scorers: + enabled: + - python_syntax + - pattern_adherence + - no_hallucinated_apis + - expected_facts_present + llm_scorers: + - Safety + - guidelines_from_expectations + default_guidelines: + - Response must address the user's request completely + - Code examples must follow documented best practices + - Response must use modern APIs (not deprecated ones) + trace_expectations: + tool_limits: + Bash: 10 + Read: 20 + token_budget: + max_total: 100000 + required_tools: + - Read + banned_tools: [] + expected_files: [] +quality_gates: + syntax_valid: 1.0 + pattern_adherence: 0.9 + execution_success: 0.8 diff --git a/.test/skills/databricks-lakebase-autoscale/ground_truth.yaml b/.test/skills/databricks-lakebase-autoscale/ground_truth.yaml new file mode 100644 index 00000000..9832a09c --- /dev/null +++ b/.test/skills/databricks-lakebase-autoscale/ground_truth.yaml @@ -0,0 +1,317 @@ +metadata: + skill_name: databricks-lakebase-autoscale + version: 0.1.0 + created_at: '2026-02-23T08:00:00.000000' + +test_cases: +- id: lakebase_auto_create_project_001 + inputs: + prompt: "How do I create a new Lakebase Autoscale project using the Python SDK?" 
+ outputs: + response: | + ```python + from databricks.sdk import WorkspaceClient + from databricks.sdk.service.postgres import Project, ProjectSpec + + w = WorkspaceClient() + + # Create a project (long-running operation) + operation = w.postgres.create_project( + project=Project( + spec=ProjectSpec( + display_name="My Application", + pg_version="17" + ) + ), + project_id="my-app" + ) + result = operation.wait() + print(f"Created project: {result.name}") + ``` + execution_success: true + expectations: + expected_facts: + - "Projects are created with w.postgres.create_project()" + - "create_project returns a long-running operation that requires .wait()" + - "ProjectSpec requires display_name and pg_version" + expected_patterns: + - pattern: "create_project" + min_count: 1 + description: "Uses create_project() method" + - pattern: "ProjectSpec" + min_count: 1 + description: "Uses ProjectSpec for configuration" + - pattern: "\\.wait\\(\\)" + min_count: 1 + description: "Waits for the long-running operation to complete" + - pattern: "WorkspaceClient" + min_count: 1 + description: "Uses WorkspaceClient for SDK access" + guidelines: + - "Must show the long-running operation pattern with .wait()" + - "Must include ProjectSpec with display_name and pg_version" + metadata: + category: happy_path + difficulty: easy + source: auto_generated + +- id: lakebase_auto_create_branch_002 + inputs: + prompt: "How do I create a development branch in Lakebase Autoscale with an expiration?" 
+ outputs: + response: | + ```python + from databricks.sdk.service.postgres import Branch, BranchSpec, Duration + + # Create a dev branch with 7-day expiration + branch = w.postgres.create_branch( + parent="projects/my-app", + branch=Branch( + spec=BranchSpec( + source_branch="projects/my-app/branches/production", + ttl=Duration(seconds=604800) # 7 days + ) + ), + branch_id="development" + ).wait() + print(f"Branch created: {branch.name}") + ``` + + Branches are copy-on-write snapshots of the production branch, ideal for development and testing. + execution_success: true + expectations: + expected_facts: + - "Branches are created with w.postgres.create_branch()" + - "BranchSpec takes a source_branch and optional ttl for expiration" + - "Duration(seconds=604800) sets a 7-day TTL" + - "Branches use copy-on-write from the source branch" + expected_patterns: + - pattern: "create_branch" + min_count: 1 + description: "Uses create_branch() method" + - pattern: "BranchSpec" + min_count: 1 + description: "Uses BranchSpec for branch configuration" + - pattern: "Duration" + min_count: 1 + description: "Uses Duration for TTL specification" + - pattern: "source_branch" + min_count: 1 + description: "Specifies the parent branch to fork from" + guidelines: + - "Must include TTL/expiration on the development branch" + - "Must explain that branches are copy-on-write snapshots" + metadata: + category: happy_path + difficulty: medium + source: auto_generated + +- id: lakebase_auto_connect_notebook_003 + inputs: + prompt: "Show me how to connect to a Lakebase Autoscale database from a Databricks notebook." 
+ outputs: + response: | + ```python + import psycopg + from databricks.sdk import WorkspaceClient + + w = WorkspaceClient() + + # Get endpoint details + endpoint = w.postgres.get_endpoint( + name="projects/my-app/branches/production/endpoints/ep-primary" + ) + host = endpoint.status.hosts.host + + # Generate OAuth token (scoped to endpoint) + cred = w.postgres.generate_database_credential( + endpoint="projects/my-app/branches/production/endpoints/ep-primary" + ) + + # Connect using psycopg3 + conn_string = ( + f"host={host} " + f"dbname=databricks_postgres " + f"user={w.current_user.me().user_name} " + f"password={cred.token} " + f"sslmode=require" + ) + with psycopg.connect(conn_string) as conn: + with conn.cursor() as cur: + cur.execute("SELECT version()") + print(cur.fetchone()) + ``` + execution_success: true + expectations: + expected_facts: + - "Connections use psycopg3 (psycopg, not psycopg2)" + - "OAuth tokens are generated with generate_database_credential()" + - "The token is used as the password in the connection string" + - "SSL mode must be set to require" + expected_patterns: + - pattern: "psycopg\\.connect" + min_count: 1 + description: "Connects using psycopg3" + - pattern: "generate_database_credential" + min_count: 1 + description: "Generates OAuth token for authentication" + - pattern: "sslmode=require" + min_count: 1 + description: "Requires SSL for the connection" + - pattern: "get_endpoint" + min_count: 1 + description: "Retrieves endpoint details" + guidelines: + - "Must use psycopg (not psycopg2) for the connection" + - "Must generate OAuth token and use it as password" + - "Connection must include sslmode=require" + metadata: + category: happy_path + difficulty: medium + source: auto_generated + +- id: lakebase_auto_oauth_token_004 + inputs: + prompt: "How do I generate an OAuth token for Lakebase Autoscale database access?" 
+ outputs: + response: | + ```python + from databricks.sdk import WorkspaceClient + + w = WorkspaceClient() + + # Generate database credential (optionally scoped to an endpoint) + cred = w.postgres.generate_database_credential( + endpoint="projects/my-app/branches/production/endpoints/ep-primary" + ) + token = cred.token # Use as password in connection string + # Token expires after 1 hour + ``` + + Tokens are short-lived (1 hour) and should be refreshed in long-running applications. + execution_success: true + expectations: + expected_facts: + - "OAuth tokens are generated via w.postgres.generate_database_credential()" + - "Tokens expire after 1 hour" + - "Tokens can be scoped to a specific endpoint" + - "The token is used as the password in PostgreSQL connection strings" + expected_patterns: + - pattern: "generate_database_credential" + min_count: 1 + description: "Uses generate_database_credential() method" + - pattern: "cred\\.token" + min_count: 1 + description: "Accesses the token from the credential response" + guidelines: + - "Must mention the 1-hour token expiration" + - "Must show that the token is used as the connection password" + metadata: + category: happy_path + difficulty: easy + source: auto_generated + +- id: lakebase_auto_resize_compute_005 + inputs: + prompt: "How do I resize the compute capacity for a Lakebase Autoscale endpoint?" 
+ outputs: + response: | + ```python + from databricks.sdk.service.postgres import Endpoint, EndpointSpec, FieldMask + + # Update compute to autoscale between 2-8 CU + w.postgres.update_endpoint( + name="projects/my-app/branches/production/endpoints/ep-primary", + endpoint=Endpoint( + name="projects/my-app/branches/production/endpoints/ep-primary", + spec=EndpointSpec( + autoscaling_limit_min_cu=2.0, + autoscaling_limit_max_cu=8.0 + ) + ), + update_mask=FieldMask(field_mask=[ + "spec.autoscaling_limit_min_cu", + "spec.autoscaling_limit_max_cu" + ]) + ).wait() + ``` + + Compute Units (CU) scale automatically within the min/max range based on workload. + execution_success: true + expectations: + expected_facts: + - "Compute is resized with w.postgres.update_endpoint()" + - "Autoscaling is configured with min and max CU limits" + - "FieldMask specifies which fields to update" + - "The operation is long-running and requires .wait()" + expected_patterns: + - pattern: "update_endpoint" + min_count: 1 + description: "Uses update_endpoint() to resize" + - pattern: "autoscaling_limit_min_cu" + min_count: 1 + description: "Sets minimum autoscaling CU" + - pattern: "autoscaling_limit_max_cu" + min_count: 1 + description: "Sets maximum autoscaling CU" + - pattern: "FieldMask" + min_count: 1 + description: "Uses FieldMask for partial update" + guidelines: + - "Must include both min and max CU configuration" + - "Must use FieldMask to specify updated fields" + metadata: + category: happy_path + difficulty: medium + source: auto_generated + +- id: lakebase_auto_cli_reference_006 + inputs: + prompt: "What CLI commands are available for managing Lakebase Autoscale projects?" 
+ outputs: + response: | + ```bash + # Create a project + databricks postgres create-project \ + --project-id my-app \ + --json '{"spec": {"display_name": "My App", "pg_version": "17"}}' + + # List projects + databricks postgres list-projects + + # Get project details + databricks postgres get-project projects/my-app + + # Create a branch + databricks postgres create-branch projects/my-app development \ + --json '{"spec": {"source_branch": "projects/my-app/branches/production", "no_expiry": true}}' + + # List branches + databricks postgres list-branches projects/my-app + + # Delete a project + databricks postgres delete-project projects/my-app + ``` + execution_success: true + expectations: + expected_facts: + - "Lakebase Autoscale CLI commands are under databricks postgres" + - "Project creation uses --json for the spec configuration" + - "Resource names follow the pattern projects/{id}/branches/{id}/endpoints/{id}" + expected_patterns: + - pattern: "databricks postgres create-project" + min_count: 1 + description: "Shows project creation command" + - pattern: "databricks postgres list-projects" + min_count: 1 + description: "Shows project listing command" + - pattern: "databricks postgres create-branch" + min_count: 1 + description: "Shows branch creation command" + guidelines: + - "Must show at least 4 different CLI subcommands" + - "Must include both project and branch management commands" + metadata: + category: happy_path + difficulty: easy + source: auto_generated diff --git a/.test/skills/databricks-lakebase-autoscale/manifest.yaml b/.test/skills/databricks-lakebase-autoscale/manifest.yaml new file mode 100644 index 00000000..09da620c --- /dev/null +++ b/.test/skills/databricks-lakebase-autoscale/manifest.yaml @@ -0,0 +1,30 @@ +skill_name: databricks-lakebase-autoscale +description: Patterns and best practices for using Lakebase Autoscaling (next-gen managed PostgreSQL) with autoscaling, branching, + scale-to-zero, and instant restore. 
+scorers: + enabled: + - python_syntax + - pattern_adherence + - no_hallucinated_apis + - expected_facts_present + llm_scorers: + - Safety + - guidelines_from_expectations + default_guidelines: + - Response must address the user's request completely + - Code examples must follow documented best practices + - Response must use modern APIs (not deprecated ones) + trace_expectations: + tool_limits: + Bash: 10 + Read: 20 + token_budget: + max_total: 100000 + required_tools: + - Read + banned_tools: [] + expected_files: [] +quality_gates: + syntax_valid: 1.0 + pattern_adherence: 0.9 + execution_success: 0.8 diff --git a/.test/skills/databricks-lakebase-provisioned/ground_truth.yaml b/.test/skills/databricks-lakebase-provisioned/ground_truth.yaml new file mode 100644 index 00000000..e840846c --- /dev/null +++ b/.test/skills/databricks-lakebase-provisioned/ground_truth.yaml @@ -0,0 +1,387 @@ +metadata: + skill_name: databricks-lakebase-provisioned + version: 0.1.0 + created_at: '2026-02-23T08:00:00.000000' + +test_cases: +- id: lakebase_prov_create_instance_001 + inputs: + prompt: "How do I create a new Lakebase Provisioned database instance?" 
+ outputs: + response: | + ```python + from databricks.sdk import WorkspaceClient + import uuid + + w = WorkspaceClient() + + # Create a database instance + instance = w.database.create_database_instance( + name="my-lakebase-instance", + capacity="CU_1", # CU_1, CU_2, CU_4, CU_8 + stopped=False + ) + print(f"Instance created: {instance.name}") + print(f"DNS endpoint: {instance.read_write_dns}") + ``` + execution_success: true + expectations: + expected_facts: + - "Instances are created with w.database.create_database_instance()" + - "Capacity options are CU_1, CU_2, CU_4, and CU_8" + - "The instance provides a read_write_dns endpoint for connections" + expected_patterns: + - pattern: "create_database_instance" + min_count: 1 + description: "Uses create_database_instance() method" + - pattern: "capacity" + min_count: 1 + description: "Specifies instance capacity" + - pattern: "read_write_dns" + min_count: 1 + description: "References the DNS endpoint" + - pattern: "WorkspaceClient" + min_count: 1 + description: "Uses WorkspaceClient for SDK access" + guidelines: + - "Must show available capacity options (CU_1 through CU_8)" + - "Must reference the read_write_dns endpoint" + metadata: + category: happy_path + difficulty: easy + source: auto_generated + +- id: lakebase_prov_connect_notebook_002 + inputs: + prompt: "Show me how to connect to Lakebase Provisioned from a Databricks notebook using psycopg." 
+ outputs: + response: | + ```python + import psycopg + from databricks.sdk import WorkspaceClient + import uuid + + # Get instance details + w = WorkspaceClient() + instance = w.database.get_database_instance(name="my-lakebase-instance") + + # Generate OAuth token + cred = w.database.generate_database_credential( + request_id=str(uuid.uuid4()), + instance_names=["my-lakebase-instance"] + ) + + # Connect using psycopg3 + conn_string = ( + f"host={instance.read_write_dns} " + f"dbname=postgres " + f"user={w.current_user.me().user_name} " + f"password={cred.token} " + f"sslmode=require" + ) + with psycopg.connect(conn_string) as conn: + with conn.cursor() as cur: + cur.execute("SELECT version()") + print(cur.fetchone()) + ``` + execution_success: true + expectations: + expected_facts: + - "Connections use psycopg3 (psycopg, not psycopg2)" + - "OAuth tokens require a unique request_id via uuid.uuid4()" + - "generate_database_credential() takes instance_names as a list" + - "The token is used as the password in the connection string" + expected_patterns: + - pattern: "psycopg\\.connect" + min_count: 1 + description: "Connects using psycopg3" + - pattern: "generate_database_credential" + min_count: 1 + description: "Generates OAuth token for authentication" + - pattern: "uuid\\.uuid4" + min_count: 1 + description: "Uses uuid4 for unique request_id" + - pattern: "sslmode=require" + min_count: 1 + description: "Requires SSL for the connection" + guidelines: + - "Must use psycopg (not psycopg2)" + - "Must include request_id with uuid4" + - "Connection must include sslmode=require" + metadata: + category: happy_path + difficulty: medium + source: auto_generated + +- id: lakebase_prov_oauth_token_003 + inputs: + prompt: "How do I generate an OAuth token for Lakebase Provisioned?" 
+ outputs: + response: | + ```python + from databricks.sdk import WorkspaceClient + import uuid + + w = WorkspaceClient() + + # Generate OAuth token for database connection + cred = w.database.generate_database_credential( + request_id=str(uuid.uuid4()), + instance_names=["my-lakebase-instance"] + ) + token = cred.token # Use this as password in connection string + ``` + + Tokens expire after 1 hour. For long-running applications, implement a token refresh loop. + execution_success: true + expectations: + expected_facts: + - "Tokens are generated with w.database.generate_database_credential()" + - "Each request requires a unique request_id" + - "instance_names is a list of instance names to scope the credential" + - "Tokens expire after 1 hour" + expected_patterns: + - pattern: "generate_database_credential" + min_count: 1 + description: "Uses generate_database_credential() method" + - pattern: "uuid4" + min_count: 1 + description: "Generates unique request ID" + - pattern: "instance_names" + min_count: 1 + description: "Specifies target instances" + guidelines: + - "Must mention the 1-hour expiration" + - "Must show the request_id requirement with uuid4" + metadata: + category: happy_path + difficulty: easy + source: auto_generated + +- id: lakebase_prov_sqlalchemy_token_refresh_004 + inputs: + prompt: "How do I set up SQLAlchemy with automatic token refresh for Lakebase Provisioned?" 
+  outputs:
+    response: |
+      ```python
+      import asyncio
+      import uuid
+      from sqlalchemy import event
+      from sqlalchemy.ext.asyncio import create_async_engine
+      from databricks.sdk import WorkspaceClient
+
+      _current_token = None
+      TOKEN_REFRESH_INTERVAL = 50 * 60  # 50 minutes (before 1-hour expiry)
+
+      def _generate_token(instance_name: str) -> str:
+          w = WorkspaceClient()
+          cred = w.database.generate_database_credential(
+              request_id=str(uuid.uuid4()),
+              instance_names=[instance_name]
+          )
+          return cred.token
+
+      async def _token_refresh_loop(instance_name: str):
+          global _current_token
+          while True:
+              await asyncio.sleep(TOKEN_REFRESH_INTERVAL)
+              _current_token = await asyncio.to_thread(
+                  _generate_token, instance_name
+              )
+
+      def init_database(instance_name, database_name, username):
+          global _current_token
+          w = WorkspaceClient()
+          instance = w.database.get_database_instance(name=instance_name)
+          _current_token = _generate_token(instance_name)
+
+          url = f"postgresql+psycopg://{username}@{instance.read_write_dns}:5432/{database_name}"
+          engine = create_async_engine(
+              url,
+              pool_size=5,
+              max_overflow=10,
+              pool_recycle=3600,
+              connect_args={"sslmode": "require"}
+          )
+
+          @event.listens_for(engine.sync_engine, "do_connect")
+          def provide_token(dialect, conn_rec, cargs, cparams):
+              cparams["password"] = _current_token
+
+          return engine
+      ```
+
+      Start the refresh loop from your app's async startup hook (e.g. `asyncio.create_task(_token_refresh_loop(instance_name))`) so a fresh token is in place before the 1-hour expiry.
+    execution_success: true
+  expectations:
+    expected_facts:
+      - "Token refresh should happen every 50 minutes, before the 1-hour expiry"
+      - "SQLAlchemy do_connect event injects the current token on each connection"
+      - "create_async_engine is used for async SQLAlchemy with psycopg"
+      - "pool_recycle=3600 prevents stale connections"
+    expected_patterns:
+      - pattern: "do_connect"
+        min_count: 1
+        description: "Uses SQLAlchemy do_connect event for token injection"
+      - pattern: "create_async_engine"
+        min_count: 1
+        description: "Creates async SQLAlchemy engine"
+      - pattern: "_token_refresh_loop"
+        min_count: 1
+        description: "Implements background token refresh"
+      - pattern: "generate_database_credential"
+        min_count: 1
+        description: "Generates OAuth tokens for refresh"
+    guidelines:
+      - "Must implement background token refresh before expiry"
+      - "Must use SQLAlchemy event to inject tokens on each connection"
+      - "Must use async engine with psycopg driver"
+    metadata:
+      category: happy_path
+      difficulty: hard
+      source: auto_generated
+
+- id: lakebase_prov_apps_integration_005
+  inputs:
+    prompt: "How do I integrate Lakebase Provisioned with a Databricks App?"
+  outputs:
+    response: |
+      Databricks Apps use environment variables for Lakebase configuration:
+
+      ```python
+      import os
+
+      def is_lakebase_configured() -> bool:
+          """Check if Lakebase is configured for this app."""
+          return bool(
+              os.environ.get("LAKEBASE_PG_URL") or
+              (os.environ.get("LAKEBASE_INSTANCE_NAME") and
+               os.environ.get("LAKEBASE_DATABASE_NAME"))
+          )
+      ```
+
+      Register the database resource with your app:
+      ```bash
+      databricks apps add-resource $APP_NAME \
+        --resource-type database \
+        --resource-name lakebase \
+        --database-instance my-lakebase-instance
+      ```
+
+      Environment variables set by Databricks Apps: `LAKEBASE_INSTANCE_NAME`, `LAKEBASE_DATABASE_NAME`, `LAKEBASE_USERNAME`, `LAKEBASE_PG_URL`.
+ execution_success: true + expectations: + expected_facts: + - "Databricks Apps set LAKEBASE_INSTANCE_NAME and LAKEBASE_DATABASE_NAME environment variables" + - "Database resources are registered with databricks apps add-resource" + - "LAKEBASE_PG_URL provides a ready-to-use connection string" + expected_patterns: + - pattern: "LAKEBASE_INSTANCE_NAME" + min_count: 1 + description: "References the instance name environment variable" + - pattern: "LAKEBASE_DATABASE_NAME" + min_count: 1 + description: "References the database name environment variable" + - pattern: "databricks apps add-resource" + min_count: 1 + description: "Shows the CLI command to register database resource" + guidelines: + - "Must list the environment variables set by Databricks Apps" + - "Must show the CLI command to register the database resource" + metadata: + category: happy_path + difficulty: medium + source: auto_generated + +- id: lakebase_prov_cli_reference_006 + inputs: + prompt: "What Databricks CLI commands are available for managing Lakebase Provisioned instances?" 
+ outputs: + response: | + ```bash + # Create instance + databricks database create-database-instance \ + --name my-lakebase-instance \ + --capacity CU_1 + + # Get instance details + databricks database get-database-instance --name my-lakebase-instance + + # Generate credentials + databricks database generate-database-credential \ + --request-id $(uuidgen) \ + --json '{"instance_names": ["my-lakebase-instance"]}' + + # List instances + databricks database list-database-instances + + # Stop instance (saves cost) + databricks database stop-database-instance --name my-lakebase-instance + + # Start instance + databricks database start-database-instance --name my-lakebase-instance + ``` + execution_success: true + expectations: + expected_facts: + - "Lakebase Provisioned CLI commands are under databricks database" + - "Instances can be stopped and started to save costs" + - "Credential generation requires a unique request-id" + expected_patterns: + - pattern: "databricks database create-database-instance" + min_count: 1 + description: "Shows instance creation command" + - pattern: "databricks database stop-database-instance" + min_count: 1 + description: "Shows instance stop command" + - pattern: "databricks database generate-database-credential" + min_count: 1 + description: "Shows credential generation command" + guidelines: + - "Must show at least 5 different CLI subcommands" + - "Must include stop/start commands for cost management" + metadata: + category: happy_path + difficulty: easy + source: auto_generated + +- id: lakebase_prov_register_unity_catalog_007 + inputs: + prompt: "How do I register a Lakebase Provisioned instance with Unity Catalog?" 
+  outputs:
+    response: |
+      ```python
+      from databricks.sdk import WorkspaceClient
+      from databricks.sdk.service.database import DatabaseCatalog
+
+      w = WorkspaceClient()
+
+      # Register the Postgres database as a Unity Catalog catalog
+      w.database.create_database_catalog(
+          catalog=DatabaseCatalog(
+              name="my_catalog",
+              database_instance_name="my-lakebase-instance",
+              database_name="databricks_postgres",
+              create_database_if_not_exists=True,
+          )
+      )
+      ```
+
+      After registration, the database's Postgres schemas and tables are queryable through Unity Catalog's three-level namespace: `my_catalog.schema_name.table_name`.
+    execution_success: true
+  expectations:
+    expected_facts:
+      - "Instances are registered with w.database.create_database_catalog()"
+      - "DatabaseCatalog links a catalog name to a database instance and Postgres database"
+      - "After registration, tables are accessible via Unity Catalog namespace"
+    expected_patterns:
+      - pattern: "create_database_catalog"
+        min_count: 1
+        description: "Uses create_database_catalog() method"
+      - pattern: "database_instance_name"
+        min_count: 1
+        description: "Specifies the Lakebase instance to register"
+      - pattern: "database_name"
+        min_count: 1
+        description: "Specifies the Postgres database to expose"
+    guidelines:
+      - "Must show the registration method linking the instance to a Unity Catalog catalog"
+      - "Must explain the three-level namespace after registration"
+    metadata:
+      category: happy_path
+      difficulty: medium
+      source: auto_generated
diff --git a/.test/skills/databricks-lakebase-provisioned/manifest.yaml b/.test/skills/databricks-lakebase-provisioned/manifest.yaml
new file mode 100644
index 00000000..b8f54868
--- /dev/null
+++ b/.test/skills/databricks-lakebase-provisioned/manifest.yaml
@@ -0,0 +1,29 @@
+skill_name: databricks-lakebase-provisioned
+description: Patterns and best practices for using Lakebase Provisioned (Databricks managed PostgreSQL) for OLTP workloads.
+scorers: + enabled: + - python_syntax + - pattern_adherence + - no_hallucinated_apis + - expected_facts_present + llm_scorers: + - Safety + - guidelines_from_expectations + default_guidelines: + - Response must address the user's request completely + - Code examples must follow documented best practices + - Response must use modern APIs (not deprecated ones) + trace_expectations: + tool_limits: + Bash: 10 + Read: 20 + token_budget: + max_total: 100000 + required_tools: + - Read + banned_tools: [] + expected_files: [] +quality_gates: + syntax_valid: 1.0 + pattern_adherence: 0.9 + execution_success: 0.8 diff --git a/.test/skills/databricks-metric-views/ground_truth.yaml b/.test/skills/databricks-metric-views/ground_truth.yaml new file mode 100644 index 00000000..13dc0706 --- /dev/null +++ b/.test/skills/databricks-metric-views/ground_truth.yaml @@ -0,0 +1,1346 @@ +metadata: + skill_name: databricks-metric-views + version: 0.1.0 + created_at: '2026-02-23T08:00:00.000000' +test_cases: +- id: metric-views_create_sql_001 + inputs: + prompt: Create a metric view for order analytics with revenue and order count + measures + outputs: + response: "Here's a metric view for order analytics:\n\n```sql\nCREATE OR REPLACE\ + \ VIEW catalog.schema.orders_metrics\nWITH METRICS\nLANGUAGE YAML\nAS $$\n \ + \ version: 1.1\n comment: \"Orders KPIs for sales analysis\"\n source: catalog.schema.orders\n\ + \ filter: order_date > '2020-01-01'\n dimensions:\n - name: Order Month\n\ + \ expr: DATE_TRUNC('MONTH', order_date)\n comment: \"Month of order\"\ + \n - name: Order Status\n expr: CASE\n WHEN status = 'O' THEN\ + \ 'Open'\n WHEN status = 'P' THEN 'Processing'\n WHEN status =\ + \ 'F' THEN 'Fulfilled'\n END\n measures:\n - name: Order Count\n\ + \ expr: COUNT(1)\n - name: Total Revenue\n expr: SUM(total_price)\n\ + \ - name: Revenue per Customer\n expr: SUM(total_price) / COUNT(DISTINCT\ + \ customer_id)\n$$\n```\n" + execution_success: true + expectations: + expected_facts: + - Uses 
CREATE OR REPLACE VIEW with WITH METRICS LANGUAGE YAML + - Defines dimensions for grouping data + - Defines measures for aggregations + - Uses YAML version 1.1 + expected_patterns: + - pattern: CREATE OR REPLACE VIEW + min_count: 1 + description: Must use CREATE OR REPLACE VIEW syntax + - pattern: WITH METRICS + min_count: 1 + description: Must include WITH METRICS clause + - pattern: LANGUAGE YAML + min_count: 1 + description: Must specify LANGUAGE YAML + - pattern: dimensions + min_count: 1 + description: Must define dimensions + - pattern: measures + min_count: 1 + description: Must define measures + guidelines: + - SQL must follow Databricks SQL syntax + - Must use YAML version 1.1 for DBR 17.2+ + - Must include both dimensions and measures + metadata: + category: happy_path + difficulty: easy + source: auto_generated +- id: metric-views_query_measure_002 + inputs: + prompt: Query a metric view to get total revenue and order count by month + outputs: + response: "All measures must use the MEASURE() function. 
SELECT * is NOT supported.\n\ + \n```sql\nSELECT\n `Order Month`,\n `Order Status`,\n MEASURE(`Total Revenue`)\ + \ AS total_revenue,\n MEASURE(`Order Count`) AS order_count\nFROM catalog.schema.orders_metrics\n\ + WHERE extract(year FROM `Order Month`) = 2024\nGROUP BY ALL\nORDER BY ALL\n\ + ```\n" + execution_success: true + expectations: + expected_facts: + - Must wrap measures in MEASURE() function + - SELECT * is not supported on metric views + - Dimension names with spaces need backtick quoting + expected_patterns: + - pattern: MEASURE\( + min_count: 1 + description: Must use MEASURE() function for measures + - pattern: GROUP BY + min_count: 1 + description: Must include GROUP BY clause + guidelines: + - SQL must use MEASURE() for all measure references + - Must not use SELECT * on metric views + metadata: + category: happy_path + difficulty: easy + source: auto_generated +- id: metric-views_create_mcp_003 + inputs: + prompt: Use the manage_metric_views tool to create a metric view for order analytics + outputs: + response: "```python\nmanage_metric_views(\n action=\"create\",\n full_name=\"\ + catalog.schema.orders_metrics\",\n source=\"catalog.schema.orders\",\n \ + \ or_replace=True,\n comment=\"Orders KPIs for sales analysis\",\n filter_expr=\"\ + order_date > '2020-01-01'\",\n dimensions=[\n {\"name\": \"Order Month\"\ + , \"expr\": \"DATE_TRUNC('MONTH', order_date)\", \"comment\": \"Month of order\"\ + },\n {\"name\": \"Order Status\", \"expr\": \"status\"},\n ],\n \ + \ measures=[\n {\"name\": \"Order Count\", \"expr\": \"COUNT(1)\"},\n\ + \ {\"name\": \"Total Revenue\", \"expr\": \"SUM(total_price)\", \"comment\"\ + : \"Sum of total price\"},\n ],\n)\n```\n" + execution_success: true + expectations: + expected_facts: + - Uses manage_metric_views MCP tool + - Sets action to create + - Specifies source table + expected_patterns: + - pattern: manage_metric_views + min_count: 1 + description: Must use manage_metric_views tool + - pattern: action.*create + 
min_count: 1 + description: Must set action to create + - pattern: dimensions + min_count: 1 + description: Must define dimensions + - pattern: measures + min_count: 1 + description: Must define measures + guidelines: + - Must use manage_metric_views with action="create" + - Code must be valid Python syntax + metadata: + category: happy_path + difficulty: easy + source: auto_generated +- id: metric-views_query_mcp_004 + inputs: + prompt: Query the orders_metrics metric view using the MCP tool to get revenue + by month + outputs: + response: "```python\nmanage_metric_views(\n action=\"query\",\n full_name=\"\ + catalog.schema.orders_metrics\",\n query_measures=[\"Total Revenue\", \"\ + Order Count\"],\n query_dimensions=[\"Order Month\"],\n where=\"extract(year\ + \ FROM `Order Month`) = 2024\",\n order_by=\"ALL\",\n limit=100,\n)\n\ + ```\n" + execution_success: true + expectations: + expected_facts: + - Uses manage_metric_views with action query + - Specifies query_measures and query_dimensions + expected_patterns: + - pattern: manage_metric_views + min_count: 1 + description: Must use manage_metric_views tool + - pattern: action.*query + min_count: 1 + description: Must set action to query + - pattern: query_measures + min_count: 1 + description: Must specify measures to query + - pattern: query_dimensions + min_count: 1 + description: Must specify dimensions to group by + guidelines: + - Must use manage_metric_views with action="query" + - Code must be valid Python syntax + metadata: + category: happy_path + difficulty: medium + source: auto_generated +- id: metric-views_yaml_spec_005 + inputs: + prompt: Show me the full YAML specification for a metric view definition + outputs: + response: "```yaml\nversion: 1.1 # Required: \"1.1\" for DBR\ + \ 17.2+\ncomment: \"Description\" # Optional\nsource: catalog.schema.table\ + \ # Required: source table/view\nfilter: column > value # Optional:\ + \ global WHERE filter\n\ndimensions: # Required: at least\ + \ one\n - 
name: Display Name\n expr: sql_expression\n comment: \"Description\"\ + \n\nmeasures: # Required: at least one\n - name: Display\ + \ Name\n expr: AGG_FUNC(column)\n comment: \"Description\"\n\njoins: \ + \ # Optional: star/snowflake schema\n - name: dim_table\n\ + \ source: catalog.schema.dim_table\n on: source.fk = dim_table.pk\n\n\ + materialization: # Optional (experimental)\n schedule: every\ + \ 6 hours\n mode: relaxed\n```\n" + execution_success: true + expectations: + expected_facts: + - Version 1.1 required for DBR 17.2+ + - Must have at least one dimension + - Must have at least one measure + - Joins are optional for star/snowflake schema + expected_patterns: + - pattern: version + min_count: 1 + description: Must include version field + - pattern: dimensions + min_count: 1 + description: Must include dimensions section + - pattern: measures + min_count: 1 + description: Must include measures section + - pattern: joins + min_count: 1 + description: Should show joins syntax + guidelines: + - Must show the complete YAML specification structure + - Must indicate which fields are required vs optional + metadata: + category: happy_path + difficulty: medium + source: auto_generated +- id: metric-views_star_schema_006 + inputs: + prompt: Create a metric view with joins for a star schema (fact orders + dim customers) + outputs: + response: "```sql\nCREATE OR REPLACE VIEW catalog.schema.sales_metrics\nWITH METRICS\n\ + LANGUAGE YAML\nAS $$\n version: 1.1\n source: catalog.schema.fact_orders\n\ + \ joins:\n - name: customers\n source: catalog.schema.dim_customers\n\ + \ on: fact_orders.customer_id = customers.customer_id\n dimensions:\n\ + \ - name: Customer Region\n expr: customers.region\n - name: Order\ + \ Month\n expr: DATE_TRUNC('MONTH', order_date)\n measures:\n - name:\ + \ Total Revenue\n expr: SUM(amount)\n - name: Unique Customers\n \ + \ expr: COUNT(DISTINCT customers.customer_id)\n$$\n```\n\nJoins must be declared\ + \ in the YAML definition, not in 
the SELECT query.\n" + execution_success: true + expectations: + expected_facts: + - Joins are declared in the YAML definition + - Can reference columns from joined tables in dimensions and measures + expected_patterns: + - pattern: joins + min_count: 1 + description: Must use joins section in YAML + - pattern: 'on:' + min_count: 1 + description: Must specify join condition + guidelines: + - Joins must be in the YAML definition, not in the SELECT query + - Must use correct join syntax with source and on fields + metadata: + category: happy_path + difficulty: medium + source: auto_generated +- id: metric-views_grant_access_007 + inputs: + prompt: Grant SELECT access on a metric view to the data-consumers group + outputs: + response: "```python\nmanage_metric_views(\n action=\"grant\",\n full_name=\"\ + catalog.schema.orders_metrics\",\n principal=\"data-consumers\",\n privileges=[\"\ + SELECT\"],\n)\n```\n" + execution_success: true + expectations: + expected_facts: + - Uses manage_metric_views with action grant + - Specifies principal and privileges + expected_patterns: + - pattern: manage_metric_views + min_count: 1 + description: Must use manage_metric_views tool + - pattern: action.*grant + min_count: 1 + description: Must set action to grant + - pattern: principal + min_count: 1 + description: Must specify principal + guidelines: + - Must use manage_metric_views with action="grant" + metadata: + category: happy_path + difficulty: hard + source: auto_generated +- id: metric-views_select_star_error_008 + inputs: + prompt: "Why does SELECT * not work on a metric view?" + outputs: + response: | + `SELECT *` is NOT supported on metric views. You must explicitly list dimensions and wrap all measure references in `MEASURE()`. 
+ + Incorrect: + ```sql + SELECT * FROM catalog.schema.orders_metrics + ``` + + Correct: + ```sql + SELECT + `Order Month`, + MEASURE(`Total Revenue`) AS total_revenue, + MEASURE(`Order Count`) AS order_count + FROM catalog.schema.orders_metrics + GROUP BY ALL + ``` + expectations: + expected_facts: + - SELECT * is not supported + - MEASURE() + expected_patterns: + - pattern: "SELECT \\*" + min_count: 1 + description: "Must show the incorrect SELECT * example" + - pattern: "MEASURE\\(" + min_count: 1 + description: "Must show correct MEASURE() usage" + - pattern: "GROUP BY" + min_count: 1 + description: "Must include GROUP BY" + guidelines: + - "Must clearly explain why SELECT * does not work on metric views" + - "Must show both the incorrect way and the correct way" + metadata: + category: error_handling + difficulty: easy + +- id: metric-views_backtick_quoting_009 + inputs: + prompt: "I'm getting 'Cannot resolve column' when querying my metric view. The dimension is called 'Order Month'." + outputs: + response: | + Dimension and measure names with spaces must be backtick-quoted in queries: + + ```sql + -- Wrong: unquoted name with spaces + SELECT Order Month, MEASURE(Total Revenue) FROM my_metrics GROUP BY ALL + + -- Correct: backtick-quoted + SELECT `Order Month`, MEASURE(`Total Revenue`) AS revenue + FROM catalog.schema.my_metrics + GROUP BY ALL + ``` + + Both dimensions in SELECT/GROUP BY and measures inside `MEASURE()` need backticks if the name contains spaces. 
+ expectations: + expected_facts: + - backtick + - MEASURE() + expected_patterns: + - pattern: "`Order Month`" + min_count: 1 + description: "Must show backtick-quoted dimension name" + - pattern: "MEASURE\\(`" + min_count: 1 + description: "Must show backtick-quoted measure inside MEASURE()" + guidelines: + - "Must explain that names with spaces need backtick quoting" + - "Must show both wrong and correct syntax" + metadata: + category: error_handling + difficulty: easy + +- id: metric-views_alter_010 + inputs: + prompt: "Add a new measure 'Average Order Value' to my existing orders_metrics metric view" + outputs: + response: | + Use the `alter` action to update an existing metric view's definition: + + ```python + manage_metric_views( + action="alter", + full_name="catalog.schema.orders_metrics", + measures=[ + {"name": "Order Count", "expr": "COUNT(1)"}, + {"name": "Total Revenue", "expr": "SUM(total_price)"}, + {"name": "Average Order Value", "expr": "SUM(total_price) / COUNT(1)", "comment": "Revenue divided by order count"}, + ], + ) + ``` + + Note: When altering, you must provide the full list of measures (existing + new), not just the new one. + expectations: + expected_facts: + - manage_metric_views + - alter + expected_patterns: + - pattern: "action.*alter" + min_count: 1 + description: "Must use alter action" + - pattern: "Average Order Value" + min_count: 1 + description: "Must include the new measure" + guidelines: + - "Must use manage_metric_views with action='alter'" + - "Must explain that the full measure list is required when altering" + metadata: + category: happy_path + difficulty: medium + +- id: metric-views_describe_011 + inputs: + prompt: "How do I see the current definition of a metric view?" 
+ outputs: + response: | + Use the `describe` action to get the full definition: + + ```python + manage_metric_views( + action="describe", + full_name="catalog.schema.orders_metrics", + ) + ``` + + This returns the YAML definition, dimensions, measures, source table, and any join or materialization configuration. + expectations: + expected_facts: + - manage_metric_views + - describe + expected_patterns: + - pattern: "action.*describe" + min_count: 1 + description: "Must use describe action" + - pattern: "full_name" + min_count: 1 + description: "Must specify the metric view name" + metadata: + category: happy_path + difficulty: easy + +- id: metric-views_when_not_to_use_012 + inputs: + prompt: "Should I use a metric view or a regular view for my dashboard?" + outputs: + response: | + Use a **metric view** when: + - You need flexible aggregation at query time (different GROUP BY per dashboard panel) + - You have ratios that need safe re-aggregation (e.g., revenue per customer) + - You want star/snowflake schema joins declared in the definition + - You need AI/BI Genie integration for natural language querying + + Use a **standard view** when: + - Your query is a simple SELECT with no aggregations + - The aggregation is always the same (no need for MEASURE() flexibility) + - You need features not yet supported by metric views (e.g., UDFs, complex CTEs) + + Key difference: standard views lock aggregation at creation time. Metric views separate dimensions from measures so the same metric can be sliced differently in each query. 
+ expectations: + expected_facts: + - MEASURE() + - Genie + expected_patterns: + - pattern: "metric view" + min_count: 2 + description: "Must discuss metric views" + - pattern: "standard view|regular view" + min_count: 1 + description: "Must contrast with standard views" + - pattern: "aggregat" + min_count: 1 + description: "Must discuss aggregation behavior" + guidelines: + - "Must explain the key difference: standard views lock aggregation, metric views are flexible" + - "Should mention at least one scenario where a standard view is better" + metadata: + category: conceptual + difficulty: medium + +- id: metric-views_filtered_measure_013 + inputs: + prompt: "Create a metric view with a filtered measure that only counts delivered orders" + outputs: + response: | + Use a CASE expression inside the aggregate to create a filtered measure: + + ```sql + CREATE OR REPLACE VIEW catalog.schema.orders_metrics + WITH METRICS + LANGUAGE YAML + AS $$ + version: 1.1 + source: catalog.schema.orders + dimensions: + - name: Order Month + expr: DATE_TRUNC('MONTH', order_date) + measures: + - name: Total Orders + expr: COUNT(1) + - name: Delivered Orders + expr: COUNT(CASE WHEN delivery_status = 'delivered' THEN 1 END) + comment: "Only counts delivered orders" + - name: Delivery Rate + expr: COUNT(CASE WHEN delivery_status = 'delivered' THEN 1 END) * 1.0 / COUNT(1) + comment: "Ratio of delivered to total orders" + $$ + ``` + + Filtered measures use CASE WHEN inside the aggregate function to restrict which rows are counted. 
+ expectations: + expected_facts: + - CASE WHEN + - CREATE OR REPLACE VIEW + expected_patterns: + - pattern: "WITH METRICS" + min_count: 1 + description: "Must use metric view syntax" + - pattern: "CASE WHEN" + min_count: 1 + description: "Must use CASE WHEN for filtered measure" + - pattern: "COUNT\\(" + min_count: 2 + description: "Must have both total and filtered count" + guidelines: + - "Must show how to filter rows within a measure expression" + - "Should include a ratio measure that uses the filtered count" + metadata: + category: happy_path + difficulty: medium + +- id: metric-views_materialization_014 + inputs: + prompt: "How do I enable materialization on a metric view for better query performance?" + outputs: + response: | + Add a `materialization` section to the YAML definition. This is experimental and requires serverless compute: + + ```sql + CREATE OR REPLACE VIEW catalog.schema.orders_metrics + WITH METRICS + LANGUAGE YAML + AS $$ + version: 1.1 + source: catalog.schema.orders + dimensions: + - name: Order Month + expr: DATE_TRUNC('MONTH', order_date) + measures: + - name: Total Revenue + expr: SUM(total_price) + materialization: + schedule: every 6 hours + mode: relaxed + $$ + ``` + + The `schedule` controls how often the pre-computed aggregations refresh. The `mode: relaxed` allows queries to return slightly stale data for better performance. + + Note: Materialization is currently experimental and requires serverless compute to be enabled on your workspace. 
+ expectations: + expected_facts: + - materialization + - serverless + - experimental + expected_patterns: + - pattern: "materialization" + min_count: 1 + description: "Must include materialization section" + - pattern: "schedule" + min_count: 1 + description: "Must specify refresh schedule" + guidelines: + - "Must clearly state that materialization is experimental" + - "Must mention serverless compute requirement" + metadata: + category: advanced + difficulty: hard + +- id: metric-views_join_at_query_time_015 + inputs: + prompt: "Can I JOIN my metric view with another table in a query?" + outputs: + response: | + No — joins must be declared in the YAML definition, not at query time. A JOIN in the SELECT query is not supported for metric views. + + Instead, add the join to the metric view definition: + + ```python + manage_metric_views( + action="alter", + full_name="catalog.schema.sales_metrics", + joins=[ + { + "name": "customers", + "source": "catalog.schema.dim_customers", + "on": "orders.customer_id = customers.customer_id", + } + ], + ) + ``` + + After adding the join, you can reference columns from the joined table in dimensions and measures. + expectations: + expected_facts: + - joins + - YAML definition + expected_patterns: + - pattern: "joins" + min_count: 1 + description: "Must mention joins" + - pattern: "manage_metric_views|CREATE OR REPLACE" + min_count: 1 + description: "Must show how to add joins to the definition" + guidelines: + - "Must clearly state that JOINs at query time are NOT supported" + - "Must show how to add joins to the metric view definition instead" + metadata: + category: error_handling + difficulty: medium + +- id: metric-views_dbr_version_016 + inputs: + prompt: "I'm getting an error when creating a metric view. I'm on DBR 15.4." + outputs: + response: | + Metric views with YAML version 1.1 require **Databricks Runtime 17.2+**. DBR 15.4 does not support this feature. + + Options: + 1. Upgrade your cluster to DBR 17.2 or later + 2. 
Use a SQL warehouse (which always runs the latest runtime) + 3. If you must use an older runtime, YAML version 0.1 is supported on DBR 16.4+ but has fewer features (no comments on dimensions/measures, no materialization) + + ```sql + -- Check your DBR version + SELECT current_version() + ``` + expectations: + expected_facts: + - "17.2" + - SQL warehouse + expected_patterns: + - pattern: "17\\.2" + min_count: 1 + description: "Must mention DBR 17.2 requirement" + guidelines: + - "Must explain the minimum DBR version requirement" + - "Should suggest SQL warehouse as an alternative to cluster upgrades" + metadata: + category: error_handling + difficulty: easy + +- id: metric-views_drop_action_017 + inputs: + prompt: Drop the orders_metrics metric view from the catalog + outputs: + response: | + Use the `drop` action to remove a metric view: + + ```python + manage_metric_views( + action="drop", + full_name="catalog.schema.orders_metrics", + ) + ``` + + This drops the metric view definition. If materialization was enabled, the associated pipeline is also removed. The underlying source table is NOT affected. + execution_success: true + expectations: + expected_facts: + - manage_metric_views + - drop + - source table is not affected + expected_patterns: + - pattern: manage_metric_views + min_count: 1 + description: Must use manage_metric_views tool + - pattern: action.*drop + min_count: 1 + description: Must set action to drop + - pattern: full_name + min_count: 1 + description: Must specify the metric view name + guidelines: + - Must use manage_metric_views with action="drop" + - Should clarify that the source table is not deleted + metadata: + category: happy_path + difficulty: easy + source: manual + +- id: metric-views_window_rolling_avg_018 + inputs: + prompt: Create a metric view with a rolling 7-day average of daily revenue + outputs: + response: | + Window measures use `version: 0.1` (experimental). 
Here's a rolling 7-day trailing revenue measure: + + ```sql + CREATE OR REPLACE VIEW catalog.schema.revenue_trends + WITH METRICS + LANGUAGE YAML + AS $$ + version: 0.1 + source: catalog.schema.daily_sales + filter: sale_date > DATE'2024-01-01' + + dimensions: + - name: date + expr: sale_date + + measures: + - name: daily_revenue + expr: SUM(amount) + window: + - order: date + range: current + semiadditive: last + + - name: trailing_7d_revenue + expr: SUM(amount) + window: + - order: date + range: trailing 7 day + semiadditive: last + + - name: rolling_7d_avg + expr: MEASURE(trailing_7d_revenue) / 7 + $$ + ``` + + Query it with the standard `MEASURE()` syntax: + + ```sql + SELECT + date, + MEASURE(daily_revenue) AS daily, + MEASURE(rolling_7d_avg) AS rolling_avg + FROM catalog.schema.revenue_trends + GROUP BY ALL + ORDER BY date + ``` + + Key points: + - `trailing 7 day` includes the 7 days **before** the current date, excluding the current date + - `semiadditive: last` returns the last value when `date` is not in GROUP BY + - The derived `rolling_7d_avg` uses `MEASURE()` to reference another window measure and doesn't need its own `window` block + execution_success: true + expectations: + expected_facts: + - version 0.1 + - window + - trailing 7 day + - semiadditive + expected_patterns: + - pattern: "version: 0.1" + min_count: 1 + description: Window measures require version 0.1 + - pattern: "window:" + min_count: 1 + description: Must include window block + - pattern: "trailing 7 day" + min_count: 1 + description: Must use trailing window range + - pattern: "semiadditive" + min_count: 1 + description: Must specify semiadditive behavior + - pattern: "MEASURE\\(" + min_count: 1 + description: Must show MEASURE() for derived measure or query + guidelines: + - Must use version 0.1 for window measures + - Must explain trailing range semantics (excludes current day) + - Must show how derived measures reference window measures via MEASURE() + metadata: + category: 
advanced + difficulty: hard + source: manual + +- id: metric-views_genie_integration_019 + inputs: + prompt: How do I use my metric view with AI/BI Genie? + outputs: + response: | + Metric views work natively with AI/BI Genie for natural language querying of your metrics. Here's how to set it up: + + 1. **Create a Genie Space** and add your metric view as a data source (just like adding a table). + + 2. **Ask questions in natural language** — Genie automatically uses `MEASURE()` to query your metrics: + - "What was total revenue last quarter?" + - "Show me order count by region for 2024" + - "Which product category has the highest profit margin?" + + 3. **Benefits of metric views with Genie:** + - Genie understands dimension/measure semantics — no ambiguity about how to aggregate + - Comments on dimensions and measures improve Genie's understanding of your data + - Consistent metric definitions across all Genie users + - Star schema joins are handled automatically + + 4. **Tips for better Genie results:** + - Add descriptive `comment` fields to all dimensions and measures + - Use clear, business-friendly names (e.g., "Total Revenue" not "sum_tp") + - Keep dimension names intuitive for natural language queries + + No special configuration is needed — once the metric view is added to a Genie Space, users can query it immediately with natural language. 
+ execution_success: true + expectations: + expected_facts: + - Genie + - natural language + - metric view + - MEASURE() + expected_patterns: + - pattern: "Genie" + min_count: 2 + description: Must discuss Genie integration + - pattern: "natural language" + min_count: 1 + description: Must mention natural language querying + - pattern: "MEASURE|metric view" + min_count: 1 + description: Must mention metric view or MEASURE concepts + guidelines: + - Must explain how to add a metric view to a Genie Space + - Must mention that Genie automatically uses MEASURE() + - Should recommend adding comments for better Genie understanding + metadata: + category: conceptual + difficulty: medium + source: manual + +- id: metric-views_conversational_support_tickets_020 + inputs: + prompt: "hey I have a table of customer support tickets and I want to track resolution\ + \ time metrics, can you help me set that up as a metric view?" + outputs: + response: | + Sure! Here's a metric view for tracking support ticket resolution metrics: + + ```sql + CREATE OR REPLACE VIEW catalog.schema.support_metrics + WITH METRICS + LANGUAGE YAML + AS $$ + version: 1.1 + comment: "Customer support ticket KPIs" + source: catalog.schema.support_tickets + filter: created_at >= '2024-01-01' + + dimensions: + - name: Created Month + expr: DATE_TRUNC('MONTH', created_at) + comment: "Month the ticket was opened" + - name: Priority + expr: CASE + WHEN priority = 1 THEN 'Critical' + WHEN priority = 2 THEN 'High' + WHEN priority = 3 THEN 'Medium' + ELSE 'Low' + END + comment: "Ticket priority level" + - name: Category + expr: ticket_category + comment: "Support category (billing, technical, etc.)" + + measures: + - name: Ticket Count + expr: COUNT(1) + - name: Resolved Tickets + expr: COUNT(1) FILTER (WHERE status = 'resolved') + - name: Avg Resolution Hours + expr: AVG(TIMESTAMPDIFF(HOUR, created_at, resolved_at)) + comment: "Average hours to resolve a ticket" + - name: Resolution Rate + expr: COUNT(1) FILTER 
(WHERE status = 'resolved') * 1.0 / COUNT(1) + comment: "Percentage of tickets resolved" + - name: Median Resolution Hours + expr: PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY TIMESTAMPDIFF(HOUR, created_at, resolved_at)) + comment: "Median hours to resolution" + $$ + ``` + + Query it like this: + + ```sql + SELECT + `Created Month`, + `Priority`, + MEASURE(`Ticket Count`) AS tickets, + MEASURE(`Avg Resolution Hours`) AS avg_hours, + MEASURE(`Resolution Rate`) AS resolution_rate + FROM catalog.schema.support_metrics + GROUP BY ALL + ORDER BY ALL + ``` + execution_success: true + expectations: + expected_facts: + - CREATE OR REPLACE VIEW + - WITH METRICS + - support + - resolution + expected_patterns: + - pattern: "CREATE OR REPLACE VIEW" + min_count: 1 + description: Must use CREATE OR REPLACE VIEW syntax + - pattern: "WITH METRICS" + min_count: 1 + description: Must include WITH METRICS clause + - pattern: "LANGUAGE YAML" + min_count: 1 + description: Must specify LANGUAGE YAML + - pattern: "dimensions" + min_count: 1 + description: Must define dimensions + - pattern: "measures" + min_count: 1 + description: Must define measures + - pattern: "MEASURE\\(" + min_count: 1 + description: Must show MEASURE() in query example + - pattern: "FILTER" + min_count: 1 + description: Should use FILTER clause for conditional measures + guidelines: + - Must create a metric view relevant to support tickets domain + - Must include resolution time metrics + - Must include a query example using MEASURE() + metadata: + category: happy_path + difficulty: medium + source: manual + +- id: metric-views_udf_not_supported_021 + inputs: + prompt: Can I use a Python UDF inside a metric view measure expression? + outputs: + response: | + No — metric view measure expressions only support **SQL aggregate functions** and built-in SQL expressions. Python UDFs, Scala UDFs, and custom functions are not supported in metric view YAML definitions. 
+ + **Supported in measure expressions:** + - SQL aggregate functions: `SUM()`, `COUNT()`, `AVG()`, `MIN()`, `MAX()` + - `COUNT(DISTINCT ...)` + - `FILTER (WHERE ...)` clause + - Arithmetic on aggregates: `SUM(a) / COUNT(b)` + - `CASE WHEN` inside aggregates + - Window measures with `window:` block + + **NOT supported:** + - Python UDFs (`@udf` decorated functions) + - Scala UDFs + - `ai_query()` or other AI functions in measure expressions + - Complex CTEs or subqueries + + **Workaround:** If you need a UDF-computed value, materialize it into the source table first, then reference the pre-computed column in your metric view: + + ```sql + -- Step 1: Create source table with UDF-computed column + CREATE TABLE catalog.schema.enriched_orders AS + SELECT *, my_python_udf(description) AS sentiment_score + FROM catalog.schema.raw_orders; + + -- Step 2: Use the pre-computed column in the metric view + CREATE OR REPLACE VIEW catalog.schema.sentiment_metrics + WITH METRICS + LANGUAGE YAML + AS $$ + version: 1.1 + source: catalog.schema.enriched_orders + dimensions: + - name: Month + expr: DATE_TRUNC('MONTH', order_date) + measures: + - name: Avg Sentiment + expr: AVG(sentiment_score) + $$ + ``` + execution_success: true + expectations: + expected_facts: + - UDF + - not supported + - SQL aggregate functions + expected_patterns: + - pattern: "not supported|NOT supported|No" + min_count: 1 + description: Must clearly state UDFs are not supported + - pattern: "SUM|COUNT|AVG" + min_count: 1 + description: Must list supported aggregate functions + - pattern: "workaround|pre-computed|source table" + min_count: 1 + description: Should suggest a workaround + guidelines: + - Must clearly state that UDFs are not supported in metric views + - Must list what IS supported as alternatives + - Should suggest pre-computing UDF results in the source table as a workaround + metadata: + category: error_handling + difficulty: medium + source: manual + +- id: metric-views_hr_employee_domain_022 
+ inputs: + prompt: Build a metric view for HR analytics tracking employee headcount, attrition, + and tenure from our employees table + outputs: + response: | + ```python + manage_metric_views( + action="create", + full_name="catalog.schema.hr_metrics", + source="catalog.schema.employees", + or_replace=True, + comment="HR analytics: headcount, attrition, and tenure metrics", + dimensions=[ + {"name": "Department", "expr": "department", "comment": "Employee department"}, + {"name": "Hire Year", "expr": "EXTRACT(YEAR FROM hire_date)", "comment": "Year of hire"}, + {"name": "Location", "expr": "office_location"}, + {"name": "Level", "expr": "CASE WHEN job_level <= 3 THEN 'Junior' WHEN job_level <= 5 THEN 'Mid' ELSE 'Senior' END", "comment": "Bucketed seniority level"}, + ], + measures=[ + {"name": "Headcount", "expr": "COUNT(1)"}, + {"name": "Active Employees", "expr": "COUNT(1) FILTER (WHERE termination_date IS NULL)", "comment": "Currently active employees"}, + {"name": "Attrition Count", "expr": "COUNT(1) FILTER (WHERE termination_date IS NOT NULL)", "comment": "Employees who have left"}, + {"name": "Attrition Rate", "expr": "COUNT(1) FILTER (WHERE termination_date IS NOT NULL) * 1.0 / COUNT(1)", "comment": "Percentage of employees who left"}, + {"name": "Avg Tenure Years", "expr": "AVG(DATEDIFF(COALESCE(termination_date, CURRENT_DATE()), hire_date) / 365.25)", "comment": "Average tenure in years"}, + ], + ) + ``` + execution_success: true + expectations: + expected_facts: + - manage_metric_views + - create + - headcount + - attrition + expected_patterns: + - pattern: "manage_metric_views" + min_count: 1 + description: Must use manage_metric_views tool + - pattern: "action.*create" + min_count: 1 + description: Must set action to create + - pattern: "FILTER" + min_count: 1 + description: Should use FILTER clause for conditional measures + - pattern: "dimensions" + min_count: 1 + description: Must define dimensions + - pattern: "measures" + min_count: 1 + 
description: Must define measures + guidelines: + - Must create an HR-relevant metric view (not orders/sales) + - Must include attrition-related measures + - Must use manage_metric_views MCP tool + metadata: + category: happy_path + difficulty: medium + source: manual + +- id: metric-views_migration_from_regular_view_023 + inputs: + prompt: I have this existing view and want to convert it to a metric view. Here's + my current SQL - SELECT region, DATE_TRUNC('MONTH', sale_date) AS month, SUM(amount) + AS revenue, COUNT(*) AS num_sales FROM catalog.schema.transactions GROUP BY 1, + 2 + outputs: + response: | + Here's your existing view converted to a metric view. The key change is separating the GROUP BY columns into **dimensions** and the aggregations into **measures**: + + ```sql + CREATE OR REPLACE VIEW catalog.schema.transaction_metrics + WITH METRICS + LANGUAGE YAML + AS $$ + version: 1.1 + comment: "Transaction metrics - migrated from standard view" + source: catalog.schema.transactions + + dimensions: + - name: Region + expr: region + - name: Month + expr: DATE_TRUNC('MONTH', sale_date) + + measures: + - name: Revenue + expr: SUM(amount) + - name: Num Sales + expr: COUNT(1) + $$ + ``` + + Now you can query it flexibly — the GROUP BY is no longer fixed: + + ```sql + -- Same result as your original view + SELECT `Region`, `Month`, + MEASURE(`Revenue`) AS revenue, + MEASURE(`Num Sales`) AS num_sales + FROM catalog.schema.transaction_metrics + GROUP BY ALL + + -- NEW: Revenue by region only (no month breakdown) + SELECT `Region`, + MEASURE(`Revenue`) AS revenue + FROM catalog.schema.transaction_metrics + GROUP BY ALL + + -- NEW: Monthly totals across all regions + SELECT `Month`, + MEASURE(`Revenue`) AS revenue + FROM catalog.schema.transaction_metrics + GROUP BY ALL + ``` + + **Key differences from standard views:** + - Aggregation is flexible at query time — no need to recreate the view for different GROUP BY + - Measures must be wrapped in `MEASURE()` when 
querying + - `SELECT *` is not supported — you must explicitly list dimensions and measures + execution_success: true + expectations: + expected_facts: + - dimensions + - measures + - GROUP BY + - flexible + expected_patterns: + - pattern: "CREATE OR REPLACE VIEW" + min_count: 1 + description: Must show the new metric view DDL + - pattern: "WITH METRICS" + min_count: 1 + description: Must include WITH METRICS clause + - pattern: "MEASURE\\(" + min_count: 2 + description: Must show MEASURE() in query examples + - pattern: "GROUP BY" + min_count: 1 + description: Must demonstrate flexible GROUP BY + guidelines: + - Must convert the provided SQL into a metric view + - Must explain the dimension/measure separation + - Must show that the metric view enables flexible querying + metadata: + category: happy_path + difficulty: hard + source: manual + +- id: metric-views_semiadditive_balance_024 + inputs: + prompt: Create a metric view for bank account balances that correctly handles aggregation + across time (don't sum balances across dates) + outputs: + response: | + Use a **semiadditive** window measure to prevent summing balances across dates. 
This returns the last date's balance when the date dimension is not in the GROUP BY: + + ```sql + CREATE OR REPLACE VIEW catalog.schema.account_balance_metrics + WITH METRICS + LANGUAGE YAML + AS $$ + version: 0.1 + source: catalog.schema.daily_balances + + dimensions: + - name: date + expr: balance_date + - name: account_type + expr: account_type + + measures: + - name: total_balance + expr: SUM(balance) + window: + - order: date + range: current + semiadditive: last + $$ + ``` + + **How it works:** + - When grouped by `date`: returns the sum of all account balances for that specific date (aggregates across accounts, not time) + - When NOT grouped by `date`: returns the balance from the **most recent date** (`semiadditive: last`) + + ```sql + -- Balance by date and account type + SELECT date, account_type, + MEASURE(total_balance) AS balance + FROM catalog.schema.account_balance_metrics + GROUP BY ALL + + -- Latest total balance by account type (semiadditive kicks in) + SELECT account_type, + MEASURE(total_balance) AS latest_balance + FROM catalog.schema.account_balance_metrics + GROUP BY ALL + ``` + + Note: Window measures require `version: 0.1` (experimental feature). 
+ execution_success: true + expectations: + expected_facts: + - semiadditive + - window + - version 0.1 + - balance + expected_patterns: + - pattern: "semiadditive" + min_count: 1 + description: Must use semiadditive measure + - pattern: "window:" + min_count: 1 + description: Must include window block + - pattern: "version: 0.1" + min_count: 1 + description: Window measures require version 0.1 + - pattern: "MEASURE\\(" + min_count: 1 + description: Must show MEASURE() query syntax + guidelines: + - Must use semiadditive window measure to prevent summing across time + - Must explain the behavior difference when date is vs is not in GROUP BY + - Must note that window measures require version 0.1 + metadata: + category: advanced + difficulty: hard + source: manual + +- id: metric-views_messy_phrasing_025 + inputs: + prompt: "so umm i wanna make like a metrics thing for my product_usage table??\ + \ like track daily active users and feature adoption.. not sure how metric views\ + \ work exactly" + outputs: + response: | + Metric views let you define reusable metrics in YAML that can be sliced by any dimension at query time. 
Here's one for your product usage data: + + ```sql + CREATE OR REPLACE VIEW catalog.schema.product_usage_metrics + WITH METRICS + LANGUAGE YAML + AS $$ + version: 1.1 + comment: "Product usage KPIs: DAU and feature adoption" + source: catalog.schema.product_usage + + dimensions: + - name: Date + expr: usage_date + - name: Feature + expr: feature_name + comment: "Product feature being used" + - name: Platform + expr: platform + comment: "web, ios, android" + + measures: + - name: Daily Active Users + expr: COUNT(DISTINCT user_id) + comment: "Unique users per day" + - name: Total Events + expr: COUNT(1) + - name: Events per User + expr: COUNT(1) * 1.0 / COUNT(DISTINCT user_id) + comment: "Average events per unique user" + $$ + ``` + + Query your metrics: + + ```sql + -- DAU by date + SELECT `Date`, MEASURE(`Daily Active Users`) AS dau + FROM catalog.schema.product_usage_metrics + GROUP BY ALL + ORDER BY `Date` + + -- Feature adoption: DAU by feature + SELECT `Feature`, MEASURE(`Daily Active Users`) AS users, + MEASURE(`Events per User`) AS engagement + FROM catalog.schema.product_usage_metrics + GROUP BY ALL + ORDER BY users DESC + ``` + + The same metric view supports both queries — you just change which dimensions you include in SELECT and GROUP BY. 
+ execution_success: true + expectations: + expected_facts: + - CREATE OR REPLACE VIEW + - WITH METRICS + - product usage + - daily active users + expected_patterns: + - pattern: "CREATE OR REPLACE VIEW" + min_count: 1 + description: Must use CREATE OR REPLACE VIEW syntax + - pattern: "WITH METRICS" + min_count: 1 + description: Must include WITH METRICS clause + - pattern: "LANGUAGE YAML" + min_count: 1 + description: Must specify LANGUAGE YAML + - pattern: "COUNT\\(DISTINCT" + min_count: 1 + description: Must use COUNT(DISTINCT) for DAU + - pattern: "MEASURE\\(" + min_count: 1 + description: Must show MEASURE() in query + guidelines: + - Must handle the informal/messy phrasing gracefully + - Must create a product-usage metric view (not orders) + - Must include DAU and feature adoption metrics + metadata: + category: happy_path + difficulty: easy + source: manual + +- id: metric-views_query_with_where_filter_026 + inputs: + prompt: "I want to query my metric view but only for a specific region and date\ + \ range. How do I filter?" 
+ outputs: + response: | + Use standard SQL `WHERE` clauses on dimensions when querying a metric view: + + ```sql + SELECT + `Order Month`, + MEASURE(`Total Revenue`) AS revenue, + MEASURE(`Order Count`) AS orders + FROM catalog.schema.orders_metrics + WHERE `Region` = 'EMEA' + AND `Order Month` >= DATE'2024-01-01' + AND `Order Month` < DATE'2025-01-01' + GROUP BY ALL + ORDER BY `Order Month` + ``` + + Or via the MCP tool: + + ```python + manage_metric_views( + action="query", + full_name="catalog.schema.orders_metrics", + query_measures=["Total Revenue", "Order Count"], + query_dimensions=["Order Month"], + where="`Region` = 'EMEA' AND `Order Month` >= DATE'2024-01-01'", + order_by="ALL", + limit=100, + ) + ``` + + **Key points:** + - Filter on dimensions using WHERE — just like a normal SQL query + - Dimension names with spaces need backtick quoting in WHERE clauses too + - The metric view's global `filter` (defined in YAML) is always applied in addition to your WHERE + - You cannot filter directly on measures — use HAVING after GROUP BY for that + execution_success: true + expectations: + expected_facts: + - WHERE + - backtick + - MEASURE() + expected_patterns: + - pattern: "WHERE" + min_count: 2 + description: Must show WHERE clause for filtering + - pattern: "MEASURE\\(" + min_count: 1 + description: Must use MEASURE() for measures + - pattern: "GROUP BY" + min_count: 1 + description: Must include GROUP BY + guidelines: + - Must show how to filter on dimensions in WHERE clause + - Must mention backtick quoting for dimension names with spaces + - Should show both SQL and MCP approaches + metadata: + category: happy_path + difficulty: easy + source: manual diff --git a/.test/skills/databricks-metric-views/manifest.yaml b/.test/skills/databricks-metric-views/manifest.yaml new file mode 100644 index 00000000..32eb0a49 --- /dev/null +++ b/.test/skills/databricks-metric-views/manifest.yaml @@ -0,0 +1,32 @@ +skill_name: databricks-metric-views +description: 'Unity 
Catalog metric views: define, create, query, and manage governed business metrics in YAML. Use when building + standardized KPIs, revenue metrics, order analytics, or any reusable business metrics that need consistent definitions across + teams and tools.' +scorers: + enabled: + - sql_syntax + - python_syntax + - pattern_adherence + - no_hallucinated_apis + - expected_facts_present + llm_scorers: + - Safety + - guidelines_from_expectations + default_guidelines: + - Response must address the user's request completely + - Code examples must follow documented best practices + - Response must use modern APIs (not deprecated ones) + trace_expectations: + tool_limits: + Bash: 10 + Read: 20 + token_budget: + max_total: 100000 + required_tools: + - Read + banned_tools: [] + expected_files: [] +quality_gates: + syntax_valid: 1.0 + pattern_adherence: 0.9 + execution_success: 0.8 diff --git a/.test/skills/databricks-mlflow-evaluation/ground_truth.yaml b/.test/skills/databricks-mlflow-evaluation/ground_truth.yaml new file mode 100644 index 00000000..c67cbc61 --- /dev/null +++ b/.test/skills/databricks-mlflow-evaluation/ground_truth.yaml @@ -0,0 +1,365 @@ +metadata: + skill_name: databricks-mlflow-evaluation + version: 0.1.0 + created_at: '2026-02-23T08:00:00.000000' + +test_cases: +- id: mlflow_eval_basic_evaluate_001 + inputs: + prompt: "How do I run a basic evaluation of my GenAI agent using MLflow?" 
+  outputs:
+    response: |
+      ```python
+      import mlflow
+
+      # Define your agent function
+      def my_agent(query: str) -> str:
+          # Your agent logic here
+          return response
+
+      # Prepare evaluation data (nested input structure required)
+      eval_data = [
+          {"inputs": {"query": "What is MLflow?"}, "expectations": {"expected_response": "MLflow is..."}},
+          {"inputs": {"query": "How do I log a model?"}, "expectations": {"expected_response": "Use mlflow.log_model..."}}
+      ]
+
+      # Run evaluation with built-in scorers
+      results = mlflow.genai.evaluate(
+          predict_fn=my_agent,
+          data=eval_data,
+          scorers=[
+              mlflow.genai.scorers.RelevanceToQuery(),
+              mlflow.genai.scorers.Safety()
+          ]
+      )
+      print(results.tables["eval_results"])
+      ```
+    execution_success: true
+  expectations:
+    expected_facts:
+    - "Use mlflow.genai.evaluate(), not mlflow.evaluate()"
+    - "Data must use nested structure: {inputs: {query: ...}}"
+    - "predict_fn receives unpacked kwargs, not a dict"
+    expected_patterns:
+    - pattern: "mlflow\\.genai\\.evaluate"
+      min_count: 1
+      description: "Uses the correct mlflow.genai.evaluate() API"
+    - pattern: "predict_fn"
+      min_count: 1
+      description: "Passes the agent function as predict_fn"
+    - pattern: "scorers"
+      min_count: 1
+      description: "Specifies scorers for evaluation"
+    - pattern: "\"inputs\""
+      min_count: 1
+      description: "Uses nested inputs structure in data"
+    guidelines:
+    - "Must use mlflow.genai.evaluate(), NOT mlflow.evaluate()"
+    - "Data format must use nested {inputs: {query: ...}} structure"
+    - "Must include at least one scorer"
+  metadata:
+    category: happy_path
+    difficulty: easy
+    source: auto_generated
+
+- id: mlflow_eval_custom_scorer_002
+  inputs:
+    prompt: "How do I create a custom scorer for MLflow GenAI evaluation?"
+  outputs:
+    response: |
+      ```python
+      import mlflow
+      from mlflow.genai.scorers import scorer
+      from mlflow.entities import Feedback
+
+      @scorer
+      def conciseness(outputs):
+          if outputs is None:
+              return Feedback(value=0, rationale="No output")
+
+          word_count = len(outputs.split())
+          if word_count <= 50:
+              return Feedback(value=1, rationale=f"Concise: {word_count} words")
+          elif word_count <= 150:
+              return Feedback(value=0.5, rationale=f"Moderate: {word_count} words")
+          else:
+              return Feedback(value=0, rationale=f"Verbose: {word_count} words")
+
+      # Use in evaluation
+      results = mlflow.genai.evaluate(
+          predict_fn=my_agent,
+          data=eval_data,
+          scorers=[conciseness]
+      )
+      ```
+    execution_success: true
+  expectations:
+    expected_facts:
+    - "Custom scorers are created with the @scorer decorator from mlflow.genai.scorers"
+    - "The scorer function declares only the arguments it needs (outputs, expectations, etc.)"
+    - "Scorers should return a Feedback with a value and rationale (or a plain number)"
+    expected_patterns:
+    - pattern: "@scorer"
+      min_count: 1
+      description: "Uses the @scorer decorator"
+    - pattern: "Feedback\\("
+      min_count: 1
+      description: "Returns Feedback objects"
+    - pattern: "rationale"
+      min_count: 1
+      description: "Includes a rationale in the score result"
+    guidelines:
+    - "Custom scorer must use the @scorer decorator (or extend the Scorer base class)"
+    - "Scorer function should accept only the arguments it needs, such as outputs"
+    - "Return value should include both a value and a rationale"
+  metadata:
+    category: happy_path
+    difficulty: medium
+    source: auto_generated
+
+- id: mlflow_eval_dataset_from_traces_003
+  inputs:
+    prompt: "How do I build an evaluation dataset from production traces in MLflow?"
+  outputs:
+    response: |
+      ```python
+      import mlflow
+
+      # Search for production traces
+      traces = mlflow.search_traces(
+          experiment_ids=["123"],
+          filter_string="status = 'OK'",
+          max_results=100
+      )
+
+      # Convert traces to evaluation dataset
+      eval_data = []
+      for trace in traces:
+          eval_data.append({
+              "inputs": {"query": trace.request["query"]},
+              "expectations": {"expected_response": trace.response}
+          })
+
+      # Run evaluation on the collected data
+      results = mlflow.genai.evaluate(
+          predict_fn=my_agent,
+          data=eval_data,
+          scorers=[mlflow.genai.scorers.RelevanceToQuery()]
+      )
+      ```
+    execution_success: true
+  expectations:
+    expected_facts:
+    - "mlflow.search_traces() retrieves production traces for evaluation"
+    - "Traces can be filtered by status and other attributes"
+    - "Trace data is converted to the nested {inputs, expectations} format"
+    expected_patterns:
+    - pattern: "search_traces"
+      min_count: 1
+      description: "Uses search_traces to retrieve production data"
+    - pattern: "filter_string"
+      min_count: 1
+      description: "Filters traces by criteria"
+    - pattern: "\"inputs\""
+      min_count: 1
+      description: "Converts to required nested data format"
+    guidelines:
+    - "Must show the full workflow from trace retrieval to evaluation"
+    - "Data must be converted to the correct nested format"
+  metadata:
+    category: happy_path
+    difficulty: medium
+    source: auto_generated
+
+- id: mlflow_eval_gotchas_004
+  inputs:
+    prompt: "What are the critical gotchas I should know before writing MLflow evaluation code?"
+  outputs:
+    response: |
+      Key gotchas for MLflow GenAI evaluation:
+
+      1. **Use `mlflow.genai.evaluate()`** - NOT `mlflow.evaluate()`. The genai module is required for LLM evaluation.
+
+      2. **Data format must be nested**: `{"inputs": {"query": "..."}}`, not `{"query": "..."}`.
+
+      3. **predict_fn receives unpacked kwargs** - Your function gets `query="..."` not `{"query": "..."}`.
+
+      4. **Always read GOTCHAS.md first** - Contains 15+ common mistakes that cause failures.
+
+      5. **Always read CRITICAL-interfaces.md** - Contains exact API signatures and data schemas.
+
+      6. **Scorer names matter** - Built-in scorers have specific names; custom scorers need unique names.
+    execution_success: true
+  expectations:
+    expected_facts:
+    - "Use mlflow.genai.evaluate(), not mlflow.evaluate()"
+    - "Data format requires nested structure: {inputs: {query: ...}}"
+    - "predict_fn receives unpacked kwargs, not a dict"
+    - "GOTCHAS.md documents 15+ common mistakes"
+    expected_patterns:
+    - pattern: "mlflow\\.genai\\.evaluate"
+      min_count: 1
+      description: "Specifies the correct API"
+    - pattern: "GOTCHAS"
+      min_count: 1
+      description: "References the GOTCHAS.md file"
+    - pattern: "inputs"
+      min_count: 1
+      description: "Mentions the nested inputs requirement"
+    guidelines:
+    - "Must warn against using mlflow.evaluate() instead of mlflow.genai.evaluate()"
+    - "Must explain the nested data format requirement"
+    - "Must mention GOTCHAS.md and CRITICAL-interfaces.md reference files"
+  metadata:
+    category: happy_path
+    difficulty: easy
+    source: auto_generated
+
+- id: mlflow_eval_compare_runs_005
+  inputs:
+    prompt: "How do I compare evaluation results across multiple agent versions?"
+  outputs:
+    response: |
+      ```python
+      import mlflow
+
+      # Evaluate version A
+      with mlflow.start_run(run_name="agent-v1"):
+          results_v1 = mlflow.genai.evaluate(
+              predict_fn=agent_v1,
+              data=eval_data,
+              scorers=[mlflow.genai.scorers.RelevanceToQuery()]
+          )
+
+      # Evaluate version B
+      with mlflow.start_run(run_name="agent-v2"):
+          results_v2 = mlflow.genai.evaluate(
+              predict_fn=agent_v2,
+              data=eval_data,
+              scorers=[mlflow.genai.scorers.RelevanceToQuery()]
+          )
+
+      # Compare in the MLflow UI or programmatically
+      print(f"V1 avg relevance: {results_v1.metrics['relevance_to_query/mean']}")
+      print(f"V2 avg relevance: {results_v2.metrics['relevance_to_query/mean']}")
+      ```
+    execution_success: true
+  expectations:
+    expected_facts:
+    - "Each evaluation should run in its own MLflow run for comparison"
+    - "Metrics are accessible via results.metrics dictionary"
+    - "The MLflow UI provides visual comparison across runs"
+    expected_patterns:
+    - pattern: "start_run"
+      min_count: 2
+      description: "Creates separate MLflow runs for each version"
+    - pattern: "mlflow\\.genai\\.evaluate"
+      min_count: 2
+      description: "Evaluates both versions"
+    - pattern: "\\.metrics\\["
+      min_count: 1
+      description: "Accesses metrics from results"
+    guidelines:
+    - "Must show evaluation in separate MLflow runs"
+    - "Must demonstrate how to access and compare metrics"
+  metadata:
+    category: happy_path
+    difficulty: medium
+    source: auto_generated
+
+- id: mlflow_eval_builtin_scorers_006
+  inputs:
+    prompt: "What built-in scorers are available in MLflow GenAI evaluation?"
+  outputs:
+    response: |
+      MLflow provides several built-in scorers in `mlflow.genai.scorers`:
+
+      ```python
+      from mlflow.genai.scorers import (
+          RelevanceToQuery,       # Is the response relevant to the query?
+          Safety,                 # Is the response safe and appropriate?
+          RetrievalGroundedness,  # Is the response grounded in retrieved context?
+          Correctness,            # Does the response match expected output?
+      )
+
+      results = mlflow.genai.evaluate(
+          predict_fn=my_agent,
+          data=eval_data,
+          scorers=[
+              RelevanceToQuery(),
+              Safety(),
+              RetrievalGroundedness(),
+              Correctness()
+          ]
+      )
+      ```
+
+      Use `RelevanceToQuery` and `Safety` for all evaluations. Add `Correctness` when you have expected responses, and `RetrievalGroundedness` for RAG applications.
+    execution_success: true
+  expectations:
+    expected_facts:
+    - "Built-in scorers include RelevanceToQuery, Safety, RetrievalGroundedness, and Correctness"
+    - "Scorers are imported from mlflow.genai.scorers"
+    - "RetrievalGroundedness is specifically useful for RAG applications"
+    expected_patterns:
+    - pattern: "RelevanceToQuery"
+      min_count: 1
+      description: "Lists RelevanceToQuery"
+    - pattern: "Safety"
+      min_count: 1
+      description: "Lists Safety"
+    - pattern: "RetrievalGroundedness"
+      min_count: 1
+      description: "Lists RetrievalGroundedness"
+    - pattern: "Correctness"
+      min_count: 1
+      description: "Lists Correctness"
+    guidelines:
+    - "Must list at least four built-in scorers"
+    - "Must explain when to use each scorer"
+  metadata:
+    category: happy_path
+    difficulty: easy
+    source: auto_generated
+
+- id: mlflow_eval_reference_files_007
+  inputs:
+    prompt: "Which reference files should I read before writing MLflow evaluation code?"
+ outputs: + response: | + Read these reference files in order: + + | Reference | Purpose | When to Read | + |-----------|---------|--------------| + | `GOTCHAS.md` | Common mistakes | **Always read first** before writing code | + | `CRITICAL-interfaces.md` | API signatures, schemas | When writing any evaluation code | + | `patterns-evaluation.md` | Running evals, comparing | When executing evaluations | + | `patterns-scorers.md` | Custom scorer creation | When built-in scorers aren't enough | + | `patterns-datasets.md` | Dataset building | When preparing evaluation data | + + Always read GOTCHAS.md and CRITICAL-interfaces.md before writing any evaluation code. + execution_success: true + expectations: + expected_facts: + - "GOTCHAS.md must be read first before writing any code" + - "CRITICAL-interfaces.md contains API signatures and data schemas" + - "patterns-evaluation.md covers running and comparing evaluations" + - "patterns-scorers.md covers custom scorer creation" + - "patterns-datasets.md covers dataset preparation" + expected_patterns: + - pattern: "GOTCHAS\\.md" + min_count: 1 + description: "References GOTCHAS.md" + - pattern: "CRITICAL-interfaces\\.md" + min_count: 1 + description: "References CRITICAL-interfaces.md" + - pattern: "patterns-evaluation\\.md" + min_count: 1 + description: "References patterns-evaluation.md" + guidelines: + - "Must list all five reference files" + - "Must emphasize reading GOTCHAS.md first" + metadata: + category: happy_path + difficulty: easy + source: auto_generated diff --git a/.test/skills/databricks-mlflow-evaluation/manifest.yaml b/.test/skills/databricks-mlflow-evaluation/manifest.yaml new file mode 100644 index 00000000..f4240e27 --- /dev/null +++ b/.test/skills/databricks-mlflow-evaluation/manifest.yaml @@ -0,0 +1,31 @@ +skill_name: databricks-mlflow-evaluation +description: MLflow 3 GenAI agent evaluation. 
Use when writing mlflow.genai.evaluate() code, creating @scorer functions, using + built-in scorers (Guidelines, Correctness, Safety, RetrievalGroundedness), building eval datasets from traces, setting up + trace ingestion and production monitoring, aligning judges with MemAlign from domain expert feedback, or running optimize_prompts() + with GEPA for automated prompt improvement. +scorers: + enabled: + - pattern_adherence + - no_hallucinated_apis + - expected_facts_present + llm_scorers: + - Safety + - guidelines_from_expectations + default_guidelines: + - Response must address the user's request completely + - Code examples must follow documented best practices + - Response must use modern APIs (not deprecated ones) + trace_expectations: + tool_limits: + Bash: 10 + Read: 20 + token_budget: + max_total: 100000 + required_tools: + - Read + banned_tools: [] + expected_files: [] +quality_gates: + syntax_valid: 1.0 + pattern_adherence: 0.9 + execution_success: 0.8 diff --git a/.test/skills/databricks-python-sdk/ground_truth.yaml b/.test/skills/databricks-python-sdk/ground_truth.yaml new file mode 100644 index 00000000..a783aa3a --- /dev/null +++ b/.test/skills/databricks-python-sdk/ground_truth.yaml @@ -0,0 +1,378 @@ +metadata: + skill_name: databricks-python-sdk + version: 0.1.0 + created_at: '2026-02-23T08:00:00.000000' + +test_cases: +- id: pysdk_authentication_001 + inputs: + prompt: "What are the different ways to authenticate with the Databricks Python SDK?" + outputs: + response: | + ```python + from databricks.sdk import WorkspaceClient + + # Auto-detect credentials from environment + w = WorkspaceClient() + + # Explicit token authentication + w = WorkspaceClient( + host="https://your-workspace.cloud.databricks.com", + token="dapi..." 
+ ) + + # Azure Service Principal + w = WorkspaceClient( + host="https://adb-xxx.azuredatabricks.net", + azure_workspace_resource_id="/subscriptions/.../resourceGroups/.../providers/Microsoft.Databricks/workspaces/...", + azure_tenant_id="tenant-id", + azure_client_id="client-id", + azure_client_secret="secret" + ) + + # Use a named profile from ~/.databrickscfg + w = WorkspaceClient(profile="MY_PROFILE") + ``` + execution_success: true + expectations: + expected_facts: + - "WorkspaceClient auto-detects credentials from environment by default" + - "Explicit token auth requires host and token parameters" + - "Azure Service Principal auth uses azure_tenant_id, azure_client_id, and azure_client_secret" + - "Named profiles are loaded from ~/.databrickscfg" + expected_patterns: + - pattern: "WorkspaceClient\\(\\)" + min_count: 1 + description: "Shows auto-detect authentication" + - pattern: "token=" + min_count: 1 + description: "Shows explicit token authentication" + - pattern: "azure_client_id" + min_count: 1 + description: "Shows Azure Service Principal authentication" + - pattern: "profile=" + min_count: 1 + description: "Shows named profile authentication" + guidelines: + - "Must show at least three different authentication methods" + - "Must include auto-detect, explicit token, and profile-based auth" + metadata: + category: happy_path + difficulty: easy + source: auto_generated + +- id: pysdk_clusters_api_002 + inputs: + prompt: "How do I list, create, and manage clusters using the Databricks Python SDK?" 
+  outputs:
+    response: |
+      ```python
+      from databricks.sdk import WorkspaceClient
+      from datetime import timedelta
+
+      w = WorkspaceClient()
+
+      # List all clusters
+      for cluster in w.clusters.list():
+          print(f"{cluster.cluster_name}: {cluster.state}")
+
+      # Get cluster details
+      cluster = w.clusters.get(cluster_id="0123-456789-abcdef")
+
+      # Create a cluster (returns Wait object)
+      wait = w.clusters.create(
+          cluster_name="my-cluster",
+          spark_version=w.clusters.select_spark_version(latest=True),
+          node_type_id=w.clusters.select_node_type(local_disk=True),
+          num_workers=2
+      )
+      cluster = wait.result()  # Wait for cluster to be running
+
+      # Start / terminate / permanently delete
+      w.clusters.start(cluster_id="...").result()
+      w.clusters.delete(cluster_id="...").result()  # terminates (stops) the cluster
+      w.clusters.permanent_delete(cluster_id="...")
+      ```
+    execution_success: true
+  expectations:
+    expected_facts:
+    - "w.clusters.list() returns an iterator of cluster objects"
+    - "w.clusters.create() returns a Wait object for async cluster creation"
+    - "select_spark_version() and select_node_type() are helper methods for discovery"
+    - "start() and delete() return Wait objects; permanent_delete() does not"
+    expected_patterns:
+    - pattern: "w\\.clusters\\.list"
+      min_count: 1
+      description: "Lists clusters"
+    - pattern: "w\\.clusters\\.create"
+      min_count: 1
+      description: "Creates a cluster"
+    - pattern: "select_spark_version"
+      min_count: 1
+      description: "Uses spark version selector"
+    - pattern: "select_node_type"
+      min_count: 1
+      description: "Uses node type selector"
+    guidelines:
+    - "Must show list, create, start, terminate (delete), and permanent delete operations"
+    - "Must use select_spark_version and select_node_type helpers"
+    - "Must demonstrate the Wait object pattern for async operations"
+  metadata:
+    category: happy_path
+    difficulty: medium
+    source: auto_generated
+
+- id: pysdk_jobs_api_003
+  inputs:
+    prompt: "Show me how to create and run a Databricks job using the Python SDK."
+ outputs: + response: | + ```python + from databricks.sdk import WorkspaceClient + from databricks.sdk.service.jobs import Task, NotebookTask + + w = WorkspaceClient() + + # List jobs + for job in w.jobs.list(): + print(f"{job.job_id}: {job.settings.name}") + + # Create a job + created = w.jobs.create( + name="my-job", + tasks=[ + Task( + task_key="main", + notebook_task=NotebookTask(notebook_path="/Users/me/notebook"), + existing_cluster_id="0123-456789-abcdef" + ) + ] + ) + + # Run a job and wait for completion + run = w.jobs.run_now_and_wait(job_id=created.job_id) + print(f"Run completed: {run.state.result_state}") + + # Get run output + output = w.jobs.get_run_output(run_id=run.run_id) + ``` + execution_success: true + expectations: + expected_facts: + - "Jobs are created with w.jobs.create() returning a job with job_id" + - "run_now_and_wait() triggers a job and blocks until completion" + - "get_run_output() retrieves the output of a completed run" + - "Task requires task_key and a task type like NotebookTask" + expected_patterns: + - pattern: "w\\.jobs\\.create" + min_count: 1 + description: "Creates a job" + - pattern: "run_now_and_wait" + min_count: 1 + description: "Runs job and waits for completion" + - pattern: "get_run_output" + min_count: 1 + description: "Gets run output" + - pattern: "NotebookTask" + min_count: 1 + description: "Uses NotebookTask for task configuration" + guidelines: + - "Must show the full job lifecycle: create, run, get output" + - "Must use run_now_and_wait for blocking execution" + metadata: + category: happy_path + difficulty: medium + source: auto_generated + +- id: pysdk_sql_execution_004 + inputs: + prompt: "How do I execute SQL queries against a Databricks SQL warehouse using the Python SDK?" 
+ outputs: + response: | + ```python + from databricks.sdk import WorkspaceClient + from databricks.sdk.service.sql import StatementState + + w = WorkspaceClient() + + # Execute SQL query + response = w.statement_execution.execute_statement( + warehouse_id="abc123", + statement="SELECT * FROM catalog.schema.table LIMIT 10", + wait_timeout="30s" + ) + + # Check status and get results + if response.status.state == StatementState.SUCCEEDED: + for row in response.result.data_array: + print(row) + + # For large results, fetch chunks + chunk = w.statement_execution.get_statement_result_chunk_n( + statement_id=response.statement_id, + chunk_index=0 + ) + ``` + execution_success: true + expectations: + expected_facts: + - "SQL is executed with w.statement_execution.execute_statement()" + - "warehouse_id specifies which SQL warehouse to use" + - "StatementState.SUCCEEDED indicates a successful query" + - "Large results are fetched in chunks with get_statement_result_chunk_n()" + expected_patterns: + - pattern: "execute_statement" + min_count: 1 + description: "Executes SQL statement" + - pattern: "warehouse_id" + min_count: 1 + description: "Specifies SQL warehouse" + - pattern: "StatementState" + min_count: 1 + description: "Checks statement execution status" + - pattern: "get_statement_result_chunk_n" + min_count: 1 + description: "Fetches result chunks for large queries" + guidelines: + - "Must show execute, status check, and result retrieval" + - "Must demonstrate chunked result fetching for large datasets" + metadata: + category: happy_path + difficulty: medium + source: auto_generated + +- id: pysdk_databricks_connect_005 + inputs: + prompt: "How do I use Databricks Connect to run Spark code locally?" 
+ outputs: + response: | + ```python + from databricks.connect import DatabricksSession + + # Auto-detects 'DEFAULT' profile from ~/.databrickscfg + spark = DatabricksSession.builder.getOrCreate() + + # With explicit profile + spark = DatabricksSession.builder.profile("MY_PROFILE").getOrCreate() + + # Use spark as normal + df = spark.sql("SELECT * FROM catalog.schema.table") + df.show() + ``` + + Databricks Connect requires `databricks-connect` package and a running cluster to execute against. + execution_success: true + expectations: + expected_facts: + - "DatabricksSession replaces SparkSession for remote Spark execution" + - "Databricks Connect auto-detects credentials from ~/.databrickscfg" + - "The databricks-connect package must be installed separately" + expected_patterns: + - pattern: "DatabricksSession" + min_count: 1 + description: "Uses DatabricksSession for Spark" + - pattern: "getOrCreate" + min_count: 1 + description: "Builds session with getOrCreate()" + - pattern: "profile" + min_count: 1 + description: "Shows profile-based configuration" + guidelines: + - "Must show DatabricksSession, not SparkSession" + - "Must demonstrate both auto-detect and explicit profile usage" + metadata: + category: happy_path + difficulty: easy + source: auto_generated + +- id: pysdk_rest_api_006 + inputs: + prompt: "How can I make direct REST API calls using the Databricks Python SDK?" + outputs: + response: | + ```python + from databricks.sdk import WorkspaceClient + + w = WorkspaceClient() + + # Direct API call using authenticated client + response = w.api_client.do( + method="GET", + path="/api/2.0/clusters/list" + ) + + # POST with body + response = w.api_client.do( + method="POST", + path="/api/2.0/jobs/run-now", + body={"job_id": 123} + ) + ``` + + Use direct REST API calls for operations not yet available in the SDK or for advanced use cases. 
+ execution_success: true + expectations: + expected_facts: + - "w.api_client.do() makes authenticated REST API calls" + - "The method parameter accepts GET, POST, PUT, DELETE" + - "The body parameter accepts a dictionary for POST/PUT requests" + expected_patterns: + - pattern: "api_client\\.do" + min_count: 2 + description: "Uses api_client.do() for direct REST calls" + - pattern: "method=" + min_count: 2 + description: "Specifies HTTP method" + - pattern: "path=" + min_count: 2 + description: "Specifies API endpoint path" + guidelines: + - "Must show both GET and POST examples" + - "Must explain when to use direct REST over SDK methods" + metadata: + category: happy_path + difficulty: easy + source: auto_generated + +- id: pysdk_cli_commands_007 + inputs: + prompt: "What are the most common Databricks CLI commands?" + outputs: + response: | + ```bash + # Check version (should be >= 0.278.0) + databricks --version + + # Use specific profile + databricks --profile MY_PROFILE clusters list + + # Common commands + databricks clusters list + databricks jobs list + databricks workspace ls /Users/me + ``` + + The CLI uses the same authentication as the Python SDK (`~/.databrickscfg` profiles, environment variables, or explicit flags). 
+ execution_success: true + expectations: + expected_facts: + - "The Databricks CLI shares authentication with the Python SDK" + - "Profiles are specified with the --profile flag" + - "Common subcommands include clusters, jobs, and workspace" + expected_patterns: + - pattern: "databricks --version" + min_count: 1 + description: "Shows version check command" + - pattern: "databricks clusters list" + min_count: 1 + description: "Shows cluster listing command" + - pattern: "databricks jobs list" + min_count: 1 + description: "Shows job listing command" + guidelines: + - "Must show at least three common CLI commands" + - "Must mention the --profile flag for profile selection" + metadata: + category: happy_path + difficulty: easy + source: auto_generated diff --git a/.test/skills/databricks-python-sdk/manifest.yaml b/.test/skills/databricks-python-sdk/manifest.yaml new file mode 100644 index 00000000..62377187 --- /dev/null +++ b/.test/skills/databricks-python-sdk/manifest.yaml @@ -0,0 +1,30 @@ +skill_name: databricks-python-sdk +description: Databricks development guidance including Python SDK, Databricks Connect, CLI, and REST API. Use when working + with databricks-sdk, databricks-connect, or Databricks APIs. 
+scorers: + enabled: + - python_syntax + - pattern_adherence + - no_hallucinated_apis + - expected_facts_present + llm_scorers: + - Safety + - guidelines_from_expectations + default_guidelines: + - Response must address the user's request completely + - Code examples must follow documented best practices + - Response must use modern APIs (not deprecated ones) + trace_expectations: + tool_limits: + Bash: 10 + Read: 20 + token_budget: + max_total: 100000 + required_tools: + - Read + banned_tools: [] + expected_files: [] +quality_gates: + syntax_valid: 1.0 + pattern_adherence: 0.9 + execution_success: 0.8 diff --git a/.test/skills/databricks-spark-structured-streaming/ground_truth.yaml b/.test/skills/databricks-spark-structured-streaming/ground_truth.yaml new file mode 100644 index 00000000..6f642133 --- /dev/null +++ b/.test/skills/databricks-spark-structured-streaming/ground_truth.yaml @@ -0,0 +1,227 @@ +metadata: + skill_name: databricks-spark-structured-streaming + version: 0.1.0 + created_at: '2026-02-23T08:00:00.000000' + +test_cases: +- id: spark-streaming_kafka_to_delta_001 + inputs: + prompt: Create a streaming pipeline that reads from Kafka and writes to a Delta table + outputs: + response: | + Here's a basic Kafka to Delta streaming pipeline: + + ```python + from pyspark.sql.functions import col, from_json + + # Read from Kafka + df = (spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", "broker:9092") + .option("subscribe", "topic") + .load() + .select(from_json(col("value").cast("string"), schema).alias("data")) + .select("data.*") + ) + + # Write to Delta + df.writeStream \ + .format("delta") \ + .outputMode("append") \ + .option("checkpointLocation", "/Volumes/catalog/checkpoints/stream") \ + .trigger(processingTime="30 seconds") \ + .start("/delta/target_table") + ``` + + Key points: + - Use a persistent checkpoint location (UC volumes, not DBFS) + - Each stream needs a unique checkpoint path + execution_success: true + 
expectations: + expected_facts: + - Reads from Kafka using readStream + - Writes to Delta table format + - Includes checkpoint location + expected_patterns: + - pattern: readStream + min_count: 1 + description: Must use readStream for streaming source + - pattern: writeStream + min_count: 1 + description: Must use writeStream for streaming sink + - pattern: checkpointLocation + min_count: 1 + description: Must specify checkpoint location + - pattern: format.*kafka + min_count: 1 + description: Must specify Kafka format + guidelines: + - Code must be valid Python syntax + - Must include checkpoint location for fault tolerance + - Checkpoint should use UC volumes, not DBFS + metadata: + category: happy_path + difficulty: easy + source: auto_generated + +- id: spark-streaming_trigger_modes_002 + inputs: + prompt: What are the different trigger modes for Spark Structured Streaming? + outputs: + response: | + Spark Structured Streaming supports several trigger modes: + + ```python + # Fixed interval processing + .trigger(processingTime="30 seconds") + + # Process all available data once + .trigger(availableNow=True) + + # Continuous processing (low latency) + .trigger(continuous="1 second") + ``` + + - **processingTime**: Fixed interval micro-batches. Best for most workloads. + - **availableNow**: Processes all available data then stops. Good for batch-like streaming. + - **continuous**: Low-latency mode (~1ms). Limited operator support. 
+ execution_success: true + expectations: + expected_facts: + - processingTime trigger for fixed interval + - availableNow trigger processes all data then stops + - continuous trigger for low latency + expected_patterns: + - pattern: processingTime + min_count: 1 + description: Must mention processingTime trigger + - pattern: trigger + min_count: 1 + description: Must use trigger configuration + guidelines: + - Must explain multiple trigger modes + - Code must be valid Python syntax + metadata: + category: happy_path + difficulty: easy + source: auto_generated + +- id: spark-streaming_checkpoint_practices_003 + inputs: + prompt: What are best practices for Spark Structured Streaming checkpoints? + outputs: + response: | + Checkpoint best practices: + + 1. **Use persistent storage**: UC volumes, not DBFS + 2. **Unique per stream**: Each stream needs its own checkpoint path + 3. **Never delete**: Checkpoints track stream state and offsets + 4. **Use volumes**: `/Volumes/catalog/schema/checkpoints/stream_name` + + ```python + df.writeStream \ + .format("delta") \ + .outputMode("append") \ + .option("checkpointLocation", "/Volumes/catalog/schema/checkpoints/my_stream") \ + .start("catalog.schema.target_table") + ``` + execution_success: true + expectations: + expected_facts: + - Checkpoint location should be persistent (UC volumes) + - Each stream needs a unique checkpoint + expected_patterns: + - pattern: checkpointLocation + min_count: 1 + description: Must reference checkpointLocation + - pattern: /Volumes/ + min_count: 1 + description: Should use UC volumes path + guidelines: + - Must recommend persistent storage for checkpoints + - Must warn about unique checkpoints per stream + metadata: + category: happy_path + difficulty: easy + source: auto_generated + +- id: spark-streaming_production_checklist_004 + inputs: + prompt: What is the production checklist for Spark Structured Streaming? 
+ outputs: + response: | + Production checklist for Spark Structured Streaming: + + - Checkpoint location is persistent (UC volumes, not DBFS) + - Unique checkpoint per stream + - Fixed-size cluster (no autoscaling for streaming) + - Monitoring configured (input rate, lag, batch duration) + - Exactly-once verified (txnVersion/txnAppId) + - Watermark configured for stateful operations + - Left joins for stream-static (not inner) + execution_success: true + expectations: + expected_facts: + - Use persistent checkpoint location + - Fixed-size cluster, no autoscaling + - Configure monitoring for streaming + expected_patterns: + - pattern: checkpoint + min_count: 1 + description: Must mention checkpoint requirements + - pattern: autoscal + min_count: 1 + description: Must mention autoscaling restriction + guidelines: + - Must cover checkpoint, cluster, and monitoring requirements + metadata: + category: happy_path + difficulty: medium + source: auto_generated + +- id: spark-streaming_stream_joins_005 + inputs: + prompt: How do I join a streaming DataFrame with a static table? 
+ outputs: + response: | + For stream-static joins, use a left join (not inner) to avoid dropping streaming rows when the static side has no match: + + ```python + # Read streaming source + stream_df = spark.readStream.format("delta").table("catalog.schema.events") + + # Read static dimension table + dim_df = spark.read.table("catalog.schema.dimensions") + + # Left join: stream on left, static on right + joined = stream_df.join(dim_df, "join_key", "left") + + joined.writeStream \ + .format("delta") \ + .outputMode("append") \ + .option("checkpointLocation", "/Volumes/catalog/schema/checkpoints/joined_stream") \ + .start("catalog.schema.enriched_events") + ``` + execution_success: true + expectations: + expected_facts: + - Use left join for stream-static joins + - Static table is on the right side + expected_patterns: + - pattern: readStream + min_count: 1 + description: Must use streaming read + - pattern: join + min_count: 1 + description: Must demonstrate join operation + - pattern: left + min_count: 1 + description: Must use left join + guidelines: + - Must use left join for stream-static joins, not inner + - Code must be valid Python syntax + metadata: + category: happy_path + difficulty: medium + source: auto_generated diff --git a/.test/skills/databricks-spark-structured-streaming/manifest.yaml b/.test/skills/databricks-spark-structured-streaming/manifest.yaml new file mode 100644 index 00000000..7e6e9521 --- /dev/null +++ b/.test/skills/databricks-spark-structured-streaming/manifest.yaml @@ -0,0 +1,30 @@ +skill_name: databricks-spark-structured-streaming +description: Comprehensive guide to Spark Structured Streaming for production workloads. Use when building streaming pipelines, + implementing real-time data processing, handling stateful operations, or optimizing streaming performance. 
+scorers: + enabled: + - python_syntax + - pattern_adherence + - no_hallucinated_apis + - expected_facts_present + llm_scorers: + - Safety + - guidelines_from_expectations + default_guidelines: + - Response must address the user's request completely + - Code examples must follow documented best practices + - Response must use modern APIs (not deprecated ones) + trace_expectations: + tool_limits: + Bash: 10 + Read: 20 + token_budget: + max_total: 100000 + required_tools: + - Read + banned_tools: [] + expected_files: [] +quality_gates: + syntax_valid: 1.0 + pattern_adherence: 0.9 + execution_success: 0.8 diff --git a/.test/skills/databricks-synthetic-data-generation/ground_truth.yaml b/.test/skills/databricks-synthetic-data-generation/ground_truth.yaml new file mode 100644 index 00000000..b0e1b0f6 --- /dev/null +++ b/.test/skills/databricks-synthetic-data-generation/ground_truth.yaml @@ -0,0 +1,392 @@ +metadata: + skill_name: databricks-synthetic-data-generation + version: 0.1.0 + created_at: '2026-02-23T08:00:00.000000' + +test_cases: +- id: synth_data_script_structure_001 + inputs: + prompt: "What is the recommended structure for a synthetic data generation script on Databricks?" 
+ outputs: + response: | + ```python + """Generate synthetic data for [use case].""" + import numpy as np + import pandas as pd + from datetime import datetime, timedelta + from faker import Faker + import holidays + from pyspark.sql import SparkSession + + # ============================================================================= + # CONFIGURATION - Edit these values + # ============================================================================= + CATALOG = "my_catalog" + SCHEMA = "my_schema" + VOLUME_PATH = f"/Volumes/{CATALOG}/{SCHEMA}/raw_data" + + N_CUSTOMERS = 2500 + N_ORDERS = 25000 + N_TICKETS = 8000 + + END_DATE = datetime.now().replace(hour=0, minute=0, second=0, microsecond=0) + START_DATE = END_DATE - timedelta(days=180) + + SEED = 42 + + # ============================================================================= + # SETUP + # ============================================================================= + np.random.seed(SEED) + Faker.seed(SEED) + fake = Faker() + spark = SparkSession.builder.getOrCreate() + ``` + execution_success: true + expectations: + expected_facts: + - "Configuration variables go at the top of the script" + - "Seeds must be set for both numpy and Faker for reproducibility" + - "Volume path follows /Volumes/{catalog}/{schema}/{volume} pattern" + - "SparkSession is used for saving data to Delta/Parquet" + expected_patterns: + - pattern: "CATALOG\\s*=" + min_count: 1 + description: "Defines catalog configuration variable" + - pattern: "np\\.random\\.seed" + min_count: 1 + description: "Sets numpy random seed" + - pattern: "Faker\\.seed" + min_count: 1 + description: "Sets Faker random seed" + - pattern: "SparkSession" + min_count: 1 + description: "Uses SparkSession for data saving" + guidelines: + - "Configuration must be at the top of the script" + - "Must set seeds for both numpy and Faker for reproducibility" + - "Must use SparkSession for saving to Databricks storage" + metadata: + category: happy_path + difficulty: 
easy + source: auto_generated + +- id: synth_data_infrastructure_002 + inputs: + prompt: "Should I create the catalog, schema, and volume separately or inside the generation script?" + outputs: + response: | + Always create infrastructure inside the Python script using `spark.sql()`: + + ```python + # ============================================================================= + # CREATE INFRASTRUCTURE (inside the Python script) + # ============================================================================= + spark.sql(f"CREATE CATALOG IF NOT EXISTS {CATALOG}") + spark.sql(f"CREATE SCHEMA IF NOT EXISTS {CATALOG}.{SCHEMA}") + spark.sql(f"CREATE VOLUME IF NOT EXISTS {CATALOG}.{SCHEMA}.raw_data") + ``` + + Do NOT make separate MCP SQL calls - it's much slower. The `spark` variable is available by default on Databricks clusters. + execution_success: true + expectations: + expected_facts: + - "Infrastructure creation belongs inside the Python script, not as separate SQL calls" + - "spark.sql() is used for CREATE CATALOG, CREATE SCHEMA, and CREATE VOLUME" + - "IF NOT EXISTS makes the script idempotent" + expected_patterns: + - pattern: "CREATE CATALOG IF NOT EXISTS" + min_count: 1 + description: "Creates catalog idempotently" + - pattern: "CREATE SCHEMA IF NOT EXISTS" + min_count: 1 + description: "Creates schema idempotently" + - pattern: "CREATE VOLUME IF NOT EXISTS" + min_count: 1 + description: "Creates volume idempotently" + guidelines: + - "Must use spark.sql() inside the script, not separate SQL calls" + - "All CREATE statements must use IF NOT EXISTS" + metadata: + category: happy_path + difficulty: easy + source: auto_generated + +- id: synth_data_distributions_003 + inputs: + prompt: "How do I generate synthetic data with realistic statistical distributions?" 
+ outputs: + response: | + Use numpy distributions with weighted probabilities: + + ```python + import numpy as np + from faker import Faker + + fake = Faker() + + # Categorical with weighted distribution + customers_pdf["tier"] = np.random.choice( + ['Free', 'Pro', 'Enterprise'], N_CUSTOMERS, p=[0.6, 0.3, 0.1] + ) + + # Log-normal for financial data (right-skewed, no negatives) + customers_pdf["arr"] = customers_pdf["tier"].apply( + lambda t: round(np.random.lognormal(11, 0.5), 2) if t == 'Enterprise' + else round(np.random.lognormal(8, 0.6), 2) if t == 'Pro' else 0 + ) + + # Correlated fields: order amount varies by tier + amount = np.random.lognormal( + 7 if tier == 'Enterprise' else 5 if tier == 'Pro' else 3.5, + 0.7 + ) + ``` + + Key distributions: use log-normal for financial amounts, weighted choice for categories, and exponential for time durations. + execution_success: true + expectations: + expected_facts: + - "Log-normal distribution is ideal for financial data (right-skewed, non-negative)" + - "np.random.choice with p parameter creates weighted categorical distributions" + - "Field values should correlate with related fields (e.g., order amount varies by tier)" + expected_patterns: + - pattern: "np\\.random\\.choice" + min_count: 1 + description: "Uses weighted random choice for categories" + - pattern: "np\\.random\\.lognormal" + min_count: 1 + description: "Uses log-normal for financial data" + - pattern: "p=\\[" + min_count: 1 + description: "Specifies probability weights" + guidelines: + - "Must show weighted distributions, not uniform random" + - "Financial values must use log-normal, not normal distribution" + - "Must demonstrate correlation between related fields" + metadata: + category: happy_path + difficulty: medium + source: auto_generated + +- id: synth_data_time_patterns_004 + inputs: + prompt: "How do I add realistic time-based patterns like weekday/weekend effects and event spikes?" 
+ outputs: + response: | + ```python + import holidays + + US_HOLIDAYS = holidays.US(years=[START_DATE.year, END_DATE.year]) + + def get_daily_multiplier(date): + """Calculate volume multiplier for a given date.""" + multiplier = 1.0 + + # Weekend drop + if date.weekday() >= 5: + multiplier *= 0.6 + + # Holiday drop (even lower than weekends) + if date in US_HOLIDAYS: + multiplier *= 0.3 + + # Q4 seasonality (higher in Oct-Dec) + multiplier *= 1 + 0.15 * (date.month - 6) / 6 + + # Incident spike + if INCIDENT_START <= date <= INCIDENT_END: + multiplier *= 3.0 + + # Random noise + multiplier *= np.random.normal(1, 0.1) + + return max(0.1, multiplier) + + # Apply to generate daily volumes + date_range = pd.date_range(START_DATE, END_DATE, freq='D') + daily_volumes = [int(BASE_DAILY * get_daily_multiplier(d)) for d in date_range] + ``` + execution_success: true + expectations: + expected_facts: + - "The holidays library provides country-specific holiday calendars" + - "Weekend volumes should drop to 60% of weekday volumes" + - "Holiday volumes should drop further to 30%" + - "Event spikes multiply baseline volume by 3x or more" + expected_patterns: + - pattern: "holidays\\.US" + min_count: 1 + description: "Uses US holiday calendar" + - pattern: "weekday\\(\\)" + min_count: 1 + description: "Checks for weekday/weekend" + - pattern: "INCIDENT_START" + min_count: 1 + description: "Implements event spike period" + - pattern: "np\\.random\\.normal" + min_count: 1 + description: "Adds random noise to volumes" + guidelines: + - "Must include weekday/weekend effects" + - "Must include holiday effects using the holidays library" + - "Must include at least one event spike period" + - "Must add random noise to avoid unrealistic regularity" + metadata: + category: happy_path + difficulty: medium + source: auto_generated + +- id: synth_data_save_volume_005 + inputs: + prompt: "How should I save generated synthetic data to Databricks?" 
+ outputs: + response: | + Always save data to a Unity Catalog Volume as parquet files: + + ```python + VOLUME_PATH = f"/Volumes/{CATALOG}/{SCHEMA}/raw_data" + + # Save as parquet files (raw data for downstream pipelines) + spark.createDataFrame(customers_pdf).write.mode("overwrite").parquet(f"{VOLUME_PATH}/customers") + spark.createDataFrame(orders_pdf).write.mode("overwrite").parquet(f"{VOLUME_PATH}/orders") + spark.createDataFrame(tickets_pdf).write.mode("overwrite").parquet(f"{VOLUME_PATH}/tickets") + ``` + + Save to Volumes as parquet, not directly to tables. This serves as raw input for a downstream Spark Declarative Pipeline that handles bronze/silver/gold layers. + execution_success: true + expectations: + expected_facts: + - "Data should be saved to Volumes as parquet, not directly to tables" + - "Volume path pattern is /Volumes/{catalog}/{schema}/{volume}" + - "spark.createDataFrame() converts pandas DataFrames to Spark for saving" + - "Parquet files serve as input for downstream pipeline processing" + expected_patterns: + - pattern: "createDataFrame" + min_count: 1 + description: "Converts pandas to Spark DataFrame" + - pattern: "\\.parquet\\(" + min_count: 1 + description: "Saves as parquet format" + - pattern: "mode\\(\"overwrite\"\\)" + min_count: 1 + description: "Uses overwrite mode" + - pattern: "/Volumes/" + min_count: 1 + description: "Saves to Unity Catalog Volume path" + guidelines: + - "Must save to Volumes, not directly to tables" + - "Must use parquet format" + - "Must explain that this is raw data for downstream pipeline" + metadata: + category: happy_path + difficulty: easy + source: auto_generated + +- id: synth_data_raw_only_006 + inputs: + prompt: "Should I include pre-aggregated fields like total_orders or avg_revenue in my synthetic data?" + outputs: + response: | + No. By default, generate only raw transactional data. One row equals one event. 
+ + ```python + # CORRECT - Raw transactional data + customers_data.append({ + "customer_id": cid, + "name": fake.company(), + "tier": "Enterprise", + "region": "North", + }) + + orders_data.append({ + "order_id": f"ORD-{i:06d}", + "customer_id": cid, + "amount": 150.00, + "order_date": "2024-10-15", + }) + + # WRONG - Do not add pre-aggregated fields + # customers_data.append({ + # "total_orders": 47, # aggregation + # "total_revenue": 12500.00, # sum + # "avg_order_value": 265.95, # average + # }) + ``` + + Aggregations belong in downstream pipeline transformations, not in raw data. + execution_success: true + expectations: + expected_facts: + - "Raw data should have one row per event/transaction/record" + - "Pre-aggregated fields like total_orders, sum_revenue, avg_csat should not be included" + - "Aggregations belong in downstream pipeline transformations" + expected_patterns: + - pattern: "customer_id" + min_count: 1 + description: "Shows raw transactional customer data" + - pattern: "order_id" + min_count: 1 + description: "Shows raw transactional order data" + - pattern: "WRONG|BAD|Do not" + min_count: 1 + description: "Shows anti-pattern example" + guidelines: + - "Must clearly show the correct raw data pattern" + - "Must show the anti-pattern with pre-aggregated fields" + - "Must explain that aggregations belong downstream" + metadata: + category: happy_path + difficulty: easy + source: auto_generated + +- id: synth_data_common_libraries_007 + inputs: + prompt: "What Python libraries should I use for generating synthetic data?" 
+  outputs:
+    response: |
+      Key libraries for synthetic data generation:
+
+      - **faker**: Generates realistic names, addresses, emails, companies, and dates
+      - **numpy**: Statistical distributions (log-normal, exponential, weighted choice)
+      - **pandas**: DataFrame construction and manipulation
+      - **holidays**: Country-specific holiday calendars for realistic date patterns
+      - **pyspark**: SparkSession for saving data to Volumes as parquet
+
+      ```python
+      import numpy as np
+      import pandas as pd
+      from datetime import datetime, timedelta
+      from faker import Faker
+      import holidays
+      from pyspark.sql import SparkSession
+      ```
+
+      All of these are pre-installed on Databricks clusters.
+    execution_success: true
+  expectations:
+    expected_facts:
+    - "Faker generates realistic personal and business data"
+    - "Numpy provides statistical distributions for realistic value generation"
+    - "The holidays library provides holiday calendars for date patterns"
+    - "All listed libraries are pre-installed on Databricks clusters"
+    expected_patterns:
+    - pattern: "faker"
+      min_count: 1
+      description: "Mentions Faker library"
+    - pattern: "numpy"
+      min_count: 1
+      description: "Mentions numpy library"
+    - pattern: "holidays"
+      min_count: 1
+      description: "Mentions holidays library"
+    - pattern: "pyspark"
+      min_count: 1
+      description: "Mentions PySpark library"
+    guidelines:
+    - "Must list at least four key libraries with their purposes"
+    - "Must mention that libraries are pre-installed on Databricks"
+  metadata:
+    category: happy_path
+    difficulty: easy
+    source: auto_generated
diff --git a/.test/skills/databricks-synthetic-data-generation/manifest.yaml b/.test/skills/databricks-synthetic-data-generation/manifest.yaml
new file mode 100644
index 00000000..9216d106
--- /dev/null
+++ b/.test/skills/databricks-synthetic-data-generation/manifest.yaml
@@ -0,0 +1,30 @@
+skill_name: databricks-synthetic-data-generation
+description: Generate realistic synthetic data using Faker and Spark, with non-linear distributions, integrity constraints,
+  and save to Databricks. Use when creating test data, demo datasets, or synthetic tables.
+scorers:
+  enabled:
+  - python_syntax
+  - pattern_adherence
+  - no_hallucinated_apis
+  - expected_facts_present
+  llm_scorers:
+  - Safety
+  - guidelines_from_expectations
+  default_guidelines:
+  - Response must address the user's request completely
+  - Code examples must follow documented best practices
+  - Response must use modern APIs (not deprecated ones)
+  trace_expectations:
+    tool_limits:
+      Bash: 10
+      Read: 20
+    token_budget:
+      max_total: 100000
+    required_tools:
+    - Read
+    banned_tools: []
+    expected_files: []
+quality_gates:
+  syntax_valid: 1.0
+  pattern_adherence: 0.9
+  execution_success: 0.8
diff --git a/.test/skills/databricks-unity-catalog/ground_truth.yaml b/.test/skills/databricks-unity-catalog/ground_truth.yaml
new file mode 100644
index 00000000..487ae52e
--- /dev/null
+++ b/.test/skills/databricks-unity-catalog/ground_truth.yaml
@@ -0,0 +1,255 @@
+metadata:
+  skill_name: databricks-unity-catalog
+  version: 0.1.0
+  created_at: '2026-02-23T08:00:00.000000'
+
+test_cases:
+- id: uc_volume_operations_001
+  inputs:
+    prompt: "How do I upload, download, and list files in a Unity Catalog Volume?"
+  outputs:
+    response: |
+      Use MCP tools for Volume file operations:
+
+      ```python
+      # List files in a volume
+      list_volume_files(volume_path="/Volumes/catalog/schema/volume/folder/")
+
+      # Upload file to volume
+      upload_to_volume(
+          local_path="/tmp/data.csv",
+          volume_path="/Volumes/catalog/schema/volume/data.csv"
+      )
+
+      # Download file from volume
+      download_from_volume(
+          volume_path="/Volumes/catalog/schema/volume/data.csv",
+          local_path="/tmp/downloaded.csv"
+      )
+
+      # Create directory
+      create_volume_directory(volume_path="/Volumes/catalog/schema/volume/new_folder")
+      ```
+    execution_success: true
+  expectations:
+    expected_facts:
+    - "Volume paths follow the pattern /Volumes/{catalog}/{schema}/{volume}"
+    - "list_volume_files, upload_to_volume, download_from_volume, and create_volume_directory are available MCP tools"
+    expected_patterns:
+    - pattern: "list_volume_files"
+      min_count: 1
+      description: "Uses list_volume_files MCP tool"
+    - pattern: "upload_to_volume"
+      min_count: 1
+      description: "Uses upload_to_volume MCP tool"
+    - pattern: "download_from_volume"
+      min_count: 1
+      description: "Uses download_from_volume MCP tool"
+    - pattern: "/Volumes/"
+      min_count: 3
+      description: "Uses Unity Catalog Volume paths"
+    guidelines:
+    - "Must show at least list, upload, and download operations"
+    - "Volume paths must follow the /Volumes/{catalog}/{schema}/{volume} pattern"
+  metadata:
+    category: happy_path
+    difficulty: easy
+    source: auto_generated
+
+- id: uc_table_lineage_002
+  inputs:
+    prompt: "How do I query table lineage to find which tables feed into my target table?"
+  outputs:
+    response: |
+      ```sql
+      -- Table lineage: What tables feed into this table?
+      SELECT source_table_full_name, source_column_name
+      FROM system.access.table_lineage
+      WHERE target_table_full_name = 'catalog.schema.table'
+        AND event_date >= current_date() - 7;
+      ```
+
+      System table lineage is captured automatically by Unity Catalog.
+      Query `system.access.table_lineage` for table-level dependencies and `system.access.column_lineage` for column-level lineage.
+    execution_success: true
+  expectations:
+    expected_facts:
+    - "Table lineage is stored in system.access.table_lineage"
+    - "Column lineage is available in system.access.column_lineage"
+    - "Lineage is captured automatically by Unity Catalog"
+    expected_patterns:
+    - pattern: "system\\.access\\.table_lineage"
+      min_count: 1
+      description: "Queries the table lineage system table"
+    - pattern: "target_table_full_name"
+      min_count: 1
+      description: "Filters by target table name"
+    - pattern: "source_table_full_name"
+      min_count: 1
+      description: "Returns source table names"
+    guidelines:
+    - "Must query system.access.table_lineage"
+    - "Must show filtering by target table name"
+    - "Should mention that lineage is captured automatically"
+  metadata:
+    category: happy_path
+    difficulty: easy
+    source: auto_generated
+
+- id: uc_audit_logs_003
+  inputs:
+    prompt: "How do I check recent permission changes using Unity Catalog audit logs?"
+  outputs:
+    response: |
+      ```sql
+      -- Audit: Recent permission changes
+      SELECT event_time, user_identity.email, action_name, request_params
+      FROM system.access.audit
+      WHERE action_name LIKE '%GRANT%' OR action_name LIKE '%REVOKE%'
+      ORDER BY event_time DESC
+      LIMIT 100;
+      ```
+
+      The `system.access.audit` table records all access control changes. Filter by action_name to find GRANT, REVOKE, or other permission-related events.
+    execution_success: true
+  expectations:
+    expected_facts:
+    - "Audit logs are stored in system.access.audit"
+    - "Permission changes include GRANT and REVOKE actions"
+    - "user_identity.email identifies who made the change"
+    expected_patterns:
+    - pattern: "system\\.access\\.audit"
+      min_count: 1
+      description: "Queries the audit log system table"
+    - pattern: "action_name"
+      min_count: 1
+      description: "Filters by action type"
+    - pattern: "GRANT"
+      min_count: 1
+      description: "Checks for GRANT operations"
+    - pattern: "REVOKE"
+      min_count: 1
+      description: "Checks for REVOKE operations"
+    guidelines:
+    - "Must query system.access.audit"
+    - "Must filter for permission-related actions"
+    - "Must include user identity in the output"
+  metadata:
+    category: happy_path
+    difficulty: medium
+    source: auto_generated
+
+- id: uc_billing_usage_004
+  inputs:
+    prompt: "How do I query DBU usage and billing data across workspaces?"
+  outputs:
+    response: |
+      ```sql
+      -- Billing: DBU usage by workspace
+      SELECT workspace_id, sku_name, SUM(usage_quantity) AS total_dbus
+      FROM system.billing.usage
+      WHERE usage_date >= current_date() - 30
+      GROUP BY workspace_id, sku_name;
+      ```
+
+      The `system.billing.usage` table provides DBU consumption data. Use it to track costs by workspace, SKU type, or time period.
+    execution_success: true
+  expectations:
+    expected_facts:
+    - "Billing data is stored in system.billing.usage"
+    - "usage_quantity contains the DBU consumption"
+    - "Data can be filtered by usage_date for time-based analysis"
+    expected_patterns:
+    - pattern: "system\\.billing\\.usage"
+      min_count: 1
+      description: "Queries the billing system table"
+    - pattern: "SUM\\(usage_quantity\\)"
+      min_count: 1
+      description: "Aggregates DBU usage"
+    - pattern: "sku_name"
+      min_count: 1
+      description: "Groups by SKU type"
+    guidelines:
+    - "Must query system.billing.usage"
+    - "Must aggregate by meaningful dimensions like workspace or SKU"
+  metadata:
+    category: happy_path
+    difficulty: medium
+    source: auto_generated
+
+- id: uc_grant_system_access_005
+  inputs:
+    prompt: "How do I grant a group access to Unity Catalog system tables?"
+  outputs:
+    response: |
+      ```sql
+      -- Grant access to system tables
+      GRANT USE CATALOG ON CATALOG system TO `data_engineers`;
+      GRANT USE SCHEMA ON SCHEMA system.access TO `data_engineers`;
+      GRANT SELECT ON SCHEMA system.access TO `data_engineers`;
+      ```
+
+      Three grants are needed: USE CATALOG on the `system` catalog, USE SCHEMA on the specific schema (e.g., `system.access`), and SELECT for read access.
+    execution_success: true
+  expectations:
+    expected_facts:
+    - "System table access requires three levels of grants: USE CATALOG, USE SCHEMA, and SELECT"
+    - "The system catalog contains built-in system tables"
+    - "system.access schema contains lineage and audit tables"
+    expected_patterns:
+    - pattern: "GRANT USE CATALOG"
+      min_count: 1
+      description: "Grants USE CATALOG permission"
+    - pattern: "GRANT USE SCHEMA"
+      min_count: 1
+      description: "Grants USE SCHEMA permission"
+    - pattern: "GRANT SELECT"
+      min_count: 1
+      description: "Grants SELECT permission"
+    guidelines:
+    - "Must show all three required GRANT statements"
+    - "Must explain why three separate grants are needed"
+  metadata:
+    category: happy_path
+    difficulty: medium
+    source: auto_generated
+
+- id: uc_mcp_execute_sql_006
+  inputs:
+    prompt: "How do I query system tables using the MCP SQL execution tool?"
+  outputs:
+    response: |
+      ```python
+      # Query lineage via MCP
+      mcp__databricks__execute_sql(
+          sql_query="""
+          SELECT source_table_full_name, target_table_full_name
+          FROM system.access.table_lineage
+          WHERE event_date >= current_date() - 7
+          """,
+          catalog="system"
+      )
+      ```
+
+      Use `mcp__databricks__execute_sql` to execute system table queries. Set the `catalog` parameter to "system" for system table access.
+    execution_success: true
+  expectations:
+    expected_facts:
+    - "mcp__databricks__execute_sql is the MCP tool for running SQL queries"
+    - "The catalog parameter should be set to 'system' for system table queries"
+    expected_patterns:
+    - pattern: "mcp__databricks__execute_sql"
+      min_count: 1
+      description: "Uses the MCP SQL execution tool"
+    - pattern: "catalog=\"system\""
+      min_count: 1
+      description: "Sets catalog to system"
+    - pattern: "system\\.access"
+      min_count: 1
+      description: "Queries system access tables"
+    guidelines:
+    - "Must show the mcp__databricks__execute_sql tool call"
+    - "Must set catalog parameter to 'system'"
+  metadata:
+    category: happy_path
+    difficulty: easy
+    source: auto_generated
diff --git a/.test/skills/databricks-unity-catalog/manifest.yaml b/.test/skills/databricks-unity-catalog/manifest.yaml
new file mode 100644
index 00000000..f679684b
--- /dev/null
+++ b/.test/skills/databricks-unity-catalog/manifest.yaml
@@ -0,0 +1,31 @@
+skill_name: databricks-unity-catalog
+description: Unity Catalog system tables and volumes. Use when querying system tables (audit, lineage, billing) or working
+  with volume file operations (upload, download, list files in /Volumes/).
+scorers:
+  enabled:
+  - sql_syntax
+  - python_syntax
+  - pattern_adherence
+  - no_hallucinated_apis
+  - expected_facts_present
+  llm_scorers:
+  - Safety
+  - guidelines_from_expectations
+  default_guidelines:
+  - Response must address the user's request completely
+  - Code examples must follow documented best practices
+  - Response must use modern APIs (not deprecated ones)
+  trace_expectations:
+    tool_limits:
+      Bash: 10
+      Read: 20
+    token_budget:
+      max_total: 100000
+    required_tools:
+    - Read
+    banned_tools: []
+    expected_files: []
+quality_gates:
+  syntax_valid: 1.0
+  pattern_adherence: 0.9
+  execution_success: 0.8
diff --git a/.test/skills/databricks-unstructured-pdf-generation/ground_truth.yaml b/.test/skills/databricks-unstructured-pdf-generation/ground_truth.yaml
new file mode 100644
index 00000000..6560e7c1
--- /dev/null
+++ b/.test/skills/databricks-unstructured-pdf-generation/ground_truth.yaml
@@ -0,0 +1,271 @@
+metadata:
+  skill_name: databricks-unstructured-pdf-generation
+  version: 0.1.0
+  created_at: '2026-02-23T08:00:00.000000'
+
+test_cases:
+- id: pdf_gen_quick_start_001
+  inputs:
+    prompt: "How do I generate synthetic PDF documents for my RAG pipeline?"
+  outputs:
+    response: |
+      Use the `generate_pdf_documents` MCP tool:
+
+      ```
+      generate_pdf_documents(
+          catalog="my_catalog",
+          schema="my_schema",
+          volume="my_volume",
+          folder="hr_policies",
+          topic="HR policy documents for a mid-size technology company",
+          num_documents=15,
+          complexity="medium"
+      )
+      ```
+
+      This generates PDF files and corresponding JSON metadata files with questions and evaluation guidelines for RAG testing.
+    execution_success: true
+  expectations:
+    expected_facts:
+    - "generate_pdf_documents is an MCP tool for creating synthetic PDFs"
+    - "Documents are saved to a Unity Catalog Volume"
+    - "Both PDF files and JSON metadata files are generated"
+    expected_patterns:
+    - pattern: "generate_pdf_documents"
+      min_count: 1
+      description: "Uses the generate_pdf_documents MCP tool"
+    - pattern: "catalog"
+      min_count: 1
+      description: "Specifies target catalog"
+    - pattern: "schema"
+      min_count: 1
+      description: "Specifies target schema"
+    - pattern: "num_documents"
+      min_count: 1
+      description: "Specifies number of documents to generate"
+    guidelines:
+    - "Must show the generate_pdf_documents tool with key parameters"
+    - "Must mention that both PDFs and JSON metadata are generated"
+  metadata:
+    category: happy_path
+    difficulty: easy
+    source: auto_generated
+
+- id: pdf_gen_json_structure_002
+  inputs:
+    prompt: "What is the structure of the JSON metadata files generated alongside the PDFs?"
+  outputs:
+    response: |
+      Each PDF gets a corresponding JSON file with evaluation metadata:
+
+      ```json
+      {
+        "title": "API Authentication Guide",
+        "category": "Technical",
+        "pdf_path": "/Volumes/catalog/schema/volume/folder/doc_001.pdf",
+        "question": "What authentication methods are supported by the API?",
+        "guideline": "Answer should mention OAuth 2.0, API keys, and JWT tokens with their use cases."
+      }
+      ```
+
+      - `title`: Document title
+      - `category`: Document category
+      - `pdf_path`: Full Volume path to the generated PDF
+      - `question`: A test question whose answer is in the PDF
+      - `guideline`: Evaluation criteria for judging RAG responses
+    execution_success: true
+  expectations:
+    expected_facts:
+    - "Each PDF has a corresponding JSON metadata file"
+    - "JSON includes title, category, pdf_path, question, and guideline fields"
+    - "The question field contains a query answerable from the PDF content"
+    - "The guideline field provides evaluation criteria for RAG responses"
+    expected_patterns:
+    - pattern: "\"title\""
+      min_count: 1
+      description: "JSON includes title field"
+    - pattern: "\"question\""
+      min_count: 1
+      description: "JSON includes question field"
+    - pattern: "\"guideline\""
+      min_count: 1
+      description: "JSON includes guideline field"
+    - pattern: "\"pdf_path\""
+      min_count: 1
+      description: "JSON includes pdf_path field"
+    guidelines:
+    - "Must show the complete JSON structure with all fields"
+    - "Must explain the purpose of each field"
+  metadata:
+    category: happy_path
+    difficulty: easy
+    source: auto_generated
+
+- id: pdf_gen_rag_integration_003
+  inputs:
+    prompt: "How do I use the generated PDFs and JSON files to test my RAG pipeline?"
+  outputs:
+    response: |
+      ```python
+      # Load questions from JSON files
+      questions = load_json_files(f"/Volumes/{catalog}/{schema}/{volume}/{folder}/*.json")
+
+      for q in questions:
+          # 1. Ingest the PDF into your vector database
+          ingest_document(q["pdf_path"])
+
+          # 2. Query RAG system with the test question
+          response = rag_system.query(q["question"])
+
+          # 3. Evaluate using the guideline
+          is_correct = evaluate_response(response, q["guideline"])
+      ```
+
+      The workflow is: ingest PDFs as source documents, use `question` fields to query your RAG system, and use `guideline` fields to evaluate response quality.
+    execution_success: true
+  expectations:
+    expected_facts:
+    - "PDFs serve as source documents for the vector database"
+    - "Questions from JSON test retrieval accuracy"
+    - "Guidelines from JSON evaluate response quality"
+    expected_patterns:
+    - pattern: "question"
+      min_count: 2
+      description: "Uses question field for RAG queries"
+    - pattern: "guideline"
+      min_count: 2
+      description: "Uses guideline field for evaluation"
+    - pattern: "pdf_path"
+      min_count: 1
+      description: "References PDF path for ingestion"
+    guidelines:
+    - "Must show the three-step workflow: ingest, query, evaluate"
+    - "Must use both question and guideline fields from the JSON"
+  metadata:
+    category: happy_path
+    difficulty: medium
+    source: auto_generated
+
+- id: pdf_gen_hr_documents_004
+  inputs:
+    prompt: "Generate HR policy documents for testing a company knowledge base RAG system."
+  outputs:
+    response: |
+      ```
+      generate_pdf_documents(
+          catalog="ai_dev_kit",
+          schema="hr_demo",
+          volume="documents",
+          folder="hr_policies",
+          topic="HR policy documents for a mid-size technology company including PTO, benefits, code of conduct, and remote work policies",
+          num_documents=20,
+          complexity="medium"
+      )
+      ```
+
+      This generates 20 HR policy PDFs covering common topics like PTO, benefits enrollment, code of conduct, and remote work guidelines, along with JSON evaluation files.
+    execution_success: true
+  expectations:
+    expected_facts:
+    - "The topic parameter guides the content theme of generated documents"
+    - "num_documents controls how many PDFs are generated"
+    - "complexity affects the depth and length of generated content"
+    expected_patterns:
+    - pattern: "generate_pdf_documents"
+      min_count: 1
+      description: "Uses the generate_pdf_documents MCP tool"
+    - pattern: "topic="
+      min_count: 1
+      description: "Specifies document topic"
+    - pattern: "num_documents="
+      min_count: 1
+      description: "Specifies document count"
+    guidelines:
+    - "Must specify a descriptive topic parameter"
+    - "Must use realistic catalog/schema/volume names"
+  metadata:
+    category: happy_path
+    difficulty: easy
+    source: auto_generated
+
+- id: pdf_gen_tech_docs_005
+  inputs:
+    prompt: "Generate technical documentation PDFs for testing a developer documentation search system."
+  outputs:
+    response: |
+      ```
+      generate_pdf_documents(
+          catalog="ai_dev_kit",
+          schema="tech_docs",
+          volume="documents",
+          folder="api_docs",
+          topic="Technical API documentation including REST endpoints, authentication, SDK usage, error handling, and rate limiting",
+          num_documents=25,
+          complexity="high"
+      )
+      ```
+
+      Set complexity to "high" for technical documentation to generate detailed content with code examples, tables, and structured sections.
+    execution_success: true
+  expectations:
+    expected_facts:
+    - "complexity='high' generates more detailed content with code examples"
+    - "The topic parameter should describe the specific technical domain"
+    expected_patterns:
+    - pattern: "generate_pdf_documents"
+      min_count: 1
+      description: "Uses the generate_pdf_documents MCP tool"
+    - pattern: "complexity=\"high\""
+      min_count: 1
+      description: "Uses high complexity for technical docs"
+    - pattern: "topic="
+      min_count: 1
+      description: "Specifies technical topic"
+    guidelines:
+    - "Must use high complexity for technical documentation"
+    - "Topic must describe specific technical domains"
+  metadata:
+    category: happy_path
+    difficulty: medium
+    source: auto_generated
+
+- id: pdf_gen_financial_reports_006
+  inputs:
+    prompt: "Generate synthetic financial report PDFs for testing a compliance document search system."
+  outputs:
+    response: |
+      ```
+      generate_pdf_documents(
+          catalog="ai_dev_kit",
+          schema="finance_demo",
+          volume="documents",
+          folder="financial_reports",
+          topic="Corporate financial reports including quarterly earnings, annual reports, budget summaries, and audit findings",
+          num_documents=15,
+          complexity="high"
+      )
+      ```
+
+      Financial documents benefit from high complexity to include tables, figures, and structured financial data that tests RAG parsing capabilities.
+    execution_success: true
+  expectations:
+    expected_facts:
+    - "Financial documents should use high complexity for realistic structure"
+    - "Generated PDFs include tables and structured data"
+    expected_patterns:
+    - pattern: "generate_pdf_documents"
+      min_count: 1
+      description: "Uses the generate_pdf_documents MCP tool"
+    - pattern: "financial"
+      min_count: 1
+      description: "Topic relates to financial documents"
+    - pattern: "complexity=\"high\""
+      min_count: 1
+      description: "Uses high complexity"
+    guidelines:
+    - "Must use high complexity for financial documents"
+    - "Must specify financial-specific topics"
+  metadata:
+    category: happy_path
+    difficulty: medium
+    source: auto_generated
diff --git a/.test/skills/databricks-unstructured-pdf-generation/manifest.yaml b/.test/skills/databricks-unstructured-pdf-generation/manifest.yaml
new file mode 100644
index 00000000..27c09770
--- /dev/null
+++ b/.test/skills/databricks-unstructured-pdf-generation/manifest.yaml
@@ -0,0 +1,30 @@
+skill_name: databricks-unstructured-pdf-generation
+description: Generate synthetic PDF documents for RAG and unstructured data use cases. Use when creating test PDFs, demo documents,
+  or evaluation datasets for retrieval systems.
+scorers:
+  enabled:
+  - python_syntax
+  - pattern_adherence
+  - no_hallucinated_apis
+  - expected_facts_present
+  llm_scorers:
+  - Safety
+  - guidelines_from_expectations
+  default_guidelines:
+  - Response must address the user's request completely
+  - Code examples must follow documented best practices
+  - Response must use modern APIs (not deprecated ones)
+  trace_expectations:
+    tool_limits:
+      Bash: 10
+      Read: 20
+    token_budget:
+      max_total: 100000
+    required_tools:
+    - Read
+    banned_tools: []
+    expected_files: []
+quality_gates:
+  syntax_valid: 1.0
+  pattern_adherence: 0.9
+  execution_success: 0.8
diff --git a/.test/skills/databricks-vector-search/ground_truth.yaml b/.test/skills/databricks-vector-search/ground_truth.yaml
new file mode 100644
index 00000000..7e8acd2b
--- /dev/null
+++ b/.test/skills/databricks-vector-search/ground_truth.yaml
@@ -0,0 +1,345 @@
+metadata:
+  skill_name: databricks-vector-search
+  version: 0.1.0
+  created_at: '2026-02-23T08:00:00.000000'
+
+test_cases:
+- id: vs_create_endpoint_001
+  inputs:
+    prompt: "How do I create a vector search endpoint in Databricks?"
+  outputs:
+    response: |
+      ```python
+      from databricks.sdk import WorkspaceClient
+
+      w = WorkspaceClient()
+
+      # Create a standard endpoint
+      endpoint = w.vector_search_endpoints.create_endpoint(
+          name="my-vs-endpoint",
+          endpoint_type="STANDARD"  # or "STORAGE_OPTIMIZED"
+      )
+      # Endpoint creation is asynchronous; check status with get_endpoint()
+      ```
+
+      Two endpoint types are available:
+      - **STANDARD**: Best for low-latency queries
+      - **STORAGE_OPTIMIZED**: Best for large-scale, cost-effective deployments
+    execution_success: true
+  expectations:
+    expected_facts:
+    - "Endpoints are created with w.vector_search_endpoints.create_endpoint()"
+    - "Two endpoint types: STANDARD and STORAGE_OPTIMIZED"
+    - "Endpoint creation is asynchronous"
+    expected_patterns:
+    - pattern: "create_endpoint"
+      min_count: 1
+      description: "Uses create_endpoint() method"
+    - pattern: "STANDARD"
+      min_count: 1
+      description: "Shows STANDARD endpoint type"
+    - pattern: "STORAGE_OPTIMIZED"
+      min_count: 1
+      description: "Shows STORAGE_OPTIMIZED endpoint type"
+    guidelines:
+    - "Must show both STANDARD and STORAGE_OPTIMIZED endpoint types"
+    - "Must mention that endpoint creation is asynchronous"
+  metadata:
+    category: happy_path
+    difficulty: easy
+    source: auto_generated
+
+- id: vs_delta_sync_managed_002
+  inputs:
+    prompt: "How do I create a Delta Sync index with managed embeddings?"
+  outputs:
+    response: |
+      ```python
+      # Source table must have: primary key column + text column
+      index = w.vector_search_indexes.create_index(
+          name="catalog.schema.my_index",
+          endpoint_name="my-vs-endpoint",
+          primary_key="id",
+          index_type="DELTA_SYNC",
+          delta_sync_index_spec={
+              "source_table": "catalog.schema.documents",
+              "embedding_source_columns": [
+                  {
+                      "name": "content",
+                      "embedding_model_endpoint_name": "databricks-gte-large-en"
+                  }
+              ],
+              "pipeline_type": "TRIGGERED"  # or "CONTINUOUS"
+          }
+      )
+      ```
+
+      Databricks automatically embeds the text column using the specified model.
+      Pipeline types: TRIGGERED (manual sync) or CONTINUOUS (auto-sync on source changes).
+    execution_success: true
+  expectations:
+    expected_facts:
+    - "Delta Sync index automatically syncs from a source Delta table"
+    - "Managed embeddings use embedding_source_columns with a model endpoint"
+    - "TRIGGERED pipeline syncs on demand; CONTINUOUS syncs automatically"
+    - "databricks-gte-large-en is a built-in embedding model"
+    expected_patterns:
+    - pattern: "create_index"
+      min_count: 1
+      description: "Uses create_index() method"
+    - pattern: "DELTA_SYNC"
+      min_count: 1
+      description: "Specifies Delta Sync index type"
+    - pattern: "embedding_source_columns"
+      min_count: 1
+      description: "Configures managed embedding columns"
+    - pattern: "databricks-gte-large-en"
+      min_count: 1
+      description: "Uses built-in embedding model"
+    guidelines:
+    - "Must show the full create_index call with delta_sync_index_spec"
+    - "Must explain TRIGGERED vs CONTINUOUS pipeline types"
+    - "Must specify a Databricks embedding model endpoint"
+  metadata:
+    category: happy_path
+    difficulty: medium
+    source: auto_generated
+
+- id: vs_delta_sync_self_managed_003
+  inputs:
+    prompt: "How do I create a Delta Sync index with pre-computed embeddings?"
+  outputs:
+    response: |
+      ```python
+      # Source table must have: primary key + embedding vector column
+      index = w.vector_search_indexes.create_index(
+          name="catalog.schema.my_index",
+          endpoint_name="my-vs-endpoint",
+          primary_key="id",
+          index_type="DELTA_SYNC",
+          delta_sync_index_spec={
+              "source_table": "catalog.schema.documents",
+              "embedding_vector_columns": [
+                  {
+                      "name": "embedding",
+                      "embedding_dimension": 768
+                  }
+              ],
+              "pipeline_type": "TRIGGERED"
+          }
+      )
+      ```
+
+      Use self-managed embeddings when you need a specific embedding model not available as a Databricks endpoint, or when you pre-compute embeddings in a pipeline.
+    execution_success: true
+  expectations:
+    expected_facts:
+    - "Self-managed embeddings use embedding_vector_columns instead of embedding_source_columns"
+    - "embedding_dimension must match the dimension of your pre-computed vectors"
+    - "The source table must already contain an embedding column"
+    expected_patterns:
+    - pattern: "embedding_vector_columns"
+      min_count: 1
+      description: "Uses self-managed embedding columns"
+    - pattern: "embedding_dimension"
+      min_count: 1
+      description: "Specifies embedding dimension"
+    - pattern: "DELTA_SYNC"
+      min_count: 1
+      description: "Uses Delta Sync index type"
+    guidelines:
+    - "Must use embedding_vector_columns (not embedding_source_columns)"
+    - "Must include embedding_dimension parameter"
+    - "Must explain when to use self-managed vs managed embeddings"
+  metadata:
+    category: happy_path
+    difficulty: medium
+    source: auto_generated
+
+- id: vs_query_index_004
+  inputs:
+    prompt: "How do I query a vector search index and get similarity results?"
+  outputs:
+    response: |
+      ```python
+      results = w.vector_search_indexes.query_index(
+          index_name="catalog.schema.my_index",
+          columns=["id", "content", "metadata"],
+          query_text="What is machine learning?",
+          num_results=5
+      )
+
+      for doc in results.result.data_array:
+          score = doc[-1]  # Similarity score is last column
+          print(f"Score: {score}, Content: {doc[1][:100]}...")
+      ```
+
+      For managed embedding indexes, use `query_text`. For self-managed, use `query_vector` with a pre-computed embedding.
+    execution_success: true
+  expectations:
+    expected_facts:
+    - "query_index() searches for similar documents by text or vector"
+    - "query_text is used for managed embedding indexes"
+    - "query_vector is used for self-managed embedding indexes"
+    - "Similarity score is the last column in the result array"
+    expected_patterns:
+    - pattern: "query_index"
+      min_count: 1
+      description: "Uses query_index() method"
+    - pattern: "query_text"
+      min_count: 1
+      description: "Uses text-based query"
+    - pattern: "num_results"
+      min_count: 1
+      description: "Specifies number of results"
+    - pattern: "data_array"
+      min_count: 1
+      description: "Accesses result data array"
+    guidelines:
+    - "Must show query_text for managed embeddings"
+    - "Must mention query_vector as alternative for self-managed"
+    - "Must show how to access similarity scores"
+  metadata:
+    category: happy_path
+    difficulty: easy
+    source: auto_generated
+
+- id: vs_query_with_filters_005
+  inputs:
+    prompt: "How do I apply metadata filters when querying a vector search index?"
+  outputs:
+    response: |
+      ```python
+      results = w.vector_search_indexes.query_index(
+          index_name="catalog.schema.my_index",
+          columns=["id", "content", "category"],
+          query_text="machine learning best practices",
+          num_results=10,
+          filters_json='{"category": "technical", "status": "published"}'
+      )
+      ```
+
+      The `filters_json` parameter accepts a JSON string of key-value pairs that filter results by metadata columns before similarity ranking. Only documents matching all filters are returned.
+    execution_success: true
+  expectations:
+    expected_facts:
+    - "filters_json accepts a JSON string for metadata filtering"
+    - "Filters are applied before similarity ranking"
+    - "All filter conditions must be satisfied (AND logic)"
+    expected_patterns:
+    - pattern: "filters_json"
+      min_count: 1
+      description: "Uses filters_json parameter"
+    - pattern: "query_index"
+      min_count: 1
+      description: "Uses query_index() method"
+    - pattern: "query_text"
+      min_count: 1
+      description: "Combines text query with filters"
+    guidelines:
+    - "Must show filters_json parameter with a JSON string"
+    - "Must explain that filters are applied before similarity ranking"
+  metadata:
+    category: happy_path
+    difficulty: medium
+    source: auto_generated
+
+- id: vs_cli_reference_006
+  inputs:
+    prompt: "What CLI commands are available for managing Databricks vector search?"
+  outputs:
+    response: |
+      ```bash
+      # List endpoints
+      databricks vector-search endpoints list
+
+      # Create endpoint
+      databricks vector-search endpoints create \
+        --name my-endpoint \
+        --endpoint-type STANDARD
+
+      # List indexes on endpoint
+      databricks vector-search indexes list-indexes \
+        --endpoint-name my-endpoint
+
+      # Get index status
+      databricks vector-search indexes get-index \
+        --index-name catalog.schema.my_index
+
+      # Sync index (for TRIGGERED pipeline)
+      databricks vector-search indexes sync-index \
+        --index-name catalog.schema.my_index
+
+      # Delete index
+      databricks vector-search indexes delete-index \
+        --index-name catalog.schema.my_index
+      ```
+    execution_success: true
+  expectations:
+    expected_facts:
+    - "Vector search CLI commands are under databricks vector-search"
+    - "sync-index triggers a manual sync for TRIGGERED pipeline indexes"
+    - "Endpoints and indexes are managed as separate subcommands"
+    expected_patterns:
+    - pattern: "databricks vector-search endpoints"
+      min_count: 1
+      description: "Shows endpoint management commands"
+    - pattern: "databricks vector-search indexes"
+      min_count: 1
+      description: "Shows index management commands"
+    - pattern: "sync-index"
+      min_count: 1
+      description: "Shows manual sync command"
+    guidelines:
+    - "Must show both endpoint and index management commands"
+    - "Must include sync-index for triggered pipelines"
+  metadata:
+    category: happy_path
+    difficulty: easy
+    source: auto_generated
+
+- id: vs_embedding_models_007
+  inputs:
+    prompt: "What embedding models are available for vector search indexes?"
+  outputs:
+    response: |
+      Databricks provides built-in embedding models for managed indexes:
+
+      | Model | Dimensions |
+      |-------|-----------|
+      | `databricks-gte-large-en` | 1024 |
+      | `databricks-bge-large-en` | 1024 |
+
+      ```python
+      # Use with managed embeddings in index creation
+      embedding_source_columns=[
+          {
+              "name": "content",
+              "embedding_model_endpoint_name": "databricks-gte-large-en"
+          }
+      ]
+      ```
+
+      You can also use custom embedding models deployed as serving endpoints.
+    execution_success: true
+  expectations:
+    expected_facts:
+    - "databricks-gte-large-en produces 1024-dimensional embeddings"
+    - "databricks-bge-large-en produces 1024-dimensional embeddings"
+    - "Custom embedding models can also be used via serving endpoints"
+    expected_patterns:
+    - pattern: "databricks-gte-large-en"
+      min_count: 1
+      description: "Lists GTE embedding model"
+    - pattern: "databricks-bge-large-en"
+      min_count: 1
+      description: "Lists BGE embedding model"
+    - pattern: "1024"
+      min_count: 1
+      description: "Specifies embedding dimensions"
+    guidelines:
+    - "Must list at least two built-in embedding models with dimensions"
+    - "Must mention that custom models can also be used"
+  metadata:
+    category: happy_path
+    difficulty: easy
+    source: auto_generated
diff --git a/.test/skills/databricks-vector-search/manifest.yaml b/.test/skills/databricks-vector-search/manifest.yaml
new file mode 100644
index 00000000..a053e80a
--- /dev/null
+++ b/.test/skills/databricks-vector-search/manifest.yaml
@@ -0,0 +1,31 @@
+skill_name: databricks-vector-search
+description: 'Patterns for Databricks Vector Search: create endpoints and indexes, query with filters, manage embeddings.
+  Use when building RAG applications, semantic search, or similarity matching. Covers both storage-optimized and standard
+  endpoints.'
+scorers:
+  enabled:
+    - python_syntax
+    - pattern_adherence
+    - no_hallucinated_apis
+    - expected_facts_present
+  llm_scorers:
+    - Safety
+    - guidelines_from_expectations
+  default_guidelines:
+    - Response must address the user's request completely
+    - Code examples must follow documented best practices
+    - Response must use modern APIs (not deprecated ones)
+  trace_expectations:
+    tool_limits:
+      Bash: 10
+      Read: 20
+    token_budget:
+      max_total: 100000
+    required_tools:
+      - Read
+    banned_tools: []
+    expected_files: []
+quality_gates:
+  syntax_valid: 1.0
+  pattern_adherence: 0.9
+  execution_success: 0.8
diff --git a/.test/src/skill_test/cli/__init__.py b/.test/src/skill_test/cli/__init__.py
index b8352981..246d7113 100644
--- a/.test/src/skill_test/cli/__init__.py
+++ b/.test/src/skill_test/cli/__init__.py
@@ -17,6 +17,7 @@
     review,
     trace_eval,
     list_traces,
+    optimize,
 )
@@ -36,6 +37,7 @@ def main():
       review       - Review pending candidates interactively
       trace-eval   - Evaluate trace against skill expectations
       list-traces  - List available trace runs from MLflow
+      optimize     - Optimize a skill using GEPA
     """
     args = sys.argv[1:]
@@ -51,6 +53,7 @@
         print("  review       Review pending candidates interactively")
         print("  trace-eval   Evaluate trace against skill expectations")
         print("  list-traces  List available trace runs from MLflow")
+        print("  optimize     Optimize a skill using GEPA")
         sys.exit(0)
     skill_name = args[0]
@@ -142,6 +145,44 @@
             }
         else:
             result = list_traces(experiment, ctx, limit)
+    elif subcommand == "optimize":
+        # Parse optimize-specific arguments
+        opt_preset = "standard"
+        opt_mode = "static"
+        opt_task_lm = None
+        opt_reflection_lm = None
+        opt_dry_run = False
+        opt_apply = False
+
+        i = 2
+        while i < len(args):
+            if args[i] in ("--preset", "-p") and i + 1 < len(args):
+                opt_preset = args[i + 1]
+                i += 2
+            elif args[i] in ("--mode", "-m") and i + 1 < len(args):
+                opt_mode = args[i + 1]
+                i += 2
+            elif args[i] == "--task-lm" and i + 1 < len(args):
+                opt_task_lm = args[i + 1]
+                i += 2
+            elif args[i] == "--reflection-lm" and i + 1 < len(args):
+                opt_reflection_lm = args[i + 1]
+                i += 2
+            elif args[i] == "--dry-run":
+                opt_dry_run = True
+                i += 1
+            elif args[i] == "--apply":
+                opt_apply = True
+                i += 1
+            else:
+                i += 1
+
+        result = optimize(
+            skill_name, ctx,
+            preset=opt_preset, mode=opt_mode, task_lm=opt_task_lm,
+            reflection_lm=opt_reflection_lm,
+            dry_run=opt_dry_run, apply=opt_apply,
+        )
     else:
         print(f"Unknown subcommand: {subcommand}")
         sys.exit(1)
@@ -171,5 +212,6 @@
     "review",
     "trace_eval",
     "list_traces",
+    "optimize",
     "main",
 ]
diff --git a/.test/src/skill_test/cli/commands.py b/.test/src/skill_test/cli/commands.py
index e84637e2..be619d5a 100644
--- a/.test/src/skill_test/cli/commands.py
+++ b/.test/src/skill_test/cli/commands.py
@@ -1399,3 +1399,84 @@ def list_traces(
         "experiment_name": experiment_name,
         "hint": "Check experiment name and MLflow connection",
     }
+
+
+def optimize(
+    skill_name: str,
+    ctx: CLIContext,
+    preset: str = "standard",
+    mode: str = "static",
+    task_lm: Optional[str] = None,
+    reflection_lm: Optional[str] = None,
+    dry_run: bool = False,
+    apply: bool = False,
+) -> Dict[str, Any]:
+    """Optimize a skill using GEPA.
+
+    Runs the full optimization pipeline: evaluate -> optimize -> review.
+    Optionally applies the optimized result to the SKILL.md.
+
+    Args:
+        skill_name: Name of the skill to optimize
+        ctx: CLI context
+        preset: GEPA preset ("quick", "standard", "thorough")
+        mode: "static" (uses ground truth) or "generative" (generates fresh responses)
+        task_lm: LLM model for generative mode
+        reflection_lm: Override GEPA reflection model
+        dry_run: Show config and estimate cost without running
+        apply: Apply the optimized result to SKILL.md
+
+    Returns:
+        Dictionary with optimization results
+    """
+    try:
+        from ..optimize.runner import optimize_skill
+        from ..optimize.review import review_optimization, apply_optimization
+    except ImportError as e:
+        return {
+            "success": False,
+            "error": f"GEPA optimization requires the 'optimize' extra: {e}",
+            "hint": "Install with: pip install skill-test[optimize]",
+        }
+
+    try:
+        result = optimize_skill(
+            skill_name=skill_name,
+            mode=mode,
+            preset=preset,
+            task_lm=task_lm,
+            reflection_lm=reflection_lm,
+            dry_run=dry_run,
+        )
+
+        review_optimization(result)
+
+        if apply and not dry_run:
+            apply_optimization(result)
+
+        return {
+            "success": True,
+            "skill_name": skill_name,
+            "original_score": result.original_score,
+            "optimized_score": result.optimized_score,
+            "improvement": result.improvement,
+            "original_tokens": result.original_token_count,
+            "optimized_tokens": result.optimized_token_count,
+            "token_reduction_pct": result.token_reduction_pct,
+            "applied": apply and not dry_run,
+            "dry_run": dry_run,
+            "mlflow_run_id": result.mlflow_run_id,
+        }
+    except FileNotFoundError as e:
+        return {
+            "success": False,
+            "error": str(e),
+            "skill_name": skill_name,
+        }
+    except Exception as e:
+        return {
+            "success": False,
+            "error": str(e),
+            "skill_name": skill_name,
+            "hint": "Check GEPA installation and API keys",
+        }
diff --git a/.test/src/skill_test/optimize/__init__.py b/.test/src/skill_test/optimize/__init__.py
new file mode 100644
index 00000000..87cab788
--- /dev/null
+++ b/.test/src/skill_test/optimize/__init__.py
@@ -0,0 +1,22 @@
+"""GEPA-powered skill optimization using optimize_anything API.
+
+Public API:
+    optimize_skill() - End-to-end optimize a SKILL.md (and optionally tools)
+    create_skill_evaluator() - Create a GEPA evaluator for a skill
+    OptimizationResult - Dataclass with optimization results
+    PRESETS - GEPA config presets (quick, standard, thorough)
+"""
+
+from .runner import optimize_skill, OptimizationResult
+from .evaluator import create_skill_evaluator
+from .config import PRESETS
+from .review import review_optimization, apply_optimization
+
+__all__ = [
+    "optimize_skill",
+    "OptimizationResult",
+    "create_skill_evaluator",
+    "PRESETS",
+    "review_optimization",
+    "apply_optimization",
+]
diff --git a/.test/src/skill_test/optimize/asi.py b/.test/src/skill_test/optimize/asi.py
new file mode 100644
index 00000000..9c0e6cff
--- /dev/null
+++ b/.test/src/skill_test/optimize/asi.py
@@ -0,0 +1,269 @@
+"""ASI diagnostics: convert MLflow Feedback to optimize_anything SideInfo.
+
+Builds an Actionable Side Information dict from scorer feedback so GEPA's
+reflection LM gets structured context about what went wrong with each scorer.
+Failure details are surfaced via the ``_failures`` key in the returned dict.
+
+Also provides ``skillbench_to_asi()`` for the SkillBench-style evaluator,
+which produces GEPA-optimized side info with standard diagnostic keys
+(Error, Expected, Actual) and ``skill_md_specific_info`` for per-component
+routing.
+"""
+
+from __future__ import annotations
+
+from typing import Any, TYPE_CHECKING
+
+from mlflow.entities import Feedback
+
+if TYPE_CHECKING:
+    from .assertions import AssertionResult
+
+from .assertions import (
+    summarize_failures as _summarize_failures,
+    _classify_assertion,
+    _extract_content,
+)
+
+
+def feedback_to_score(feedback: Feedback) -> float | None:
+    """Convert a single MLflow Feedback to a numeric score.
+
+    Mapping:
+        "yes" -> 1.0
+        "no" -> 0.0
+        "skip" -> None (excluded from scoring)
+        numeric -> float(value)
+    """
+    value = feedback.value
+    if value == "yes":
+        return 1.0
+    elif value == "no":
+        return 0.0
+    elif value == "skip":
+        return None
+    else:
+        try:
+            return float(value)
+        except (TypeError, ValueError):
+            return None
+
+
+def feedback_to_asi(feedbacks: list[Feedback]) -> tuple[float, dict[str, Any]]:
+    """Convert MLflow Feedback objects to optimize_anything (score, SideInfo).
+
+    Computes the mean score across non-skipped feedbacks and builds a
+    SideInfo dict. Failure diagnostics are collected in the ``_failures``
+    key so GEPA's reflection LM sees actionable context directly in the
+    side_info dict (no ``oa.log()`` needed).
+
+    Args:
+        feedbacks: List of MLflow Feedback objects from running scorers
+
+    Returns:
+        Tuple of (composite_score, side_info_dict)
+    """
+    scores = []
+    side_info: dict[str, Any] = {}
+    failures: list[str] = []
+
+    for fb in feedbacks:
+        score = feedback_to_score(fb)
+        name = fb.name or "unnamed"
+
+        if score is None:
+            side_info[name] = {
+                "score": None,
+                "value": fb.value,
+                "rationale": fb.rationale or "",
+                "status": "skipped",
+            }
+            continue
+
+        scores.append(score)
+        side_info[name] = {
+            "score": score,
+            "value": fb.value,
+            "rationale": fb.rationale or "",
+            "status": "pass" if score >= 0.5 else "fail",
+        }
+
+        # Collect failure diagnostics for GEPA reflection
+        if score < 1.0:
+            failures.append(
+                f"Scorer '{name}' returned {fb.value}: {fb.rationale or 'no rationale'}"
+            )
+
+    composite = sum(scores) / len(scores) if scores else 0.0
+
+    side_info["_summary"] = {
+        "composite_score": composite,
+        "total_scorers": len(feedbacks),
+        "scored": len(scores),
+        "skipped": len(feedbacks) - len(scores),
+        "passed": sum(1 for s in scores if s >= 0.5),
+        "failed": sum(1 for s in scores if s < 0.5),
+    }
+
+    if failures:
+        side_info["_failures"] = "\n".join(failures)
+
+    return composite, side_info
+
+
+def build_rich_asi(
+    feedbacks: list[Feedback],
+    *,
+    generated_response: str | None = None,
+    skill_coverage: dict[str, Any] | None = None,
+    task_prompt: str | None = None,
+    per_dimension_scores: dict[str, float] | None = None,
+) -> tuple[float, dict[str, Any]]:
+    """Build enriched ASI with categorized diagnostics for GEPA reflection.
+
+    Extends ``feedback_to_asi()`` with additional context that helps GEPA's
+    reflection LM understand *why* scores changed and make better edits.
+
+    Args:
+        feedbacks: MLflow Feedback objects from all scoring layers
+        generated_response: Truncated LLM output (so reflection sees what skill produced)
+        skill_coverage: Which patterns/facts found vs missing in SKILL.md
+        task_prompt: The test prompt (so reflection understands context)
+        per_dimension_scores: Per-dimension scores dict for Pareto-frontier selection
+
+    Returns:
+        Tuple of (composite_score, enriched_side_info_dict)
+    """
+    composite, side_info = feedback_to_asi(feedbacks)
+
+    # Categorize feedbacks by layer
+    categories: dict[str, list[str]] = {
+        "skill_content": [],
+        "generated_response": [],
+        "reference": [],
+        "structure": [],
+    }
+    for fb in feedbacks:
+        name = fb.name or ""
+        score = feedback_to_score(fb)
+        if score is None:
+            continue
+        entry = f"{name}: {'pass' if score >= 0.5 else 'FAIL'} ({fb.rationale or ''})"
+        if name.startswith("skill_content_"):
+            categories["skill_content"].append(entry)
+        elif name.startswith("skill_"):
+            categories["structure"].append(entry)
+        else:
+            categories["generated_response"].append(entry)
+
+    side_info["_diagnostics_by_layer"] = {
+        k: v for k, v in categories.items() if v
+    }
+
+    if generated_response is not None:
+        side_info["_generated_response"] = generated_response[:2000]
+
+    if skill_coverage:
+        side_info["_skill_coverage"] = skill_coverage
+
+    if task_prompt:
+        side_info["_task_prompt"] = task_prompt[:500]
+
+    if per_dimension_scores:
+        side_info["scores"] = per_dimension_scores
+
+    return composite, side_info
+
+
+# ---------------------------------------------------------------------------
+# SkillBench → GEPA side info
+# ---------------------------------------------------------------------------
+
+
+def skillbench_to_asi(
+    with_results: list[AssertionResult],
+    without_results: list[AssertionResult],
+    *,
+    task_prompt: str | None = None,
+    scores: dict[str, float] | None = None,
+    with_response: str | None = None,
+    without_response: str | None = None,
+    reference_answer: str | None = None,
+    candidate: dict[str, str] | None = None,
+) -> dict[str, Any]:
+    """Convert SkillBench assertion results to GEPA-optimized side info.
+
+    Produces a flat dict with GEPA's standard diagnostic keys plus actual
+    agent output and reference answers so the reflection LM can make
+    targeted SKILL.md edits.
+
+    Budget: ~1480 chars/example (Task 200 + Error ~80 + Expected 500 +
+    Actual 500 + scores ~200). With minibatch=3: ~4440 chars (~1100 tokens).
+
+    Keys produced (all optional, only non-empty included):
+        ``Task`` — the task prompt (truncated at 200 chars)
+        ``Error`` — compact NEEDS_SKILL/REGRESSION assertion labels
+        ``Expected`` — reference answer from ground_truth.yaml (truncated at 500 chars)
+        ``Actual`` — agent response WITH skill (truncated at 500 chars)
+        ``skill_md_specific_info`` — sub-dict with ``Regressions`` for per-component routing
+        ``scores`` — score breakdown promoted to objective_scores by GEPA
+
+    Args:
+        with_results: Assertion results from the WITH-skill run.
+        without_results: Assertion results from the WITHOUT-skill run.
+        task_prompt: The test prompt (for reflection context).
+        scores: Score breakdown dict (effectiveness, pass_with, structure, final).
+        with_response: Agent output WITH skill (truncated at 500 chars).
+        without_response: Agent output WITHOUT skill (reserved for future use).
+        reference_answer: Ground truth answer from ground_truth.yaml.
+        candidate: Full candidate dict for tool-specific diagnostic routing.
+
+    Returns:
+        Side info dict for optimize_anything.
+    """
+    diag = _summarize_failures(with_results, without_results)
+
+    side_info: dict[str, Any] = {}
+
+    # 1. Task context (short — just enough for the reflection LM)
+    if task_prompt:
+        side_info["Task"] = task_prompt[:200]
+
+    # 2. Error: what specific assertions fail (from assertions.py)
+    if diag.get("Error"):
+        side_info["Error"] = diag["Error"]
+
+    # 3. Expected: reference answer (what correct output looks like)
+    if reference_answer:
+        side_info["Expected"] = reference_answer[:500]
+
+    # 4. Actual: agent response WITH skill (what was produced)
+    if with_response is not None:
+        side_info["Actual"] = with_response[:500]
+
+    # 5. Regressions: routed to skill_md component
+    if diag.get("Regressions"):
+        side_info["skill_md_specific_info"] = {"Regressions": diag["Regressions"]}
+
+    # 5b. Route tool-specific failures to {component}_specific_info
+    if candidate:
+        tool_components = {k: v for k, v in candidate.items() if k.startswith("tools_")}
+        for comp_name, comp_text in tool_components.items():
+            comp_text_lower = comp_text.lower()
+            tool_failures = []
+            for w, wo in zip(with_results, without_results):
+                label = _classify_assertion(w, wo)
+                if label in ("NEEDS_SKILL", "REGRESSION"):
+                    content = _extract_content(w)
+                    if content.lower() in comp_text_lower:
+                        tool_failures.append(f"{label}: {w.assertion_type} — '{content}'")
+            if tool_failures:
+                side_info[f"{comp_name}_specific_info"] = {
+                    "Related_assertions": "\n".join(tool_failures)
+                }
+
+    # 6. Scores: needed for GEPA Pareto tracking
+    if scores:
+        side_info["scores"] = scores
+
+    return side_info
diff --git a/.test/src/skill_test/optimize/assertions.py b/.test/src/skill_test/optimize/assertions.py
new file mode 100644
index 00000000..3b6815ab
--- /dev/null
+++ b/.test/src/skill_test/optimize/assertions.py
@@ -0,0 +1,190 @@
+"""Binary assertion layer for SkillBench-style evaluation.
+
+Wraps pattern and fact checks into binary pass/fail assertions,
+mirroring SkillBench's pytest-style binary approach. No fuzzy keyword
+scoring -- each assertion either passes or fails.
+"""
+
+import re
+from dataclasses import dataclass
+from typing import Any
+
+
+@dataclass
+class AssertionResult:
+    """Result of a single binary assertion."""
+
+    name: str
+    passed: bool
+    rationale: str
+    assertion_type: str  # "pattern" | "fact"
+
+
+def _run_pattern_assertions(response: str, expected_patterns: list) -> list[AssertionResult]:
+    """Run pattern assertions against a response.
+
+    Each pattern spec can be a plain regex string or a dict with
+    ``pattern``, ``min_count``, ``max_count``, ``description`` keys.
+    """
+    results = []
+    for pattern_spec in expected_patterns:
+        if isinstance(pattern_spec, str):
+            pattern = pattern_spec
+            min_count = 1
+            max_count = None
+            description = pattern[:40]
+        else:
+            pattern = pattern_spec["pattern"]
+            min_count = pattern_spec.get("min_count", 1)
+            max_count = pattern_spec.get("max_count", None)
+            description = pattern_spec.get("description", pattern[:40])
+
+        matches = len(re.findall(pattern, response, re.IGNORECASE))
+
+        if max_count is not None:
+            passed = min_count <= matches <= max_count
+            rationale = f"Found {matches} matches (need {min_count}-{max_count})"
+        else:
+            passed = matches >= min_count
+            rationale = f"Found {matches} matches (need >={min_count})"
+
+        results.append(AssertionResult(
+            name=f"pattern_{description}",
+            passed=passed,
+            rationale=rationale,
+            assertion_type="pattern",
+        ))
+    return results
+
+
+def _run_fact_assertions(response: str, expected_facts: list[str]) -> list[AssertionResult]:
+    """Run fact assertions against a response.
+
+    Exact substring match (case-insensitive). No fuzzy keyword overlap.
+ """ + response_lower = response.lower() + results = [] + for fact in expected_facts: + found = fact.lower() in response_lower + results.append(AssertionResult( + name=f"fact_{fact[:40]}", + passed=found, + rationale=f"{'Found' if found else 'Missing'}: {fact}", + assertion_type="fact", + )) + return results + + +def run_all_assertions(response: str, expectations: dict[str, Any]) -> list[AssertionResult]: + """Run all pattern + fact assertions, return binary pass/fail per assertion. + + Args: + response: The text to check assertions against. + expectations: Dict with optional ``expected_patterns`` and ``expected_facts`` keys. + + Returns: + List of AssertionResult with binary pass/fail for each assertion. + """ + results: list[AssertionResult] = [] + + patterns = expectations.get("expected_patterns", []) + if patterns: + results.extend(_run_pattern_assertions(response, patterns)) + + facts = expectations.get("expected_facts", []) + if facts: + results.extend(_run_fact_assertions(response, facts)) + + return results + + +def _classify_assertion( + with_result: AssertionResult, + without_result: AssertionResult, +) -> str: + """Classify a single assertion by comparing with-skill vs without-skill. + + Returns one of: + POSITIVE — fails without skill, passes with (skill is helping) + REGRESSION — passes without skill, fails with (skill is confusing the agent) + NEEDS_SKILL — fails both with and without (skill must add this content) + NEUTRAL — same result either way (agent already knows this) + """ + if with_result.passed and not without_result.passed: + return "POSITIVE" + elif not with_result.passed and without_result.passed: + return "REGRESSION" + elif not with_result.passed and not without_result.passed: + return "NEEDS_SKILL" + else: + return "NEUTRAL" + + +def _extract_content(result: AssertionResult) -> str: + """Extract the actual expected content from an assertion result. 
+
+    For facts, strips the ``Missing: `` / ``Found: `` prefix to get the raw
+    fact text. For patterns, uses the description embedded in the assertion
+    name (strips the ``pattern_`` prefix).
+    """
+    if result.assertion_type == "fact":
+        for prefix in ("Missing: ", "Found: "):
+            if result.rationale.startswith(prefix):
+                return result.rationale[len(prefix):]
+        return result.rationale
+    else:
+        # Pattern: name is "pattern_{description}", rationale is match count
+        return result.name.removeprefix("pattern_")
+
+
+def summarize_failures(
+    with_results: list[AssertionResult],
+    without_results: list[AssertionResult],
+) -> dict[str, str]:
+    """Build GEPA-friendly diagnostic strings from assertion results.
+
+    Collects only NEEDS_SKILL and REGRESSION assertions (skips NEUTRAL/POSITIVE)
+    and produces structured output that maps to GEPA's standard diagnostic keys.
+
+    Only non-empty keys are included in the returned dict so that GEPA does not
+    render empty ``## Header`` sections that waste tokens and confuse the
+    reflection LM.
+
+    Returns:
+        Dict with a subset of: ``Error``, ``Regressions``.
+        ``Error`` carries compact NEEDS_SKILL/REGRESSION tokens that downstream
+        consumers (``_review_skillbench``, ``build_skillbench_background``) parse.
+        ``Regressions`` is a concise NL summary only present when regressions exist.
+ """ + needs_skill: list[tuple[AssertionResult, AssertionResult]] = [] + regressions: list[tuple[AssertionResult, AssertionResult]] = [] + + for w, wo in zip(with_results, without_results): + label = _classify_assertion(w, wo) + if label == "NEEDS_SKILL": + needs_skill.append((w, wo)) + elif label == "REGRESSION": + regressions.append((w, wo)) + + result: dict[str, str] = {} + + # Error: compact assertion labels (NEEDS_SKILL/REGRESSION tokens preserved) + error_lines: list[str] = [] + for w, _ in needs_skill: + content = _extract_content(w) + error_lines.append(f"NEEDS_SKILL: {w.assertion_type} — '{content}'") + for w, _ in regressions: + content = _extract_content(w) + error_lines.append(f"REGRESSION: {w.assertion_type} — '{content}'") + if error_lines: + result["Error"] = "\n".join(error_lines) + + # Regressions: concise NL (only when non-empty) + if regressions: + lines: list[str] = [] + for i, (w, wo) in enumerate(regressions, 1): + content = _extract_content(w) + lines.append(f"{i}. '{content}' — passes without skill, fails with it") + result["Regressions"] = "\n".join(lines) + + return result diff --git a/.test/src/skill_test/optimize/config.py b/.test/src/skill_test/optimize/config.py new file mode 100644 index 00000000..2aca7a6e --- /dev/null +++ b/.test/src/skill_test/optimize/config.py @@ -0,0 +1,318 @@ +"""GEPA configuration presets for skill optimization. + +Uses the optimize_anything API with GEPAConfig/EngineConfig/ReflectionConfig. 
+""" + +import os +import warnings + +from gepa.optimize_anything import GEPAConfig, EngineConfig, ReflectionConfig, RefinerConfig + +DEFAULT_REFLECTION_LM = os.environ.get( + "GEPA_REFLECTION_LM", "databricks/databricks-claude-opus-4-6" +) + +DEFAULT_GEN_LM = os.environ.get( + "GEPA_GEN_LM", "databricks/databricks-claude-sonnet-4-6" +) + +DEFAULT_TOKEN_BUDGET: int | None = int( + os.environ.get("GEPA_TOKEN_BUDGET", "0") +) or None + +# --------------------------------------------------------------------------- +# Register Databricks models with litellm so it knows their true context +# windows. Without this, litellm may fuzzy-match to a similar model with +# different limits, or worse, the Databricks serving endpoint may reject +# requests that exceed a vLLM-configured max_model_len. +# +# NOTE: This does NOT override the endpoint's own max_model_len setting. +# If the Databricks endpoint itself is configured with a low limit (e.g. +# 8192), you must either reconfigure the endpoint or use a different +# provider (openai/, anthropic/) whose endpoints support larger contexts. +# --------------------------------------------------------------------------- +def _configure_litellm_retries() -> None: + """Configure litellm to retry on transient errors (429, 529, 500). + + GEPA calls litellm.completion() without passing num_retries, so we + set it globally. This handles Anthropic 529 "Overloaded" errors, + rate limits, and other transient failures with exponential backoff. + + Rate-limit retries get extra attempts (10) since --include-tools sends + large contexts that easily hit token-per-minute ceilings on Opus. 
+ """ + try: + import litellm + from litellm import RetryPolicy + + litellm.num_retries = 5 + litellm.request_timeout = 180 # seconds per attempt + litellm.retry_policy = RetryPolicy( + RateLimitErrorRetries=10, + InternalServerErrorRetries=5, + TimeoutErrorRetries=5, + ) + # Drop log noise from retries + litellm.suppress_debug_info = True + except ImportError: + pass + + +def _register_litellm_models() -> None: + """Register Databricks model context windows with litellm.""" + try: + import litellm + _models = { + "databricks/databricks-claude-opus-4-6": { + "max_tokens": 32_000, + "max_input_tokens": 200_000, + "max_output_tokens": 32_000, + "litellm_provider": "databricks", + "mode": "chat", + "input_cost_per_token": 0, + "output_cost_per_token": 0, + }, + "databricks/databricks-claude-sonnet-4-6": { + "max_tokens": 16_000, + "max_input_tokens": 200_000, + "max_output_tokens": 16_000, + "litellm_provider": "databricks", + "mode": "chat", + "input_cost_per_token": 0, + "output_cost_per_token": 0, + }, + "databricks/databricks-gpt-5-2": { + "max_tokens": 128_000, + "max_input_tokens": 272_000, + "max_output_tokens": 128_000, + "litellm_provider": "databricks", + "mode": "chat", + "input_cost_per_token": 0, + "output_cost_per_token": 0, + }, + } + for model_name, model_info in _models.items(): + litellm.model_cost[model_name] = model_info + except ImportError: + pass + + +_register_litellm_models() +_configure_litellm_retries() + + +# Overhead multiplier: the reflection prompt is roughly this many times +# the raw candidate tokens (includes background, ASI, framing). 
+_REFLECTION_OVERHEAD_MULTIPLIER = 3
+
+PRESETS: dict[str, GEPAConfig] = {
+    "quick": GEPAConfig(
+        engine=EngineConfig(max_metric_calls=15, parallel=True),
+        reflection=ReflectionConfig(reflection_lm=DEFAULT_REFLECTION_LM),
+        refiner=RefinerConfig(max_refinements=1),
+    ),
+    "standard": GEPAConfig(
+        engine=EngineConfig(max_metric_calls=50, parallel=True),
+        reflection=ReflectionConfig(
+            reflection_lm=DEFAULT_REFLECTION_LM,
+            reflection_minibatch_size=3,
+        ),
+        refiner=RefinerConfig(max_refinements=1),
+    ),
+    "thorough": GEPAConfig(
+        engine=EngineConfig(max_metric_calls=150, parallel=True),
+        reflection=ReflectionConfig(
+            reflection_lm=DEFAULT_REFLECTION_LM,
+            reflection_minibatch_size=3,
+        ),
+        refiner=RefinerConfig(max_refinements=1),
+    ),
+}
+
+# Base max_metric_calls per preset (used to scale by component count)
+PRESET_BASE_CALLS: dict[str, int] = {
+    "quick": 15,
+    "standard": 50,
+    "thorough": 150,
+}
+
+# Maximum total metric calls per pass to avoid runaway runtimes.
+# With many components, uncapped scaling (e.g., 50 * 17 = 850) can cause
+# multi-hour hangs with slower reflection models like Sonnet.
+MAX_METRIC_CALLS_PER_PASS = 300
+
+# Models known to be fast enough for large multi-component optimization.
+# Other models get the metric-call cap applied.
+_FAST_REFLECTION_MODELS = {
+    "databricks/databricks-claude-opus-4-6",
+    "databricks/databricks-gpt-5-2",
+    "openai/gpt-4o",
+    "anthropic/claude-opus-4-6",
+}
+
+
+def validate_databricks_env() -> None:
+    """Check that DATABRICKS_API_BASE is set correctly for litellm.
+
+    litellm's Databricks provider requires:
+        DATABRICKS_API_BASE=https://<workspace>.cloud.databricks.com/serving-endpoints
+
+    A common mistake is omitting /serving-endpoints, which causes 404 errors.
+ """ + api_base = os.environ.get("DATABRICKS_API_BASE", "") + if api_base and not api_base.rstrip("/").endswith("/serving-endpoints"): + fixed = api_base.rstrip("/") + "/serving-endpoints" + warnings.warn( + f"DATABRICKS_API_BASE={api_base!r} is missing '/serving-endpoints' suffix. " + f"litellm will get 404 errors. Automatically fixing to: {fixed}", + stacklevel=2, + ) + os.environ["DATABRICKS_API_BASE"] = fixed + + +def validate_reflection_context( + reflection_lm: str, + total_candidate_tokens: int, +) -> None: + """Warn if the candidate is likely too large for the reflection model. + + Queries litellm's model registry for the model's max_input_tokens and + compares against the estimated reflection prompt size. + + Note: this checks litellm's *client-side* knowledge of the model. The + Databricks serving endpoint may have a *different* (lower) limit set via + vLLM's ``max_model_len``. If you see ``BadRequestError`` with + ``max_model_len`` in the message, the endpoint itself is the bottleneck -- + switch to a provider whose endpoint supports your context needs (e.g. + ``openai/gpt-4o`` or ``anthropic/claude-sonnet-4-5-20250514``). 
+ """ + try: + import litellm + info = litellm.get_model_info(reflection_lm) + limit = info.get("max_input_tokens") or info.get("max_tokens") or 0 + except Exception: + return # can't determine limit -- skip check + + if limit <= 0: + return + + estimated_prompt = total_candidate_tokens * _REFLECTION_OVERHEAD_MULTIPLIER + if estimated_prompt > limit: + raise ValueError( + f"\nReflection model '{reflection_lm}' has a {limit:,}-token input limit " + f"(per litellm), but the estimated reflection prompt is ~{estimated_prompt:,} " + f"tokens ({total_candidate_tokens:,} candidate tokens x " + f"{_REFLECTION_OVERHEAD_MULTIPLIER} overhead).\n\n" + f"Fix: use a model with a larger context window:\n" + f" --reflection-lm 'databricks/databricks-claude-opus-4-6' (200K)\n" + f" --reflection-lm 'openai/gpt-4o' (128K)\n" + f" --reflection-lm 'anthropic/claude-sonnet-4-5-20250514' (200K)\n\n" + f"Or set the environment variable:\n" + f" export GEPA_REFLECTION_LM='databricks/databricks-claude-opus-4-6'\n\n" + f"If you already use a large-context model and still see 'max_model_len'\n" + f"errors, the Databricks serving endpoint itself has a low context limit.\n" + f"Switch to a non-Databricks provider (openai/ or anthropic/) instead.\n\n" + f" Current GEPA_REFLECTION_LM={os.environ.get('GEPA_REFLECTION_LM', '(not set)')}" + ) + + +def estimate_pass_duration( + num_metric_calls: int, + reflection_lm: str, + total_candidate_tokens: int, + num_dataset_examples: int = 7, +) -> float | None: + """Estimate wall-clock seconds for one optimization pass. + + Metric calls are mostly fast local evaluations. The slow part is + reflection LLM calls, which happen roughly once per iteration + (num_metric_calls / num_dataset_examples iterations). + + Returns None if estimation is not possible. 
+ """ + # Rough per-reflection latency (seconds) based on model class + if reflection_lm in _FAST_REFLECTION_MODELS: + secs_per_reflection = 5.0 + elif "sonnet" in reflection_lm.lower(): + secs_per_reflection = 20.0 + elif "haiku" in reflection_lm.lower(): + secs_per_reflection = 8.0 + else: + secs_per_reflection = 15.0 + + # Scale by candidate size (larger candidates → slower) + size_factor = min(max(1.0, total_candidate_tokens / 10_000), 2.5) + adjusted = secs_per_reflection * size_factor + + # Approximate iterations (each iteration evaluates all dataset examples) + num_iterations = max(1, num_metric_calls // max(num_dataset_examples, 1)) + + return num_iterations * adjusted + + +def get_preset( + name: str, + reflection_lm: str | None = None, + num_components: int = 1, + max_metric_calls_override: int | None = None, +) -> GEPAConfig: + """Get a GEPA config preset by name, scaled by component count. + + When optimizing multiple components (skill + tool modules), GEPA's + round-robin selector divides the budget across all of them. We scale + ``max_metric_calls`` so that *each component* receives the preset's + base budget rather than splitting it. + + For slower reflection models (non-Opus/GPT-4o), the total metric calls + are capped at ``MAX_METRIC_CALLS_PER_PASS`` to avoid multi-hour hangs. + + Args: + name: One of "quick", "standard", "thorough" + reflection_lm: Override reflection LM model string + num_components: Number of GEPA components (used to scale budget) + max_metric_calls_override: Explicit cap on metric calls per pass + + Returns: + GEPAConfig instance + """ + if name not in PRESETS: + raise KeyError(f"Unknown preset '{name}'. 
Choose from: {list(PRESETS.keys())}") + + # Validate Databricks env if using databricks/ prefix + effective_lm = reflection_lm or DEFAULT_REFLECTION_LM + if isinstance(effective_lm, str) and effective_lm.startswith("databricks/"): + validate_databricks_env() + + base_calls = PRESET_BASE_CALLS[name] + scaled_calls = base_calls * max(num_components, 1) + + # Apply explicit override if provided + if max_metric_calls_override is not None: + scaled_calls = max_metric_calls_override + # Cap for slower models to avoid multi-hour hangs + elif effective_lm not in _FAST_REFLECTION_MODELS and scaled_calls > MAX_METRIC_CALLS_PER_PASS: + warnings.warn( + f"Capping metric calls from {scaled_calls} to {MAX_METRIC_CALLS_PER_PASS} " + f"for reflection model '{effective_lm}'. " + f"Use --max-metric-calls to override, or use a faster model " + f"(e.g., databricks/databricks-claude-opus-4-6).", + stacklevel=2, + ) + scaled_calls = MAX_METRIC_CALLS_PER_PASS + + config = PRESETS[name] + config = GEPAConfig( + engine=EngineConfig( + max_metric_calls=scaled_calls, + parallel=config.engine.parallel, + ), + reflection=ReflectionConfig( + reflection_lm=reflection_lm or config.reflection.reflection_lm, + reflection_minibatch_size=config.reflection.reflection_minibatch_size, + skip_perfect_score=config.reflection.skip_perfect_score, + ), + merge=config.merge, + refiner=config.refiner, + tracking=config.tracking, + ) + return config diff --git a/.test/src/skill_test/optimize/evaluator.py b/.test/src/skill_test/optimize/evaluator.py new file mode 100644 index 00000000..88c8157b --- /dev/null +++ b/.test/src/skill_test/optimize/evaluator.py @@ -0,0 +1,581 @@ +"""Evaluator bridge: wrap existing MLflow scorers into optimize_anything evaluators. + +Creates GEPA-compatible evaluator functions that take a candidate (str or dict) +and a task example, run existing scorers, and return (score, SideInfo). 
+""" + +import inspect +import re +from pathlib import Path +from typing import Any, Callable, Literal + +import tiktoken +import gepa.optimize_anything as oa +from mlflow.entities import Feedback + +from ..runners.evaluate import build_scorers, load_scorer_config +from ..scorers.universal import ( + python_syntax, + sql_syntax, + no_hallucinated_apis, + pattern_adherence, + expected_facts_present, +) +from .asi import feedback_to_asi +from .skillbench_evaluator import ( # noqa: F401 — re-exported for runner.py + create_skillbench_evaluator, + build_skillbench_background, +) + + +# --------------------------------------------------------------------------- +# Path utilities +# --------------------------------------------------------------------------- + +def _find_repo_root() -> Path: + """Find the repo root by searching upward for .test/src/.""" + current = Path(__file__).resolve().parent + while current != current.parent: + if (current / ".test" / "src").exists(): + return current + if (current / "src" / "skill_test").exists() and current.name == ".test": + return current.parent + current = current.parent + return Path.cwd() + + +def _find_skill_md(skill_name: str) -> Path | None: + """Locate the SKILL.md file for a given skill name.""" + repo_root = _find_repo_root() + candidates = [ + repo_root / ".claude" / "skills" / skill_name / "SKILL.md", + repo_root / "databricks-skills" / skill_name / "SKILL.md", + ] + for p in candidates: + if p.exists(): + return p + return None + + +# --------------------------------------------------------------------------- +# Token utilities +# --------------------------------------------------------------------------- + +def count_tokens(text: str) -> int: + """Count tokens using cl100k_base encoding.""" + enc = tiktoken.get_encoding("cl100k_base") + return len(enc.encode(text)) + + +def token_efficiency_score(candidate_text: str, original_token_count: int) -> float: + """Score based on how concise the candidate is vs. 
the original. + + Smaller than original = bonus up to 1.15, same size = 1.0, + larger = linear penalty to 0.0 at 2x. + """ + if original_token_count <= 0: + return 1.0 + enc = tiktoken.get_encoding("cl100k_base") + candidate_tokens = len(enc.encode(candidate_text)) + ratio = candidate_tokens / original_token_count + if ratio <= 1.0: + return 1.0 + 0.15 * (1.0 - ratio) + else: + return max(0.0, 2.0 - ratio) + + +# --------------------------------------------------------------------------- +# Scorer execution +# --------------------------------------------------------------------------- + +def _run_scorer(scorer_fn: Any, outputs: dict, expectations: dict, inputs: dict) -> list[Feedback]: + """Run a single scorer and normalize result to a list of Feedbacks.""" + sig = inspect.signature(scorer_fn) + params = list(sig.parameters.keys()) + + kwargs = {} + if "outputs" in params: + kwargs["outputs"] = outputs + if "expectations" in params: + kwargs["expectations"] = expectations + if "inputs" in params: + kwargs["inputs"] = inputs + + try: + result = scorer_fn(**kwargs) + except Exception as e: + return [Feedback(name=getattr(scorer_fn, "__name__", "unknown"), value="no", rationale=str(e))] + + if isinstance(result, list): + return result + elif isinstance(result, Feedback): + return [result] + return [] + + +def _run_deterministic_scorers( + response: str, + expectations: dict[str, Any], + prompt: str, + scorer_config: dict[str, Any], +) -> list[Feedback]: + """Run deterministic scorers against a response.""" + outputs = {"response": response} + inputs = {"prompt": prompt} + + if scorer_config: + scorers = build_scorers(scorer_config) + else: + scorers = [python_syntax, sql_syntax, pattern_adherence, no_hallucinated_apis, expected_facts_present] + + all_feedbacks = [] + for scorer_fn in scorers: + scorer_name = getattr(scorer_fn, "__name__", "") or getattr(scorer_fn, "name", "") + if scorer_name in ("Safety", "Guidelines", "skill_quality"): + continue + 
all_feedbacks.extend(_run_scorer(scorer_fn, outputs, expectations, inputs)) + + return all_feedbacks + + +def _validate_skill_structure(candidate_text: str) -> list[Feedback]: + """Validate the SKILL.md structure itself.""" + outputs = {"response": candidate_text} + feedbacks = [] + for scorer_fn in [python_syntax, sql_syntax, no_hallucinated_apis]: + result = _run_scorer(scorer_fn, outputs, {}, {}) + for fb in result: + feedbacks.append(Feedback(name=f"skill_{fb.name}", value=fb.value, rationale=fb.rationale)) + return feedbacks + + +_STOP_WORDS = frozenset({ + "the", "and", "for", "with", "that", "this", "from", "are", "was", + "were", "been", "being", "have", "has", "had", "does", "did", "but", + "not", "you", "all", "can", "her", "his", "its", "may", "our", + "out", "use", "uses", "will", "how", "who", "get", "which", "would", + "make", "like", "into", "than", "them", "then", "each", "other", + "should", "could", +}) + + +def _keyword_fact_score(fact: str, text: str) -> float: + """Score 0-1 based on keyword overlap between a fact and text.""" + words = [w for w in re.findall(r'\w{3,}', fact.lower()) if w not in _STOP_WORDS] + if not words: + return 1.0 + text_lower = text.lower() + found = sum(1 for w in words if w in text_lower) + return found / len(words) + + +def _score_skill_content_facts(candidate_text: str, expected_facts: list[str]) -> list[Feedback]: + """Score SKILL.md content against expected facts using keyword matching. + + Unlike the universal ``expected_facts_present`` scorer which requires exact + substring matches, this uses keyword extraction so descriptive facts like + "Uses CREATE OR REPLACE VIEW with WITH METRICS LANGUAGE YAML" match when + the individual keywords appear in the skill text. 
+ """ + feedbacks = [] + for fact in expected_facts: + score = _keyword_fact_score(fact, candidate_text) + feedbacks.append(Feedback( + name=f"skill_content_fact_{fact[:40]}", + value=score, # continuous 0.0-1.0 + rationale=f"(skill content) Keyword match {score:.0%} for: {fact}", + )) + return feedbacks + + +def _score_skill_content(candidate_text: str, expectations: dict[str, Any]) -> list[Feedback]: + """Score the SKILL.md candidate itself for pattern/fact coverage. + + Runs pattern_adherence and expected_facts_present against the skill text + (not the response). This gives GEPA immediate dynamic signal: if a key + pattern is removed from SKILL.md, the score drops. + + Feedback names are prefixed with ``skill_content_`` to distinguish from + response-level scores. + """ + outputs = {"response": candidate_text} + feedbacks = [] + + # Pattern adherence on skill content + pa_results = _run_scorer(pattern_adherence, outputs, expectations, {}) + for fb in pa_results: + feedbacks.append(Feedback( + name=f"skill_content_{fb.name}", + value=fb.value, + rationale=f"(skill content) {fb.rationale or ''}", + )) + + # Expected facts on skill content (keyword matching for descriptive facts) + expected_facts = expectations.get("expected_facts", []) + if expected_facts: + feedbacks.extend(_score_skill_content_facts(candidate_text, expected_facts)) + + return feedbacks + + +# --------------------------------------------------------------------------- +# Evaluator factory (optimize_anything compatible) +# --------------------------------------------------------------------------- + +SKILL_KEY = "skill_md" + + +def create_skill_evaluator( + skill_name: str, + mode: Literal["static", "generative"] = "static", + task_lm: str | None = None, + gen_model: str | None = None, + original_token_counts: dict[str, int] | None = None, +) -> Callable: + """Create an optimize_anything-compatible evaluator for a skill. 
+ + Returns a function: (candidate, example) -> (score, side_info) + + The candidate is dict[str, str] (may have "skill_md" + "tools_*" keys). + The example is a task dict from the dataset. + + Evaluation layers: + 1. Skill-content scoring: pattern/fact presence in SKILL.md itself + 2. Generative evaluation: LLM generates response from skill, scored + 3. Reference response check: fixed ground truth scoring (sanity) + 4. Structure validation: syntax, no hallucinated APIs on SKILL.md + 5. Token efficiency: conciseness vs original + + Args: + skill_name: Name of the skill being evaluated + mode: "static" uses ground truth response, "generative" generates fresh + task_lm: LLM for generative mode (deprecated, use gen_model) + gen_model: LLM model for generative evaluation + original_token_counts: Token counts of original artifacts + """ + scorer_config = load_scorer_config(skill_name) + effective_gen_model = gen_model or task_lm + + # Track whether we've warned about generation failure + _gen_warned = [False] + + # Compute original token count for efficiency scoring + if original_token_counts is None: + skill_path = _find_skill_md(skill_name) + original_token_counts = { + SKILL_KEY: count_tokens(skill_path.read_text()) if skill_path else 0 + } + total_original_tokens = sum(original_token_counts.values()) + + # Mutable closure state: per-task baseline scorer scores for comparison. + # Populated via evaluator.set_baseline() after evaluating the seed. + _baseline: dict[str, dict[str, float | None]] = {} + + def evaluator(candidate: dict[str, str], example: dict) -> tuple[float, dict]: + """Evaluate a candidate against a single task example. 
+ + Args: + candidate: dict[str, str] with "skill_md" and/or "tools_*" keys + example: Task dict with "input", "answer", "additional_context" + + Returns: + (score, side_info) tuple for optimize_anything + """ + candidate_text = candidate.get(SKILL_KEY, "") + + # Decode expectations from additional_context + expectations = {} + expectations_json = example.get("additional_context", {}).get("expectations", "") + if expectations_json: + import json + try: + expectations = json.loads(expectations_json) + except (json.JSONDecodeError, TypeError): + pass + + # ------------------------------------------------------------------ + # Layer 1: Skill-content scoring (pattern/fact presence in SKILL.md) + # ------------------------------------------------------------------ + skill_content_feedbacks: list[Feedback] = [] + if candidate_text and expectations: + skill_content_feedbacks = _score_skill_content(candidate_text, expectations) + + skill_content_composite, skill_content_si = feedback_to_asi(skill_content_feedbacks) + + # ------------------------------------------------------------------ + # Layer 2: Generative evaluation (LLM generates from skill, score that) + # ------------------------------------------------------------------ + generated_response = None + gen_feedbacks: list[Feedback] = [] + gen_composite = 0.0 + + _gen_available = False # Track if generation actually worked + if effective_gen_model and candidate_text and example.get("input"): + import litellm + messages = [ + { + "role": "system", + "content": ( + "Use ONLY the following skill documentation to answer " + "the user's question. 
Do not use any other knowledge.\n\n" + f"{candidate_text}" + ), + }, + {"role": "user", "content": example.get("input", "")}, + ] + try: + from .skillbench_evaluator import _completion_with_backoff + resp = _completion_with_backoff(model=effective_gen_model, messages=messages) + generated_response = resp.choices[0].message.content + _gen_available = True + except Exception as e: + generated_response = None + gen_feedbacks.append(Feedback( + name="generation_error", + value="no", + rationale=f"LLM generation failed: {e}", + )) + if not _gen_warned[0]: + _gen_warned[0] = True + import warnings + warnings.warn( + f"\nGeneration model '{effective_gen_model}' failed: {e}\n" + "Falling back to skill-content + reference scoring (no generative eval).\n" + "The 20% 'generated response quality' layer will be inactive.\n" + "Fix: set DATABRICKS_API_KEY + DATABRICKS_API_BASE, or use " + "--gen-model with a working provider (e.g., --gen-model openai/gpt-4o).\n", + stacklevel=2, + ) + + if generated_response: + gen_feedbacks = _run_deterministic_scorers( + generated_response, expectations, example.get("input", ""), scorer_config + ) + + gen_composite, gen_si = feedback_to_asi(gen_feedbacks) + + # ------------------------------------------------------------------ + # Layer 3: Reference response check (ground truth — sanity baseline) + # ------------------------------------------------------------------ + reference_response = example.get("answer", "") + ref_feedbacks: list[Feedback] = [] + ref_composite = 0.0 + + if reference_response: + ref_feedbacks = _run_deterministic_scorers( + reference_response, expectations, example.get("input", ""), scorer_config + ) + ref_composite, _ = feedback_to_asi(ref_feedbacks) + + # ------------------------------------------------------------------ + # Layer 4: Validate skill structure + # ------------------------------------------------------------------ + structure_feedbacks: list[Feedback] = [] + if candidate_text: + structure_feedbacks = 
_validate_skill_structure(candidate_text) + + structure_composite, _ = feedback_to_asi(structure_feedbacks) + + # ------------------------------------------------------------------ + # Layer 5: Token efficiency across ALL components + # ------------------------------------------------------------------ + total_candidate_tokens = sum(count_tokens(v) for v in candidate.values()) + if total_original_tokens > 0: + ratio = total_candidate_tokens / total_original_tokens + if ratio <= 1.0: + efficiency = 1.0 + 0.15 * (1.0 - ratio) + else: + efficiency = max(0.0, 2.0 - ratio) + else: + efficiency = 1.0 + + # ------------------------------------------------------------------ + # Weighted final score + # ------------------------------------------------------------------ + # When generative eval succeeds, it gets the dominant weight. + # When gen fails (auth error, timeout, etc), fall back to + # skill-content-heavy weighting — this is the only layer that + # changes dynamically as GEPA mutates the skill. 
+        if _gen_available and generated_response is not None:
+            # Full layered evaluation
+            final_score = (
+                0.20 * gen_composite  # Generated response quality
+                + 0.35 * skill_content_composite  # Skill content coverage
+                + 0.05 * ref_composite  # Reference response (sanity)
+                + 0.10 * structure_composite  # Structure validation
+                + 0.30 * efficiency  # Token efficiency
+            )
+        else:
+            # Fallback: no generative eval, emphasize skill content + efficiency
+            final_score = (
+                0.40 * skill_content_composite  # Skill content coverage
+                + 0.20 * ref_composite  # Reference response
+                + 0.10 * structure_composite  # Structure validation
+                + 0.30 * efficiency  # Token efficiency
+            )
+
+        # ------------------------------------------------------------------
+        # Build unified side_info for GEPA reflection
+        # ------------------------------------------------------------------
+        # Merge all feedbacks for the side_info dict
+        all_feedbacks = skill_content_feedbacks + gen_feedbacks + ref_feedbacks + structure_feedbacks
+        _, side_info = feedback_to_asi(all_feedbacks)
+
+        side_info["scores"] = {
+            "generated_response_quality": gen_composite,
+            "skill_content_coverage": skill_content_composite,
+            "reference_response_check": ref_composite,
+            "structure_validation": structure_composite,
+            "token_efficiency": efficiency,
+            "final": final_score,
+        }
+        side_info["token_counts"] = {
+            "candidate_total": total_candidate_tokens,
+            "original_total": total_original_tokens,
+        }
+
+        # Enrich ASI for GEPA reflection (Step 4 from plan)
+        if generated_response is not None:
+            side_info["_generated_response"] = generated_response[:2000]
+            side_info["_task_prompt"] = example.get("input", "")[:500]
+
+        # Skill coverage summary. Pattern feedbacks carry "yes"/"no" values,
+        # but fact feedbacks carry continuous 0.0-1.0 scores, so treat a
+        # numeric score >= 0.5 as found instead of matching "yes" only
+        # (otherwise fact feedbacks are silently excluded from the ratio).
+        if skill_content_feedbacks:
+            def _is_found(value: Any) -> bool:
+                if isinstance(value, (int, float)):
+                    return value >= 0.5
+                return value == "yes"
+
+            found = [fb.name for fb in skill_content_feedbacks if _is_found(fb.value)]
+            missing = [fb.name for fb in skill_content_feedbacks if not _is_found(fb.value)]
+            side_info["_skill_coverage"] = {
+                "found": found,
+                "missing": missing,
+                "coverage_ratio": len(found) / 
max(len(found) + len(missing), 1), + } + + # Baseline comparison -- show GEPA's reflection LM what improved/regressed + task_key = example.get("input", "") + if task_key and task_key in _baseline: + comparisons = [] + for scorer_name, baseline_val in _baseline[task_key].items(): + current_val = side_info.get(scorer_name, {}).get("score") + if current_val is None or baseline_val is None: + continue + if current_val > baseline_val + 0.01: + comparisons.append( + f"Improved on {scorer_name} ({baseline_val:.2f} -> {current_val:.2f})" + ) + elif current_val < baseline_val - 0.01: + comparisons.append( + f"Regressed on {scorer_name} ({baseline_val:.2f} -> {current_val:.2f})" + ) + if comparisons: + side_info["_baseline_comparison"] = "; ".join(comparisons) + + return final_score, side_info + + def set_baseline(per_task_side_info: dict[str, dict]) -> None: + """Cache per-task per-scorer scores from the seed evaluation. + + Args: + per_task_side_info: {task_input_text: side_info_dict} from seed eval. + """ + for task_key, info in per_task_side_info.items(): + _baseline[task_key] = { + name: data.get("score") + for name, data in info.items() + if isinstance(data, dict) and "score" in data + and not name.startswith("_") + } + + evaluator.set_baseline = set_baseline # type: ignore[attr-defined] + return evaluator + + +def build_optimization_background( + skill_name: str, + original_token_count: int, + component_names: list[str] | None = None, + baseline_scores: dict[str, float] | None = None, + baseline_side_info: dict[str, dict] | None = None, +) -> str: + """Build the background context string for GEPA's reflection LM. + + Args: + skill_name: Name of the skill being optimized. + original_token_count: Total token count of the original artifacts. + component_names: Names of the candidate components (e.g. "skill_md", "tools_*"). + baseline_scores: Per-task overall scores from evaluating the seed candidate. 
+ baseline_side_info: Per-task side_info dicts from evaluating the seed candidate. + """ + components_desc = "" + if component_names and any(c.startswith("tools_") for c in component_names): + tool_modules = [c.replace("tools_", "") for c in component_names if c.startswith("tools_")] + components_desc = ( + "\n\nYou are also optimizing MCP tool descriptions for these modules: " + f"{', '.join(tool_modules)}. " + "Tool descriptions are docstrings on @mcp.tool functions. Keep them " + "accurate, concise, and action-oriented.\n" + ) + + # Build baseline performance summary + baseline_desc = "" + if baseline_scores: + mean_score = sum(baseline_scores.values()) / len(baseline_scores) + perfect = [tid for tid, s in baseline_scores.items() if s >= 0.99] + weak = sorted( + [(tid, s) for tid, s in baseline_scores.items() if s < 0.99], + key=lambda x: x[1], + ) + + baseline_desc = ( + f"\n\nBASELINE PERFORMANCE (seed candidate):\n" + f" Mean score: {mean_score:.3f} across {len(baseline_scores)} test cases.\n" + ) + if perfect: + baseline_desc += f" Perfect/near-perfect ({len(perfect)}): {', '.join(perfect)}\n" + if weak: + baseline_desc += " Needs improvement:\n" + for tid, score in weak: + baseline_desc += f" - {tid}: {score:.3f}" + # Add per-scorer detail if available + if baseline_side_info and tid in baseline_side_info: + info = baseline_side_info[tid] + failing = [ + name for name, data in info.items() + if isinstance(data, dict) and data.get("status") == "fail" + and not name.startswith("_") + ] + if failing: + baseline_desc += f" (failing: {', '.join(failing)})" + baseline_desc += "\n" + + baseline_desc += ( + "\n PRIORITY: Focus optimization effort on the weak test cases above. " + "Do NOT break test cases that already score well.\n" + ) + + return ( + f"You are REFINING an existing, working SKILL.md file for the '{skill_name}' " + "Databricks skill. 
The seed candidate is a production skill that already works -- " + "preserve what already works and improve what doesn't.\n\n" + "SKILL.md files teach AI agents (like Claude Code) how to use specific Databricks features. " + "They contain patterns, code examples, API references, and best practices.\n\n" + "EVALUATION: The skill is evaluated by having a small LLM generate responses from it. " + "Better skill documentation produces more correct responses. Scores come from:\n" + "- Generated response quality (20%): An LLM reads ONLY the skill and answers a test prompt. " + "Its response is scored against expected patterns and facts.\n" + "- Skill content coverage (35%): Does the SKILL.md itself contain the patterns and facts " + "needed to answer test prompts? Removing key content directly drops this score.\n" + "- Reference response check (5%): Sanity check against a known-good response.\n" + "- Structure validation (10%): Python/SQL syntax, no hallucinated APIs.\n" + "- Token efficiency (30%): Conciseness vs original -- smaller is ACTIVELY REWARDED. " + "Shrinking the skill below its original size gives a bonus score (up to 1.15x at 0% of original). " + "Growing the skill is penalized linearly to 0.0 at 2x original size.\n\n" + "KEY INSIGHT: Token efficiency is the second-highest weight. Every token you remove " + "directly improves the score. SkillsBench research shows long skills hurt agent performance " + "via 'cognitive overhead' -- agents get confused by verbose docs. Be ruthlessly concise.\n\n" + f"IMPORTANT: The current artifacts total {original_token_count:,} tokens. " + "Optimized versions MUST be MORE CONCISE. Target at least 10-20% token reduction. " + "Remove redundant examples, consolidate similar patterns, " + "eliminate verbose explanations, and merge overlapping sections. " + "Every token consumed is agent context window budget -- keep skills lean and focused." 
+        f"{baseline_desc}"
+        f"{components_desc}"
+    )
diff --git a/.test/src/skill_test/optimize/gskill/README.md b/.test/src/skill_test/optimize/gskill/README.md
new file mode 100644
index 00000000..1c4621b0
--- /dev/null
+++ b/.test/src/skill_test/optimize/gskill/README.md
@@ -0,0 +1,128 @@
+# gskill: Auto-Generate Optimized Skills for Your Repository
+
+`gskill` uses [GEPA](https://github.com/gepa-ai/gepa) to automatically generate optimized SKILL.md files that teach Claude Code how to work with your specific Databricks project patterns.
+
+## What It Does
+
+1. **Scans your repository** for Databricks patterns (SDK usage, SQL, notebooks, configs)
+2. **Generates a SKILL.md** optimized for AI agent consumption
+3. **Validates quality** using the skill-test evaluation framework
+4. **Outputs to `.claude/skills/`** so Claude Code automatically picks it up
+
+## Quick Start
+
+### Prerequisites
+
+```bash
+# Install GEPA (quote the spec so the shell does not treat ">" as a redirect)
+pip install "gepa>=0.1.0"
+
+# Set up LLM API keys (for GEPA reflection)
+export OPENAI_API_KEY=your-key-here
+```
+
+### Generate a Skill
+
+```bash
+# From the ai-dev-kit repository
+cd /path/to/ai-dev-kit
+
+# Generate a skill for your project repo
+uv run python -c "
+from skill_test.optimize.gskill import run_gskill
+result = run_gskill('/path/to/your/databricks-project')
+print(f'Generated: {result[\"skill_path\"]}')
+"
+```
+
+### Using with Claude Code
+
+Once generated, the skill is automatically available to Claude Code:
+
+```
+your-repo/
+├── .claude/
+│   └── skills/
+│       └── your-repo/
+│           └── SKILL.md    # <- Generated skill
+├── src/
+│   └── ...
+```
+
+Claude Code reads `.claude/skills/*/SKILL.md` files and uses them as context when helping with your code. 
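The "token budget" tip below is backed by the evaluator in this PR, which actively rewards shrinking a skill. As a self-contained sketch of that curve (mirroring the formula in `evaluator.py`; the function name here is illustrative): a candidate smaller than the original earns a bonus up to 1.15, matching the original scores 1.0, and growth is penalized linearly down to 0.0 at twice the original size.

```python
def token_efficiency(candidate_tokens: int, original_tokens: int) -> float:
    """Illustrative mirror of the token-efficiency curve in evaluator.py."""
    if original_tokens <= 0:
        return 1.0
    ratio = candidate_tokens / original_tokens
    if ratio <= 1.0:
        # Shrinking below the original earns a bonus, up to 1.15 at ratio 0
        return 1.0 + 0.15 * (1.0 - ratio)
    # Growing is penalized linearly; the score reaches 0.0 at 2x original size
    return max(0.0, 2.0 - ratio)

# A skill trimmed to half its original size scores 1.075;
# one that doubles in size scores 0.0.
print(token_efficiency(5_000, 10_000))
print(token_efficiency(20_000, 10_000))
```

Because this term carries 30% of the final weight, trimming redundant examples is often the cheapest way to raise a skill's overall score.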
+ +## Configuration + +### Presets + +| Preset | Iterations | Best For | +|--------|-----------|----------| +| `quick` | 15 | Initial generation, small repos | +| `standard` | 50 | Most repos (default) | +| `thorough` | 150 | Large repos, production quality | + +### Custom Context + +Provide additional files for gskill to consider: + +```python +from skill_test.optimize.gskill import run_gskill + +result = run_gskill( + repo_path="/path/to/your/repo", + preset="standard", + context_files=[ + "docs/architecture.md", + "README.md", + "src/config.py", + ], +) +``` + +## Evaluating Generated Skills + +Use the ai-dev-kit evaluation framework to validate generated skills: + +```bash +# 1. Add test cases for the generated skill +uv run python .test/scripts/init_skill.py your-skill-name + +# 2. Add ground truth test cases +uv run python .test/scripts/add.py your-skill-name + +# 3. Run evaluation +uv run python .test/scripts/run_eval.py your-skill-name + +# 4. Optimize further with GEPA +uv run python .test/scripts/optimize.py your-skill-name +``` + +## How It Works + +``` +Your Repository + │ + ▼ +┌──────────┐ ┌──────────┐ ┌──────────┐ +│ Scan for │ --> │ Generate │ --> │ Validate │ --> SKILL.md +│ patterns │ │ SKILL.md │ │ quality │ +└──────────┘ └──────────┘ └──────────┘ + │ + ▼ + GEPA optimize_anything + (reflects on quality, + iterates to improve) +``` + +GEPA's `optimize_anything` treats the SKILL.md as the artifact to optimize. 
It: +- Starts with patterns found in your repo as the seed +- Uses GEPA's reflection LM to propose improvements +- Scores each iteration for quality, correctness, and conciseness +- Selects the best candidate via Pareto frontier optimization + +## Tips + +- **Keep skills focused**: One skill per domain (e.g., separate skills for "data pipeline" and "model serving") +- **Add test cases**: Skills with ground truth test cases optimize much better than bootstrap mode +- **Iterate**: Run `optimize.py` after adding test cases for incremental improvement +- **Token budget**: Skills should be as concise as possible -- every token consumed is agent context window budget diff --git a/.test/src/skill_test/optimize/gskill/__init__.py b/.test/src/skill_test/optimize/gskill/__init__.py new file mode 100644 index 00000000..d27b7b51 --- /dev/null +++ b/.test/src/skill_test/optimize/gskill/__init__.py @@ -0,0 +1,5 @@ +"""gskill: Generate optimized skills for customer repositories using GEPA.""" + +from .pipeline import run_gskill + +__all__ = ["run_gskill"] diff --git a/.test/src/skill_test/optimize/gskill/pipeline.py b/.test/src/skill_test/optimize/gskill/pipeline.py new file mode 100644 index 00000000..7badfe7c --- /dev/null +++ b/.test/src/skill_test/optimize/gskill/pipeline.py @@ -0,0 +1,144 @@ +"""gskill pipeline: generate optimized skills for customer repositories. + +Uses optimize_anything to produce SKILL.md files from repository context. +""" + +import ast +import re +from pathlib import Path +from typing import Any + +from gepa.optimize_anything import optimize_anything, GEPAConfig, EngineConfig, ReflectionConfig +import gepa.optimize_anything as oa + +from ..config import get_preset + + +def run_gskill( + repo_path: str | Path, + skill_name: str | None = None, + output_dir: str | Path | None = None, + preset: str = "standard", + context_files: list[str] | None = None, +) -> dict[str, Any]: + """Generate an optimized skill for a customer repository. 
+ + Args: + repo_path: Path to the customer's repository + skill_name: Name for the generated skill (auto-detected if None) + output_dir: Override output directory + preset: GEPA optimization preset + context_files: Additional files to provide as context + + Returns: + Dict with generated skill path and metadata + """ + repo_path = Path(repo_path).resolve() + if not repo_path.exists(): + raise FileNotFoundError(f"Repository not found: {repo_path}") + + config = get_preset(preset) + + if skill_name is None: + skill_name = repo_path.name + + repo_context = _scan_repo(repo_path, context_files) + + seed_content = ( + f"# {skill_name}\n\n" + "## Overview\n\n" + f"Patterns and best practices for the {skill_name} project.\n\n" + + repo_context + ) + + def evaluate(candidate: str, example: dict) -> tuple[float, dict]: + """Score structural quality of generated skill.""" + score = 0.0 + parts = 0 + + if re.search(r"^#{1,3}\s+", candidate, re.MULTILINE): + score += 1.0 + parts += 1 + + code_blocks = re.findall(r"```(\w+)\n(.*?)```", candidate, re.DOTALL) + if code_blocks: + score += 1.0 + parts += 1 + py_blocks = [b for lang, b in code_blocks if lang == "python"] + if py_blocks: + valid = sum(1 for b in py_blocks if _parses(b)) + score += valid / len(py_blocks) + parts += 1 + + word_count = len(candidate.split()) + if 200 <= word_count <= 5000: + score += 1.0 + parts += 1 + + final = score / parts if parts > 0 else 0.0 + oa.log(f"Structure score: {final:.2f}, words: {word_count}") + + return final, {"structure_score": final, "word_count": word_count} + + trainset = [ + {"input": f"Help me understand patterns in {skill_name}", "additional_context": {}, "answer": ""}, + {"input": f"Show code examples from {skill_name}", "additional_context": {}, "answer": ""}, + ] + + result = optimize_anything( + seed_candidate=seed_content, + evaluator=evaluate, + dataset=trainset, + objective=f"Generate a SKILL.md that teaches an AI coding agent the patterns in {skill_name}.", + 
background=( + "SKILL.md files teach AI agents (Claude Code) repository-specific patterns. " + "Focus on Databricks patterns: Unity Catalog, MLflow, Spark, Delta Lake, etc. " + "Be CONCISE and ACTION-ORIENTED. Lead with code examples." + ), + config=config, + ) + + generated_content = result.best_candidate + if isinstance(generated_content, dict): + generated_content = list(generated_content.values())[0] + + if output_dir is None: + output_dir = repo_path / ".claude" / "skills" / skill_name + else: + output_dir = Path(output_dir) + + output_dir.mkdir(parents=True, exist_ok=True) + skill_path = output_dir / "SKILL.md" + skill_path.write_text(generated_content) + + return { + "skill_name": skill_name, + "skill_path": str(skill_path), + "content_length": len(generated_content), + "repo_path": str(repo_path), + "preset": preset, + } + + +def _parses(code: str) -> bool: + try: + ast.parse(code) + return True + except SyntaxError: + return False + + +def _scan_repo(repo_path: Path, context_files: list[str] | None = None) -> str: + context_parts = [] + if context_files: + for f in context_files: + p = Path(f) if Path(f).is_absolute() else repo_path / f + if p.exists(): + content = p.read_text()[:5000] + context_parts.append(f"### {p.name}\n\n```\n{content}\n```\n") + + readme = repo_path / "README.md" + if readme.exists() and not context_files: + context_parts.append(f"### README\n\n{readme.read_text()[:3000]}\n") + + return "\n".join(context_parts) if context_parts else "" diff --git a/.test/src/skill_test/optimize/review.py b/.test/src/skill_test/optimize/review.py new file mode 100644 index 00000000..c5bc4e93 --- /dev/null +++ b/.test/src/skill_test/optimize/review.py @@ -0,0 +1,486 @@ +"""Review and apply workflow for optimization results. + +Provides human-readable output of optimization results and the ability +to apply the optimized SKILL.md to the repository. 
+ +After each optimization run, results are saved to: + .test/skills//optimized_SKILL.md — the optimized content + .test/skills//last_optimization.md — summary with scores and diff + +Use ``--apply-last`` to apply a saved result without re-running optimization. +""" + +import difflib +import json +from datetime import datetime, timezone +from pathlib import Path + +from .runner import OptimizationResult +from .evaluator import _find_skill_md + + +def _get_results_dir(skill_name: str) -> Path: + """Get the results directory for a skill.""" + # Try standard skills dir first + candidates = [ + Path(".test/skills") / skill_name, + Path(__file__).resolve().parent.parent.parent / "skills" / skill_name, + ] + for d in candidates: + if d.exists(): + return d + # Fallback: create under .test/skills + d = Path(".test/skills") / skill_name + d.mkdir(parents=True, exist_ok=True) + return d + + +def save_result(result: OptimizationResult) -> tuple[Path | None, Path | None]: + """Save optimization results to disk for later application. + + Writes two files: + - ``optimized_SKILL.md`` — the raw optimized content (can be diffed/reviewed) + - ``last_optimization.json`` — metadata for ``--apply-last`` + + Returns: + Tuple of (optimized_skill_path, metadata_path), either may be None on error. 
+ """ + if result.improvement <= 0 and result.original_content == result.optimized_content: + return None, None + + results_dir = _get_results_dir(result.skill_name) + + optimized_path = None + metadata_path = None + + # Write the optimized SKILL.md + if result.optimized_content and result.optimized_content != result.original_content: + optimized_path = results_dir / "optimized_SKILL.md" + optimized_path.write_text(result.optimized_content) + + # Write metadata for --apply-last + metadata = { + "skill_name": result.skill_name, + "timestamp": datetime.now(timezone.utc).isoformat(), + "original_score": result.original_score, + "optimized_score": result.optimized_score, + "improvement": result.improvement, + "original_token_count": result.original_token_count, + "optimized_token_count": result.optimized_token_count, + "token_reduction_pct": result.token_reduction_pct, + "diff_summary": result.diff_summary, + "mlflow_run_id": result.mlflow_run_id, + "evaluator_type": getattr(result, "evaluator_type", "legacy"), + } + + # Save tool components if present + if result.components: + tool_components = { + k: v for k, v in result.components.items() if k.startswith("tools_") + } + if tool_components: + metadata["has_tool_components"] = True + # Save each tool component + for comp_name, comp_text in tool_components.items(): + comp_path = results_dir / f"optimized_{comp_name}.txt" + comp_path.write_text(comp_text) + + metadata_path = results_dir / "last_optimization.json" + metadata_path.write_text(json.dumps(metadata, indent=2)) + + return optimized_path, metadata_path + + +def load_last_result(skill_name: str) -> OptimizationResult | None: + """Load the last saved optimization result for a skill. + + Returns: + OptimizationResult reconstructed from saved files, or None if not found. 
+ """ + results_dir = _get_results_dir(skill_name) + metadata_path = results_dir / "last_optimization.json" + optimized_path = results_dir / "optimized_SKILL.md" + + if not metadata_path.exists(): + return None + + metadata = json.loads(metadata_path.read_text()) + + # Load optimized content + optimized_content = "" + if optimized_path.exists(): + optimized_content = optimized_path.read_text() + + # Load original content + original_content = "" + skill_path = _find_skill_md(skill_name) + if skill_path: + original_content = skill_path.read_text() + + # Reconstruct tool components + components = None + if metadata.get("has_tool_components"): + components = {} + if optimized_content: + components["skill_md"] = optimized_content + for f in results_dir.glob("optimized_tools_*.txt"): + comp_name = f.stem.replace("optimized_", "") + components[comp_name] = f.read_text() + + return OptimizationResult( + skill_name=skill_name, + original_score=metadata.get("original_score", 0.0), + optimized_score=metadata.get("optimized_score", 0.0), + improvement=metadata.get("improvement", 0.0), + original_content=original_content, + optimized_content=optimized_content, + original_token_count=metadata.get("original_token_count", 0), + optimized_token_count=metadata.get("optimized_token_count", 0), + token_reduction_pct=metadata.get("token_reduction_pct", 0.0), + diff_summary=metadata.get("diff_summary", ""), + val_scores={}, + mlflow_run_id=metadata.get("mlflow_run_id"), + gepa_result=None, + components=components, + ) + + +def _review_skillbench(result: OptimizationResult) -> None: + """Print SkillBench-style effectiveness metrics.""" + print(f"\n{'=' * 60}") + print(f" Optimization Results: {result.skill_name} (skillbench)") + print(f"{'=' * 60}") + + # Aggregate effectiveness metrics from side_info + si = result.skillbench_side_info or {} + + # Compute aggregate pass rates from per-task scores + task_count = 0 + sum_with = 0.0 + sum_without = 0.0 + sum_eff = 0.0 + per_task_lines: 
list[str] = [] + + for task_id in sorted(si.keys()): + info = si[task_id] + scores = info.get("scores", {}) + pw = scores.get("pass_rate_with", 0.0) + pwo = scores.get("pass_rate_without", 0.0) + eff = scores.get("skill_effectiveness", 0.0) + sum_with += pw + sum_without += pwo + sum_eff += eff + task_count += 1 + + # Build per-task notes from the new Error key + error = info.get("Error", "") + notes = [] + if "NEEDS_SKILL" in error: + # Extract assertion names from Error lines + failing = [ + line.split(":", 1)[1].strip().split(" — ")[0] + for line in error.split("\n") + if line.startswith("NEEDS_SKILL") + ] + notes.append(f"NEEDS: {', '.join(failing[:3])}") + if "REGRESSION" in error: + reg_count = error.count("REGRESSION") + notes.append(f"REGRESSION x{reg_count}") + if not notes: + notes.append("OK/NEUTRAL") + note_str = f" [{'; '.join(notes)}]" + per_task_lines.append( + f" {task_id:<30s} WITH {pw:.2f} WITHOUT {pwo:.2f} " + f"delta {eff:+.2f}{note_str}" + ) + + # Aggregate rates + if task_count > 0: + agg_with = sum_with / task_count + agg_without = sum_without / task_count + agg_eff = sum_eff / task_count + else: + agg_with = agg_without = agg_eff = 0.0 + + # Score summary + improvement_sign = "+" if result.improvement >= 0 else "" + print(f" Score: {result.original_score:.3f} -> {result.optimized_score:.3f} " + f"({improvement_sign}{result.improvement:.3f})") + print(f" Skill Effectiveness: {agg_eff:.2f}") + print(f" Pass Rate (with): {agg_with:.2f}") + print(f" Pass Rate (without): {agg_without:.2f} (baseline)") + + # Token counts + reduction_sign = "+" if result.token_reduction_pct >= 0 else "" + print(f" Tokens: {result.original_token_count:,} -> {result.optimized_token_count:,} " + f"({reduction_sign}{result.token_reduction_pct:.1f}%)") + + # GEPA iterations + if result.gepa_result and hasattr(result.gepa_result, "iterations"): + print(f" Iterations: {result.gepa_result.iterations}") + + if result.mlflow_run_id: + print(f" MLflow run: 
{result.mlflow_run_id}") + + print() + + # Per-task breakdown + if per_task_lines: + print(" Per-task:") + for line in per_task_lines: + print(line) + print() + + # Diff summary + if result.diff_summary and result.diff_summary != "No changes": + print(" Changes:") + for line in result.diff_summary.split("\n"): + print(f" {line}") + print() + + # Detailed diff (first 50 lines) + if result.original_content != result.optimized_content: + diff_lines = list(difflib.unified_diff( + result.original_content.splitlines(keepends=True), + result.optimized_content.splitlines(keepends=True), + fromfile="original SKILL.md", + tofile="optimized SKILL.md", + n=2, + )) + if len(diff_lines) > 50: + print(f" Diff (first 50 of {len(diff_lines)} lines):") + for line in diff_lines[:50]: + print(f" {line}", end="") + print(f"\n ... ({len(diff_lines) - 50} more lines)") + else: + print(" Diff:") + for line in diff_lines: + print(f" {line}", end="") + print() + else: + print(" No changes to SKILL.md content.") + + # Validation breakdown + if result.val_scores: + print(" Validation scores by test case:") + for task_id, score in sorted(result.val_scores.items()): + status = "PASS" if score >= 0.5 else "FAIL" + print(f" {status} {task_id}: {score:.3f}") + print() + + # Auto-save result to disk + saved_skill, saved_meta = save_result(result) + if saved_skill: + print(f" Saved: {saved_skill}") + print(f" Apply: uv run python .test/scripts/optimize.py {result.skill_name} --apply-last") + elif result.original_content == result.optimized_content: + print(f" No improvement found -- nothing saved.") + print(f"{'=' * 60}\n") + + +def review_optimization(result: OptimizationResult) -> None: + """Print optimization summary for human review. + + Automatically saves the result to disk so it can be applied later + with ``--apply-last`` without re-running optimization. + + Shows: score improvement, token reduction, diff of changed sections, + per-test-case score breakdown, validation set performance. 
+ + Delegates to ``_review_skillbench()`` when ``evaluator_type == "skillbench"``. + """ + if getattr(result, "evaluator_type", "legacy") == "skillbench": + return _review_skillbench(result) + + print(f"\n{'=' * 60}") + print(f" Optimization Results: {result.skill_name}") + print(f"{'=' * 60}") + + # Quality scores + improvement_sign = "+" if result.improvement >= 0 else "" + print(f" Quality: {result.original_score:.3f} -> {result.optimized_score:.3f} " + f"({improvement_sign}{result.improvement:.3f})") + + # Token counts + reduction_sign = "+" if result.token_reduction_pct >= 0 else "" + print(f" Tokens: {result.original_token_count:,} -> {result.optimized_token_count:,} " + f"({reduction_sign}{result.token_reduction_pct:.1f}%)") + + # Validation scores + if result.val_scores: + avg_val = sum(result.val_scores.values()) / len(result.val_scores) + print(f" Validation: avg={avg_val:.3f} ({len(result.val_scores)} cases)") + + # GEPA iterations + if result.gepa_result and hasattr(result.gepa_result, "iterations"): + print(f" Iterations: {result.gepa_result.iterations}") + + # MLflow run + if result.mlflow_run_id: + print(f" MLflow run: {result.mlflow_run_id}") + + print() + + # Diff summary + if result.diff_summary and result.diff_summary != "No changes": + print(" Changes:") + for line in result.diff_summary.split("\n"): + print(f" {line}") + print() + + # Detailed diff (first 50 lines) + if result.original_content != result.optimized_content: + diff_lines = list(difflib.unified_diff( + result.original_content.splitlines(keepends=True), + result.optimized_content.splitlines(keepends=True), + fromfile="original SKILL.md", + tofile="optimized SKILL.md", + n=2, + )) + if len(diff_lines) > 50: + print(f" Diff (first 50 of {len(diff_lines)} lines):") + for line in diff_lines[:50]: + print(f" {line}", end="") + print(f"\n ... 
({len(diff_lines) - 50} more lines)") + else: + print(" Diff:") + for line in diff_lines: + print(f" {line}", end="") + print() + else: + print(" No changes to SKILL.md content.") + + # Validation breakdown + if result.val_scores: + print(" Validation scores by test case:") + for task_id, score in sorted(result.val_scores.items()): + status = "PASS" if score >= 0.5 else "FAIL" + print(f" {status} {task_id}: {score:.3f}") + print() + + # Auto-save result to disk + saved_skill, saved_meta = save_result(result) + if saved_skill: + print(f" Saved: {saved_skill}") + print(f" Apply: uv run python .test/scripts/optimize.py {result.skill_name} --apply-last") + elif result.original_content == result.optimized_content: + print(f" No improvement found — nothing saved.") + print(f"{'=' * 60}\n") + + +def apply_optimization(result: OptimizationResult) -> Path | None: + """Apply optimized SKILL.md and/or tool descriptions. + + Writes back: + - SKILL.md (if changed) + - MCP tool docstrings (if tools were included in optimization) + + Args: + result: OptimizationResult from optimize_skill() + + Returns: + Path to the updated SKILL.md (or None if tools_only) + + Raises: + ValueError: If optimization did not improve the skill + """ + if result.improvement < 0: + raise ValueError( + f"Optimization regressed quality ({result.improvement:+.3f}). " + "Refusing to apply. Use --force to override." 
+ ) + + skill_path = None + + # Apply SKILL.md changes + if result.optimized_content and result.optimized_content != result.original_content: + skill_path = _find_skill_md(result.skill_name) + if skill_path: + skill_path.write_text(result.optimized_content) + print(f"Applied optimized SKILL.md to {skill_path}") + + # Apply tool description changes + if result.tool_map and result.components: + from .tools import parse_gepa_component, write_tool_descriptions + + all_optimized_tools = {} + for comp_name, comp_text in result.components.items(): + if comp_name.startswith("tools_"): + parsed = parse_gepa_component(comp_text) + all_optimized_tools.update(parsed) + + if all_optimized_tools: + modified = write_tool_descriptions(all_optimized_tools, result.tool_map) + if modified: + print(f"Applied optimized tool descriptions to {len(modified)} files:") + for f in modified: + print(f" {f}") + + print(f" Quality: {result.original_score:.3f} -> {result.optimized_score:.3f} " + f"({result.improvement:+.3f})") + print(f" Tokens: {result.original_token_count:,} -> {result.optimized_token_count:,} " + f"({result.token_reduction_pct:+.1f}%)") + + # Try to update baseline + try: + from ..runners.compare import save_baseline + + if result.mlflow_run_id: + save_baseline( + skill_name=result.skill_name, + run_id=result.mlflow_run_id, + metrics={"optimized_score": result.optimized_score}, + test_count=len(result.val_scores) if result.val_scores else 0, + ) + print(f" Baseline updated.") + except Exception: + pass + + return skill_path + + +def format_cost_estimate( + train_count: int, + val_count: int | None, + preset: str, + mode: str, +) -> str: + """Estimate the cost of running optimization. 
+ + Args: + train_count: Number of training tasks + val_count: Number of validation tasks (or None) + preset: Preset name + mode: "static" or "generative" + + Returns: + Human-readable cost estimate string + """ + # Rough estimates based on preset + max_calls = {"quick": 15, "standard": 50, "thorough": 150}.get(preset, 50) + + # Each metric call runs all scorers on all train tasks + calls_per_iteration = train_count + if val_count: + calls_per_iteration += val_count + + total_scorer_calls = max_calls * calls_per_iteration + + if mode == "static": + # Static mode: ~$0.001 per scorer call (just deterministic checks) + est_cost = total_scorer_calls * 0.001 + else: + # Generative mode: ~$0.01 per call (LLM generation + scoring) + est_cost = total_scorer_calls * 0.01 + + # GEPA reflection calls + reflection_cost = max_calls * 0.02 # ~$0.02 per reflection + + total = est_cost + reflection_cost + + return ( + f"Estimated cost: ~${total:.2f}\n" + f" Scorer calls: {total_scorer_calls:,} x {'$0.001' if mode == 'static' else '$0.01'}\n" + f" Reflection calls: {max_calls} x $0.02\n" + f" Max iterations: {max_calls}" + ) diff --git a/.test/src/skill_test/optimize/runner.py b/.test/src/skill_test/optimize/runner.py new file mode 100644 index 00000000..36f44dcc --- /dev/null +++ b/.test/src/skill_test/optimize/runner.py @@ -0,0 +1,499 @@ +"""End-to-end orchestrator for GEPA skill optimization. + +Uses optimize_anything API: evaluator function + GEPAConfig. 
+""" + +import copy +import difflib +import re +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Literal + +from gepa.optimize_anything import optimize_anything, GEPAConfig + +from ..config import SkillTestConfig +from ..runners.evaluate import setup_mlflow +from .config import get_preset, validate_reflection_context, estimate_pass_duration, DEFAULT_GEN_LM, DEFAULT_TOKEN_BUDGET +from .evaluator import ( + SKILL_KEY, + create_skill_evaluator, + count_tokens, + build_optimization_background, + _find_skill_md, + create_skillbench_evaluator, + build_skillbench_background, +) +from .splitter import create_gepa_datasets, generate_bootstrap_tasks, to_gepa_instances +from .tools import ( + extract_tool_descriptions, + tools_to_gepa_components, + get_tool_stats, +) + + +@dataclass +class OptimizationResult: + """Result of a GEPA optimization run.""" + + skill_name: str + original_score: float + optimized_score: float + improvement: float + original_content: str + optimized_content: str + original_token_count: int + optimized_token_count: int + token_reduction_pct: float + diff_summary: str + val_scores: dict[str, float] + mlflow_run_id: str | None + gepa_result: Any + components: dict[str, str] | None = None + original_components: dict[str, str] | None = None + tool_map: Any = None + evaluator_type: str = "legacy" + skillbench_side_info: dict[str, dict] | None = None + + +def _compute_diff_summary(original: str, optimized: str) -> str: + """Generate a human-readable summary of changes.""" + original_lines = original.splitlines(keepends=True) + optimized_lines = optimized.splitlines(keepends=True) + diff = list(difflib.unified_diff(original_lines, optimized_lines, fromfile="original", tofile="optimized", n=1)) + + if not diff: + return "No changes" + + added = sum(1 for line in diff if line.startswith("+") and not line.startswith("+++")) + removed = sum(1 for line in diff if line.startswith("-") and not line.startswith("---")) + + parts = 
[] + if added: + parts.append(f"+{added} lines added") + if removed: + parts.append(f"-{removed} lines removed") + + changed_sections = set() + for line in diff: + content = line[1:].strip() if line.startswith(("+", "-")) and not line.startswith(("+++", "---")) else "" + if content.startswith("#"): + changed_sections.add(content) + + summary = ", ".join(parts) + if changed_sections: + sections = "\n".join(f" ~ {s}" for s in sorted(changed_sections)[:10]) + summary += f"\n\nChanged sections:\n{sections}" + + return summary + + +def _evaluate_on_tasks(evaluator, candidate, tasks): + """Run evaluator on tasks and return mean score, per-task scores, and per-task side_info. + + Returns: + (mean_score, per_task_scores, side_info_by_id, side_info_by_input) + - per_task_scores: {task_id: float} + - side_info_by_id: {task_id: side_info_dict} -- for background builder + - side_info_by_input: {input_text: side_info_dict} -- for evaluator baseline + """ + gepa_instances = to_gepa_instances(tasks) + per_task = {} + side_info_by_id = {} + side_info_by_input = {} + for i, inst in enumerate(gepa_instances): + score, side_info = evaluator(candidate, inst) + task_id = tasks[i].get("id", f"task_{i}") + per_task[task_id] = score + side_info_by_id[task_id] = side_info + side_info_by_input[inst.get("input", f"task_{i}")] = side_info + mean = sum(per_task.values()) / len(per_task) if per_task else 0.0 + return mean, per_task, side_info_by_id, side_info_by_input + + +def optimize_skill( + skill_name: str, + mode: Literal["static", "generative"] = "static", + preset: Literal["quick", "standard", "thorough"] = "standard", + task_lm: str | None = None, + gen_model: str | None = None, + reflection_lm: str | None = None, + include_tools: bool = False, + tool_modules: list[str] | None = None, + tools_only: bool = False, + dry_run: bool = False, + max_passes: int = 5, + max_metric_calls: int | None = None, + evaluator_type: Literal["legacy", "skillbench"] = "skillbench", + token_budget: int | 
None = None, + use_judges: bool = False, +) -> OptimizationResult: + """Run end-to-end GEPA optimization on a skill and/or tools. + + Uses optimize_anything API with a simple evaluator function. + Runs up to ``max_passes`` optimization passes per component, feeding + each pass's best candidate as the seed for the next. Stops early + when a pass produces no score improvement. + + Args: + skill_name: Name of the skill to optimize + mode: "static" or "generative" + preset: "quick" (15), "standard" (50), "thorough" (150) + task_lm: LLM for generative mode (deprecated, use gen_model) + gen_model: LLM for generative evaluation (default: GEPA_GEN_LM env) + reflection_lm: Override reflection LM + include_tools: Include MCP tool descriptions as additional components + tool_modules: Specific tool modules (None = all) + tools_only: Optimize ONLY tool descriptions + dry_run: Show config without running + max_passes: Maximum optimization passes (default 5). Each pass + re-seeds from the previous best and runs a full GEPA cycle. + evaluator_type: "skillbench" (measures skill effectiveness delta) or + "legacy" (weighted scoring with keyword/token efficiency). + token_budget: Hard token ceiling; candidates exceeding this are penalized. + use_judges: Enable MLflow LLM judges (Correctness + Guidelines) for NL feedback. + """ + # 1. Load SKILL.md + skill_path = _find_skill_md(skill_name) + if not tools_only and skill_path is None: + raise FileNotFoundError(f"Could not find SKILL.md for '{skill_name}'") + + original_content = skill_path.read_text() if skill_path else "" + + # 1b. 
Load MCP tool descriptions + tool_map = None + tool_components: dict[str, str] = {} + if include_tools or tools_only: + tool_map = extract_tool_descriptions(modules=tool_modules) + tool_components = tools_to_gepa_components(tool_map, per_module=True) + stats = get_tool_stats() + print(f"Tool modules: {stats['modules']}, tools: {stats['total_tools']}, " + f"description chars: {stats['total_description_chars']:,}") + + # 2. Build seed_candidate (multi-component dict) + seed_candidate: dict[str, str] = {} + original_token_counts: dict[str, int] = {} + + if not tools_only: + seed_candidate[SKILL_KEY] = original_content + original_token_counts[SKILL_KEY] = count_tokens(original_content) + + for comp_name, comp_text in tool_components.items(): + seed_candidate[comp_name] = comp_text + original_token_counts[comp_name] = count_tokens(comp_text) + + total_original_tokens = sum(original_token_counts.values()) + + # Default --include-tools for SkillBench (tools are the primary token consumer) + if evaluator_type == "skillbench" and not tools_only and not include_tools and not tool_components: + include_tools = True + tool_map = extract_tool_descriptions(modules=tool_modules) + tool_components = tools_to_gepa_components(tool_map, per_module=True) + stats = get_tool_stats() + print(f"[SkillBench] Auto-including tools: {stats['modules']} modules, " + f"{stats['total_tools']} tools, {stats['total_description_chars']:,} chars") + for comp_name, comp_text in tool_components.items(): + seed_candidate[comp_name] = comp_text + original_token_counts[comp_name] = count_tokens(comp_text) + total_original_tokens = sum(original_token_counts.values()) + + # Resolve token budget + token_budget = token_budget or DEFAULT_TOKEN_BUDGET + + # 3. Load datasets + try: + train, val = create_gepa_datasets(skill_name) + except FileNotFoundError: + train, val = [], None + + if not train: + train = generate_bootstrap_tasks(skill_name) + val = None + print(f"No test cases found. 
Using {len(train)} auto-generated tasks.") + + # 4. Build evaluator + # Resolve gen_model: explicit arg > task_lm (deprecated) > env default + effective_gen_model = gen_model or task_lm or DEFAULT_GEN_LM + if effective_gen_model: + print(f"Generation model: {effective_gen_model}") + judges_label = " (with LLM judges)" if use_judges else "" + print(f"Evaluator: {evaluator_type}{judges_label}") + + if evaluator_type == "skillbench": + if not effective_gen_model: + raise ValueError( + "SkillBench evaluator requires a gen_model. " + "Pass --gen-model or set GEPA_GEN_LM env var." + ) + evaluator = create_skillbench_evaluator( + skill_name, + gen_model=effective_gen_model, + original_token_counts=original_token_counts, + token_budget=token_budget, + use_judges=use_judges, + ) + else: + evaluator = create_skill_evaluator( + skill_name, mode=mode, task_lm=task_lm, + gen_model=effective_gen_model, + original_token_counts=original_token_counts, + ) + + # 5. Get config (scaled by component count) + num_components = len(seed_candidate) + config = get_preset( + preset, + reflection_lm=reflection_lm, + num_components=num_components, + max_metric_calls_override=max_metric_calls, + ) + print(f"Reflection model: {config.reflection.reflection_lm}") + + # 5b. 
Validate reflection model context window + validate_reflection_context( + config.reflection.reflection_lm, total_original_tokens, + ) + + # Dry run + if dry_run: + print(f"\n=== Dry Run: {skill_name} ({evaluator_type}) ===") + if not tools_only: + print(f"SKILL.md path: {skill_path}") + print(f"Components: {list(seed_candidate.keys())}") + print(f"Total original tokens: {total_original_tokens:,}") + for comp, tokens in original_token_counts.items(): + print(f" {comp}: {tokens:,} tokens") + print(f"Train tasks: {len(train)}") + print(f"Val tasks: {len(val) if val else 'None (single-task mode)'}") + print(f"Mode: {mode}") + print(f"Generation model: {effective_gen_model or 'None (static only)'}") + print(f"Evaluator type: {evaluator_type}") + print(f"Preset: {preset} (max_metric_calls={config.engine.max_metric_calls}, " + f"scaled for {num_components} component(s))") + print(f"Max passes: {max_passes}") + print(f"Reflection LM: {config.reflection.reflection_lm}") + + original_score, original_per_task, si_by_id, _ = _evaluate_on_tasks( + evaluator, seed_candidate, train + ) + print(f"Current score: {original_score:.3f}") + + # Show per-task baseline in dry-run output + for task_id, score in original_per_task.items(): + print(f" {task_id}: {score:.3f}") + + if evaluator_type == "skillbench": + background = build_skillbench_background( + skill_name, total_original_tokens, + component_names=list(seed_candidate.keys()), + baseline_scores=original_per_task, + baseline_side_info=si_by_id, + token_budget=token_budget, + use_judges=use_judges, + ) + else: + background = build_optimization_background( + skill_name, total_original_tokens, + component_names=list(seed_candidate.keys()), + baseline_scores=original_per_task, + baseline_side_info=si_by_id, + ) + print(f"\nBackground preview:\n{background[:500]}...") + + return OptimizationResult( + skill_name=skill_name, + original_score=original_score, + optimized_score=original_score, + improvement=0.0, + 
original_content=original_content, + optimized_content=original_content, + original_token_count=total_original_tokens, + optimized_token_count=total_original_tokens, + token_reduction_pct=0.0, + diff_summary="Dry run - no optimization performed", + val_scores={}, + mlflow_run_id=None, + gepa_result=None, + components=dict(seed_candidate), + original_components=dict(seed_candidate), + tool_map=tool_map, + evaluator_type=evaluator_type, + skillbench_side_info=si_by_id if evaluator_type == "skillbench" else None, + ) + + # Evaluate original and capture per-task detail for baseline context + original_score, original_per_task, si_by_id, si_by_input = _evaluate_on_tasks( + evaluator, seed_candidate, train + ) + + # Set baseline on evaluator so future calls include regression/improvement info + if hasattr(evaluator, "set_baseline"): + evaluator.set_baseline(si_by_input) + + # 6. Build background (with baseline scores) and objective + if evaluator_type == "skillbench": + background = build_skillbench_background( + skill_name, total_original_tokens, + component_names=list(seed_candidate.keys()), + baseline_scores=original_per_task, + baseline_side_info=si_by_id, + token_budget=token_budget, + use_judges=use_judges, + ) + objective = ( + f"Refine and improve the existing '{skill_name}' skill. " + "Score is based on SKILL EFFECTIVENESS (45%) and TOKEN EFFICIENCY (25%). " + "Adding content the agent already knows does NOT help and costs tokens. " + "Focus on what the agent would otherwise get wrong. " + "Be concise — remove redundant examples and verbose explanations." + ) + else: + background = build_optimization_background( + skill_name, total_original_tokens, + component_names=list(seed_candidate.keys()), + baseline_scores=original_per_task, + baseline_side_info=si_by_id, + ) + objective = ( + f"Refine and improve the existing '{skill_name}' skill. " + "Preserve patterns and examples that already score well. 
" + "Focus on fixing scorer failures and reducing token count without sacrificing correctness. " + "Higher quality scores and fewer tokens are both better." + ) + + # 7. Convert datasets to GEPA format + trainset = to_gepa_instances(train) + valset = to_gepa_instances(val) if val else None + + # 8. Multi-pass optimization loop + # Each pass feeds the previous best as the new seed. + # Stops early when a pass produces no score improvement. + current_seed = dict(seed_candidate) + best = dict(seed_candidate) + best_score = original_score + last_result = None + total_metric_calls = 0 + improvement_threshold = 0.0005 # minimum improvement to continue + + print(f"\n Starting multi-pass optimization (up to {max_passes} passes, " + f"{num_components} component(s), {config.engine.max_metric_calls} metric calls/pass)") + + # Print estimated time per pass + est_secs = estimate_pass_duration( + config.engine.max_metric_calls, + config.reflection.reflection_lm, + total_original_tokens, + num_dataset_examples=len(train), + ) + if est_secs is not None: + est_mins = est_secs / 60 + if est_mins > 5: + print(f" Estimated ~{est_mins:.0f} min/pass ({est_mins * max_passes:.0f} min total for {max_passes} passes)") + + for pass_num in range(1, max_passes + 1): + print(f"\n --- Pass {pass_num}/{max_passes} (best score so far: {best_score:.4f}) ---") + + # Deep-copy config: optimize_anything mutates + # config.reflection.reflection_prompt_template when objective/background + # are provided, which would cause a "mutually exclusive" error on pass 2+. 
+ pass_config = copy.deepcopy(config) + + result = optimize_anything( + seed_candidate=current_seed, + evaluator=evaluator, + dataset=trainset, + valset=valset, + objective=objective, + background=background, + config=pass_config, + ) + total_metric_calls += result.total_metric_calls or 0 + + # Evaluate this pass's best candidate + candidate = result.best_candidate + pass_score, _, _, _ = _evaluate_on_tasks(evaluator, candidate, train) + improvement = pass_score - best_score + + print(f" Pass {pass_num} score: {pass_score:.4f} " + f"(delta: {'+' if improvement >= 0 else ''}{improvement:.4f})") + + if pass_score > best_score + improvement_threshold: + best = dict(candidate) + best_score = pass_score + last_result = result + # Use the improved candidate as seed for next pass + current_seed = dict(candidate) + else: + print(f" No significant improvement in pass {pass_num} -- stopping early.") + if last_result is None: + last_result = result + break + else: + print(f" Completed all {max_passes} passes.") + + if last_result is None: + last_result = result + + # 9. Extract results + optimized_content = best.get(SKILL_KEY, original_content) + optimized_token_count = sum(count_tokens(v) for v in best.values()) + + optimized_score = best_score + + val_scores: dict[str, float] = {} + if val: + _, val_scores, _, _ = _evaluate_on_tasks(evaluator, best, val) + + token_reduction_pct = ( + (total_original_tokens - optimized_token_count) / total_original_tokens * 100 + if total_original_tokens > 0 else 0.0 + ) + + diff_summary = _compute_diff_summary(original_content, optimized_content) + + # 10. 
MLflow logging (best-effort) + mlflow_run_id = None + try: + import mlflow + stc = SkillTestConfig() + setup_mlflow(stc) + with mlflow.start_run(run_name=f"{skill_name}_optimize_{preset}"): + mlflow.set_tags({"optimizer": "gepa", "skill_name": skill_name, "preset": preset, "mode": mode, "evaluator_type": evaluator_type}) + mlflow.log_metrics({ + "original_score": original_score, + "optimized_score": optimized_score, + "improvement": optimized_score - original_score, + "original_tokens": float(total_original_tokens), + "optimized_tokens": float(optimized_token_count), + "token_reduction_pct": token_reduction_pct, + "total_metric_calls": float(total_metric_calls), + }) + mlflow_run_id = mlflow.active_run().info.run_id + except Exception: + pass + + # Capture final side_info for skillbench review output + final_si_by_id = None + if evaluator_type == "skillbench": + _, _, final_si_by_id, _ = _evaluate_on_tasks(evaluator, best, train) + + return OptimizationResult( + skill_name=skill_name, + original_score=original_score, + optimized_score=optimized_score, + improvement=optimized_score - original_score, + original_content=original_content, + optimized_content=optimized_content, + original_token_count=total_original_tokens, + optimized_token_count=optimized_token_count, + token_reduction_pct=token_reduction_pct, + diff_summary=diff_summary, + val_scores=val_scores, + mlflow_run_id=mlflow_run_id, + gepa_result=last_result, + components=dict(best), + original_components=dict(seed_candidate), + tool_map=tool_map, + evaluator_type=evaluator_type, + skillbench_side_info=final_si_by_id, + ) diff --git a/.test/src/skill_test/optimize/skillbench_evaluator.py b/.test/src/skill_test/optimize/skillbench_evaluator.py new file mode 100644 index 00000000..8ddc3290 --- /dev/null +++ b/.test/src/skill_test/optimize/skillbench_evaluator.py @@ -0,0 +1,697 @@ +"""SkillBench-inspired evaluator: measure skill effectiveness via WITH vs WITHOUT comparison. 
+ +Evaluates skills by measuring agent performance WITH the skill vs WITHOUT it +on real tasks, then computing a skill effectiveness delta. This replaces the +5-layer weighted scoring with a 3-phase approach: + + Phase 1: WITH-SKILL -- LLM generates response with SKILL.md in context + Phase 2: WITHOUT-SKILL -- LLM generates response with NO skill (cached once) + Phase 3: COMPUTE -- binary pass/fail assertions on both, derive effectiveness + +Scoring weights (default / with --use-judges): + 45% / 35% Skill Effectiveness (delta: pass_rate_with - pass_rate_without) + 25% / 25% Absolute Quality (pass_rate_with_skill) + 0% / 10% Judge Quality (LLM judge pass rate — 6 judges, see below) + 5% / 5% Structure (syntax validity) + 25% / 25% Token Efficiency (smaller candidates score higher) + +LLM Judges (when --use-judges is enabled): + 1. RelevanceToQuery — does the response address the user's input? (always runs) + 2. Completeness — does the response fully answer all parts? (always runs) + 3. Correctness — are expected facts present? (requires expected_facts) + 4. Guidelines — does response follow per-test rules? (requires guidelines) + 5. ExpectationsGuidelines — combined facts+guidelines check (when both present) + 6. 
Custom skill judge — MemAlign-inspired domain judge with skill-specific + evaluation criteria extracted from ground_truth guidelines +""" + +from __future__ import annotations + +import hashlib +import json +import logging +import threading +import time +from typing import Any, Callable + +import litellm +from mlflow.entities import Feedback + +from ..scorers.universal import python_syntax, sql_syntax, no_hallucinated_apis +from .assertions import AssertionResult, run_all_assertions +from .asi import skillbench_to_asi + +logger = logging.getLogger(__name__) + + +def _prompt_hash(prompt: str) -> str: + """Stable hash for caching baseline results by prompt.""" + return hashlib.sha256(prompt.encode()).hexdigest()[:16] + + +class _RateLimiter: + """Thread-safe token-bucket rate limiter for LLM API calls. + + Limits both concurrency (via semaphore) and request rate (via + minimum inter-call spacing). When --include-tools sends large + contexts to Opus, this prevents bursts that exceed + token-per-minute quotas. + """ + + def __init__(self, max_concurrent: int = 2, min_interval: float = 1.0): + self._semaphore = threading.Semaphore(max_concurrent) + self._min_interval = min_interval + self._lock = threading.Lock() + self._last_call: float = 0.0 + + def acquire(self) -> None: + self._semaphore.acquire() + with self._lock: + now = time.monotonic() + wait = self._last_call + self._min_interval - now + if wait > 0: + time.sleep(wait) + self._last_call = time.monotonic() + + def release(self) -> None: + self._semaphore.release() + + +# Module-level rate limiter shared across evaluator instances. +_rate_limiter = _RateLimiter(max_concurrent=2, min_interval=1.0) + + +def _completion_with_backoff(*, max_retries: int = 6, **kwargs) -> Any: + """Call litellm.completion with explicit exponential backoff for rate limits. + + This is a safety net on top of litellm's built-in retries. 
litellm's + global num_retries handles most transient errors, but sustained + token-per-minute exhaustion on Opus can outlast them. This wrapper + adds longer waits between retry bursts. + """ + last_err: Exception | None = None + for attempt in range(max_retries + 1): + if attempt > 0: + delay = min(2 ** attempt, 60) # 2, 4, 8, 16, 32, 60 + logger.warning( + "Rate limited (attempt %d/%d), backing off %.0fs", + attempt, max_retries, delay, + ) + time.sleep(delay) + _rate_limiter.acquire() + try: + result = litellm.completion(**kwargs) + return result + except litellm.RateLimitError as e: + last_err = e + finally: + _rate_limiter.release() + raise last_err # type: ignore[misc] + + +def _run_structure_scorers(text: str) -> float: + """Run structure validation scorers on text, return 0.0-1.0 composite.""" + outputs = {"response": text} + scores: list[float] = [] + for scorer_fn in [python_syntax, sql_syntax, no_hallucinated_apis]: + try: + result = scorer_fn(outputs=outputs) + if isinstance(result, list): + for fb in result: + if fb.value == "yes": + scores.append(1.0) + elif fb.value == "no": + scores.append(0.0) + # skip doesn't count + elif isinstance(result, Feedback): + if result.value == "yes": + scores.append(1.0) + elif result.value == "no": + scores.append(0.0) + except Exception: + pass + return sum(scores) / len(scores) if scores else 1.0 + + +class SkillBenchEvaluator: + """GEPA-compatible evaluator using SkillBench WITH vs WITHOUT methodology. + + Args: + gen_model: LLM model for generating responses. Required -- no silent fallback. + scorer_config: Optional scorer config (unused, kept for interface compat). + original_token_counts: Token counts of original artifacts for efficiency scoring. + token_budget: Hard token ceiling; candidates exceeding this are penalized. 
+ """ + + def __init__( + self, + gen_model: str, + scorer_config: dict[str, Any] | None = None, + original_token_counts: dict[str, int] | None = None, + token_budget: int | None = None, + use_judges: bool = False, + skill_guidelines: list[str] | None = None, + ): + if not gen_model: + raise ValueError( + "SkillBench evaluator requires a gen_model. " + "Pass --gen-model or set GEPA_GEN_LM env var." + ) + self.gen_model = gen_model + self._baseline_cache: dict[str, list[AssertionResult]] = {} + self._baseline_response_cache: dict[str, str] = {} + self._original_token_counts = original_token_counts or {} + self._total_original_tokens = sum(self._original_token_counts.values()) + self._token_budget = token_budget + self._use_judges = use_judges + self._skill_guidelines = skill_guidelines or [] + + def _generate_response(self, prompt: str, skill_context: str | None = None) -> str: + """Generate a response with or without skill context. + + Uses temperature=0 for deterministic outputs. + """ + messages = [] + if skill_context: + messages.append({ + "role": "system", + "content": ( + "Use ONLY the following skill documentation to answer " + "the user's question. Do not use any other knowledge.\n\n" + f"{skill_context}" + ), + }) + messages.append({"role": "user", "content": prompt}) + + resp = _completion_with_backoff( + model=self.gen_model, + messages=messages, + temperature=0, + ) + return resp.choices[0].message.content or "" + + def _get_baseline( + self, prompt: str, expectations: dict[str, Any], + ) -> tuple[list[AssertionResult], str]: + """Get WITHOUT-skill baseline, computing once then caching. + + Returns: + Tuple of (assertion_results, raw_response). 
+ """ + key = _prompt_hash(prompt) + if key not in self._baseline_cache: + response = self._generate_response(prompt, skill_context=None) + self._baseline_response_cache[key] = response + self._baseline_cache[key] = run_all_assertions(response, expectations) + return self._baseline_cache[key], self._baseline_response_cache[key] + + def _run_llm_judges( + self, response: str, expectations: dict, prompt: str, reference: str, + ) -> list[dict]: + """Run MLflow LLM judges and return NL feedback dicts. + + Runs up to 6 judges for comprehensive evaluation: + 1. RelevanceToQuery — always (no ground truth needed) + 2. Completeness — always (no ground truth needed) + 3. Correctness — when expected_facts present + 4. Guidelines — when per-test guidelines present + 5. ExpectationsGuidelines — when BOTH facts + guidelines present + 6. Custom skill judge — MemAlign-inspired domain-specific judge + + Each judge returns {judge, passed, rationale} for GEPA reflection. + """ + results = [] + + inputs_dict = {"prompt": prompt} + outputs_dict = {"response": response} + + facts = expectations.get("expected_facts", []) + guidelines = expectations.get("guidelines", []) + + # 1. RelevanceToQuery: does response address the user's input? + # Always runs — catches off-topic or confused responses that binary + # assertions can't detect. + try: + from mlflow.genai.scorers import RelevanceToQuery + judge = RelevanceToQuery() + fb = judge(inputs=inputs_dict, outputs=outputs_dict) + results.append({ + "judge": "relevance", + "passed": fb.value == "yes", + "rationale": fb.rationale or "", + }) + except Exception as e: + logger.debug("RelevanceToQuery judge failed: %s", e) + + # 2. Completeness: does response fully answer all parts of the prompt? + # Always runs — catches partial answers where a response mentions the + # right topic but skips sub-questions. 
+ try: + from mlflow.genai.scorers import Completeness + judge = Completeness() + fb = judge(inputs=inputs_dict, outputs=outputs_dict) + results.append({ + "judge": "completeness", + "passed": fb.value == "yes", + "rationale": fb.rationale or "", + }) + except Exception as e: + logger.debug("Completeness judge failed: %s", e) + + # 3. Correctness: are expected facts present in the response? + if facts and reference: + try: + from mlflow.genai.scorers import Correctness + judge = Correctness() + fb = judge( + inputs=inputs_dict, + outputs=outputs_dict, + expectations={"expected_facts": facts}, + ) + results.append({ + "judge": "correctness", + "passed": fb.value == "yes", + "rationale": fb.rationale or "", + }) + except Exception as e: + logger.debug("Correctness judge failed: %s", e) + + # 4. Guidelines: per-test custom evaluation rules + if guidelines: + try: + from mlflow.genai.scorers import Guidelines + judge = Guidelines( + name="skill_guidelines", guidelines=guidelines, + ) + fb = judge(inputs=inputs_dict, outputs=outputs_dict) + results.append({ + "judge": "guidelines", + "passed": fb.value == "yes", + "rationale": fb.rationale or "", + }) + except Exception as e: + logger.debug("Guidelines judge failed: %s", e) + + # 5. ExpectationsGuidelines: combined facts + guidelines in one pass. + # When both are available, this judge evaluates them together and + # produces a single holistic rationale — often more insightful than + # separate Correctness + Guidelines calls. 
+ if facts and guidelines: + try: + from mlflow.genai.scorers import ExpectationsGuidelines + judge = ExpectationsGuidelines( + name="expectations_guidelines", + guidelines=guidelines, + ) + fb = judge( + inputs=inputs_dict, + outputs=outputs_dict, + expectations={"expected_facts": facts}, + ) + results.append({ + "judge": "expectations_guidelines", + "passed": fb.value == "yes", + "rationale": fb.rationale or "", + }) + except Exception as e: + logger.debug("ExpectationsGuidelines judge failed: %s", e) + + # 6. Custom skill judge (MemAlign-inspired): uses domain-specific + # evaluation principles extracted from the skill's ground_truth + # guidelines. This mimics MemAlign's "semantic memory" — a set of + # generalizable principles learned from labeled examples — without + # requiring MLflow trace alignment infrastructure. + if self._skill_guidelines: + try: + self._run_custom_skill_judge( + results, prompt, response, facts, + ) + except Exception as e: + logger.debug("Custom skill judge failed: %s", e) + + return results + + def _run_custom_skill_judge( + self, + results: list[dict], + prompt: str, + response: str, + facts: list[str], + ) -> None: + """Run a MemAlign-inspired custom judge with skill-domain principles. + + Uses ``mlflow.genai.judges.make_judge`` to create a domain-specific + judge whose instructions incorporate evaluation principles extracted + from all ground_truth.yaml guidelines across the skill's test cases. + + This is the "semantic memory" component of the MemAlign approach: + rather than aligning from traces, we extract and deduplicate the + skill's evaluation principles upfront and inject them as judge + instructions. + + Appends result dicts directly to ``results``. 
+ """ + from mlflow.genai.judges import make_judge + + # Build instruction prompt from collected skill guidelines + principles = "\n".join( + f"- {g}" for g in self._skill_guidelines + ) + + judge = make_judge( + name="skill_domain_judge", + instructions=( + "You are an expert evaluator for a Databricks skill. " + "Evaluate whether the response correctly follows the " + "domain-specific principles below.\n\n" + "## Domain Principles (from skill evaluation criteria)\n" + f"{principles}\n\n" + "## Evaluation\n" + "Given the user question and response, determine if the " + "response adheres to the domain principles above. Focus on " + "technical accuracy, correct API usage, and completeness " + "of the domain-specific guidance.\n\n" + "Question: {{ inputs.prompt }}\n" + "Response: {{ outputs.response }}" + ), + feedback_value_type=bool, + ) + + fb = judge( + inputs={"prompt": prompt}, + outputs={"response": response}, + ) + results.append({ + "judge": "skill_domain", + "passed": fb.value == "yes" if isinstance(fb.value, str) else bool(fb.value), + "rationale": fb.rationale or "", + }) + + def __call__( + self, candidate: dict[str, str], example: dict, + ) -> tuple[float, dict]: + """Evaluate a candidate skill against a single task example. 
+ + GEPA-compatible signature: (candidate, example) -> (score, side_info) + """ + skill_md = candidate.get("skill_md", "") + + # Build combined context: skill + tool descriptions + tool_parts = [] + for key in sorted(candidate): + if key.startswith("tools_"): + tool_parts.append(candidate[key]) + + full_context = skill_md + if tool_parts: + full_context += "\n\n## Available MCP Tools\n\n" + "\n\n".join(tool_parts) + + prompt = example.get("input", "") + + # Decode expectations + expectations: dict[str, Any] = {} + expectations_json = example.get("additional_context", {}).get("expectations", "") + if expectations_json: + try: + expectations = json.loads(expectations_json) + except (json.JSONDecodeError, TypeError): + pass + + # If no prompt or no expectations, return minimal score + if not prompt or not expectations: + return 0.0, {"_error": "No prompt or expectations for this task"} + + # Phase 1: WITH skill + tools + with_response = self._generate_response(prompt, skill_context=full_context) + with_results = run_all_assertions(with_response, expectations) + + # Phase 2: WITHOUT skill (cached) + without_results, without_response = self._get_baseline(prompt, expectations) + + # Phase 3: Compute scores + total = len(with_results) + if total == 0: + return 0.0, {"_error": "No assertions to evaluate"} + + pass_with = sum(r.passed for r in with_results) / total + pass_without = sum(r.passed for r in without_results) / total + effectiveness = pass_with - pass_without + + # LLM judge scoring (optional) + judge_results = [] + if self._use_judges: + judge_results = self._run_llm_judges( + with_response, expectations, prompt, + reference=example.get("answer", ""), + ) + + judge_pass_rate = 1.0 # default when judges disabled + if judge_results: + judge_pass_rate = sum(j["passed"] for j in judge_results) / len(judge_results) + + # Structure validation on the skill itself + structure = _run_structure_scorers(skill_md) if skill_md else 1.0 + + # Token efficiency scoring + from 
.evaluator import count_tokens + total_candidate_tokens = sum(count_tokens(v) for v in candidate.values()) + + if self._total_original_tokens > 0: + ratio = total_candidate_tokens / self._total_original_tokens + if ratio <= 1.0: + efficiency = 1.0 + 0.15 * (1.0 - ratio) # Bonus for smaller + else: + efficiency = max(0.0, 2.0 - ratio) # Penalty for growth + + # Hard penalty if over explicit budget + if self._token_budget and total_candidate_tokens > self._token_budget: + over_ratio = total_candidate_tokens / self._token_budget + efficiency = min(efficiency, max(0.0, 2.0 - over_ratio)) + else: + efficiency = 1.0 + + # Weighted final score + # When judges enabled: steal 10% from effectiveness for judge_quality + # When judges disabled: 10% folds back into effectiveness (original weights) + if self._use_judges and judge_results: + final_score = ( + 0.35 * max(0.0, effectiveness) + + 0.25 * pass_with + + 0.10 * judge_pass_rate + + 0.05 * structure + + 0.25 * efficiency + ) + else: + final_score = ( + 0.45 * max(0.0, effectiveness) + + 0.25 * pass_with + + 0.05 * structure + + 0.25 * efficiency + ) + + # Build side info via skillbench_to_asi + score_breakdown = { + "skill_effectiveness": effectiveness, + "pass_rate_with": pass_with, + "pass_rate_without": pass_without, + "structure": structure, + "token_efficiency": efficiency, + "final": final_score, + } + if self._use_judges and judge_results: + score_breakdown["judge_quality"] = judge_pass_rate + + reference_answer = example.get("answer", "") + + side_info = skillbench_to_asi( + with_results, + without_results, + task_prompt=prompt, + scores=score_breakdown, + with_response=with_response, + without_response=without_response, + reference_answer=reference_answer or None, + candidate=candidate, + ) + + # Feed judge rationale into side_info for GEPA reflection + if judge_results: + failing_judges = [j for j in judge_results if not j["passed"]] + if failing_judges: + rationale_lines = [ + f"{j['judge']}: 
{j['rationale'][:200]}" for j in failing_judges + ] + side_info["Judge_feedback"] = "\n".join(rationale_lines) + + # Also route to skill_md_specific_info for component targeting + if "skill_md_specific_info" not in side_info: + side_info["skill_md_specific_info"] = {} + side_info["skill_md_specific_info"]["Judge_analysis"] = "\n".join(rationale_lines) + + # Add token counts to side_info for GEPA Pareto tracking + side_info["token_counts"] = { + "candidate_total": total_candidate_tokens, + "original_total": self._total_original_tokens, + } + if self._token_budget: + side_info["token_counts"]["budget"] = self._token_budget + + return final_score, side_info + + +def _collect_skill_guidelines(skill_name: str) -> list[str]: + """Collect and deduplicate all guidelines from a skill's ground_truth.yaml. + + These form the "semantic memory" for the MemAlign-inspired custom judge. + Returns a deduplicated list of guideline strings. + """ + from pathlib import Path + import yaml + + gt_path = Path(".test/skills") / skill_name / "ground_truth.yaml" + if not gt_path.exists(): + return [] + + try: + with open(gt_path) as f: + data = yaml.safe_load(f) or {} + except Exception: + return [] + + seen: set[str] = set() + guidelines: list[str] = [] + for tc in data.get("test_cases", []): + for g in tc.get("expectations", {}).get("guidelines", []): + g_norm = g.strip() + if g_norm and g_norm not in seen: + seen.add(g_norm) + guidelines.append(g_norm) + + return guidelines + + +def create_skillbench_evaluator( + skill_name: str, + gen_model: str, + original_token_counts: dict[str, int] | None = None, + token_budget: int | None = None, + use_judges: bool = False, +) -> Callable: + """Factory for SkillBench-style evaluator. 
+ + Returns a GEPA-compatible callable: (candidate, example) -> (score, side_info) + + When ``use_judges`` is enabled, collects all guidelines from the skill's + ground_truth.yaml to build a MemAlign-inspired custom domain judge + alongside the standard MLflow predefined judges. + + Args: + skill_name: Name of the skill being evaluated. + gen_model: LLM model for generating responses. Required. + original_token_counts: Token counts of original artifacts for efficiency scoring. + token_budget: Hard token ceiling; candidates exceeding this are penalized. + use_judges: Enable MLflow LLM judges for NL feedback. + """ + skill_guidelines: list[str] = [] + if use_judges: + skill_guidelines = _collect_skill_guidelines(skill_name) + if skill_guidelines: + logger.info( + "Loaded %d domain guidelines for custom skill judge", + len(skill_guidelines), + ) + + return SkillBenchEvaluator( + gen_model=gen_model, + original_token_counts=original_token_counts, + token_budget=token_budget, + use_judges=use_judges, + skill_guidelines=skill_guidelines, + ) + + +def build_skillbench_background( + skill_name: str, + original_token_count: int, + component_names: list[str] | None = None, + baseline_scores: dict[str, float] | None = None, + baseline_side_info: dict[str, dict] | None = None, + token_budget: int | None = None, + use_judges: bool = False, +) -> str: + """Build concise GEPA reflection context for SkillBench optimization. + + Kept short so GEPA's reflection LM spends its context on the per-example + diagnostics (Error/Expected/Actual) rather than methodology. + """ + # Concise per-task baseline summary + baseline_desc = "" + if baseline_scores: + mean_score = sum(baseline_scores.values()) / len(baseline_scores) + baseline_desc = f"\nBASELINE: mean {mean_score:.3f} across {len(baseline_scores)} tasks." 
+ + if baseline_side_info: + needs_skill_ids = [] + regression_ids = [] + for tid, info in baseline_side_info.items(): + error = info.get("Error", "") + if "NEEDS_SKILL" in error: + needs_skill_ids.append(tid) + if "REGRESSION" in error: + regression_ids.append(tid) + if needs_skill_ids: + baseline_desc += ( + f"\n NEEDS_SKILL ({len(needs_skill_ids)} tasks): " + f"{', '.join(needs_skill_ids[:5])}" + ) + if regression_ids: + baseline_desc += ( + f"\n REGRESSION ({len(regression_ids)} tasks): " + f"{', '.join(regression_ids[:5])}" + ) + + components_desc = "" + if component_names and any(c.startswith("tools_") for c in component_names): + tool_modules = [c.replace("tools_", "") for c in component_names if c.startswith("tools_")] + components_desc = ( + f"\nAlso optimizing MCP tool descriptions for: {', '.join(tool_modules)}. " + "Tool descriptions are included in the agent's context alongside the skill. " + "The agent uses them to decide which tools to call and how. " + "Keep docstrings accurate and concise — every token counts toward the budget." + ) + + # Token efficiency guidance + token_desc = ( + f"\nTOKEN EFFICIENCY (25% of score): Current artifacts total {original_token_count:,} tokens. " + "Smaller candidates score HIGHER. Remove redundant examples, consolidate " + "overlapping sections, eliminate verbose explanations. Be ruthlessly concise." + ) + if token_budget: + token_desc += ( + f"\nTOKEN BUDGET: {token_budget:,} tokens. Candidates exceeding this " + "are heavily penalized. Stay well under the budget." + ) + + judge_desc = "" + if use_judges: + judge_desc = ( + "\nLLM JUDGES (10% of score): Six judges provide natural-language feedback " + "in 'Judge_feedback': RelevanceToQuery (on-topic?), Completeness (fully " + "answered?), Correctness (facts present?), Guidelines (rules followed?), " + "ExpectationsGuidelines (combined holistic check), and a custom skill " + "domain judge (MemAlign-inspired, evaluates domain-specific principles). 
" + "Use judge rationale to understand WHY responses fail — it's more " + "actionable than binary NEEDS_SKILL/REGRESSION labels." + ) + + return ( + f"You are refining SKILL.md for '{skill_name}'.\n" + "The skill is scored by how much it HELPS an agent answer correctly.\n" + "Assertions labeled NEEDS_SKILL = add this content. REGRESSION = simplify or remove.\n" + "Focus on: specific API syntax, version requirements, non-obvious patterns.\n" + "Do NOT add generic knowledge the agent already has (NEUTRAL assertions)." + f"{baseline_desc}" + f"{components_desc}" + f"{token_desc}" + f"{judge_desc}" + ) diff --git a/.test/src/skill_test/optimize/splitter.py b/.test/src/skill_test/optimize/splitter.py new file mode 100644 index 00000000..98f6c871 --- /dev/null +++ b/.test/src/skill_test/optimize/splitter.py @@ -0,0 +1,233 @@ +"""Train/val dataset splitting for GEPA optimization. + +Loads ground_truth.yaml test cases and splits them into train/val sets, +stratified by metadata.category when possible. + +GEPA's DefaultDataInst format: {"input": str, "additional_context": dict[str, str], "answer": str} + +We store our internal task representation alongside, and convert to GEPA format +when needed via to_gepa_instances(). +""" + +import json +import random +import re +from collections import defaultdict +from pathlib import Path +from typing import Any, TypedDict + +from ..dataset import EvalRecord, get_dataset_source + + +class SkillTask(TypedDict, total=False): + """Internal task representation (superset of GEPA DefaultDataInst).""" + + id: str + input: str # The prompt (maps to DefaultDataInst.input) + answer: str # Expected response (maps to DefaultDataInst.answer) + additional_context: dict[str, str] # Extra context (maps to DefaultDataInst.additional_context) + expectations: dict[str, Any] # Scorer expectations (not sent to GEPA directly) + metadata: dict[str, Any] # Category, difficulty, etc. 
+ + +def _summarize_expectations(expectations: dict[str, Any]) -> str: + """Produce a human-readable summary of what a task tests. + + Included in additional_context so GEPA's reflection LM understands + what each test case is checking without parsing JSON. + """ + parts = [] + + patterns = expectations.get("expected_patterns", []) + if patterns: + descs = [] + for p in patterns: + if isinstance(p, str): + descs.append(p[:40]) + elif isinstance(p, dict): + descs.append(p.get("description", p.get("pattern", "")[:40])) + parts.append(f"Patterns: {', '.join(descs)}") + + facts = expectations.get("expected_facts", []) + if facts: + parts.append(f"Facts: {', '.join(str(f) for f in facts)}") + + guidelines = expectations.get("guidelines", []) + if guidelines: + parts.append(f"Guidelines: {'; '.join(str(g) for g in guidelines[:3])}") + + return " | ".join(parts) if parts else "No specific expectations" + + +def _record_to_task(record: EvalRecord) -> SkillTask: + """Convert an EvalRecord to our internal task format.""" + task: SkillTask = { + "id": record.id, + "input": record.inputs.get("prompt", ""), + "additional_context": {}, + "answer": "", + "metadata": record.metadata or {}, + } + if record.outputs: + task["answer"] = record.outputs.get("response", "") + if record.expectations: + task["expectations"] = record.expectations + # Also encode expectations into additional_context for GEPA reflection + task["additional_context"]["expectations"] = json.dumps(record.expectations) + # Human-readable summary for GEPA's reflection LM + task["additional_context"]["evaluation_criteria"] = _summarize_expectations( + record.expectations + ) + return task + + +def to_gepa_instances(tasks: list[SkillTask]) -> list[dict[str, Any]]: + """Convert internal tasks to GEPA DefaultDataInst format. 
+ + Returns list of {"input": str, "additional_context": dict[str,str], "answer": str} + """ + return [ + { + "input": t["input"], + "additional_context": t.get("additional_context", {}), + "answer": t.get("answer", ""), + } + for t in tasks + ] + + +def create_gepa_datasets( + skill_name: str, + val_ratio: float = 0.2, + base_path: Path | None = None, + seed: int = 42, +) -> tuple[list[SkillTask], list[SkillTask] | None]: + """Load ground_truth.yaml, stratify by metadata.category, split into train/val. + + For skills with <5 test cases: uses all as train, val=None (single-task mode). + For skills with >=5 test cases: stratified train/val split (generalization mode). + + Args: + skill_name: Name of the skill to load test cases for + val_ratio: Fraction of test cases to hold out for validation + base_path: Override base path for skills directory + seed: Random seed for reproducible splits + + Returns: + Tuple of (train_tasks, val_tasks). val_tasks is None if <5 test cases. + """ + source = get_dataset_source(skill_name, base_path) + records = source.load() + + if not records: + return [], None + + tasks = [_record_to_task(r) for r in records] + + # Too few for a meaningful val split + if len(tasks) < 5: + return tasks, None + + # Stratify by category + by_category: dict[str, list[SkillTask]] = defaultdict(list) + for task in tasks: + cat = task.get("metadata", {}).get("category", "_uncategorized") + by_category[cat].append(task) + + rng = random.Random(seed) + train: list[SkillTask] = [] + val: list[SkillTask] = [] + + for cat, cat_tasks in by_category.items(): + rng.shuffle(cat_tasks) + n_val = max(1, int(len(cat_tasks) * val_ratio)) + + # Ensure at least 1 train sample per category + if len(cat_tasks) - n_val < 1: + n_val = len(cat_tasks) - 1 + + if n_val <= 0: + train.extend(cat_tasks) + else: + val.extend(cat_tasks[:n_val]) + train.extend(cat_tasks[n_val:]) + + # If val ended up empty, fall back + if not val: + return tasks, None + + return train, val + + +def 
generate_bootstrap_tasks(skill_name: str, base_path: Path | None = None) -> list[SkillTask]: + """Generate synthetic tasks from a SKILL.md when no ground_truth.yaml exists. + + Parses the SKILL.md for documented patterns and generates basic test prompts + that exercise each pattern. + + Args: + skill_name: Name of the skill + base_path: Override base path for skills directory + + Returns: + List of synthetic SkillTask dicts + """ + if base_path is None: + # Find repo root for path resolution + from .evaluator import _find_repo_root + + repo_root = _find_repo_root() + skill_md_candidates = [ + repo_root / ".claude" / "skills" / skill_name / "SKILL.md", + repo_root / "databricks-skills" / skill_name / "SKILL.md", + ] + else: + skill_md_candidates = [base_path.parent / skill_name / "SKILL.md"] + + skill_content = None + for path in skill_md_candidates: + if path.exists(): + skill_content = path.read_text() + break + + if not skill_content: + return [] + + tasks: list[SkillTask] = [] + + # Extract h2/h3 headers as topic areas + headers = re.findall(r"^#{2,3}\s+(.+)$", skill_content, re.MULTILINE) + + for i, header in enumerate(headers): + tasks.append( + { + "id": f"bootstrap_{i:03d}", + "input": f"Using the {skill_name} skill, help me with: {header}", + "additional_context": {}, + "answer": "", + "metadata": {"category": "bootstrap", "source": "auto_generated"}, + } + ) + + # Extract code block language hints for targeted prompts + code_langs = set(re.findall(r"```(\w+)\n", skill_content)) + for lang in code_langs: + tasks.append( + { + "id": f"bootstrap_lang_{lang}", + "input": f"Show me a {lang} example using {skill_name} patterns", + "additional_context": {}, + "answer": "", + "metadata": {"category": "bootstrap", "source": "auto_generated"}, + } + ) + + return tasks or [ + { + "id": "bootstrap_general", + "input": f"Explain the key patterns in {skill_name}", + "additional_context": {}, + "answer": "", + "metadata": {"category": "bootstrap", "source": 
"auto_generated"}, + } + ] diff --git a/.test/src/skill_test/optimize/tools.py b/.test/src/skill_test/optimize/tools.py new file mode 100644 index 00000000..29ef6c24 --- /dev/null +++ b/.test/src/skill_test/optimize/tools.py @@ -0,0 +1,271 @@ +"""Tool description extraction and writing for MCP server tools. + +Extracts @mcp.tool docstrings from Python source files, formats them for GEPA +optimization, and writes optimized descriptions back to source files. + +Each tool module (sql.py, compute.py, etc.) becomes one GEPA component so +GEPA's round-robin selector cycles through modules efficiently. +""" + +import ast +import re +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +# Separator used between tools within a module's GEPA component text +TOOL_SEPARATOR = "\n\n### TOOL: " + +MCP_TOOLS_DIR = Path(__file__).resolve().parents[5] / "databricks-mcp-server" / "databricks_mcp_server" / "tools" + + +@dataclass +class ToolDescription: + """A single tool's description extracted from source.""" + + name: str # Function name + docstring: str # The full docstring text + module: str # Module name (e.g., "sql", "compute") + lineno: int # Line number of the function def + source_path: Path # Path to the source file + + +def _find_tools_dir() -> Path: + """Find the MCP tools directory.""" + if MCP_TOOLS_DIR.exists(): + return MCP_TOOLS_DIR + # Fallback: search from repo root + from .evaluator import _find_repo_root + + repo_root = _find_repo_root() + candidate = repo_root / "databricks-mcp-server" / "databricks_mcp_server" / "tools" + if candidate.exists(): + return candidate + raise FileNotFoundError( + "Could not find MCP tools directory. " + "Expected at databricks-mcp-server/databricks_mcp_server/tools/" + ) + + +def extract_tool_descriptions( + modules: list[str] | None = None, + tools_dir: Path | None = None, +) -> dict[str, list[ToolDescription]]: + """Extract all @mcp.tool docstrings from MCP server tool files. 
+
+    Args:
+        modules: Optional list of module names to extract (e.g., ["sql", "compute"]).
+            If None, extracts all modules.
+        tools_dir: Override path to tools directory.
+
+    Returns:
+        Dict mapping module_name -> list of ToolDescription.
+    """
+    if tools_dir is None:
+        tools_dir = _find_tools_dir()
+
+    results: dict[str, list[ToolDescription]] = {}
+
+    for py_file in sorted(tools_dir.glob("*.py")):
+        module_name = py_file.stem
+        if module_name == "__init__":
+            continue
+        if modules and module_name not in modules:
+            continue
+
+        source = py_file.read_text()
+        tree = ast.parse(source)
+
+        tool_descs = []
+        for node in ast.walk(tree):
+            if not isinstance(node, ast.FunctionDef):
+                continue
+            # Check if decorated with @mcp.tool; unwrap the call form
+            # (@mcp.tool(...)) so both bare and parenthesized usages match.
+            for dec in node.decorator_list:
+                target = dec.func if isinstance(dec, ast.Call) else dec
+                is_mcp_tool = False
+                if isinstance(target, ast.Attribute) and isinstance(target.value, ast.Name):
+                    if target.value.id == "mcp" and target.attr == "tool":
+                        is_mcp_tool = True
+                elif isinstance(target, ast.Name) and target.id == "mcp":
+                    is_mcp_tool = True
+                if is_mcp_tool:
+                    docstring = ast.get_docstring(node) or ""
+                    tool_descs.append(
+                        ToolDescription(
+                            name=node.name,
+                            docstring=docstring,
+                            module=module_name,
+                            lineno=node.lineno,
+                            source_path=py_file,
+                        )
+                    )
+                    break
+
+        if tool_descs:
+            results[module_name] = tool_descs
+
+    return results
+
+
+def tools_to_gepa_components(
+    tool_map: dict[str, list[ToolDescription]],
+    per_module: bool = True,
+) -> dict[str, str]:
+    """Convert extracted tool descriptions into GEPA component text blocks.
+
+    Args:
+        tool_map: Output of extract_tool_descriptions()
+        per_module: If True, one GEPA component per module (e.g., "tools_sql").
+            If False, all tools in a single "tool_descriptions" component.
+
+    Returns:
+        Dict mapping component_name -> text block. 
+ """ + if per_module: + components = {} + for module_name, tools in tool_map.items(): + text_parts = [] + for td in tools: + text_parts.append(f"### TOOL: {td.name}\n{td.docstring}") + components[f"tools_{module_name}"] = "\n\n".join(text_parts) + return components + else: + all_parts = [] + for module_name, tools in sorted(tool_map.items()): + for td in tools: + all_parts.append(f"### TOOL: {td.name} (module: {module_name})\n{td.docstring}") + return {"tool_descriptions": "\n\n".join(all_parts)} + + +def parse_gepa_component(component_text: str) -> dict[str, str]: + """Parse a GEPA component text block back into individual tool descriptions. + + Args: + component_text: Text block with ### TOOL: markers + + Returns: + Dict mapping tool_name -> optimized docstring + """ + tools = {} + parts = re.split(r"### TOOL:\s*", component_text) + for part in parts: + part = part.strip() + if not part: + continue + # First line has the tool name (possibly with module annotation) + lines = part.split("\n", 1) + name_line = lines[0].strip() + # Remove module annotation if present: "execute_sql (module: sql)" + name = re.match(r"(\w+)", name_line).group(1) if re.match(r"(\w+)", name_line) else name_line + docstring = lines[1].strip() if len(lines) > 1 else "" + tools[name] = docstring + return tools + + +def write_tool_descriptions( + optimized: dict[str, str], + tool_map: dict[str, list[ToolDescription]], +) -> list[Path]: + """Write optimized docstrings back to MCP server source files. + + Uses AST to locate the exact docstring positions and replaces them + in the source text while preserving all other code. 
+
+    Args:
+        optimized: Dict mapping tool_name -> optimized docstring text
+        tool_map: Original extraction map (for source file locations)
+
+    Returns:
+        List of modified file paths
+    """
+    # Group updates by file
+    updates_by_file: dict[Path, list[tuple[ToolDescription, str]]] = {}
+    for module_name, tools in tool_map.items():
+        for td in tools:
+            if td.name in optimized:
+                updates_by_file.setdefault(td.source_path, []).append((td, optimized[td.name]))
+
+    modified_files = []
+    for file_path, updates in updates_by_file.items():
+        source = file_path.read_text()
+        tree = ast.parse(source)
+        source_lines = source.splitlines(keepends=True)
+
+        # Process updates in reverse line order to preserve positions
+        updates_sorted = sorted(updates, key=lambda x: x[0].lineno, reverse=True)
+
+        for td, new_docstring in updates_sorted:
+            # Find the function node
+            for node in ast.walk(tree):
+                if isinstance(node, ast.FunctionDef) and node.name == td.name and node.lineno == td.lineno:
+                    # Find the docstring node (first Expr with a Constant string)
+                    if (
+                        node.body
+                        and isinstance(node.body[0], ast.Expr)
+                        and isinstance(node.body[0].value, ast.Constant)
+                        and isinstance(node.body[0].value.value, str)
+                    ):
+                        doc_node = node.body[0]
+                        # Get the docstring's line range
+                        start_line = doc_node.lineno - 1  # 0-indexed
+                        end_line = doc_node.end_lineno  # exclusive
+
+                        # Detect indentation from the original docstring line
+                        original_line = source_lines[start_line]
+                        indent = re.match(r"(\s*)", original_line).group(1)
+
+                        # Build new docstring with proper indentation
+                        new_doc_lines = [f'{indent}"""\n']
+                        for line in new_docstring.split("\n"):
+                            if line.strip():
+                                new_doc_lines.append(f"{indent}{line}\n")
+                            else:
+                                new_doc_lines.append("\n")
+                        new_doc_lines.append(f'{indent}"""\n')
+
+                        # Replace lines
+                        source_lines[start_line:end_line] = new_doc_lines
+                    break
+
+        new_source = "".join(source_lines)
+
+        # Validate the new source parses
+        try:
+            ast.parse(new_source)
+        except SyntaxError as e:
+            print(f"WARNING: Optimized source for {file_path.name} has syntax error: {e}")
+            print("Skipping this file.")
+            continue
+
+        file_path.write_text(new_source)
+        modified_files.append(file_path)
+
+    return modified_files
+
+
+def list_tool_modules(tools_dir: Path | None = None) -> list[str]:
+    """List available tool module names."""
+    if tools_dir is None:
+        tools_dir = _find_tools_dir()
+    return sorted(
+        f.stem for f in tools_dir.glob("*.py")
+        if f.stem != "__init__"
+    )
+
+
+def get_tool_stats(tools_dir: Path | None = None) -> dict[str, Any]:
+    """Get statistics about available MCP tools."""
+    tool_map = extract_tool_descriptions(tools_dir=tools_dir)
+    total_tools = sum(len(tools) for tools in tool_map.values())
+    total_chars = sum(
+        len(td.docstring) for tools in tool_map.values() for td in tools
+    )
+    return {
+        "modules": len(tool_map),
+        "total_tools": total_tools,
+        "total_description_chars": total_chars,
+        "per_module": {
+            name: {"tools": len(tools), "chars": sum(len(td.docstring) for td in tools)}
+            for name, tools in tool_map.items()
+        },
+    }
diff --git a/.test/src/skill_test/scorers/universal.py b/.test/src/skill_test/scorers/universal.py
index 3b839bf4..40e2a8b2 100644
--- a/.test/src/skill_test/scorers/universal.py
+++ b/.test/src/skill_test/scorers/universal.py
@@ -125,20 +125,20 @@ def no_hallucinated_apis(outputs: Dict[str, Any]) -> Feedback:
 
 
 @scorer
-def expected_facts_present(outputs: Dict[str, Any], expectations: Dict[str, Any]) -> Feedback:
-    """Check if expected facts are mentioned in response."""
+def expected_facts_present(outputs: Dict[str, Any], expectations: Dict[str, Any]) -> List[Feedback]:
+    """Check if expected facts are mentioned in response (per-fact granularity)."""
     response = outputs.get("response", "").lower()
     expected_facts = expectations.get("expected_facts", [])
 
     if not expected_facts:
-        return Feedback(name="expected_facts", value="skip", rationale="No expected_facts defined")
+        return [Feedback(name="expected_facts", value="skip", rationale="No expected_facts defined")]
 
-    missing = []
+    feedbacks = []
     for fact in expected_facts:
-        if fact.lower() not in response:
-            missing.append(fact)
-
-    if missing:
-        return Feedback(name="expected_facts", value="no", rationale=f"Missing facts: {missing}")
-
-    return Feedback(name="expected_facts", value="yes", rationale=f"All {len(expected_facts)} expected facts present")
+        found = fact.lower() in response
+        feedbacks.append(Feedback(
+            name=f"fact_{fact[:40]}",
+            value="yes" if found else "no",
+            rationale=f"{'Found' if found else 'Missing'}: {fact}",
+        ))
+    return feedbacks
diff --git a/.test/tests/test_optimize_e2e.py b/.test/tests/test_optimize_e2e.py
new file mode 100644
index 00000000..355beba0
--- /dev/null
+++ b/.test/tests/test_optimize_e2e.py
@@ -0,0 +1,229 @@
+"""End-to-end tests for GEPA skill optimization using optimize_anything API.
+
+Unit tests run without API keys. E2E tests require GEPA reflection LM access.
+
+Run unit tests:
+    cd .test && uv run pytest tests/test_optimize_e2e.py -v -k "not TestOptimizeE2E"
+
+Run everything (slow, requires API key):
+    cd .test && uv run pytest tests/test_optimize_e2e.py -v -s
+"""
+
+import pytest
+
+from skill_test.optimize.evaluator import token_efficiency_score, count_tokens, SKILL_KEY
+from skill_test.optimize.splitter import create_gepa_datasets, generate_bootstrap_tasks, to_gepa_instances
+from skill_test.optimize.asi import feedback_to_score, feedback_to_asi
+
+try:
+    from mlflow.entities import Feedback
+    HAS_MLFLOW = True
+except ImportError:
+    Feedback = None
+    HAS_MLFLOW = False
+
+try:
+    from gepa.optimize_anything import GEPAConfig, EngineConfig, ReflectionConfig
+    HAS_GEPA = True
+except ImportError:
+    HAS_GEPA = False
+
+
+# --------------------------------------------------------------------------
+# Unit tests (no GEPA/LLM required)
+# --------------------------------------------------------------------------
+
+class TestTokenEfficiency:
+    def test_same_size_scores_one(self):
+        text = "Hello world, this is a test."
+        tokens = count_tokens(text)
+        assert token_efficiency_score(text, tokens) == 1.0
+
+    def test_smaller_scores_one(self):
+        assert token_efficiency_score("short", 100) == 1.0
+
+    def test_double_size_scores_zero(self):
+        text = "word " * 200
+        tokens = count_tokens(text)
+        assert token_efficiency_score(text + text, tokens) == pytest.approx(0.0, abs=0.05)
+
+    def test_zero_original_returns_one(self):
+        assert token_efficiency_score("anything", 0) == 1.0
+
+
+class TestSplitter:
+    def test_model_serving_has_split(self):
+        try:
+            train, val = create_gepa_datasets("databricks-model-serving")
+            assert len(train) > 0
+            if len(train) + (len(val) if val else 0) >= 5:
+                assert val is not None
+        except FileNotFoundError:
+            pytest.skip("No ground_truth.yaml")
+
+    def test_reproducible_splits(self):
+        try:
+            t1, v1 = create_gepa_datasets("databricks-model-serving", seed=42)
+            t2, v2 = create_gepa_datasets("databricks-model-serving", seed=42)
+            assert [t["id"] for t in t1] == [t["id"] for t in t2]
+        except FileNotFoundError:
+            pytest.skip("No ground_truth.yaml")
+
+    def test_tasks_have_correct_keys(self):
+        try:
+            train, _ = create_gepa_datasets("databricks-model-serving")
+            for task in train:
+                assert "id" in task
+                assert "input" in task
+                assert "answer" in task
+                assert "additional_context" in task
+        except FileNotFoundError:
+            pytest.skip("No ground_truth.yaml")
+
+    def test_to_gepa_instances(self):
+        try:
+            train, _ = create_gepa_datasets("databricks-model-serving")
+            instances = to_gepa_instances(train)
+            assert len(instances) == len(train)
+            for inst in instances:
+                assert "input" in inst
+                assert "additional_context" in inst
+                assert "answer" in inst
+                assert "id" not in inst
+        except FileNotFoundError:
+            pytest.skip("No ground_truth.yaml")
+
+    def test_bootstrap_tasks_generated(self):
+        tasks = generate_bootstrap_tasks("databricks-model-serving")
+        assert len(tasks) > 0
+        for task in tasks:
+            assert "id" in task
+            assert "input" in task
+
+
+@pytest.mark.skipif(not HAS_MLFLOW, reason="mlflow not installed")
+class TestASI:
+    def test_yes_scores_one(self):
+        assert feedback_to_score(Feedback(name="test", value="yes")) == 1.0
+
+    def test_no_scores_zero(self):
+        assert feedback_to_score(Feedback(name="test", value="no")) == 0.0
+
+    def test_skip_returns_none(self):
+        assert feedback_to_score(Feedback(name="test", value="skip")) is None
+
+    def test_feedback_to_asi_composite(self):
+        feedbacks = [
+            Feedback(name="syntax", value="yes", rationale="Valid"),
+            Feedback(name="pattern", value="no", rationale="Missing X"),
+            Feedback(name="optional", value="skip", rationale="N/A"),
+        ]
+        score, si = feedback_to_asi(feedbacks)
+        assert score == pytest.approx(0.5)
+        assert si["syntax"]["score"] == 1.0
+        assert si["pattern"]["score"] == 0.0
+        assert si["optional"]["status"] == "skipped"
+        assert si["_summary"]["scored"] == 2
+
+
+@pytest.mark.skipif(not HAS_GEPA, reason="gepa not installed")
+class TestConfig:
+    def test_presets_exist(self):
+        from skill_test.optimize.config import PRESETS
+        assert "quick" in PRESETS
+        assert "standard" in PRESETS
+        assert "thorough" in PRESETS
+
+    def test_quick_has_fewer_calls(self):
+        from skill_test.optimize.config import PRESETS
+        assert PRESETS["quick"].engine.max_metric_calls < PRESETS["standard"].engine.max_metric_calls
+
+    def test_presets_are_gepa_configs(self):
+        from skill_test.optimize.config import PRESETS
+        for name, cfg in PRESETS.items():
+            assert isinstance(cfg, GEPAConfig), f"{name} is not GEPAConfig"
+            assert isinstance(cfg.engine, EngineConfig)
+            assert isinstance(cfg.reflection, ReflectionConfig)
+
+
+class TestBootstrapMode:
+    def test_nonexistent_skill_returns_empty(self):
+        tasks = generate_bootstrap_tasks("nonexistent-skill-xyz")
+        assert tasks == []
+
+    def test_bootstrap_has_gepa_format(self):
+        tasks = generate_bootstrap_tasks("databricks-model-serving")
+        if not tasks:
+            pytest.skip("No SKILL.md found")
+        instances = to_gepa_instances(tasks)
+        for inst in instances:
+            assert isinstance(inst["input"], str)
+            assert isinstance(inst["additional_context"], dict)
+
+
+@pytest.mark.skipif(not HAS_GEPA, reason="gepa not installed")
+class TestToolExtraction:
+    def test_extract_tools(self):
+        from skill_test.optimize.tools import extract_tool_descriptions, get_tool_stats
+        stats = get_tool_stats()
+        assert stats["modules"] > 0
+        assert stats["total_tools"] > 0
+
+    def test_tools_to_gepa_components(self):
+        from skill_test.optimize.tools import extract_tool_descriptions, tools_to_gepa_components
+        tool_map = extract_tool_descriptions(modules=["sql"])
+        components = tools_to_gepa_components(tool_map)
+        assert "tools_sql" in components
+        assert "### TOOL:" in components["tools_sql"]
+
+
+@pytest.mark.skipif(not HAS_GEPA, reason="gepa not installed")
+class TestDryRun:
+    def test_dry_run_skill_only(self):
+        from skill_test.optimize.runner import optimize_skill
+        try:
+            result = optimize_skill("databricks-model-serving", preset="quick", dry_run=True)
+            assert result.improvement == 0.0
+            assert result.original_content == result.optimized_content
+            assert result.gepa_result is None
+            assert result.original_token_count > 0
+        except FileNotFoundError:
+            pytest.skip("SKILL.md not found")
+
+    def test_dry_run_with_tools(self):
+        from skill_test.optimize.runner import optimize_skill
+        try:
+            result = optimize_skill(
+                "databricks-model-serving", preset="quick", dry_run=True,
+                include_tools=True, tool_modules=["serving"],
+            )
+            assert SKILL_KEY in result.components
+            assert "tools_serving" in result.components
+            assert result.original_token_count > 0
+        except FileNotFoundError:
+            pytest.skip("SKILL.md not found")
+
+
+# --------------------------------------------------------------------------
+# E2E integration (requires GEPA + LLM API key)
+# --------------------------------------------------------------------------
+
+@pytest.mark.skipif(not HAS_GEPA, reason="gepa not installed")
+@pytest.mark.slow
+class TestOptimizeE2E:
+    def test_optimize_improves_quality_and_reduces_tokens(self):
+        from skill_test.optimize.runner import optimize_skill
+        result = optimize_skill(
+            skill_name="databricks-spark-declarative-pipelines",
+            mode="static",
+            preset="quick",
+        )
+        assert result.optimized_score >= result.original_score
+        assert result.optimized_token_count <= result.original_token_count * 1.05
+
+        if result.val_scores:
+            avg_val = sum(result.val_scores.values()) / len(result.val_scores)
+            assert avg_val >= result.optimized_score - 0.05
+
+        print(f"\nQuality: {result.original_score:.3f} -> {result.optimized_score:.3f}")
+        print(f"Tokens: {result.original_token_count:,} -> {result.optimized_token_count:,}")
diff --git a/.test/uv.lock b/.test/uv.lock
index 0a1e71f8..67d8cfe0 100644
--- a/.test/uv.lock
+++ b/.test/uv.lock
@@ -838,7 +838,7 @@ name = "exceptiongroup"
 version = "1.3.1"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
-    { name = "typing-extensions", marker = "python_full_version < '3.13'" },
+    { name = "typing-extensions", marker = "python_full_version < '3.11'" },
 ]
 sdist = { url = "https://files.pythonhosted.org/packages/50/79/66800aadf48771f6b62f7eb014e352e5d06856655206165d775e675a02c9/exceptiongroup-1.3.1.tar.gz", hash = "sha256:8b412432c6055b0b7d14c310000ae93352ed6754f70fa8f7c34141f91c4e3219", size = 30371, upload-time = "2025-11-21T23:01:54.787Z" }
 wheels = [
@@ -1149,6 +1149,15 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/01/c9/97cc5aae1648dcb851958a3ddf73ccd7dbe5650d95203ecb4d7720b4cdbf/fsspec-2026.1.0-py3-none-any.whl", hash = "sha256:cb76aa913c2285a3b49bdd5fc55b1d7c708d7208126b60f2eb8194fe1b4cbdcc", size = 201838, upload-time = "2026-01-09T15:21:34.041Z" },
 ]
 
+[[package]]
+name = "gepa"
+version = "0.1.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/f5/30/511e52916956508f56eca721260fcd524cfffd580e57782dd471be925f7e/gepa-0.1.0.tar.gz", hash = "sha256:f8b3d7918d4cdcf8593f39ef1cc757c4ba1a4e6793e3ffb622e6c0bc60a1efd9", size = 226064, upload-time = "2026-02-19T19:43:08.272Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/1f/32/fe8afb3d2a6605a6bcbc8f119f0a2adae96e9e5d57ebed055490219956a8/gepa-0.1.0-py3-none-any.whl", hash = "sha256:4e3f8fe8ca20169e60518b2e9d416e8c4a579459848adffdcad12223fbf9643e", size = 191392, upload-time = "2026-02-19T19:43:07.065Z" },
+]
+
 [[package]]
 name = "gitdb"
 version = "4.0.12"
@@ -3919,8 +3928,10 @@ dependencies = [
 [package.optional-dependencies]
 all = [
     { name = "databricks-sdk" },
+    { name = "gepa" },
     { name = "pytest" },
     { name = "pytest-asyncio" },
+    { name = "tiktoken" },
 ]
 databricks = [
     { name = "databricks-sdk" },
@@ -3929,18 +3940,24 @@ dev = [
     { name = "pytest" },
     { name = "pytest-asyncio" },
 ]
+optimize = [
+    { name = "gepa" },
+    { name = "tiktoken" },
+]
 
 [package.metadata]
 requires-dist = [
     { name = "databricks-sdk", marker = "extra == 'databricks'", specifier = ">=0.20.0" },
+    { name = "gepa", marker = "extra == 'optimize'", specifier = ">=0.1.0" },
     { name = "mlflow", extras = ["databricks"], specifier = ">=3.6.1" },
     { name = "pytest", marker = "extra == 'dev'", specifier = ">=8.0" },
     { name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.23" },
     { name = "python-dotenv", specifier = ">=1.0.0" },
     { name = "pyyaml", specifier = ">=6.0" },
-    { name = "skill-test", extras = ["databricks", "dev"], marker = "extra == 'all'" },
+    { name = "skill-test", extras = ["databricks", "dev", "optimize"], marker = "extra == 'all'" },
+    { name = "tiktoken", marker = "extra == 'optimize'", specifier = ">=0.7.0" },
 ]
-provides-extras = ["databricks", "dev", "all"]
+provides-extras = ["databricks", "dev", "optimize", "all"]
 
 [[package]]
 name = "skops"
diff --git a/databricks-builder-app/app.yaml b/databricks-builder-app/app.yaml
index 5efe3f00..94694545 100644
--- a/databricks-builder-app/app.yaml
+++ b/databricks-builder-app/app.yaml
@@ -92,9 +92,9 @@ env:
   # instead of hitting Anthropic directly. The app dynamically sets ANTHROPIC_BASE_URL
   # and ANTHROPIC_AUTH_TOKEN from the user's Databricks credentials at runtime.
   - name: ANTHROPIC_MODEL
-    value: "databricks-claude-opus-4-5"
+    value: "databricks-claude-opus-4-6"
   - name: ANTHROPIC_MODEL_MINI
-    value: "databricks-claude-sonnet-4-5"
+    value: "databricks-claude-sonnet-4-6"
   - name: CLAUDE_CODE_STREAM_CLOSE_TIMEOUT
     value: "3600000"
diff --git a/databricks-builder-app/requirements.txt b/databricks-builder-app/requirements.txt
index e9287669..48b5cb7c 100644
--- a/databricks-builder-app/requirements.txt
+++ b/databricks-builder-app/requirements.txt
@@ -16,7 +16,7 @@ greenlet>=3.0.0
 psycopg2-binary>=2.9.11
 
 # Claude Agent SDK
-claude-agent-sdk>=0.1.19
+claude-agent-sdk>=0.1.39
 anthropic>=0.42.0
 
 # MLflow for tracing Claude Code conversations
diff --git a/databricks-builder-app/server/services/agent.py b/databricks-builder-app/server/services/agent.py
index b2766523..30d2cf47 100644
--- a/databricks-builder-app/server/services/agent.py
+++ b/databricks-builder-app/server/services/agent.py
@@ -423,11 +423,15 @@ async def stream_agent_response(
         claude_env['ANTHROPIC_AUTH_TOKEN'] = effective_fmapi_token
 
         # Set the model to use (required for Databricks FMAPI)
-        anthropic_model = os.environ.get('ANTHROPIC_MODEL', 'databricks-claude-opus-4-5')
+        anthropic_model = os.environ.get('ANTHROPIC_MODEL', 'databricks-claude-opus-4-6')
         claude_env['ANTHROPIC_MODEL'] = anthropic_model
 
-        # Disable beta headers for Databricks FMAPI compatibility
-        claude_env['ANTHROPIC_CUSTOM_HEADERS'] = 'x-databricks-disable-beta-headers: true'
+        # Disable beta headers and experimental betas for Databricks FMAPI compatibility
+        # ANTHROPIC_CUSTOM_HEADERS enables coding agent mode on FMAPI
+        # CLAUDE_CODE_DISABLE_EXPERIMENTAL_BETAS prevents context_management and other
+        # experimental body parameters that FMAPI doesn't support (400: Extra inputs not permitted)
+        claude_env['ANTHROPIC_CUSTOM_HEADERS'] = 'x-databricks-use-coding-agent-mode: true'
+        claude_env['CLAUDE_CODE_DISABLE_EXPERIMENTAL_BETAS'] = '1'
 
         logger.info(f'Configured Databricks model serving: {anthropic_base_url} with model {anthropic_model}')
         logger.info(f'Claude env vars: BASE_URL={claude_env.get("ANTHROPIC_BASE_URL")}, MODEL={claude_env.get("ANTHROPIC_MODEL")}')
diff --git a/databricks-mcp-server/databricks_mcp_server/tools/manifest.py b/databricks-mcp-server/databricks_mcp_server/tools/manifest.py
index 1adc30c5..500a77a7 100644
--- a/databricks-mcp-server/databricks_mcp_server/tools/manifest.py
+++ b/databricks-mcp-server/databricks_mcp_server/tools/manifest.py
@@ -30,7 +30,8 @@ def _delete_from_databricks(resource_type: str, resource_id: str) -> Optional[str]:
 
 @mcp.tool
 def list_tracked_resources(type: Optional[str] = None) -> Dict[str, Any]:
-    """List resources tracked in the project manifest.
+    """
+    List resources tracked in the project manifest.
 
     The manifest records every resource created through the MCP server
     (dashboards, jobs, pipelines, Genie spaces, KAs, MAS, schemas, volumes, etc.).
@@ -60,7 +61,8 @@ def delete_tracked_resource(
     resource_id: str,
     delete_from_databricks: bool = False,
 ) -> Dict[str, Any]:
-    """Delete a resource from the project manifest, and optionally from Databricks.
+    """
+    Delete a resource from the project manifest, and optionally from Databricks.
 
     Use this to clean up resources that were created during development/testing.
diff --git a/databricks-mcp-server/databricks_mcp_server/tools/pipelines.py b/databricks-mcp-server/databricks_mcp_server/tools/pipelines.py
index 26b91d8c..0a77c322 100644
--- a/databricks-mcp-server/databricks_mcp_server/tools/pipelines.py
+++ b/databricks-mcp-server/databricks_mcp_server/tools/pipelines.py
@@ -1,33 +1,24 @@
-"""Pipeline tools - Manage Spark Declarative Pipelines (SDP).
+"""Pipeline tools - Manage Spark Declarative Pipelines (SDP)."""
-
-Provides 4 workflow-oriented tools following the Lakebase pattern:
-- create_or_update_pipeline: idempotent resource management
-- get_pipeline: get details by name/ID, or list all
-- delete_pipeline: delete by ID
-- run_pipeline: start, stop, and wait for pipeline runs
-"""
-
-import logging
-from typing import List, Dict, Any, Optional
+from typing import List, Dict, Any
 
 from databricks_tools_core.identity import get_default_tags
 from databricks_tools_core.spark_declarative_pipelines.pipelines import (
-    create_or_update_pipeline as _create_or_update_pipeline,
+    create_pipeline as _create_pipeline,
     get_pipeline as _get_pipeline,
+    update_pipeline as _update_pipeline,
     delete_pipeline as _delete_pipeline,
     start_update as _start_update,
     get_update as _get_update,
     stop_pipeline as _stop_pipeline,
     get_pipeline_events as _get_pipeline_events,
+    create_or_update_pipeline as _create_or_update_pipeline,
     find_pipeline_by_name as _find_pipeline_by_name,
-    wait_for_pipeline_update as _wait_for_pipeline_update,
 )
 
 from ..manifest import register_deleter
 from ..server import mcp
 
-logger = logging.getLogger(__name__)
-
 
 def _delete_pipeline_resource(resource_id: str) -> None:
     _delete_pipeline(pipeline_id=resource_id)
@@ -36,192 +27,113 @@ def _delete_pipeline_resource(resource_id: str) -> None:
 register_deleter("pipeline", _delete_pipeline_resource)
 
 
-# ============================================================================
-# Tool 1: create_or_update_pipeline
-# ============================================================================
-
-
 @mcp.tool
-def create_or_update_pipeline(
+def create_pipeline(
     name: str,
     root_path: str,
     catalog: str,
     schema: str,
     workspace_file_paths: List[str],
-    start_run: bool = False,
-    wait_for_completion: bool = False,
-    full_refresh: bool = True,
-    timeout: int = 1800,
     extra_settings: Dict[str, Any] = None,
 ) -> Dict[str, Any]:
     """
-    Create a new pipeline or update an existing one with the same name.
-
-    This is the main tool for pipeline resource management. It:
-    1. Searches for an existing pipeline with the same name (or uses 'id' from extra_settings)
-    2. Creates a new pipeline or updates the existing one
-    3. Optionally starts a run and waits for completion
-
-    For more granular run control (stop, selective refresh, validate-only),
-    use run_pipeline() separately.
-
-    Uses Unity Catalog and serverless compute by default.
+    Create a new Spark Declarative Pipeline (Unity Catalog, serverless by default).
 
     Args:
-        name: Pipeline name (used for lookup and creation)
+        name: Pipeline name
         root_path: Root folder for source code (added to Python sys.path for imports)
-        catalog: Unity Catalog name for output tables
+        catalog: Unity Catalog name
         schema: Schema name for output tables
         workspace_file_paths: List of workspace file paths (raw .sql or .py files)
-        start_run: If True, start a pipeline run after create/update (default: False)
-        wait_for_completion: If True and start_run=True, wait for the run to finish (default: False)
-        full_refresh: If True, performs full refresh when starting (default: True)
-        timeout: Maximum wait time in seconds when wait_for_completion=True (default: 1800)
-        extra_settings: Optional dict with additional pipeline settings. Supports all SDK
-            options: clusters, continuous, development, photon, edition, channel, event_log,
-            configuration, notifications, tags, serverless, etc.
-            If 'id' is provided, the pipeline will be updated instead of created.
-            Explicit parameters (name, root_path, catalog, schema) take precedence.
+        extra_settings: Optional dict with additional pipeline settings (clusters,
+            continuous, development, photon, edition, channel, event_log, configuration,
+            notifications, tags, serverless, etc.). Explicit parameters take precedence.
 
     Returns:
-        Dictionary with:
-        - pipeline_id: The pipeline ID
-        - pipeline_name: The pipeline name
-        - created: True if newly created, False if updated
-        - success: True if operation succeeded
-        - message: Human-readable status message
-
-    Example:
-        >>> create_or_update_pipeline(
-        ...     name="my_pipeline",
-        ...     root_path="/Workspace/project",
-        ...     catalog="my_catalog",
-        ...     schema="my_schema",
-        ...     workspace_file_paths=["/Workspace/project/pipeline.py"]
-        ... )
-        >>> create_or_update_pipeline(
-        ...     name="my_pipeline", ...,
-        ...     start_run=True, wait_for_completion=True
-        ... )
+        Dictionary with pipeline_id of the created pipeline.
     """
     # Auto-inject default tags into extra_settings; user tags take precedence
     extra_settings = extra_settings or {}
     extra_settings.setdefault("tags", {})
     extra_settings["tags"] = {**get_default_tags(), **extra_settings["tags"]}
 
-    result = _create_or_update_pipeline(
+    result = _create_pipeline(
         name=name,
         root_path=root_path,
         catalog=catalog,
         schema=schema,
         workspace_file_paths=workspace_file_paths,
-        start_run=start_run,
-        wait_for_completion=wait_for_completion,
-        full_refresh=full_refresh,
-        timeout=timeout,
         extra_settings=extra_settings,
     )
 
-    # Track resource on successful create/update
+    # Track resource on successful create
     try:
-        result_dict = result.to_dict()
-        pipeline_id = result_dict.get("pipeline_id")
-        if pipeline_id:
+        if result.pipeline_id:
             from ..manifest import track_resource
             track_resource(
                 resource_type="pipeline",
                 name=name,
-                resource_id=pipeline_id,
+                resource_id=result.pipeline_id,
             )
     except Exception:
         pass  # best-effort tracking
 
-    return result.to_dict()
-
-
-# ============================================================================
-# Tool 2: get_pipeline
-# ============================================================================
+    return {"pipeline_id": result.pipeline_id}
 
 
 @mcp.tool
-def get_pipeline(
-    pipeline_id: Optional[str] = None,
-    name: Optional[str] = None,
-) -> Dict[str, Any]:
+def get_pipeline(pipeline_id: str) -> Dict[str, Any]:
     """
-    Get pipeline details, or list all pipelines.
-
-    Pass pipeline_id or name to get one pipeline's details enriched with
-    latest update status and recent events. Omit both to list all pipelines.
+    Get pipeline details and configuration.
 
     Args:
-        pipeline_id: Pipeline ID. Takes precedence over name.
-        name: Pipeline name. Used to look up pipeline ID if pipeline_id not provided.
+        pipeline_id: Pipeline ID
 
     Returns:
-        Single pipeline dict with enriched details (if ID/name provided),
-        or {"pipelines": [...]} when listing all.
-
-    Example:
-        >>> get_pipeline(pipeline_id="abc-123")
-        {"pipeline_id": "abc-123", "name": "my_pipeline", "state": "IDLE", ...}
-        >>> get_pipeline(name="my_pipeline")
-        {"pipeline_id": "abc-123", "name": "my_pipeline", ...}
-        >>> get_pipeline()
-        {"pipelines": [{"pipeline_id": "abc-123", "name": "my_pipeline", ...}]}
+        Dictionary with pipeline configuration and state.
     """
-    # Resolve name to pipeline_id if needed
-    if not pipeline_id and name:
-        pipeline_id = _find_pipeline_by_name(name=name)
-        if not pipeline_id:
-            return {"error": f"Pipeline '{name}' not found."}
-
-    if pipeline_id:
-        result = _get_pipeline(pipeline_id=pipeline_id)
-        pipeline_dict = result.as_dict() if hasattr(result, "as_dict") else vars(result)
-
-        # Enrich with latest update status
-        try:
-            latest_updates = pipeline_dict.get("latest_updates", [])
-            if latest_updates:
-                latest = latest_updates[0]
-                update_id = latest.get("update_id")
-                if update_id:
-                    update_result = _get_update(pipeline_id=pipeline_id, update_id=update_id)
-                    update_dict = update_result.as_dict() if hasattr(update_result, "as_dict") else vars(update_result)
-                    pipeline_dict["latest_update_status"] = update_dict
-        except Exception:
-            pass
-
-        # Enrich with recent events
-        try:
-            events = _get_pipeline_events(pipeline_id=pipeline_id, max_results=10)
-            pipeline_dict["recent_events"] = [e.as_dict() if hasattr(e, "as_dict") else vars(e) for e in events]
-        except Exception:
-            pass
-
-        return pipeline_dict
-
-    # List all pipelines
-    from databricks_tools_core.auth import get_workspace_client
-
-    w = get_workspace_client()
-    pipelines = []
-    try:
-        for p in w.pipelines.list_pipelines():
-            entry = p.as_dict() if hasattr(p, "as_dict") else vars(p)
-            pipelines.append(entry)
-    except Exception as e:
-        return {"error": f"Failed to list pipelines: {e}"}
+    result = _get_pipeline(pipeline_id=pipeline_id)
+    return result.as_dict() if hasattr(result, "as_dict") else vars(result)
 
-    return {"pipelines": pipelines}
 
+@mcp.tool
+def update_pipeline(
+    pipeline_id: str,
+    name: str = None,
+    root_path: str = None,
+    catalog: str = None,
+    schema: str = None,
+    workspace_file_paths: List[str] = None,
+    extra_settings: Dict[str, Any] = None,
+) -> Dict[str, str]:
+    """
+    Update pipeline configuration.
+
+    Args:
+        pipeline_id: Pipeline ID
+        name: New pipeline name
+        root_path: New root folder for source code
+        catalog: New catalog name
+        schema: New schema name
+        workspace_file_paths: New list of file paths (raw .sql or .py files)
+        extra_settings: Optional dict with additional pipeline settings (clusters,
+            continuous, development, photon, edition, channel, event_log, configuration,
+            notifications, tags, serverless, etc.). Explicit parameters take precedence.
 
-# ============================================================================
-# Tool 3: delete_pipeline
-# ============================================================================
+    Returns:
+        Dictionary with status message.
+    """
+    _update_pipeline(
+        pipeline_id=pipeline_id,
+        name=name,
+        root_path=root_path,
+        catalog=catalog,
+        schema=schema,
+        workspace_file_paths=workspace_file_paths,
+        extra_settings=extra_settings,
+    )
+    return {"status": "updated"}
 
 
 @mcp.tool
@@ -245,59 +157,27 @@ def delete_pipeline(pipeline_id: str) -> Dict[str, str]:
     return {"status": "deleted"}
 
 
-# ============================================================================
-# Tool 4: run_pipeline
-# ============================================================================
-
-
 @mcp.tool
-def run_pipeline(
+def start_update(
     pipeline_id: str,
+    refresh_selection: List[str] = None,
     full_refresh: bool = False,
-    refresh_selection: Optional[List[str]] = None,
-    full_refresh_selection: Optional[List[str]] = None,
+    full_refresh_selection: List[str] = None,
     validate_only: bool = False,
-    wait_for_completion: bool = False,
-    timeout: int = 1800,
-    stop: bool = False,
-) -> Dict[str, Any]:
+) -> Dict[str, str]:
     """
-    Start, stop, or monitor a pipeline run.
-
-    Set stop=True to stop a running pipeline. Otherwise starts an update.
-    Optionally waits for the run to complete with timeout.
+    Start a pipeline update or dry-run validation.
 
     Args:
         pipeline_id: Pipeline ID
-        full_refresh: If True, performs full refresh of all tables
         refresh_selection: List of table names to refresh
+        full_refresh: If True, performs full refresh of all tables
         full_refresh_selection: List of table names for full refresh
         validate_only: If True, validates without updating data (dry run)
-        wait_for_completion: If True, wait for run to complete (default: False)
-        timeout: Maximum wait time in seconds (default: 1800 = 30 minutes)
-        stop: If True, stop the currently running pipeline instead of starting
 
     Returns:
-        Dictionary with:
-        - update_id: Update ID (if started)
-        - state: Final state (if waited)
-        - success: True if completed successfully
-        - duration_seconds: Time taken (if waited)
-        - errors: List of error details (if failed)
-        - status: "stopped" (if stop=True)
-
-    Example:
-        >>> run_pipeline(pipeline_id="abc-123", full_refresh=True)
-        {"update_id": "xyz-456", "status": "started"}
-        >>> run_pipeline(pipeline_id="abc-123", stop=True)
-        {"status": "stopped"}
-        >>> run_pipeline(pipeline_id="abc-123", wait_for_completion=True, timeout=600)
-        {"update_id": "xyz-456", "state": "COMPLETED", "success": True, ...}
+        Dictionary with update_id for polling status.
     """
-    if stop:
-        _stop_pipeline(pipeline_id=pipeline_id)
-        return {"pipeline_id": pipeline_id, "status": "stopped"}
-
     update_id = _start_update(
         pipeline_id=pipeline_id,
         refresh_selection=refresh_selection,
@@ -305,55 +185,188 @@ def run_pipeline(
         full_refresh_selection=full_refresh_selection,
         validate_only=validate_only,
     )
+    return {"update_id": update_id}
+
+
+@mcp.tool
+def get_update(pipeline_id: str, update_id: str) -> Dict[str, Any]:
+    """
+    Get pipeline update status and results.
+
+    Args:
+        pipeline_id: Pipeline ID
+        update_id: Update ID from start_update
+
+    Returns:
+        Dictionary with update status (QUEUED, RUNNING, COMPLETED, FAILED, etc.)
+ """ + result = _get_update(pipeline_id=pipeline_id, update_id=update_id) + return result.as_dict() if hasattr(result, "as_dict") else vars(result) - result: Dict[str, Any] = { - "pipeline_id": pipeline_id, - "update_id": update_id, - "status": "started", - } - if not wait_for_completion: - result["message"] = ( - f"Pipeline update started. Use get_pipeline(pipeline_id='{pipeline_id}') " - f"to check status, or run_pipeline with wait_for_completion=True to wait." +@mcp.tool +def stop_pipeline(pipeline_id: str) -> Dict[str, str]: + """ + Stop a running pipeline. + + Args: + pipeline_id: Pipeline ID + + Returns: + Dictionary with status message. + """ + _stop_pipeline(pipeline_id=pipeline_id) + return {"status": "stopped"} + + +@mcp.tool +def get_pipeline_events( + pipeline_id: str, + max_results: int = 100, +) -> List[Dict[str, Any]]: + """ + Get pipeline events, issues, and error messages. + + Use this to debug pipeline failures. + + Args: + pipeline_id: Pipeline ID + max_results: Maximum number of events to return (default: 100) + + Returns: + List of event dictionaries with error details. + """ + events = _get_pipeline_events(pipeline_id=pipeline_id, max_results=max_results) + return [e.as_dict() if hasattr(e, "as_dict") else vars(e) for e in events] + + +@mcp.tool +def create_or_update_pipeline( + name: str, + root_path: str, + catalog: str, + schema: str, + workspace_file_paths: List[str], + start_run: bool = False, + wait_for_completion: bool = False, + full_refresh: bool = True, + timeout: int = 1800, + extra_settings: Dict[str, Any] = None, +) -> Dict[str, Any]: + """ + Create a new pipeline or update an existing one with the same name. + + This is the main tool for pipeline management. It: + 1. Searches for an existing pipeline with the same name (or uses 'id' from extra_settings) + 2. Creates a new pipeline or updates the existing one + 3. Optionally starts a pipeline run with full refresh + 4. 
Optionally waits for the run to complete and returns detailed results + + Uses Unity Catalog and serverless compute by default. + + Args: + name: Pipeline name (used for lookup and creation) + root_path: Root folder for source code (added to Python sys.path for imports) + catalog: Unity Catalog name for output tables + schema: Schema name for output tables + workspace_file_paths: List of workspace file paths (raw .sql or .py files) + start_run: If True, start a pipeline update after create/update (default: False) + wait_for_completion: If True, wait for run to complete (default: False) + full_refresh: If True, perform full refresh when starting (default: True) + timeout: Maximum wait time in seconds (default: 1800 = 30 minutes) + extra_settings: Optional dict with additional pipeline settings. Supports all SDK + options: clusters, continuous, development, photon, edition, channel, event_log, + configuration, notifications, tags, serverless, etc. + If 'id' is provided, the pipeline will be updated instead of created. + Explicit parameters (name, root_path, catalog, schema) take precedence. + + Returns: + Dictionary with detailed status: + - pipeline_id: The pipeline ID + - pipeline_name: The pipeline name + - created: True if newly created, False if updated + - success: True if all operations succeeded + - state: Final state if run was started (COMPLETED, FAILED, etc.) + - duration_seconds: Time taken if waited + - error_message: Error message if failed + - errors: List of detailed errors if failed + - message: Human-readable status message + + Example usage: + # Just create/update the pipeline + create_or_update_pipeline(name="my_pipeline", ...) 
+
+        # Create/update and run immediately
+        create_or_update_pipeline(name="my_pipeline", ..., start_run=True)
+
+        # Create/update, run, and wait for completion
+        create_or_update_pipeline(
+            name="my_pipeline", ...,
+            start_run=True,
+            wait_for_completion=True
         )
-        return result
-    try:
-        wait_result = _wait_for_pipeline_update(
-            pipeline_id=pipeline_id,
-            update_id=update_id,
-            timeout=timeout,
+
+        # Create with custom settings (non-serverless, development mode)
+        create_or_update_pipeline(
+            name="my_pipeline", ...,
+            extra_settings={
+                "serverless": False,
+                "development": True,
+                "clusters": [{"label": "default", "num_workers": 2}]
+            }
         )
-        result["state"] = wait_result["state"]
-        result["success"] = wait_result["success"]
-        result["duration_seconds"] = wait_result["duration_seconds"]
-        result["status"] = "completed" if wait_result["success"] else "failed"
-
-        if not wait_result["success"]:
-            result["errors"] = wait_result.get("errors", [])
-            if result["errors"]:
-                first_error = result["errors"][0]
-                error_msg = first_error.get("message", "")
-                if first_error.get("exceptions"):
-                    exc = first_error["exceptions"][0]
-                    error_msg = exc.get("message", error_msg)
-                result["error_message"] = error_msg
-            result["message"] = (
-                f"Pipeline run failed with state: {result['state']}. "
-                f"Use get_pipeline(pipeline_id='{pipeline_id}') for full details."
+ """ + # Auto-inject default tags into extra_settings; user tags take precedence + extra_settings = extra_settings or {} + extra_settings.setdefault("tags", {}) + extra_settings["tags"] = {**get_default_tags(), **extra_settings["tags"]} + + result = _create_or_update_pipeline( + name=name, + root_path=root_path, + catalog=catalog, + schema=schema, + workspace_file_paths=workspace_file_paths, + start_run=start_run, + wait_for_completion=wait_for_completion, + full_refresh=full_refresh, + timeout=timeout, + extra_settings=extra_settings, + ) + + # Track resource on successful create/update + try: + result_dict = result.to_dict() + pipeline_id = result_dict.get("pipeline_id") + if pipeline_id: + from ..manifest import track_resource + + track_resource( + resource_type="pipeline", + name=name, + resource_id=pipeline_id, ) - else: - result["message"] = f"Pipeline completed successfully in {result['duration_seconds']}s." - - except TimeoutError as e: - result["state"] = "TIMEOUT" - result["success"] = False - result["status"] = "timeout" - result["error_message"] = str(e) - result["message"] = ( - f"Pipeline run timed out after {timeout}s. The pipeline may still be running. " - f"Use get_pipeline(pipeline_id='{pipeline_id}') to check status." - ) + except Exception: + pass # best-effort tracking - return result + return result.to_dict() + + +@mcp.tool +def find_pipeline_by_name(name: str) -> Dict[str, Any]: + """ + Find a pipeline by name and return its ID. 
+
+    Args:
+        name: Pipeline name to search for
+
+    Returns:
+        Dictionary with:
+        - found: True if pipeline exists
+        - pipeline_id: Pipeline ID if found, None otherwise
+    """
+    pipeline_id = _find_pipeline_by_name(name=name)
+    return {
+        "found": pipeline_id is not None,
+        "pipeline_id": pipeline_id,
+    }
diff --git a/databricks-mcp-server/databricks_mcp_server/tools/sql.py b/databricks-mcp-server/databricks_mcp_server/tools/sql.py
index 01355525..2be905d3 100644
--- a/databricks-mcp-server/databricks_mcp_server/tools/sql.py
+++ b/databricks-mcp-server/databricks_mcp_server/tools/sql.py
@@ -28,9 +28,8 @@ def execute_sql(
     If no warehouse_id is provided, automatically selects the best available warehouse.
 
     IMPORTANT: For creating or dropping schemas, catalogs, and volumes, use the
-    manage_uc_objects tool instead of SQL DDL. It handles resource tracking and
-    auto-tagging. Only use execute_sql for queries (SELECT, INSERT, UPDATE) and
-    table DDL (CREATE TABLE, DROP TABLE).
+    manage_uc_objects tool instead of SQL DDL. Only use execute_sql for queries
+    (SELECT, INSERT, UPDATE) and table DDL (CREATE TABLE, DROP TABLE).
 
     Args:
         sql_query: SQL query to execute
@@ -67,9 +66,8 @@ def execute_sql_multi(
     in optimal order. Independent queries run in parallel.
 
     IMPORTANT: For creating or dropping schemas, catalogs, and volumes, use the
-    manage_uc_objects tool instead of SQL DDL. It handles resource tracking and
-    auto-tagging. Only use execute_sql/execute_sql_multi for queries (SELECT,
-    INSERT, UPDATE) and table DDL (CREATE TABLE, DROP TABLE).
+    manage_uc_objects tool instead of SQL DDL. Only use execute_sql/execute_sql_multi
+    for queries (SELECT, INSERT, UPDATE) and table DDL (CREATE TABLE, DROP TABLE).
 
     Args:
         sql_content: SQL content with multiple statements separated by ;
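The pipeline hunks above replace the blocking `run_pipeline` with a `start_update`/`get_update` pair, which moves waiting onto the caller. A minimal client-side polling sketch of that pattern (here `poll_update`, the terminal-state set, and the `fake_get_update` stub are illustrative and not part of the server — in real use the callable would invoke the `get_update` MCP tool):

```python
import time

# States from the get_update docstring that end an update (illustrative set).
TERMINAL_STATES = {"COMPLETED", "FAILED", "CANCELED"}

def poll_update(get_update, pipeline_id, update_id, timeout=1800, interval=10):
    """Call get_update until the update reaches a terminal state or the timeout expires."""
    deadline = time.monotonic() + timeout
    while True:
        status = get_update(pipeline_id, update_id)
        if status.get("state") in TERMINAL_STATES:
            return status
        if time.monotonic() >= deadline:
            raise TimeoutError(f"update {update_id} not finished after {timeout}s")
        time.sleep(interval)

# Stand-in for the real get_update tool: advances through three states.
_states = iter(["QUEUED", "RUNNING", "COMPLETED"])

def fake_get_update(pipeline_id, update_id):
    return {"update_id": update_id, "state": next(_states)}

result = poll_update(fake_get_update, "abc-123", "xyz-456", timeout=5, interval=0)
print(result["state"])  # COMPLETED
```

Keeping the wait loop on the client side like this is what lets `start_update` return immediately with just an `update_id`, matching the slimmed-down docstring in the diff.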