diff --git a/internal/db/queries.go b/internal/db/queries.go index a49fd99d..d21d0681 100644 --- a/internal/db/queries.go +++ b/internal/db/queries.go @@ -9,6 +9,20 @@ import ( "time" ) +// listSearchCondition returns a SQL condition and args for a free-text search. +// The title is searched as "#{number} {title}" so substring queries can match +// the number, the title, or both at once (e.g. "278" hits "#278 fix bug"). +// Author is matched separately. The alias is the table alias used in the +// surrounding query (e.g. "p" for merge requests, "i" for issues). +func listSearchCondition(alias, search string) (string, []any) { + like := "%" + search + "%" + cond := fmt.Sprintf( + "(('#' || %s.number || ' ' || %s.title) LIKE ? OR %s.author LIKE ?)", + alias, alias, alias, + ) + return cond, []any{like, like} +} + func sqlPlaceholders(count int) string { parts := make([]string, count) for i := range parts { @@ -1651,9 +1665,9 @@ func (d *DB) ListMergeRequests(ctx context.Context, opts ListMergeRequestsOpts) conds = append(conds, "s.number IS NOT NULL") } if opts.Search != "" { - conds = append(conds, "(p.title LIKE ? OR p.author LIKE ?)") - like := "%" + opts.Search + "%" - args = append(args, like, like) + cond, condArgs := listSearchCondition("p", opts.Search) + conds = append(conds, cond) + args = append(args, condArgs...) } where := "" @@ -2391,9 +2405,9 @@ func (d *DB) ListIssues( conds = append(conds, "s.number IS NOT NULL") } if opts.Search != "" { - conds = append(conds, "(i.title LIKE ? OR i.author LIKE ?)") - like := "%" + opts.Search + "%" - args = append(args, like, like) + cond, condArgs := listSearchCondition("i", opts.Search) + conds = append(conds, cond) + args = append(args, condArgs...) } where := "" diff --git a/internal/db/queries_test.go b/internal/db/queries_test.go index 337b8ecf..744b16b8 100644 --- a/internal/db/queries_test.go +++ b/internal/db/queries_test.go @@ -1479,6 +1479,34 @@ func TestListPullRequestsFilterBySearch(t *testing.T) { Assert.Equal(t, 1, prs[0].Number) } +func TestListPullRequestsFilterBySearchNumber(t *testing.T) { + require := require.New(t) + assert := Assert.New(t) + d := openTestDB(t) + + repoID := insertTestRepo(t, d, "owner", "repo") + base := baseTime() + + insertTestMR(t, d, repoID, 12, "add feature", base) + insertTestMR(t, d, repoID, 278, "fix bug", base.Add(time.Hour)) + insertTestMR(t, d, repoID, 290, "another change", base.Add(2*time.Hour)) + + prs, err := d.ListMergeRequests(t.Context(), ListMergeRequestsOpts{Search: "278"}) + require.NoError(err) + require.Len(prs, 1) + assert.Equal(278, prs[0].Number) + + prs, err = d.ListMergeRequests(t.Context(), ListMergeRequestsOpts{Search: "#278"}) + require.NoError(err) + require.Len(prs, 1) + assert.Equal(278, prs[0].Number) + + // Substring of number matches multiple. + prs, err = d.ListMergeRequests(t.Context(), ListMergeRequestsOpts{Search: "2"}) + require.NoError(err) + require.Len(prs, 3) +} + func TestListPullRequestsFilterByKanban(t *testing.T) { assert := Assert.New(t) require := require.New(t) @@ -2391,6 +2419,39 @@ func TestIssueRepoScopedQueriesCanonicalizeOwnerName(t *testing.T) { assert.Equal(issueID, gotID) } +func TestListIssuesFilterBySearch(t *testing.T) { + require := require.New(t) + assert := Assert.New(t) + d := openTestDB(t) + + repoID := insertTestRepo(t, d, "owner", "repo") + base := baseTime() + + insertTestIssue(t, d, repoID, 12, "report a bug", base) + insertTestIssue(t, d, repoID, 278, "filter broken", base.Add(time.Hour)) + insertTestIssue(t, d, repoID, 290, "another change", base.Add(2*time.Hour)) + + issues, err := d.ListIssues(t.Context(), ListIssuesOpts{Search: "broken"}) + require.NoError(err) + require.Len(issues, 1) + assert.Equal(278, issues[0].Number) + + issues, err = d.ListIssues(t.Context(), ListIssuesOpts{Search: "278"}) + require.NoError(err) + require.Len(issues, 1) + assert.Equal(278, issues[0].Number) + + issues, err = d.ListIssues(t.Context(), ListIssuesOpts{Search: "#278"}) + require.NoError(err) + require.Len(issues, 1) + assert.Equal(278, issues[0].Number) + + // Substring of number matches multiple. + issues, err = d.ListIssues(t.Context(), ListIssuesOpts{Search: "2"}) + require.NoError(err) + require.Len(issues, 3) +} + func TestReplaceIssueLabels_RejectsWrongRepoID(t *testing.T) { require := require.New(t) d := openTestDB(t) diff --git a/internal/server/api_test.go b/internal/server/api_test.go index 341e0c0b..93b2e797 100644 --- a/internal/server/api_test.go +++ b/internal/server/api_test.go @@ -695,6 +695,10 @@ func withSeedPRHeadRepoCloneURL(cloneURL string) seedPROpt { return func(pr *db.MergeRequest) { pr.HeadRepoCloneURL = cloneURL } } +func withSeedPRTitle(title string) seedPROpt { + return func(pr *db.MergeRequest) { pr.Title = title } +} + // seedPR inserts a repo and a PR into the DB, returning the PR's internal ID. func seedPR(t *testing.T, database *db.DB, owner, name string, number int, opts ...seedPROpt) int64 { t.Helper() @@ -5386,6 +5390,86 @@ func TestAPIListPullsStateFilter(t *testing.T) { require.Equal(http.StatusBadRequest, resp.StatusCode()) } +func TestAPIListPullsSearchByNumber(t *testing.T) { + require := require.New(t) + assert := Assert.New(t) + srv, database := setupTestServer(t) + ctx := t.Context() + + seedPR(t, database, "acme", "widget", 12, withSeedPRTitle("add feature")) + seedPR(t, database, "acme", "widget", 278, withSeedPRTitle("fix bug")) + seedPR(t, database, "acme", "widget", 290, withSeedPRTitle("another change")) + + client := setupTestClient(t, srv) + + pullNumbers := func(params *generated.ListPullsParams) []int { + t.Helper() + resp, err := client.HTTP.ListPullsWithResponse(ctx, params) + require.NoError(err) + require.Equal(http.StatusOK, resp.StatusCode()) + require.NotNil(resp.JSON200) + nums := make([]int, 0, len(*resp.JSON200)) + for _, pr := range *resp.JSON200 { + nums = append(nums, int(pr.Number)) + } + return nums + } + + q := "278" + assert.ElementsMatch([]int{278}, pullNumbers(&generated.ListPullsParams{Q: &q})) + + q = "#278" + assert.ElementsMatch([]int{278}, pullNumbers(&generated.ListPullsParams{Q: &q})) + + // Title still matches. + q = "fix" + assert.ElementsMatch([]int{278}, pullNumbers(&generated.ListPullsParams{Q: &q})) + + // Substring of number matches multiple. + q = "2" + assert.ElementsMatch([]int{12, 278, 290}, pullNumbers(&generated.ListPullsParams{Q: &q})) +} + +func TestAPIListIssuesSearchByNumber(t *testing.T) { + require := require.New(t) + assert := Assert.New(t) + srv, database := setupTestServer(t) + ctx := t.Context() + + seedIssueOnHost(t, database, "github.com", "acme", "widget", 12, "open", "report a bug") + seedIssueOnHost(t, database, "github.com", "acme", "widget", 278, "open", "filter broken") + seedIssueOnHost(t, database, "github.com", "acme", "widget", 290, "open", "another change") + + client := setupTestClient(t, srv) + + issueNumbers := func(params *generated.ListIssuesParams) []int { + t.Helper() + resp, err := client.HTTP.ListIssuesWithResponse(ctx, params) + require.NoError(err) + require.Equal(http.StatusOK, resp.StatusCode()) + require.NotNil(resp.JSON200) + nums := make([]int, 0, len(*resp.JSON200)) + for _, issue := range *resp.JSON200 { + nums = append(nums, int(issue.Number)) + } + return nums + } + + q := "278" + assert.ElementsMatch([]int{278}, issueNumbers(&generated.ListIssuesParams{Q: &q})) + + q = "#278" + assert.ElementsMatch([]int{278}, issueNumbers(&generated.ListIssuesParams{Q: &q})) + + // Title still matches. + q = "broken" + assert.ElementsMatch([]int{278}, issueNumbers(&generated.ListIssuesParams{Q: &q})) + + // Substring of number matches multiple. + q = "2" + assert.ElementsMatch([]int{12, 278, 290}, issueNumbers(&generated.ListIssuesParams{Q: &q})) +} + func TestAPIListPullsReportsBackfilledMergedPRFromMergedAt(t *testing.T) { assert := Assert.New(t) require := require.New(t)