Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 46 additions & 59 deletions crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,48 +56,48 @@ const LABELED_SECTIONS_INSTRUCTIONS: &str = indoc! {r#"
const NUMBERED_LINES_INSTRUCTIONS: &str = indoc! {r#"
# Instructions

You are a code completion assistant helping a programmer finish their work. Your task is to:
You are an edit prediction agent in a code editor.
Your job is to predict the next edit that the user will make,
based on their last few edits and their current cursor location.

1. Analyze the edit history to understand what the programmer is trying to achieve
2. Identify any incomplete refactoring or changes that need to be finished
3. Make the remaining edits that a human programmer would logically make next
4. Apply systematic changes consistently across the entire codebase - if you see a pattern starting, complete it everywhere.
## Output Format

Focus on:
- Understanding the intent behind the changes (e.g., improving error handling, refactoring APIs, fixing bugs)
- Completing any partially-applied changes across the codebase
- Ensuring consistency with the programming style and patterns already established
- Making edits that maintain or improve code quality
- If the programmer started refactoring one instance of a pattern, find and update ALL similar instances
- Don't write a lot of code if you're not sure what to do

Rules:
- Do not just mechanically apply patterns - reason about what changes make sense given the context and the programmer's apparent goals.
- Do not just fix syntax errors - look for the broader refactoring pattern and apply it systematically throughout the code.
- Write the edits in the unified diff format as shown in the example.

# Example output:
You must briefly explain your understanding of the user's goal, in one
or two sentences, and then specify their next edit in the form of a
unified diff, like this:

```
--- a/src/myapp/cli.py
+++ b/src/myapp/cli.py
@@ -1,3 +1,3 @@
-
-
-import sys
+import json
@@ ... @@
import os
import time
import sys
+from constants import LOG_LEVEL_WARNING
@@ ... @@
config.headless()
config.set_interactive(false)
-config.set_log_level(LOG_L)
+config.set_log_level(LOG_LEVEL_WARNING)
config.set_use_color(True)
```

# Edit History:
## Edit History

"#};

const UNIFIED_DIFF_REMINDER: &str = indoc! {"
---

Please analyze the edit history and the files, then provide the unified diff for your predicted edits.
Analyze the edit history and the files, then provide the unified diff for your predicted edits.
Do not include the cursor marker in your output.
If you're editing multiple files, be sure to reflect filename in the hunk's header.
Your diff should include edited file paths in its file headers (lines beginning with `---` and `+++`).
Do not include line numbers in the hunk headers, use `@@ ... @@`.
Removed lines begin with `-`.
Added lines begin with `+`.
Context lines begin with an extra space.
Context and removed lines are used to match the target edit location, so make sure to include enough of them
to uniquely identify it amongst all excerpts of code provided.
"};

pub fn build_prompt(
Expand All @@ -121,8 +121,7 @@ pub fn build_prompt(
EDITABLE_REGION_END_MARKER_WITH_NEWLINE,
),
],
PromptFormat::LabeledSections => vec![(request.cursor_point, CURSOR_MARKER)],
PromptFormat::NumLinesUniDiff => {
PromptFormat::LabeledSections | PromptFormat::NumLinesUniDiff => {
vec![(request.cursor_point, CURSOR_MARKER)]
}
PromptFormat::OnlySnippets => vec![],
Expand All @@ -132,46 +131,31 @@ pub fn build_prompt(
PromptFormat::MarkedExcerpt => MARKED_EXCERPT_INSTRUCTIONS.to_string(),
PromptFormat::LabeledSections => LABELED_SECTIONS_INSTRUCTIONS.to_string(),
PromptFormat::NumLinesUniDiff => NUMBERED_LINES_INSTRUCTIONS.to_string(),
// only intended for use via zeta_cli
PromptFormat::OnlySnippets => String::new(),
};

if request.events.is_empty() {
prompt.push_str("(No edit history)\n\n");
} else {
prompt.push_str(
"The following are the latest edits made by the user, from earlier to later.\n\n",
);
prompt.push_str("Here are the latest edits made by the user, from earlier to later.\n\n");
push_events(&mut prompt, &request.events);
}

prompt.push_str(indoc! {"
# Code Excerpts

The cursor marker <|user_cursor|> indicates the current user cursor position.
The file is in current state, edits from edit history have been applied.
"});

if request.prompt_format == PromptFormat::NumLinesUniDiff {
if request.referenced_declarations.is_empty() {
prompt.push_str(indoc! {"
# File under the cursor:

The cursor marker <|user_cursor|> indicates the current user cursor position.
The file is in current state, edits from edit history have been applied.
We prepend line numbers (e.g., `123|<actual line>`); they are not part of the file.

"});
} else {
// Note: This hasn't been trained on yet
prompt.push_str(indoc! {"
# Code Excerpts:

The cursor marker <|user_cursor|> indicates the current user cursor position.
Other excerpts of code from the project have been included as context based on their similarity to the code under the cursor.
Context excerpts are not guaranteed to be relevant, so use your own judgement.
Files are in their current state, edits from edit history have been applied.
We prepend line numbers (e.g., `123|<actual line>`); they are not part of the file.

"});
}
} else {
prompt.push_str("\n## Code\n\n");
prompt.push_str(indoc! {"
We prepend line numbers (e.g., `123|<actual line>`); they are not part of the file.
"});
}

prompt.push('\n');

let mut section_labels = Default::default();

if !request.referenced_declarations.is_empty() || !request.signatures.is_empty() {
Expand All @@ -198,8 +182,11 @@ pub fn build_prompt(
}
}

if request.prompt_format == PromptFormat::NumLinesUniDiff {
prompt.push_str(UNIFIED_DIFF_REMINDER);
match request.prompt_format {
PromptFormat::NumLinesUniDiff => {
prompt.push_str(UNIFIED_DIFF_REMINDER);
}
_ => {}
}

Ok((prompt, section_labels))
Expand Down
43 changes: 27 additions & 16 deletions crates/zeta2/src/zeta2.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
use anyhow::{Context as _, Result, anyhow};
use anyhow::{Context as _, Result, anyhow, bail};
use chrono::TimeDelta;
use client::{Client, EditPredictionUsage, UserStore};
use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
use cloud_llm_client::{
AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
ZED_VERSION_HEADER_NAME,
};
use cloud_zeta2_prompt::DEFAULT_MAX_PROMPT_BYTES;
use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
use cloud_zeta2_prompt::{CURSOR_MARKER, DEFAULT_MAX_PROMPT_BYTES};
use collections::HashMap;
use edit_prediction_context::{
DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
Expand Down Expand Up @@ -943,23 +943,34 @@ impl Zeta {

let (res, usage) = response?;
let request_id = EditPredictionId(res.id.clone().into());
let Some(output_text) = text_from_response(res) else {
let Some(mut output_text) = text_from_response(res) else {
return Ok((None, usage))
};

let (edited_buffer_snapshot, edits) =
crate::udiff::parse_diff(&output_text, |path| {
included_files
.iter()
.find_map(|(_, buffer, probe_path, ranges)| {
if probe_path.as_ref() == path {
Some((buffer, ranges.as_slice()))
} else {
None
}
})
})
.await?;
if output_text.contains(CURSOR_MARKER) {
log::trace!("Stripping out {CURSOR_MARKER} from response");
output_text = output_text.replace(CURSOR_MARKER, "");
}

let (edited_buffer_snapshot, edits) = match options.prompt_format {
PromptFormat::NumLinesUniDiff => {
crate::udiff::parse_diff(&output_text, |path| {
included_files
.iter()
.find_map(|(_, buffer, probe_path, ranges)| {
if probe_path.as_ref() == path {
Some((buffer, ranges.as_slice()))
} else {
None
}
})
})
.await?
}
_ => {
bail!("unsupported prompt format {}", options.prompt_format)
}
};

let edited_buffer = included_files
.iter()
Expand Down
2 changes: 1 addition & 1 deletion crates/zeta_cli/src/evaluate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ pub async fn run_evaluate_one(
);
as_json
} else {
zeta2_predict(example.clone(), &app_state, cx)
zeta2_predict(example.clone(), Default::default(), &app_state, cx)
.await
.unwrap()
};
Expand Down
14 changes: 12 additions & 2 deletions crates/zeta_cli/src/predict.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
use crate::PromptFormat;
use crate::example::{ActualExcerpt, NamedExample};
use crate::headless::ZetaCliAppState;
use crate::paths::LOGS_DIR;
use ::serde::Serialize;
use anyhow::{Result, anyhow};
use clap::Args;
// use cloud_llm_client::predict_edits_v3::PromptFormat;
use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
use futures::StreamExt as _;
use gpui::{AppContext, AsyncApp};
Expand All @@ -19,9 +21,11 @@ use std::time::{Duration, Instant};

#[derive(Debug, Args)]
pub struct PredictArguments {
example_path: PathBuf,
#[arg(long, value_enum, default_value_t = PromptFormat::default())]
prompt_format: PromptFormat,
#[clap(long, short, value_enum, default_value_t = PredictionsOutputFormat::Md)]
format: PredictionsOutputFormat,
example_path: PathBuf,
}

#[derive(clap::ValueEnum, Debug, Clone)]
Expand All @@ -36,7 +40,9 @@ pub async fn run_zeta2_predict(
cx: &mut AsyncApp,
) {
let example = NamedExample::load(args.example_path).unwrap();
let result = zeta2_predict(example, &app_state, cx).await.unwrap();
let result = zeta2_predict(example, args.prompt_format, &app_state, cx)
.await
.unwrap();
result.write(args.format, std::io::stdout()).unwrap();
}

Expand All @@ -46,6 +52,7 @@ thread_local! {

pub async fn zeta2_predict(
example: NamedExample,
prompt_format: PromptFormat,
app_state: &Arc<ZetaCliAppState>,
cx: &mut AsyncApp,
) -> Result<PredictionDetails> {
Expand Down Expand Up @@ -193,6 +200,9 @@ pub async fn zeta2_predict(
});

zeta.update(cx, |zeta, cx| {
let mut options = zeta.options().clone();
options.prompt_format = prompt_format.into();
zeta.set_options(options);
zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx)
})?
.await?;
Expand Down
Loading