diff --git a/cmd/msgvault/cmd/account_identity.go b/cmd/msgvault/cmd/account_identity.go new file mode 100644 index 00000000..40aab23e --- /dev/null +++ b/cmd/msgvault/cmd/account_identity.go @@ -0,0 +1,69 @@ +package cmd + +import ( + "fmt" + "io" + "strings" + + "github.com/wesm/msgvault/internal/store" +) + +// noDefaultIdentityHelp is the flag help text for --no-default-identity. +// Each ingest command registers its own bool variable and reuses this constant. +const noDefaultIdentityHelp = "Suppress auto-default-identity at account creation. " + + "Note: a one-time legacy [identity] config migration may still write confirmed " + + "identifiers to the account on first post-upgrade startup." + +// confirmDefaultIdentity writes one confirmed identifier to a freshly +// created source's identity. Best-effort: any error is logged and swallowed +// so a partially failed identity write never breaks ingest. Empty identifiers +// are a silent no-op. +// +// Skips the write when the source already has at least one identity row. +// add-account / add-imap / add-o365 / import-* commands all call this on +// every invocation (including reruns and rebinds), so without this guard +// an identity the user explicitly removed via `identity remove` would be +// re-added on the next ingest re-run, silently affecting dedup sent-copy +// detection. The guard preserves the documented "freshly created source" +// intent while degrading gracefully if the user has removed every +// identity (in which case the default is restored, which is desirable). +// +// **Ordering note:** ingest commands MUST call confirmDefaultIdentity +// BEFORE runPostSourceCreateMigrations on the same invocation. The +// legacy [identity] migration uses set-semantics merge, so calling the +// default-identity write first and the migration second produces the +// correct merged state. 
Calling them in the other order populates +// account_identities with the legacy addresses first, then the +// `len(existing) > 0` guard suppresses the source's own account +// identifier entirely (regression caught in iter15). See the per-ingest +// command order in addaccount.go etc. +// +// account is the user-facing account name shown in the confirmation message. +// Callers should gate this behind the per-command --no-default-identity flag. +func confirmDefaultIdentity(out io.Writer, s *store.Store, sourceID int64, account, identifier, signal string) { + id := strings.TrimSpace(identifier) + if id == "" { + return + } + existing, err := s.ListAccountIdentities(sourceID) + if err != nil { + logger.Warn("auto-default-identity precheck failed", + "source_id", sourceID, + "account", account, + "error", err.Error()) + return + } + if len(existing) > 0 { + return + } + if err := s.AddAccountIdentity(sourceID, id, signal); err != nil { + logger.Warn("auto-default-identity write failed", + "source_id", sourceID, + "account", account, + "identifier", id, + "signal", signal, + "error", err.Error()) + return + } + _, _ = fmt.Fprintf(out, "Confirmed identity %s on %s (signal: %s).\n", id, account, signal) +} diff --git a/cmd/msgvault/cmd/account_identity_test.go b/cmd/msgvault/cmd/account_identity_test.go new file mode 100644 index 00000000..8475b1bc --- /dev/null +++ b/cmd/msgvault/cmd/account_identity_test.go @@ -0,0 +1,122 @@ +package cmd + +import ( + "io" + "log/slog" + "path/filepath" + "testing" + + "github.com/wesm/msgvault/internal/store" +) + +func TestConfirmDefaultIdentity_HappyPath(t *testing.T) { + tmpDir := t.TempDir() + s, err := store.Open(filepath.Join(tmpDir, "msgvault.db")) + if err != nil { + t.Fatal(err) + } + defer func() { _ = s.Close() }() + if err := s.InitSchema(); err != nil { + t.Fatal(err) + } + + src, err := s.GetOrCreateSource("gmail", "alice@example.com") + if err != nil { + t.Fatal(err) + } + confirmDefaultIdentity(io.Discard, s, src.ID, 
"alice@example.com", "alice@example.com", "account-identifier") + rows, err := s.ListAccountIdentities(src.ID) + if err != nil { + t.Fatal(err) + } + if len(rows) != 1 || rows[0].Address != "alice@example.com" { + t.Fatalf("got %+v", rows) + } + if rows[0].SourceSignal != "account-identifier" { + t.Errorf("signal=%q", rows[0].SourceSignal) + } +} + +func TestConfirmDefaultIdentity_EmptyIdentifierIsNoOp(t *testing.T) { + tmpDir := t.TempDir() + s, err := store.Open(filepath.Join(tmpDir, "msgvault.db")) + if err != nil { + t.Fatal(err) + } + defer func() { _ = s.Close() }() + if err := s.InitSchema(); err != nil { + t.Fatal(err) + } + + src, err := s.GetOrCreateSource("gmail", "alice@example.com") + if err != nil { + t.Fatal(err) + } + confirmDefaultIdentity(io.Discard, s, src.ID, "alice@example.com", "", "account-identifier") + rows, _ := s.ListAccountIdentities(src.ID) + if len(rows) != 0 { + t.Errorf("want empty, got %+v", rows) + } +} + +func TestConfirmDefaultIdentity_StoreErrorDoesNotPanic(t *testing.T) { + tmpDir := t.TempDir() + s, err := store.Open(filepath.Join(tmpDir, "msgvault.db")) + if err != nil { + t.Fatal(err) + } + defer func() { _ = s.Close() }() + if err := s.InitSchema(); err != nil { + t.Fatal(err) + } + + savedLogger := logger + defer func() { logger = savedLogger }() + logger = slog.New(slog.NewTextHandler(io.Discard, nil)) + + prevDefault := slog.Default() + slog.SetDefault(slog.New(slog.NewTextHandler(io.Discard, nil))) + t.Cleanup(func() { slog.SetDefault(prevDefault) }) + + // sourceID 99999 does not exist; FK violation returns an error + // from AddAccountIdentity. The helper must swallow it. 
+ confirmDefaultIdentity(io.Discard, s, 99999, "ghost@example.com", "ghost@example.com", "account-identifier") +} + +// TestConfirmDefaultIdentity_LegacyMigrationOverridesNoDefault pins the +// documented behavior: skipping confirmDefaultIdentity (simulating +// --no-default-identity) does NOT prevent MigrateLegacyIdentityConfig from +// writing the address. +func TestConfirmDefaultIdentity_LegacyMigrationOverridesNoDefault(t *testing.T) { + tmpDir := t.TempDir() + s, err := store.Open(filepath.Join(tmpDir, "msgvault.db")) + if err != nil { + t.Fatal(err) + } + defer func() { _ = s.Close() }() + if err := s.InitSchema(); err != nil { + t.Fatal(err) + } + + _, err = s.GetOrCreateSource("gmail", "alice@example.com") + if err != nil { + t.Fatal(err) + } + // Simulate --no-default-identity: do not call confirmDefaultIdentity. + // Then run startup migrations with a non-empty legacy address list. + applied, _, _, _, err := s.MigrateLegacyIdentityConfig([]string{"alice@example.com"}) + if err != nil { + t.Fatal(err) + } + if !applied { + t.Fatal("migration did not apply") + } + src, err := s.GetOrCreateSource("gmail", "alice@example.com") + if err != nil { + t.Fatal(err) + } + rows, _ := s.ListAccountIdentities(src.ID) + if len(rows) != 1 { + t.Fatalf("legacy migration should have written, got %+v", rows) + } +} diff --git a/cmd/msgvault/cmd/account_scope.go b/cmd/msgvault/cmd/account_scope.go new file mode 100644 index 00000000..1f6a350e --- /dev/null +++ b/cmd/msgvault/cmd/account_scope.go @@ -0,0 +1,138 @@ +package cmd + +import ( + "errors" + "fmt" + + "github.com/wesm/msgvault/internal/store" +) + +// Scope is the result of resolving a user-supplied --account or +// --collection flag against the store. +type Scope struct { + Input string + Source *store.Source + Collection *store.CollectionWithSources +} + +// IsEmpty reports whether the scope resolved to nothing. 
+func (s Scope) IsEmpty() bool { + return s.Source == nil && s.Collection == nil +} + +// IsCollection reports whether the scope refers to a collection. +func (s Scope) IsCollection() bool { + return s.Collection != nil +} + +// SourceIDs returns the source IDs that this scope expands to. +func (s Scope) SourceIDs() []int64 { + switch { + case s.Collection != nil: + return append([]int64(nil), s.Collection.SourceIDs...) + case s.Source != nil: + return []int64{s.Source.ID} + } + return nil +} + +// DisplayName returns a human-readable label for the scope. +func (s Scope) DisplayName() string { + switch { + case s.Collection != nil: + return s.Collection.Name + case s.Source != nil: + return s.Source.Identifier + } + return "" +} + +// ResolveAccountFlag resolves the value of an --account flag. +// It rejects collection names with a hint to use --collection. +func ResolveAccountFlag(st *store.Store, input string) (Scope, error) { + scope := Scope{Input: input} + if input == "" { + return scope, nil + } + + // Try source resolution first. + sources, err := st.GetSourcesByIdentifierOrDisplayName(input) + if err != nil { + return scope, fmt.Errorf("look up source for %q: %w", input, err) + } + if len(sources) > 1 { + names := make([]string, 0, len(sources)) + for _, s := range sources { + names = append(names, fmt.Sprintf( + "%s (%s, id=%d)", + s.Identifier, s.SourceType, s.ID, + )) + } + return scope, fmt.Errorf( + "ambiguous account %q matches multiple sources: %v", + input, names, + ) + } + if len(sources) == 1 { + scope.Source = sources[0] + return scope, nil + } + + // No source match — check whether a collection exists with this name and + // reject with a helpful hint. + _, cerr := st.GetCollectionByName(input) + switch { + case cerr == nil: + return scope, fmt.Errorf( + "%q is a collection, not an account; use --collection %s", + input, input, + ) + case errors.Is(cerr, store.ErrCollectionNotFound): + // Neither a source nor a collection. 
+ default: + return scope, fmt.Errorf("look up collection %q: %w", input, cerr) + } + + return scope, fmt.Errorf( + "no account found for %q (try 'msgvault list-accounts')", + input, + ) +} + +// ResolveCollectionFlag resolves the value of a --collection flag. +// It rejects account identifiers with a hint to use --account. +func ResolveCollectionFlag(st *store.Store, input string) (Scope, error) { + scope := Scope{Input: input} + if input == "" { + return scope, nil + } + + // Try collection resolution first. + coll, err := st.GetCollectionByName(input) + switch { + case err == nil: + scope.Collection = coll + return scope, nil + case errors.Is(err, store.ErrCollectionNotFound): + // Fall through to source check. + default: + return scope, fmt.Errorf("look up collection %q: %w", input, err) + } + + // No collection found — check whether any source matches and reject with a hint. + sources, serr := st.GetSourcesByIdentifierOrDisplayName(input) + if serr != nil { + return scope, fmt.Errorf("look up source for %q: %w", input, serr) + } + if len(sources) >= 1 { + return scope, fmt.Errorf( + "%q is an account, not a collection; use --account %s", + input, input, + ) + } + + return scope, fmt.Errorf( + "no collection named %q (try 'msgvault collection list')", + input, + ) +} diff --git a/cmd/msgvault/cmd/account_scope_test.go b/cmd/msgvault/cmd/account_scope_test.go new file mode 100644 index 00000000..b1e261f9 --- /dev/null +++ b/cmd/msgvault/cmd/account_scope_test.go @@ -0,0 +1,177 @@ +package cmd + +import ( + "strings" + "testing" + + "github.com/wesm/msgvault/internal/testutil" + "github.com/wesm/msgvault/internal/testutil/storetest" +) + +// setupScopeFixture creates a store with one source and one collection for +// resolver tests. Returns the store, source identifier, and collection name. 
+func setupScopeFixture(t *testing.T) ( + f *storetest.Fixture, + accountID string, + collectionName string, +) { + t.Helper() + f = storetest.New(t) + // f.Source is "test@example.com" / gmail, created by storetest.New. + accountID = f.Source.Identifier // "test@example.com" + + collectionName = "inbox-collection" + _, err := f.Store.CreateCollection(collectionName, "", []int64{f.Source.ID}) + testutil.MustNoErr(t, err, "CreateCollection") + + return f, accountID, collectionName +} + +func TestResolveAccountFlag_EmptyInput(t *testing.T) { + f, _, _ := setupScopeFixture(t) + + scope, err := ResolveAccountFlag(f.Store, "") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !scope.IsEmpty() { + t.Errorf("expected empty scope, got source=%v collection=%v", + scope.Source, scope.Collection) + } +} + +func TestResolveCollectionFlag_EmptyInput(t *testing.T) { + f, _, _ := setupScopeFixture(t) + + scope, err := ResolveCollectionFlag(f.Store, "") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !scope.IsEmpty() { + t.Errorf("expected empty scope, got source=%v collection=%v", + scope.Source, scope.Collection) + } +} + +func TestResolveAccountFlag_ValidAccount(t *testing.T) { + f, accountID, _ := setupScopeFixture(t) + + scope, err := ResolveAccountFlag(f.Store, accountID) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if scope.Source == nil { + t.Fatal("expected Source to be populated") + } + if scope.Source.Identifier != accountID { + t.Errorf("source identifier = %q, want %q", scope.Source.Identifier, accountID) + } + if scope.Collection != nil { + t.Error("expected Collection to be nil") + } +} + +func TestResolveCollectionFlag_ValidCollection(t *testing.T) { + f, _, collectionName := setupScopeFixture(t) + + scope, err := ResolveCollectionFlag(f.Store, collectionName) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if scope.Collection == nil { + t.Fatal("expected Collection to be populated") + } + if 
scope.Collection.Name != collectionName { + t.Errorf("collection name = %q, want %q", scope.Collection.Name, collectionName) + } + if scope.Source != nil { + t.Error("expected Source to be nil") + } +} + +func TestResolveAccountFlag_RejectsCollectionName(t *testing.T) { + f, _, collectionName := setupScopeFixture(t) + + _, err := ResolveAccountFlag(f.Store, collectionName) + if err == nil { + t.Fatal("expected error for collection name passed as --account") + } + msg := err.Error() + if !strings.Contains(msg, "is a collection") { + t.Errorf("error should contain 'is a collection': %q", msg) + } + if !strings.Contains(msg, "--collection") { + t.Errorf("error should contain '--collection': %q", msg) + } +} + +func TestResolveCollectionFlag_RejectsAccountIdentifier(t *testing.T) { + f, accountID, _ := setupScopeFixture(t) + + _, err := ResolveCollectionFlag(f.Store, accountID) + if err == nil { + t.Fatal("expected error for account identifier passed as --collection") + } + msg := err.Error() + if !strings.Contains(msg, "is an account") { + t.Errorf("error should contain 'is an account': %q", msg) + } + if !strings.Contains(msg, "--account") { + t.Errorf("error should contain '--account': %q", msg) + } +} + +// TestResolveAccountFlag_BothExist verifies the tie-break rule: when a name +// exists as both an account and a collection, --account resolves the account. +func TestResolveAccountFlag_BothExist(t *testing.T) { + f := storetest.New(t) + + // Create a second source whose identifier matches our collection name. 
+ sharedName := "shared-name" + src2, err := f.Store.GetOrCreateSource("mbox", sharedName) + testutil.MustNoErr(t, err, "GetOrCreateSource") + + _, err = f.Store.CreateCollection(sharedName, "", []int64{f.Source.ID}) + testutil.MustNoErr(t, err, "CreateCollection") + + scope, err := ResolveAccountFlag(f.Store, sharedName) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if scope.Source == nil { + t.Fatal("expected Source to be populated") + } + if scope.Source.ID != src2.ID { + t.Errorf("source ID = %d, want %d", scope.Source.ID, src2.ID) + } + if scope.Collection != nil { + t.Error("expected Collection to be nil when resolving as --account") + } +} + +// TestResolveCollectionFlag_BothExist verifies that when a name exists as both +// an account and a collection, --collection resolves the collection. +func TestResolveCollectionFlag_BothExist(t *testing.T) { + f := storetest.New(t) + + sharedName := "shared-name" + _, err := f.Store.GetOrCreateSource("mbox", sharedName) + testutil.MustNoErr(t, err, "GetOrCreateSource") + + _, err = f.Store.CreateCollection(sharedName, "", []int64{f.Source.ID}) + testutil.MustNoErr(t, err, "CreateCollection") + + scope, err := ResolveCollectionFlag(f.Store, sharedName) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if scope.Collection == nil { + t.Fatal("expected Collection to be populated") + } + if scope.Collection.Name != sharedName { + t.Errorf("collection name = %q, want %q", scope.Collection.Name, sharedName) + } + if scope.Source != nil { + t.Error("expected Source to be nil when resolving as --collection") + } +} diff --git a/cmd/msgvault/cmd/addaccount.go b/cmd/msgvault/cmd/addaccount.go index 70c7f526..3c250e4d 100644 --- a/cmd/msgvault/cmd/addaccount.go +++ b/cmd/msgvault/cmd/addaccount.go @@ -11,10 +11,11 @@ import ( ) var ( - headless bool - accountDisplayName string - forceReauth bool - oauthAppName string + headless bool + accountDisplayName string + forceReauth bool + oauthAppName 
string + noDefaultIdentityAddAccount bool ) var addAccountCmd = &cobra.Command{ @@ -61,6 +62,9 @@ Examples: if err := s.InitSchema(); err != nil { return fmt.Errorf("init schema: %w", err) } + if err := runStartupMigrationsForIngest(s); err != nil { + return fmt.Errorf("startup migrations: %w", err) + } // Look up existing source to detect binding changes existingSource, err := findGmailSource(s, email) @@ -214,6 +218,18 @@ Examples: return fmt.Errorf("set display name: %w", err) } } + // Auto-default-identity must run BEFORE the legacy migration + // retry (runPostSourceCreateMigrations). The migration's + // set-semantics merge handles the case where the legacy + // [identity] block contains the same address. Reverse order + // would leave the source without its own account identifier + // because confirmDefaultIdentity skips on any existing rows. + if !noDefaultIdentityAddAccount { + confirmDefaultIdentity(cmd.OutOrStdout(), s, source.ID, email, email, "account-identifier") + } + if err := runPostSourceCreateMigrations(s); err != nil { + return fmt.Errorf("post-source-create migrations: %w", err) + } if bindingChanged { fmt.Printf("Account %s: OAuth app binding updated to %q.\n", email, resolvedApp) } else { @@ -272,6 +288,14 @@ Examples: return fmt.Errorf("set display name: %w", err) } } + // Auto-default-identity must run BEFORE the legacy migration + // retry — see comment on the token-reusable path above. 
+ if !noDefaultIdentityAddAccount { + confirmDefaultIdentity(cmd.OutOrStdout(), s, source.ID, email, email, "account-identifier") + } + if err := runPostSourceCreateMigrations(s); err != nil { + return fmt.Errorf("post-source-create migrations: %w", err) + } fmt.Printf("\nAccount %s authorized successfully!\n", email) fmt.Println("You can now run: msgvault sync-full", email) @@ -300,5 +324,6 @@ func init() { addAccountCmd.Flags().BoolVar(&forceReauth, "force", false, "Delete existing token and re-authorize") addAccountCmd.Flags().StringVar(&accountDisplayName, "display-name", "", "Display name for the account (e.g., \"Work\", \"Personal\")") addAccountCmd.Flags().StringVar(&oauthAppName, "oauth-app", "", "Named OAuth app from config (for Google Workspace orgs)") + addAccountCmd.Flags().BoolVar(&noDefaultIdentityAddAccount, "no-default-identity", false, noDefaultIdentityHelp) rootCmd.AddCommand(addAccountCmd) } diff --git a/cmd/msgvault/cmd/addaccount_test.go b/cmd/msgvault/cmd/addaccount_test.go index bcb6d3fa..802003fc 100644 --- a/cmd/msgvault/cmd/addaccount_test.go +++ b/cmd/msgvault/cmd/addaccount_test.go @@ -142,6 +142,7 @@ func TestAddAccount_InheritedBindingValidatesToken(t *testing.T) { testCmd.Flags().BoolVar(&headless, "headless", false, "") testCmd.Flags().BoolVar(&forceReauth, "force", false, "") testCmd.Flags().StringVar(&accountDisplayName, "display-name", "", "") + testCmd.Flags().BoolVar(&noDefaultIdentityAddAccount, "no-default-identity", false, "") root := newTestRootCmd() root.AddCommand(testCmd) @@ -242,6 +243,7 @@ func TestAddAccount_RebindWithExistingToken(t *testing.T) { testCmd.Flags().BoolVar(&headless, "headless", false, "") testCmd.Flags().BoolVar(&forceReauth, "force", false, "") testCmd.Flags().StringVar(&accountDisplayName, "display-name", "", "") + testCmd.Flags().BoolVar(&noDefaultIdentityAddAccount, "no-default-identity", false, "") root := newTestRootCmd() root.AddCommand(testCmd) @@ -345,6 +347,7 @@ func 
TestAddAccount_NewRegistrationRejectsMismatchedToken(t *testing.T) { testCmd.Flags().BoolVar(&headless, "headless", false, "") testCmd.Flags().BoolVar(&forceReauth, "force", false, "") testCmd.Flags().StringVar(&accountDisplayName, "display-name", "", "") + testCmd.Flags().BoolVar(&noDefaultIdentityAddAccount, "no-default-identity", false, "") root := newTestRootCmd() root.AddCommand(testCmd) @@ -417,6 +420,7 @@ func TestAddAccount_ExplicitDefaultRejectsMismatchedToken(t *testing.T) { testCmd.Flags().BoolVar(&headless, "headless", false, "") testCmd.Flags().BoolVar(&forceReauth, "force", false, "") testCmd.Flags().StringVar(&accountDisplayName, "display-name", "", "") + testCmd.Flags().BoolVar(&noDefaultIdentityAddAccount, "no-default-identity", false, "") root := newTestRootCmd() root.AddCommand(testCmd) @@ -485,6 +489,7 @@ func TestAddAccount_ExplicitDefaultAcceptsMatchingToken(t *testing.T) { testCmd.Flags().BoolVar(&headless, "headless", false, "") testCmd.Flags().BoolVar(&forceReauth, "force", false, "") testCmd.Flags().StringVar(&accountDisplayName, "display-name", "", "") + testCmd.Flags().BoolVar(&noDefaultIdentityAddAccount, "no-default-identity", false, "") // Pre-cancel so if regression causes auth attempt, it fails fast // instead of opening a browser. 
@@ -570,6 +575,7 @@ func TestAddAccount_ForceRebindPreservesBindingOnFailure(t *testing.T) { testCmd.Flags().BoolVar(&headless, "headless", false, "") testCmd.Flags().BoolVar(&forceReauth, "force", false, "") testCmd.Flags().StringVar(&accountDisplayName, "display-name", "", "") + testCmd.Flags().BoolVar(&noDefaultIdentityAddAccount, "no-default-identity", false, "") root := newTestRootCmd() root.AddCommand(testCmd) @@ -665,6 +671,7 @@ func TestAddAccount_HeadlessExplicitEmptyOAuthApp(t *testing.T) { testCmd.Flags().BoolVar(&headless, "headless", false, "") testCmd.Flags().BoolVar(&forceReauth, "force", false, "") testCmd.Flags().StringVar(&accountDisplayName, "display-name", "", "") + testCmd.Flags().BoolVar(&noDefaultIdentityAddAccount, "no-default-identity", false, "") root := newTestRootCmd() root.AddCommand(testCmd) @@ -692,6 +699,419 @@ func TestAddAccount_HeadlessExplicitEmptyOAuthApp(t *testing.T) { } } +// TestAddAccount_AutoDefaultIdentityFires verifies that running add-account +// with a reusable token writes an account-identifier identity row. 
+func TestAddAccount_AutoDefaultIdentityFires(t *testing.T) { + tmpDir := t.TempDir() + dbPath := filepath.Join(tmpDir, "msgvault.db") + + tokensDir := filepath.Join(tmpDir, "tokens") + if err := os.MkdirAll(tokensDir, 0700); err != nil { + t.Fatalf("mkdir tokens: %v", err) + } + tokenData, _ := json.Marshal(map[string]string{ + "access_token": "fake-access", + "refresh_token": "fake-refresh", + "token_type": "Bearer", + "client_id": "test.apps.googleusercontent.com", + }) + if err := os.WriteFile(filepath.Join(tokensDir, "user@example.com.json"), tokenData, 0600); err != nil { + t.Fatalf("write token: %v", err) + } + + secretsPath := filepath.Join(tmpDir, "secret.json") + if err := os.WriteFile(secretsPath, []byte(fakeClientSecrets), 0600); err != nil { + t.Fatalf("write secrets: %v", err) + } + + savedCfg := cfg + savedLogger := logger + savedOAuthApp := oauthAppName + savedNoDefault := noDefaultIdentityAddAccount + defer func() { + cfg = savedCfg + logger = savedLogger + oauthAppName = savedOAuthApp + noDefaultIdentityAddAccount = savedNoDefault + }() + + cfg = &config.Config{ + HomeDir: tmpDir, + Data: config.DataConfig{DataDir: tmpDir}, + OAuth: config.OAuthConfig{ClientSecrets: secretsPath}, + } + logger = slog.New(slog.NewTextHandler(os.Stderr, nil)) + + testCmd := &cobra.Command{ + Use: "add-account ", + Args: cobra.ExactArgs(1), + RunE: addAccountCmd.RunE, + } + testCmd.Flags().StringVar(&oauthAppName, "oauth-app", "", "") + testCmd.Flags().BoolVar(&headless, "headless", false, "") + testCmd.Flags().BoolVar(&forceReauth, "force", false, "") + testCmd.Flags().StringVar(&accountDisplayName, "display-name", "", "") + testCmd.Flags().BoolVar(&noDefaultIdentityAddAccount, "no-default-identity", false, "") + + root := newTestRootCmd() + root.AddCommand(testCmd) + root.SetArgs([]string{"add-account", "user@example.com"}) + + if err := root.Execute(); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + s, err := store.Open(dbPath) + if err != nil { + 
t.Fatalf("reopen store: %v", err) + } + defer func() { _ = s.Close() }() + + src, err := findGmailSource(s, "user@example.com") + if err != nil { + t.Fatalf("find source: %v", err) + } + if src == nil { + t.Fatal("source not found") + } + + ids, err := s.ListAccountIdentities(src.ID) + if err != nil { + t.Fatalf("ListAccountIdentities: %v", err) + } + if len(ids) != 1 { + t.Fatalf("expected 1 identity row, got %d", len(ids)) + } + if ids[0].Address != "user@example.com" { + t.Errorf("address = %q, want user@example.com", ids[0].Address) + } + if ids[0].SourceSignal != "account-identifier" { + t.Errorf("source_signal = %q, want account-identifier", ids[0].SourceSignal) + } +} + +// TestAddAccount_NoDefaultIdentitySuppresses verifies that --no-default-identity +// prevents the auto-identity write. +func TestAddAccount_NoDefaultIdentitySuppresses(t *testing.T) { + tmpDir := t.TempDir() + dbPath := filepath.Join(tmpDir, "msgvault.db") + + tokensDir := filepath.Join(tmpDir, "tokens") + if err := os.MkdirAll(tokensDir, 0700); err != nil { + t.Fatalf("mkdir tokens: %v", err) + } + tokenData, _ := json.Marshal(map[string]string{ + "access_token": "fake-access", + "refresh_token": "fake-refresh", + "token_type": "Bearer", + "client_id": "test.apps.googleusercontent.com", + }) + if err := os.WriteFile(filepath.Join(tokensDir, "user@example.com.json"), tokenData, 0600); err != nil { + t.Fatalf("write token: %v", err) + } + + secretsPath := filepath.Join(tmpDir, "secret.json") + if err := os.WriteFile(secretsPath, []byte(fakeClientSecrets), 0600); err != nil { + t.Fatalf("write secrets: %v", err) + } + + savedCfg := cfg + savedLogger := logger + savedOAuthApp := oauthAppName + savedNoDefault := noDefaultIdentityAddAccount + defer func() { + cfg = savedCfg + logger = savedLogger + oauthAppName = savedOAuthApp + noDefaultIdentityAddAccount = savedNoDefault + }() + + cfg = &config.Config{ + HomeDir: tmpDir, + Data: config.DataConfig{DataDir: tmpDir}, + OAuth: 
config.OAuthConfig{ClientSecrets: secretsPath}, + } + logger = slog.New(slog.NewTextHandler(os.Stderr, nil)) + + testCmd := &cobra.Command{ + Use: "add-account ", + Args: cobra.ExactArgs(1), + RunE: addAccountCmd.RunE, + } + testCmd.Flags().StringVar(&oauthAppName, "oauth-app", "", "") + testCmd.Flags().BoolVar(&headless, "headless", false, "") + testCmd.Flags().BoolVar(&forceReauth, "force", false, "") + testCmd.Flags().StringVar(&accountDisplayName, "display-name", "", "") + testCmd.Flags().BoolVar(&noDefaultIdentityAddAccount, "no-default-identity", false, "") + + root := newTestRootCmd() + root.AddCommand(testCmd) + root.SetArgs([]string{"add-account", "user@example.com", "--no-default-identity"}) + + if err := root.Execute(); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + s, err := store.Open(dbPath) + if err != nil { + t.Fatalf("reopen store: %v", err) + } + defer func() { _ = s.Close() }() + + src, err := findGmailSource(s, "user@example.com") + if err != nil { + t.Fatalf("find source: %v", err) + } + if src == nil { + t.Fatal("source not found") + } + + ids, err := s.ListAccountIdentities(src.ID) + if err != nil { + t.Fatalf("ListAccountIdentities: %v", err) + } + if len(ids) != 0 { + t.Fatalf("expected 0 identity rows with --no-default-identity, got %d", len(ids)) + } +} + +// TestAddAccount_DeferredLegacyIdentityMigrationFires verifies that legacy +// [identity] addresses configured before any source exists are migrated +// onto the first source created in the same add-account invocation. +// Regression test for iter10: previously, runStartupMigrations ran before +// GetOrCreateSource, so the deferred migration parked at startup and only +// applied on the *next* command — leaving the new source without its +// configured identities until then. 
+func TestAddAccount_DeferredLegacyIdentityMigrationFires(t *testing.T) { + tmpDir := t.TempDir() + dbPath := filepath.Join(tmpDir, "msgvault.db") + + tokensDir := filepath.Join(tmpDir, "tokens") + if err := os.MkdirAll(tokensDir, 0700); err != nil { + t.Fatalf("mkdir tokens: %v", err) + } + tokenData, _ := json.Marshal(map[string]string{ + "access_token": "fake-access", + "refresh_token": "fake-refresh", + "token_type": "Bearer", + "client_id": "test.apps.googleusercontent.com", + }) + if err := os.WriteFile(filepath.Join(tokensDir, "user@example.com.json"), tokenData, 0600); err != nil { + t.Fatalf("write token: %v", err) + } + + secretsPath := filepath.Join(tmpDir, "secret.json") + if err := os.WriteFile(secretsPath, []byte(fakeClientSecrets), 0600); err != nil { + t.Fatalf("write secrets: %v", err) + } + + savedCfg := cfg + savedLogger := logger + savedOAuthApp := oauthAppName + savedNoDefault := noDefaultIdentityAddAccount + defer func() { + cfg = savedCfg + logger = savedLogger + oauthAppName = savedOAuthApp + noDefaultIdentityAddAccount = savedNoDefault + }() + + cfg = &config.Config{ + HomeDir: tmpDir, + Data: config.DataConfig{DataDir: tmpDir}, + OAuth: config.OAuthConfig{ClientSecrets: secretsPath}, + Identity: config.IdentityConfig{ + Addresses: []string{"alias@example.com", "alt@work.com"}, + }, + } + var logBuf strings.Builder + logger = slog.New(slog.NewTextHandler(&logBuf, nil)) + + testCmd := &cobra.Command{ + Use: "add-account ", + Args: cobra.ExactArgs(1), + RunE: addAccountCmd.RunE, + } + testCmd.Flags().StringVar(&oauthAppName, "oauth-app", "", "") + testCmd.Flags().BoolVar(&headless, "headless", false, "") + testCmd.Flags().BoolVar(&forceReauth, "force", false, "") + testCmd.Flags().StringVar(&accountDisplayName, "display-name", "", "") + testCmd.Flags().BoolVar(&noDefaultIdentityAddAccount, "no-default-identity", false, "") + + root := newTestRootCmd() + root.AddCommand(testCmd) + // --no-default-identity isolates the test to the legacy 
migration path: + // the auto-default would otherwise add a third identity row. + root.SetArgs([]string{"add-account", "user@example.com", "--no-default-identity"}) + + if err := root.Execute(); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // The user-facing notice must only describe the applied path. + // Emitting the "deferred — will run on the next command" notice + // inside an invocation that DID apply the migration is misleading + // and a regression of the iter10 polish fix. + logs := logBuf.String() + if strings.Contains(logs, "migration deferred until a source exists") { + t.Errorf("deferred notice fired in same invocation that applied the migration; logs:\n%s", logs) + } + if !strings.Contains(logs, "legacy identity migrated") { + t.Errorf("expected applied notice in logs; got:\n%s", logs) + } + + s, err := store.Open(dbPath) + if err != nil { + t.Fatalf("reopen store: %v", err) + } + defer func() { _ = s.Close() }() + + src, err := findGmailSource(s, "user@example.com") + if err != nil { + t.Fatalf("find source: %v", err) + } + if src == nil { + t.Fatal("source not found") + } + + ids, err := s.ListAccountIdentities(src.ID) + if err != nil { + t.Fatalf("ListAccountIdentities: %v", err) + } + if len(ids) != 2 { + t.Fatalf("expected 2 legacy-migrated identity rows on first invocation, got %d: %+v", len(ids), ids) + } + got := map[string]string{ids[0].Address: ids[0].SourceSignal, ids[1].Address: ids[1].SourceSignal} + for _, addr := range []string{"alias@example.com", "alt@work.com"} { + signal, ok := got[addr] + if !ok { + t.Errorf("missing identity row for %q (have %+v)", addr, got) + continue + } + if signal != "config_migration" { + t.Errorf("address %q: source_signal = %q, want config_migration", addr, signal) + } + } + + applied, err := s.IsMigrationApplied("legacy_identity_to_per_account") + if err != nil { + t.Fatalf("IsMigrationApplied: %v", err) + } + if !applied { + t.Error("migration sentinel should be set after first successful 
add-account") + } +} + +// TestAddAccount_LegacyMigrationDoesNotSuppressDefaultIdentity verifies +// that the legacy [identity] migration writing rows DOES NOT suppress +// the auto-default-identity write for the source's own account +// identifier. Regression test for iter15 codex Medium: previously, +// runPostSourceCreateMigrations ran BEFORE confirmDefaultIdentity, so +// the legacy migration populated account_identities first, then +// confirmDefaultIdentity's `len(existing) > 0` guard skipped the +// account-identifier write entirely — leaving the source without its +// own identifier and breaking dedup sent-copy detection. +func TestAddAccount_LegacyMigrationDoesNotSuppressDefaultIdentity(t *testing.T) { + tmpDir := t.TempDir() + dbPath := filepath.Join(tmpDir, "msgvault.db") + + tokensDir := filepath.Join(tmpDir, "tokens") + if err := os.MkdirAll(tokensDir, 0700); err != nil { + t.Fatalf("mkdir tokens: %v", err) + } + tokenData, _ := json.Marshal(map[string]string{ + "access_token": "fake-access", + "refresh_token": "fake-refresh", + "token_type": "Bearer", + "client_id": "test.apps.googleusercontent.com", + }) + if err := os.WriteFile(filepath.Join(tokensDir, "user@example.com.json"), tokenData, 0600); err != nil { + t.Fatalf("write token: %v", err) + } + + secretsPath := filepath.Join(tmpDir, "secret.json") + if err := os.WriteFile(secretsPath, []byte(fakeClientSecrets), 0600); err != nil { + t.Fatalf("write secrets: %v", err) + } + + savedCfg := cfg + savedLogger := logger + savedOAuthApp := oauthAppName + savedNoDefault := noDefaultIdentityAddAccount + defer func() { + cfg = savedCfg + logger = savedLogger + oauthAppName = savedOAuthApp + noDefaultIdentityAddAccount = savedNoDefault + }() + + cfg = &config.Config{ + HomeDir: tmpDir, + Data: config.DataConfig{DataDir: tmpDir}, + OAuth: config.OAuthConfig{ClientSecrets: secretsPath}, + // Legacy [identity] block with two addresses, neither of which + // is the account being added. 
The migration must fire BUT also + // the auto-default identity must be written for user@example.com. + Identity: config.IdentityConfig{ + Addresses: []string{"alias@example.com", "alt@work.com"}, + }, + } + logger = slog.New(slog.NewTextHandler(os.Stderr, nil)) + + testCmd := &cobra.Command{ + Use: "add-account ", + Args: cobra.ExactArgs(1), + RunE: addAccountCmd.RunE, + } + testCmd.Flags().StringVar(&oauthAppName, "oauth-app", "", "") + testCmd.Flags().BoolVar(&headless, "headless", false, "") + testCmd.Flags().BoolVar(&forceReauth, "force", false, "") + testCmd.Flags().StringVar(&accountDisplayName, "display-name", "", "") + testCmd.Flags().BoolVar(&noDefaultIdentityAddAccount, "no-default-identity", false, "") + + root := newTestRootCmd() + root.AddCommand(testCmd) + // Note: NOT passing --no-default-identity. The bug only manifests + // when the auto-default write is supposed to fire. + root.SetArgs([]string{"add-account", "user@example.com"}) + + if err := root.Execute(); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + s, err := store.Open(dbPath) + if err != nil { + t.Fatalf("reopen store: %v", err) + } + defer func() { _ = s.Close() }() + + src, err := findGmailSource(s, "user@example.com") + if err != nil { + t.Fatalf("find source: %v", err) + } + if src == nil { + t.Fatal("source not found") + } + + ids, err := s.ListAccountIdentities(src.ID) + if err != nil { + t.Fatalf("ListAccountIdentities: %v", err) + } + // Want 3 rows: 2 legacy-migrated + 1 account-identifier. 
+ if len(ids) != 3 { + t.Fatalf("expected 3 identity rows (2 legacy + 1 account-identifier), got %d: %+v", len(ids), ids) + } + got := make(map[string]bool, len(ids)) + for _, ai := range ids { + got[ai.Address] = true + } + for _, want := range []string{"alias@example.com", "alt@work.com", "user@example.com"} { + if !got[want] { + t.Errorf("missing identity row for %q (have %v)", want, got) + } + } +} + func TestAddAccount_HeadlessServiceAccountReturnsActionableError(t *testing.T) { tmpDir := t.TempDir() diff --git a/cmd/msgvault/cmd/addimap.go b/cmd/msgvault/cmd/addimap.go index 2a548843..bc9041dd 100644 --- a/cmd/msgvault/cmd/addimap.go +++ b/cmd/msgvault/cmd/addimap.go @@ -49,11 +49,12 @@ func choosePasswordStrategy( } var ( - imapHost string - imapPort int - imapUsername string - imapNoTLS bool - imapSTARTTLS bool + imapHost string + imapPort int + imapUsername string + imapNoTLS bool + imapSTARTTLS bool + noDefaultIdentityAddImap bool ) var addIMAPCmd = &cobra.Command{ @@ -148,6 +149,9 @@ Examples: if err := s.InitSchema(); err != nil { return fmt.Errorf("init schema: %w", err) } + if err := runStartupMigrationsForIngest(s); err != nil { + return fmt.Errorf("startup migrations: %w", err) + } // Build identifier and save credentials identifier := imapCfg.Identifier() @@ -176,6 +180,15 @@ Examples: return fmt.Errorf("set display name: %w", err) } + // Auto-default-identity must run BEFORE the legacy migration + // retry — see comment in account_identity.go. 
+ if !noDefaultIdentityAddImap { + confirmDefaultIdentity(cmd.OutOrStdout(), s, source.ID, imapUsername, imapUsername, "account-identifier") + } + if err := runPostSourceCreateMigrations(s); err != nil { + return fmt.Errorf("post-source-create migrations: %w", err) + } + fmt.Printf("\nIMAP account added successfully!\n") fmt.Printf(" Identifier: %s\n", identifier) fmt.Printf(" Note: Password stored on disk at %s\n", imapclient.CredentialsPath(cfg.TokensDir(), identifier)) @@ -231,5 +244,6 @@ func init() { addIMAPCmd.Flags().StringVar(&imapUsername, "username", "", "IMAP username / email address (required)") addIMAPCmd.Flags().BoolVar(&imapNoTLS, "no-tls", false, "Disable TLS (plain connection, not recommended)") addIMAPCmd.Flags().BoolVar(&imapSTARTTLS, "starttls", false, "Use STARTTLS instead of implicit TLS") + addIMAPCmd.Flags().BoolVar(&noDefaultIdentityAddImap, "no-default-identity", false, noDefaultIdentityHelp) rootCmd.AddCommand(addIMAPCmd) } diff --git a/cmd/msgvault/cmd/addo365.go b/cmd/msgvault/cmd/addo365.go index a5d913f0..6f6c319b 100644 --- a/cmd/msgvault/cmd/addo365.go +++ b/cmd/msgvault/cmd/addo365.go @@ -10,7 +10,10 @@ import ( "github.com/wesm/msgvault/internal/store" ) -var o365TenantID string +var ( + o365TenantID string + noDefaultIdentityAddO365 bool +) var addO365Cmd = &cobra.Command{ Use: "add-o365 ", @@ -81,6 +84,9 @@ Examples: if err := s.InitSchema(); err != nil { return fmt.Errorf("init schema: %w", err) } + if err := runStartupMigrationsForIngest(s); err != nil { + return fmt.Errorf("startup migrations: %w", err) + } identifier := imapCfg.Identifier() @@ -113,7 +119,6 @@ Examples: return fmt.Errorf("create source: %w", err) } } - cfgJSON, err := imapCfg.ToJSON() if err != nil { return fmt.Errorf("serialize config: %w", err) @@ -125,6 +130,15 @@ Examples: return fmt.Errorf("set display name: %w", err) } + // Auto-default-identity must run BEFORE the legacy migration + // retry — see comment in account_identity.go. 
+ if !noDefaultIdentityAddO365 { + confirmDefaultIdentity(cmd.OutOrStdout(), s, source.ID, email, email, "account-identifier") + } + if err := runPostSourceCreateMigrations(s); err != nil { + return fmt.Errorf("post-source-create migrations: %w", err) + } + fmt.Printf("\nMicrosoft 365 account added successfully!\n") fmt.Printf(" Email: %s\n", email) fmt.Printf(" Identifier: %s\n", identifier) @@ -155,5 +169,6 @@ func isMicrosoftIMAPSource(src *store.Source, email string) bool { func init() { addO365Cmd.Flags().StringVar(&o365TenantID, "tenant", "", "Azure AD tenant ID (default: \"common\" for multi-tenant)") + addO365Cmd.Flags().BoolVar(&noDefaultIdentityAddO365, "no-default-identity", false, noDefaultIdentityHelp) rootCmd.AddCommand(addO365Cmd) } diff --git a/cmd/msgvault/cmd/build_cache.go b/cmd/msgvault/cmd/build_cache.go index 6f25a87a..7e4730c4 100644 --- a/cmd/msgvault/cmd/build_cache.go +++ b/cmd/msgvault/cmd/build_cache.go @@ -83,6 +83,10 @@ Use --full-rebuild to recreate all cache files from scratch.`, _ = s.Close() return fmt.Errorf("init schema: %w", err) } + if err := runStartupMigrations(s); err != nil { + _ = s.Close() + return fmt.Errorf("startup migrations: %w", err) + } _ = s.Close() result, err := buildCache(dbPath, analyticsDir, fullRebuild) @@ -301,7 +305,7 @@ func buildCache(dbPath, analyticsDir string, fullRebuild bool) (*buildResult, er CAST(EXTRACT(YEAR FROM m.sent_at) AS INTEGER) as year, CAST(EXTRACT(MONTH FROM m.sent_at) AS INTEGER) as month FROM sqlite_db.messages m - WHERE m.sent_at IS NOT NULL%s + WHERE m.sent_at IS NOT NULL AND m.deleted_at IS NULL%s ) TO '%s' ( FORMAT PARQUET, PARTITION_BY (year), @@ -680,8 +684,12 @@ func setupSQLiteSource(duckDB *sql.DB, dbPath string) (cleanup func(), err error query string typeOverrides string // DuckDB types parameter for read_csv_auto (empty = infer all) }{ - {"messages", "SELECT id, source_id, source_message_id, conversation_id, subject, snippet, sent_at, size_estimate, has_attachments, 
attachment_count, deleted_from_source_at, sender_id, message_type FROM messages WHERE sent_at IS NOT NULL", - "types={'sent_at': 'TIMESTAMP', 'deleted_from_source_at': 'TIMESTAMP'}"}, + // deleted_at is exported so the main COPY query can apply the + // `deleted_at IS NULL` filter on this path the same way it does + // on the sqlite_scanner path; otherwise DuckDB binds against a + // CSV view that lacks the column and the export fails on Windows. + {"messages", "SELECT id, source_id, source_message_id, conversation_id, subject, snippet, sent_at, size_estimate, has_attachments, attachment_count, deleted_from_source_at, deleted_at, sender_id, message_type FROM messages WHERE sent_at IS NOT NULL", + "types={'sent_at': 'TIMESTAMP', 'deleted_from_source_at': 'TIMESTAMP', 'deleted_at': 'TIMESTAMP'}"}, {"message_recipients", "SELECT message_id, participant_id, recipient_type, display_name FROM message_recipients", ""}, {"message_labels", "SELECT message_id, label_id FROM message_labels", ""}, {"attachments", "SELECT message_id, size, filename FROM attachments", ""}, diff --git a/cmd/msgvault/cmd/build_cache_test.go b/cmd/msgvault/cmd/build_cache_test.go index e4e74ab0..05f9968e 100644 --- a/cmd/msgvault/cmd/build_cache_test.go +++ b/cmd/msgvault/cmd/build_cache_test.go @@ -55,6 +55,7 @@ func setupTestSQLite(t *testing.T) (string, func()) { deleted_from_source_at TIMESTAMP, sender_id INTEGER, message_type TEXT NOT NULL DEFAULT 'email', + deleted_at DATETIME, UNIQUE(source_id, source_message_id) ); @@ -1133,7 +1134,7 @@ func TestBuildCache_EmptyDatabase(t *testing.T) { db, _ := sql.Open("sqlite3", dbPath) _, _ = db.Exec(` CREATE TABLE sources (id INTEGER PRIMARY KEY, identifier TEXT); - CREATE TABLE messages (id INTEGER PRIMARY KEY, source_id INTEGER, source_message_id TEXT, sent_at TIMESTAMP, size_estimate INTEGER, has_attachments BOOLEAN, subject TEXT, snippet TEXT, conversation_id INTEGER, deleted_from_source_at TIMESTAMP, attachment_count INTEGER DEFAULT 0, sender_id 
INTEGER, message_type TEXT NOT NULL DEFAULT 'email'); + CREATE TABLE messages (id INTEGER PRIMARY KEY, source_id INTEGER, source_message_id TEXT, sent_at TIMESTAMP, size_estimate INTEGER, has_attachments BOOLEAN, subject TEXT, snippet TEXT, conversation_id INTEGER, deleted_from_source_at TIMESTAMP, attachment_count INTEGER DEFAULT 0, sender_id INTEGER, message_type TEXT NOT NULL DEFAULT 'email', deleted_at DATETIME); CREATE TABLE participants (id INTEGER PRIMARY KEY, email_address TEXT, domain TEXT, display_name TEXT, phone_number TEXT); CREATE TABLE message_recipients (message_id INTEGER, participant_id INTEGER, recipient_type TEXT, display_name TEXT); CREATE TABLE labels (id INTEGER PRIMARY KEY, name TEXT); @@ -1178,8 +1179,8 @@ func TestCSVFallbackPath(t *testing.T) { query string typeOverrides string }{ - {"messages", "SELECT id, source_id, source_message_id, conversation_id, subject, snippet, sent_at, size_estimate, has_attachments, deleted_from_source_at FROM messages WHERE sent_at IS NOT NULL", - "types={'sent_at': 'TIMESTAMP', 'deleted_from_source_at': 'TIMESTAMP'}"}, + {"messages", "SELECT id, source_id, source_message_id, conversation_id, subject, snippet, sent_at, size_estimate, has_attachments, attachment_count, deleted_from_source_at, deleted_at, sender_id, message_type FROM messages WHERE sent_at IS NOT NULL", + "types={'sent_at': 'TIMESTAMP', 'deleted_from_source_at': 'TIMESTAMP', 'deleted_at': 'TIMESTAMP'}"}, {"message_recipients", "SELECT message_id, participant_id, recipient_type, display_name FROM message_recipients", ""}, {"message_labels", "SELECT message_id, label_id FROM message_labels", ""}, {"attachments", "SELECT message_id, size, filename FROM attachments", ""}, @@ -1333,7 +1334,7 @@ func BenchmarkBuildCache(b *testing.B) { // Create schema _, _ = db.Exec(` CREATE TABLE sources (id INTEGER PRIMARY KEY, identifier TEXT); - CREATE TABLE messages (id INTEGER PRIMARY KEY, source_id INTEGER, source_message_id TEXT, sent_at TIMESTAMP, 
size_estimate INTEGER, has_attachments BOOLEAN, subject TEXT, snippet TEXT, conversation_id INTEGER, deleted_from_source_at TIMESTAMP, attachment_count INTEGER DEFAULT 0, sender_id INTEGER, message_type TEXT NOT NULL DEFAULT 'email'); + CREATE TABLE messages (id INTEGER PRIMARY KEY, source_id INTEGER, source_message_id TEXT, sent_at TIMESTAMP, size_estimate INTEGER, has_attachments BOOLEAN, subject TEXT, snippet TEXT, conversation_id INTEGER, deleted_from_source_at TIMESTAMP, attachment_count INTEGER DEFAULT 0, sender_id INTEGER, message_type TEXT NOT NULL DEFAULT 'email', deleted_at DATETIME); CREATE TABLE participants (id INTEGER PRIMARY KEY, email_address TEXT UNIQUE, domain TEXT, display_name TEXT, phone_number TEXT); CREATE TABLE message_recipients (message_id INTEGER, participant_id INTEGER, recipient_type TEXT, display_name TEXT); CREATE TABLE labels (id INTEGER PRIMARY KEY, name TEXT); @@ -1427,6 +1428,7 @@ func setupTestSQLiteEmpty(t *testing.T) (string, func()) { deleted_from_source_at TIMESTAMP, sender_id INTEGER, message_type TEXT NOT NULL DEFAULT 'email', + deleted_at DATETIME, UNIQUE(source_id, source_message_id) ); CREATE TABLE participants ( @@ -1887,6 +1889,60 @@ func TestCacheNeedsBuild_IgnoresAlreadyProcessedUpdatedSyncRun(t *testing.T) { } } +// TestCacheNeedsBuild_DedupHidesAfterLastSync covers the regression +// where dedup-hidden rows (deleted_at) added after the cache was built +// silently stayed in Parquet because the staleness check only watched +// deleted_from_source_at. The check now treats dedup hides the same +// way: any row whose deleted_at is at or after LastSyncAt forces a +// full rebuild. 
+func TestCacheNeedsBuild_DedupHidesAfterLastSync(t *testing.T) { + tmpDir, cleanup := setupTestSQLiteEmpty(t) + defer cleanup() + + dbPath := filepath.Join(tmpDir, "test.db") + analyticsDir := filepath.Join(tmpDir, "analytics") + + stateTime := time.Date(2026, 3, 18, 12, 0, 0, 0, time.UTC) + writeSyncStateAt(t, analyticsDir, 5, stateTime) + createFakeParquet(t, analyticsDir) + + db, err := sql.Open("sqlite3", dbPath) + if err != nil { + t.Fatalf("open db: %v", err) + } + defer func() { _ = db.Close() }() + + // Insert one live row and one row dedup-hidden after LastSyncAt. + if _, err := db.Exec( + `INSERT INTO messages + (id, source_id, source_message_id, sent_at, deleted_at) + VALUES (1, 1, 'msg1', datetime('now'), NULL)`, + ); err != nil { + t.Fatalf("insert live row: %v", err) + } + hiddenAt := stateTime.Add(1 * time.Hour). + Format("2006-01-02 15:04:05") + if _, err := db.Exec( + `INSERT INTO messages + (id, source_id, source_message_id, sent_at, deleted_at) + VALUES (2, 1, 'msg2', datetime('now'), ?)`, + hiddenAt, + ); err != nil { + t.Fatalf("insert dedup-hidden row: %v", err) + } + + got := cacheNeedsBuild(dbPath, analyticsDir) + if !got.NeedsBuild { + t.Fatalf("cacheNeedsBuild() = %+v, want NeedsBuild=true after dedup hide", got) + } + if !got.FullRebuild { + t.Fatalf("cacheNeedsBuild() = %+v, want FullRebuild=true after dedup hide", got) + } + if !strings.Contains(got.Reason, "dedup-hidden") { + t.Errorf("Reason = %q, want substring 'dedup-hidden'", got.Reason) + } +} + func TestBuildCache_RecordsLastCompletedSyncRunID(t *testing.T) { tmpDir, cleanup := setupTestSQLite(t) defer cleanup() @@ -1991,7 +2047,7 @@ func BenchmarkBuildCacheIncremental(b *testing.B) { // Create schema and initial data (10000 messages) _, _ = db.Exec(` CREATE TABLE sources (id INTEGER PRIMARY KEY, identifier TEXT); - CREATE TABLE messages (id INTEGER PRIMARY KEY, source_id INTEGER, source_message_id TEXT, sent_at TIMESTAMP, size_estimate INTEGER, has_attachments BOOLEAN, subject 
TEXT, snippet TEXT, conversation_id INTEGER, deleted_from_source_at TIMESTAMP, attachment_count INTEGER DEFAULT 0, sender_id INTEGER, message_type TEXT NOT NULL DEFAULT 'email'); + CREATE TABLE messages (id INTEGER PRIMARY KEY, source_id INTEGER, source_message_id TEXT, sent_at TIMESTAMP, size_estimate INTEGER, has_attachments BOOLEAN, subject TEXT, snippet TEXT, conversation_id INTEGER, deleted_from_source_at TIMESTAMP, attachment_count INTEGER DEFAULT 0, sender_id INTEGER, message_type TEXT NOT NULL DEFAULT 'email', deleted_at DATETIME); CREATE TABLE participants (id INTEGER PRIMARY KEY, email_address TEXT UNIQUE, domain TEXT, display_name TEXT, phone_number TEXT); CREATE TABLE message_recipients (message_id INTEGER, participant_id INTEGER, recipient_type TEXT, display_name TEXT); CREATE TABLE labels (id INTEGER PRIMARY KEY, name TEXT); diff --git a/cmd/msgvault/cmd/collection.go b/cmd/msgvault/cmd/collection.go new file mode 100644 index 00000000..d682f590 --- /dev/null +++ b/cmd/msgvault/cmd/collection.go @@ -0,0 +1,276 @@ +package cmd + +import ( + "errors" + "fmt" + "os" + "strconv" + "strings" + "text/tabwriter" + + "github.com/spf13/cobra" + "github.com/wesm/msgvault/internal/store" +) + +var collectionCmd = &cobra.Command{ + Use: "collection", + Short: "Manage named groups of accounts", + Long: `Collections are named groupings of accounts that let you view and +deduplicate across multiple sources as one unified archive. 
+ +A default "All" collection is created automatically and includes +every account.`, +} + +var collectionCreateCmd = &cobra.Command{ + Use: "create --accounts ", + Short: "Create a new collection", + Args: cobra.ExactArgs(1), + RunE: runCollectionCreate, +} + +var collectionListCmd = &cobra.Command{ + Use: "list", + Short: "List all collections", + RunE: runCollectionList, +} + +var collectionShowCmd = &cobra.Command{ + Use: "show ", + Short: "Show collection details", + Args: cobra.ExactArgs(1), + RunE: runCollectionShow, +} + +var collectionAddCmd = &cobra.Command{ + Use: "add --accounts ", + Short: "Add accounts to a collection", + Args: cobra.ExactArgs(1), + RunE: runCollectionAdd, +} + +var collectionRemoveCmd = &cobra.Command{ + Use: "remove --accounts ", + Short: "Remove accounts from a collection", + Args: cobra.ExactArgs(1), + RunE: runCollectionRemove, +} + +var collectionDeleteCmd = &cobra.Command{ + Use: "delete ", + Short: "Delete a collection (sources and messages are untouched)", + Args: cobra.ExactArgs(1), + RunE: runCollectionDelete, +} + +var ( + collectionCreateAccounts string + collectionAddAccounts string + collectionRemoveAccounts string +) + +func runCollectionCreate(_ *cobra.Command, args []string) error { + st, err := openStoreAndInit() + if err != nil { + return err + } + defer func() { _ = st.Close() }() + + name := args[0] + sourceIDs, err := resolveAccountList(st, collectionCreateAccounts) + if err != nil { + return err + } + + coll, err := st.CreateCollection(name, "", sourceIDs) + if err != nil { + return err + } + fmt.Printf("Created collection %q with %d source(s).\n", + coll.Name, len(sourceIDs)) + return nil +} + +func runCollectionList(_ *cobra.Command, _ []string) error { + st, err := openStoreAndInit() + if err != nil { + return err + } + defer func() { _ = st.Close() }() + + collections, err := st.ListCollections() + if err != nil { + return err + } + if len(collections) == 0 { + fmt.Println("No collections.") + return nil + 
} + + w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0) + _, _ = fmt.Fprintln(w, "NAME\tSOURCES\tMESSAGES") + for _, c := range collections { + _, _ = fmt.Fprintf(w, "%s\t%d\t%s\n", + c.Name, len(c.SourceIDs), + formatCount(c.MessageCount)) + } + _ = w.Flush() + return nil +} + +func runCollectionShow(_ *cobra.Command, args []string) error { + st, err := openStoreAndInit() + if err != nil { + return err + } + defer func() { _ = st.Close() }() + + coll, err := st.GetCollectionByName(args[0]) + if err != nil { + return err + } + + fmt.Printf("Collection: %s\n", coll.Name) + if coll.Description != "" { + fmt.Printf("Description: %s\n", coll.Description) + } + fmt.Printf("Sources: %d\n", len(coll.SourceIDs)) + fmt.Printf("Messages: %s\n", formatCount(coll.MessageCount)) + fmt.Printf("Created: %s\n", coll.CreatedAt.Format("2006-01-02 15:04")) + + if len(coll.SourceIDs) > 0 { + fmt.Println("\nMember sources:") + for _, sid := range coll.SourceIDs { + src, err := st.GetSourceByID(sid) + if err != nil { + return fmt.Errorf("get source %d: %w", sid, err) + } + label := src.Identifier + if src.DisplayName.Valid && src.DisplayName.String != "" { + label = src.DisplayName.String + } + fmt.Printf("- %s (id %d)\n", label, src.ID) + } + } + return nil +} + +func runCollectionAdd(_ *cobra.Command, args []string) error { + st, err := openStoreAndInit() + if err != nil { + return err + } + defer func() { _ = st.Close() }() + + sourceIDs, err := resolveAccountList(st, collectionAddAccounts) + if err != nil { + return err + } + + if err := st.AddSourcesToCollection(args[0], sourceIDs); err != nil { + return err + } + fmt.Printf("Added %d source(s) to %q.\n", len(sourceIDs), args[0]) + return nil +} + +func runCollectionRemove(_ *cobra.Command, args []string) error { + st, err := openStoreAndInit() + if err != nil { + return err + } + defer func() { _ = st.Close() }() + + sourceIDs, err := resolveAccountList(st, collectionRemoveAccounts) + if err != nil { + return err + } + + if err 
:= st.RemoveSourcesFromCollection(args[0], sourceIDs); err != nil { + return err + } + fmt.Printf("Removed %d source(s) from %q.\n", len(sourceIDs), args[0]) + return nil +} + +func runCollectionDelete(_ *cobra.Command, args []string) error { + st, err := openStoreAndInit() + if err != nil { + return err + } + defer func() { _ = st.Close() }() + + if err := st.DeleteCollection(args[0]); err != nil { + return err + } + fmt.Printf("Deleted collection %q.\n", args[0]) + return nil +} + +func resolveAccountList(st *store.Store, accounts string) ([]int64, error) { + if accounts == "" { + return nil, fmt.Errorf("--accounts is required") + } + parts := strings.Split(accounts, ",") + var ids []int64 + for _, p := range parts { + p = strings.TrimSpace(p) + if p == "" { + continue + } + // Try as numeric ID first, but only for plain digit tokens. + // strconv.ParseInt accepts a leading '+' or '-' sign, so an + // E.164 phone identifier like "+15551234567" would parse as + // the integer 15551234567 and be treated as a source ID, + // silently breaking WhatsApp/Google Voice accounts that key + // on phone numbers. Restrict the numeric branch to tokens + // whose first byte is a decimal digit so signed inputs fall + // through to identifier resolution. If the numeric lookup + // returns ErrSourceNotFound, fall through to ResolveAccountFlag + // — the digit string may be a numeric identifier (e.g. + // unprefixed phone number, account name) rather than a source + // ID. Surface any other error (real DB failure) so it isn't + // masked as a "not found". 
+ if p[0] >= '0' && p[0] <= '9' { + if id, err := strconv.ParseInt(p, 10, 64); err == nil { + _, lookupErr := st.GetSourceByID(id) + switch { + case lookupErr == nil: + ids = append(ids, id) + continue + case errors.Is(lookupErr, store.ErrSourceNotFound): + // fall through to identifier resolution + default: + return nil, fmt.Errorf("get source %d: %w", id, lookupErr) + } + } + } + // Resolve by identifier + scope, err := ResolveAccountFlag(st, p) + if err != nil { + return nil, err + } + ids = append(ids, scope.SourceIDs()...) + } + if len(ids) == 0 { + return nil, fmt.Errorf("no valid accounts in --accounts") + } + return ids, nil +} + +func init() { + rootCmd.AddCommand(collectionCmd) + collectionCmd.AddCommand(collectionCreateCmd) + collectionCmd.AddCommand(collectionListCmd) + collectionCmd.AddCommand(collectionShowCmd) + collectionCmd.AddCommand(collectionAddCmd) + collectionCmd.AddCommand(collectionRemoveCmd) + collectionCmd.AddCommand(collectionDeleteCmd) + + collectionCreateCmd.Flags().StringVar(&collectionCreateAccounts, + "accounts", "", "Comma-separated account emails or source IDs") + collectionAddCmd.Flags().StringVar(&collectionAddAccounts, + "accounts", "", "Comma-separated account emails or source IDs") + collectionRemoveCmd.Flags().StringVar(&collectionRemoveAccounts, + "accounts", "", "Comma-separated account emails or source IDs") +} diff --git a/cmd/msgvault/cmd/collection_test.go b/cmd/msgvault/cmd/collection_test.go new file mode 100644 index 00000000..32d22dcd --- /dev/null +++ b/cmd/msgvault/cmd/collection_test.go @@ -0,0 +1,138 @@ +package cmd + +import ( + "fmt" + "path/filepath" + "strings" + "testing" + + "github.com/spf13/cobra" + "github.com/wesm/msgvault/internal/config" + "github.com/wesm/msgvault/internal/store" +) + +func TestCollectionShowPrintsReadableSourceNames(t *testing.T) { + savedCfg := cfg + defer func() { cfg = savedCfg }() + + tmpDir := t.TempDir() + cfg = &config.Config{ + HomeDir: tmpDir, + Data: 
config.DataConfig{DataDir: tmpDir}, + } + + dbPath := filepath.Join(tmpDir, "msgvault.db") + st, err := store.Open(dbPath) + if err != nil { + t.Fatalf("open store: %v", err) + } + if err := st.InitSchema(); err != nil { + t.Fatalf("init schema: %v", err) + } + + alice, err := st.GetOrCreateSource("gmail", "alice@example.com") + if err != nil { + t.Fatalf("create alice source: %v", err) + } + if err := st.UpdateSourceDisplayName(alice.ID, "Personal"); err != nil { + t.Fatalf("set display name: %v", err) + } + bob, err := st.GetOrCreateSource("imap", "bob@example.com") + if err != nil { + t.Fatalf("create bob source: %v", err) + } + if _, err := st.CreateCollection("team", "", []int64{alice.ID, bob.ID}); err != nil { + t.Fatalf("create collection: %v", err) + } + if err := st.Close(); err != nil { + t.Fatalf("close setup store: %v", err) + } + + done := captureStdout(t) + if err := runCollectionShow(&cobra.Command{}, []string{"team"}); err != nil { + t.Fatalf("runCollectionShow: %v", err) + } + out := done() + + if !strings.Contains(out, "Personal (id ") { + t.Fatalf("missing display name in output:\n%s", out) + } + if !strings.Contains(out, "bob@example.com (id ") { + t.Fatalf("missing identifier in output:\n%s", out) + } +} + +func TestResolveAccountListRejectsMissingNumericID(t *testing.T) { + tmpDir := t.TempDir() + st, err := store.Open(filepath.Join(tmpDir, "msgvault.db")) + if err != nil { + t.Fatalf("open store: %v", err) + } + defer func() { _ = st.Close() }() + if err := st.InitSchema(); err != nil { + t.Fatalf("init schema: %v", err) + } + + src, err := st.GetOrCreateSource("gmail", "alice@example.com") + if err != nil { + t.Fatalf("create source: %v", err) + } + + ids, err := resolveAccountList(st, fmt.Sprintf("%d", src.ID)) + if err != nil { + t.Fatalf("resolveAccountList(existing id): %v", err) + } + if len(ids) != 1 || ids[0] != src.ID { + t.Fatalf("resolveAccountList(existing id) = %v, want [%d]", ids, src.ID) + } + + // "999999" is neither an 
existing source ID nor an existing + // identifier/display name, so resolveAccountList errors via the + // final ResolveAccountFlag fall-through. Iter12 codex flagged that + // the prior shape errored *before* the fall-through, so a numeric + // identifier (e.g. unprefixed phone "15551234567") that wasn't a + // source ID would never get a chance to match by identifier. The + // test below asserts the fall-through path is reachable. + if _, err := resolveAccountList(st, "999999"); err == nil { + t.Fatal("expected error for missing numeric source ID, got nil") + } +} + +// TestResolveAccountListNumericFallthroughResolvesIdentifier verifies +// that a plain-digit token that does NOT match a source ID falls +// through to identifier resolution. Regression test for iter12 codex +// Low: previously, a numeric identifier (e.g. an unprefixed phone +// number) that happened to not match a source ID would error +// immediately instead of being looked up by identifier. +func TestResolveAccountListNumericFallthroughResolvesIdentifier(t *testing.T) { + tmpDir := t.TempDir() + st, err := store.Open(filepath.Join(tmpDir, "msgvault.db")) + if err != nil { + t.Fatalf("open store: %v", err) + } + defer func() { _ = st.Close() }() + if err := st.InitSchema(); err != nil { + t.Fatalf("init schema: %v", err) + } + + // Create a source with a numeric identifier that is unlikely to + // collide with the auto-assigned source ID. Use a 12-digit string + // (way past any plausible primary-key value) so the test stays + // stable. 
+ phoneIdentifier := "987654321098" + src, err := st.GetOrCreateSource("whatsapp", phoneIdentifier) + if err != nil { + t.Fatalf("create source: %v", err) + } + if fmt.Sprintf("%d", src.ID) == phoneIdentifier { + t.Fatalf("test assumption broken: source id %d collides with identifier", src.ID) + } + + ids, err := resolveAccountList(st, phoneIdentifier) + if err != nil { + t.Fatalf("resolveAccountList(numeric identifier): %v", err) + } + if len(ids) != 1 || ids[0] != src.ID { + t.Fatalf("resolveAccountList(numeric identifier) = %v, want [%d]", ids, src.ID) + } +} diff --git a/cmd/msgvault/cmd/confirm.go b/cmd/msgvault/cmd/confirm.go new file mode 100644 index 00000000..f8e5d693 --- /dev/null +++ b/cmd/msgvault/cmd/confirm.go @@ -0,0 +1,87 @@ +package cmd + +import ( + "bufio" + "fmt" + "io" + "strings" +) + +// ConfirmMode selects the prompt and validation for confirmDestructive. +type ConfirmMode int + +const ( + // ConfirmModePermanent — destructive remote delete (rung 04). + // Requires the literal word "delete" to confirm. Anything else, + // including EOF, prints the verbatim cancellation message and + // returns (false, nil). + ConfirmModePermanent ConfirmMode = iota + + // ConfirmModeAllHidden — destructive local hard delete (rung 03) + // targeting every hidden row. Accepts y/yes; n/no/EOF. EOF produces + // the contract-naming error (cannot be skipped with --yes). + ConfirmModeAllHidden + + // ConfirmModeYesNo — ordinary destructive prompt that may be + // skipped with --yes by the caller. Accepts y/yes; n/no/EOF cancel + // without an error so scripted/non-interactive use exits cleanly + // when the prompt is reached unexpectedly. + ConfirmModeYesNo +) + +// confirmDestructive prompts on the provided writer and reads a single +// line of input from the provided reader. Returns (true, nil) on +// confirmation, (false, nil) on cancellation, (_, err) on a contract +// violation that should fail the command (e.g. AllHidden EOF). 
+// +// The reader/writer split lets unit tests inject fixed input and +// inspect the prompt + cancellation messages without standing up the +// full cobra RunE harness. +func confirmDestructive(r io.Reader, w io.Writer, mode ConfirmMode) (bool, error) { + switch mode { + case ConfirmModePermanent: + _, _ = fmt.Fprint(w, `Type "delete" to confirm permanent deletion (no recovery): `) + scanner := bufio.NewScanner(r) + if !scanner.Scan() { + if err := scanner.Err(); err != nil { + return false, fmt.Errorf("read confirmation: %w", err) + } + _, _ = fmt.Fprintln(w, "Cancelled. Drop --permanent to use trash deletion without elevated permissions.") + return false, nil + } + if strings.TrimSpace(scanner.Text()) != "delete" { + _, _ = fmt.Fprintln(w, "Cancelled. Drop --permanent to use trash deletion without elevated permissions.") + return false, nil + } + return true, nil + + case ConfirmModeAllHidden: + _, _ = fmt.Fprint(w, "Proceed? This is irreversible. [y/N]: ") + scanner := bufio.NewScanner(r) + if !scanner.Scan() { + if err := scanner.Err(); err != nil { + return false, fmt.Errorf("read confirmation: %w", err) + } + return false, fmt.Errorf( + "no confirmation input (stdin closed); --all-hidden cannot be skipped with --yes", + ) + } + answer := strings.TrimSpace(strings.ToLower(scanner.Text())) + return answer == "y" || answer == "yes", nil + + case ConfirmModeYesNo: + _, _ = fmt.Fprint(w, "Proceed? This is irreversible. 
[y/N]: ") + scanner := bufio.NewScanner(r) + if !scanner.Scan() { + if err := scanner.Err(); err != nil { + return false, fmt.Errorf("read confirmation: %w", err) + } + return false, nil + } + answer := strings.TrimSpace(strings.ToLower(scanner.Text())) + return answer == "y" || answer == "yes", nil + + default: + return false, fmt.Errorf("unknown ConfirmMode: %d", mode) + } +} diff --git a/cmd/msgvault/cmd/confirm_test.go b/cmd/msgvault/cmd/confirm_test.go new file mode 100644 index 00000000..9223692a --- /dev/null +++ b/cmd/msgvault/cmd/confirm_test.go @@ -0,0 +1,113 @@ +package cmd + +import ( + "bytes" + "strings" + "testing" +) + +func TestConfirmDestructive_Permanent_LiteralDeleteAccepted(t *testing.T) { + var out bytes.Buffer + ok, err := confirmDestructive(strings.NewReader("delete\n"), &out, ConfirmModePermanent) + if err != nil { + t.Fatalf("err = %v", err) + } + if !ok { + t.Errorf("ok = false, want true after typing 'delete'") + } +} + +func TestConfirmDestructive_Permanent_NonDeleteRejected(t *testing.T) { + var out bytes.Buffer + ok, err := confirmDestructive(strings.NewReader("y\n"), &out, ConfirmModePermanent) + if err != nil { + t.Fatalf("err = %v", err) + } + if ok { + t.Errorf("ok = true, want false when input is 'y' under Permanent mode") + } + if !strings.Contains(out.String(), `Cancelled. Drop --permanent to use trash deletion without elevated permissions.`) { + t.Errorf("output missing verbatim cancellation message: %q", out.String()) + } +} + +func TestConfirmDestructive_Permanent_StdinClosed(t *testing.T) { + var out bytes.Buffer + ok, err := confirmDestructive(strings.NewReader(""), &out, ConfirmModePermanent) + if err != nil { + t.Fatalf("err = %v", err) + } + if ok { + t.Errorf("ok = true on closed stdin, want false") + } + if !strings.Contains(out.String(), `Cancelled. 
Drop --permanent to use trash deletion without elevated permissions.`) { + t.Errorf("output missing verbatim cancellation message on EOF: %q", out.String()) + } +} + +func TestConfirmDestructive_AllHidden_YesAccepted(t *testing.T) { + var out bytes.Buffer + ok, err := confirmDestructive(strings.NewReader("y\n"), &out, ConfirmModeAllHidden) + if err != nil { + t.Fatalf("err = %v", err) + } + if !ok { + t.Errorf("ok = false on 'y', want true under AllHidden mode") + } +} + +func TestConfirmDestructive_AllHidden_NoRejected(t *testing.T) { + var out bytes.Buffer + ok, err := confirmDestructive(strings.NewReader("n\n"), &out, ConfirmModeAllHidden) + if err != nil { + t.Fatalf("err = %v", err) + } + if ok { + t.Errorf("ok = true on 'n', want false") + } +} + +func TestConfirmDestructive_AllHidden_StdinClosed(t *testing.T) { + var out bytes.Buffer + _, err := confirmDestructive(strings.NewReader(""), &out, ConfirmModeAllHidden) + if err == nil { + t.Fatalf("err = nil on closed stdin, want named error") + } + wantSubstr := "no confirmation input (stdin closed); --all-hidden cannot be skipped with --yes" + if !strings.Contains(err.Error(), wantSubstr) { + t.Errorf("err = %q, want substring %q", err.Error(), wantSubstr) + } +} + +func TestConfirmDestructive_YesNo_YesAccepted(t *testing.T) { + var out bytes.Buffer + ok, err := confirmDestructive(strings.NewReader("y\n"), &out, ConfirmModeYesNo) + if err != nil { + t.Fatalf("err = %v", err) + } + if !ok { + t.Errorf("ok = false on 'y', want true under YesNo mode") + } +} + +func TestConfirmDestructive_YesNo_NoRejected(t *testing.T) { + var out bytes.Buffer + ok, err := confirmDestructive(strings.NewReader("n\n"), &out, ConfirmModeYesNo) + if err != nil { + t.Fatalf("err = %v", err) + } + if ok { + t.Errorf("ok = true on 'n', want false") + } +} + +func TestConfirmDestructive_YesNo_StdinClosed(t *testing.T) { + var out bytes.Buffer + ok, err := confirmDestructive(strings.NewReader(""), &out, ConfirmModeYesNo) + if err != nil { 
+ t.Fatalf("err = %v on closed stdin, want nil (cancel-on-EOF)", err) + } + if ok { + t.Errorf("ok = true on closed stdin, want false") + } +} diff --git a/cmd/msgvault/cmd/deduplicate.go b/cmd/msgvault/cmd/deduplicate.go new file mode 100644 index 00000000..33386868 --- /dev/null +++ b/cmd/msgvault/cmd/deduplicate.go @@ -0,0 +1,643 @@ +package cmd + +import ( + "bufio" + "crypto/rand" + "encoding/hex" + "errors" + "fmt" + "io" + "os" + "path/filepath" + "strings" + "time" + + "github.com/spf13/cobra" + "github.com/wesm/msgvault/internal/dedup" + "github.com/wesm/msgvault/internal/store" +) + +var deduplicateCmd = &cobra.Command{ + Use: "deduplicate", + Aliases: []string{"dedup", "dedupe"}, + Short: "Find and merge duplicate messages within an account", + Long: `Find and merge duplicate messages within a single account +(for example, the same mbox imported twice, or stored MIME that +generates two copies of the same RFC822 Message-ID inside one ingest +source). Cross-source comparison requires --collection. + +Duplicates are grouped by the RFC822 Message-ID header. For each group the +engine selects a survivor, unions the labels from every copy onto the +survivor, and hides the pruned copies in the msgvault database. + +By default, deduplicate ONLY modifies the msgvault database. Your original +source files and remote servers are never modified. Hidden rows can be +restored with --undo, so a dedup run is fully reversible. + +Terminology: + "account" One ingest source/archive (a single Gmail OAuth + connection, one mbox import, one IMAP source, etc.). + "collection" A named, user-defined grouping of accounts. + +Scope: + --account Scope dedup to one account. Never crosses + source boundaries. + --collection Dedup across every member account of a collection. 
+ This is the only way to compare messages across + sources, and it is an explicit user opt-in: + a duplicate Message-ID or matching content hash + across two accounts in the collection will hide + the loser locally. Use --dry-run first to + review what would be merged. Cross-source pruning + is local-only and reversible with --undo; + --delete-dups-from-source-server only stages + remote deletion when the loser and the survivor + share a source (same-source-only). + (no flag) Dedup runs per-account independently for every + account. Source boundaries are never crossed. + +Use --dry-run to scan and report without writing anything. +Use --content-hash to also group messages by normalized raw MIME when +Message-ID matching is insufficient. +Use --undo to reverse a previous dedup run. Pass --undo +multiple times to reverse several batches in one invocation; failures +on one batch do not skip later batches, and any errors are aggregated +and reported at the end.`, + RunE: runDeduplicate, +} + +var ( + dedupDryRun bool + dedupNoBackup bool + dedupPrefer string + dedupContentHash bool + dedupUndo []string + dedupAccount string + dedupCollection string + dedupDeleteFromSourceSrvr bool + dedupYes bool +) + +func runDeduplicate(cmd *cobra.Command, _ []string) error { + st, err := openStoreAndInit() + if err != nil { + return err + } + defer func() { _ = st.Close() }() + + // dbPath is the on-disk filesystem path used by VACUUM INTO + // backup; resolving it now also rejects non-file DSNs (e.g. + // postgres://) up-front rather than at the first backup attempt. + dbPath, err := cfg.DatabasePath() + if err != nil { + return fmt.Errorf("resolve database path: %w", err) + } + + deletionsDir := filepath.Join(cfg.Data.DataDir, "deletions") + + // --undo operates on a recorded batch ID; scope is captured in the + // batch itself. 
Cobra rejects --undo combined with --account or + // --collection, so by the time we reach this branch undo can run + // without resolving scope flags (a stale or renamed account would + // otherwise block a valid undo). + if len(dedupUndo) > 0 { + undoConfig := dedup.Config{DeletionsDir: deletionsDir} + engine := dedup.NewEngine(st, undoConfig, logger) + var allStillRunning []string + var undoErrs []error + for _, batchID := range dedupUndo { + restored, stillRunning, err := engine.Undo(batchID) + // Undo is best-effort: database rows may have been restored + // even if cancelling pending manifests failed. Always report + // the restored count and any still-running manifests before + // continuing so the user isn't left thinking the undo did + // nothing. Errors aggregate across batches so a failure on + // one batch ID doesn't skip the rest. + fmt.Printf("Restored %d messages from batch %q.\n", + restored, batchID) + allStillRunning = append(allStillRunning, stillRunning...) + if err != nil { + fmt.Fprintf(os.Stderr, + "\nError cancelling one or more pending manifests "+ + "for batch %q:\n %v\n", batchID, err) + undoErrs = append(undoErrs, fmt.Errorf("undo dedup %q: %w", batchID, err)) + } + } + printStillRunningWarning(allStillRunning) + return errors.Join(undoErrs...) 
+ } + + preference := dedup.DefaultSourcePreference + if dedupPrefer != "" { + preference = strings.Split(dedupPrefer, ",") + known := make(map[string]bool, len(dedup.DefaultSourcePreference)) + for _, t := range dedup.DefaultSourcePreference { + known[t] = true + } + for i := range preference { + preference[i] = strings.TrimSpace(preference[i]) + if !known[preference[i]] { + fmt.Fprintf(os.Stderr, "Warning: unknown source type in --prefer: %q\n", preference[i]) + } + } + } + + var ( + accountSourceIDs []int64 + canonicalAccount string + scopeIsCollection bool + ) + switch { + case dedupAccount != "": + scope, err := ResolveAccountFlag(st, dedupAccount) + if err != nil { + return err + } + accountSourceIDs = scope.SourceIDs() + if len(accountSourceIDs) == 0 { + return fmt.Errorf("--account %q resolved to zero sources", dedupAccount) + } + canonicalAccount = scope.DisplayName() + case dedupCollection != "": + scope, err := ResolveCollectionFlag(st, dedupCollection) + if err != nil { + return err + } + accountSourceIDs = scope.SourceIDs() + if len(accountSourceIDs) == 0 { + return fmt.Errorf("--collection %q has no member accounts", dedupCollection) + } + canonicalAccount = scope.DisplayName() + scopeIsCollection = true + } + + config := dedup.Config{ + SourcePreference: preference, + ContentHashFallback: dedupContentHash, + DryRun: dedupDryRun, + AccountSourceIDs: accountSourceIDs, + Account: canonicalAccount, + ScopeIsCollection: scopeIsCollection, + DeleteDupsFromSourceServer: dedupDeleteFromSourceSrvr, + DeletionsDir: deletionsDir, + } + + if len(accountSourceIDs) > 0 { + bySource, err := loadPerSourceIdentities(st, accountSourceIDs) + if err != nil { + return fmt.Errorf("load per-source identities: %w", err) + } + config.IdentityAddressesBySource = bySource + if len(bySource) > 0 { + logger.Info("dedup per-source identities loaded", + "sources", len(bySource)) + } + } + + if scopeIsCollection { + allSources, err := st.ListSources("") + if err != nil { + return 
fmt.Errorf("list sources: %w", err) + } + idSet := make(map[int64]struct{}, len(accountSourceIDs)) + for _, id := range accountSourceIDs { + idSet[id] = struct{}{} + } + var memberNames []string + for _, src := range allSources { + if _, ok := idSet[src.ID]; ok { + memberNames = append(memberNames, src.Identifier) + } + } + fmt.Printf("Deduping across collection %q (%d accounts: %s)\n", + canonicalAccount, len(memberNames), strings.Join(memberNames, ", ")) + // When the collection spans more than one account, dedup is + // crossing source boundaries — a duplicate Message-ID or matching + // content hash between two accounts will hide the loser locally. + // Print a one-line hint so the user can confirm they meant to + // cross those boundaries (and remind them that --dry-run would + // preview without writing). + if len(memberNames) > 1 { + fmt.Println( + " Note: cross-source dedup is reversible (--undo); " + + "remote deletion stays same-source-only. " + + "Re-run with --dry-run to preview.", + ) + } + } + + if len(accountSourceIDs) == 0 { + // Per-source path constructs its own scoped engines per + // source, so no top-level engine is needed here. + return runDeduplicatePerSource(cmd, st, dbPath, config) + } + + // Single-account/single-collection path uses one engine shared + // across the whole scope. 
+ engine := dedup.NewEngine(st, config, logger) + return runDeduplicateOnce(cmd, st, dbPath, config, engine) +} + +func runDeduplicatePerSource( + cmd *cobra.Command, + st *store.Store, + dbPath string, + cfgBase dedup.Config, +) error { + sources, err := st.ListSources("") + if err != nil { + return fmt.Errorf("list sources: %w", err) + } + if len(sources) == 0 { + fmt.Println("No sources found.") + return nil + } + + fmt.Println( + "No --account specified; deduping each source independently.", + ) + fmt.Println() + + backedUp := false + anyRan := false + var executedBatches []string + for _, src := range sources { + cfgScoped := cfgBase + cfgScoped.AccountSourceIDs = []int64{src.ID} + cfgScoped.Account = src.Identifier + bySource, err := loadPerSourceIdentities(st, []int64{src.ID}) + if err != nil { + return fmt.Errorf("load identities for %s: %w", src.Identifier, err) + } + cfgScoped.IdentityAddressesBySource = bySource + engineScoped := dedup.NewEngine(st, cfgScoped, logger) + + fmt.Printf("--- %s (%s) ---\n", src.Identifier, src.SourceType) + report, err := engineScoped.Scan(cmd.Context()) + if err != nil { + return fmt.Errorf("scan %s: %w", src.Identifier, err) + } + if report.DuplicateGroups == 0 { + // Scan can backfill rfc822_message_id even when no duplicate + // groups are produced (idempotent metadata derivation). Report + // that side effect so the user knows the scan did something + // before falling through to the "No duplicates." message. 
+ if report.BackfilledCount != 0 { + fmt.Print(engineScoped.FormatReport(report)) + } + fmt.Println(" No duplicates.") + fmt.Println() + continue + } + + anyRan = true + fmt.Print(engineScoped.FormatReport(report)) + if cfgScoped.DryRun { + fmt.Println() + continue + } + + if !dedupYes { + // See runDeduplicateOnce for the rationale on the + // rfc822-backfill note: scan already performed it + // (idempotent metadata derivation) regardless of the + // answer below, so the prompt explicitly scopes "hide N + // duplicates" to the merge that follows. + if report.BackfilledCount > 0 { + fmt.Printf( + "\nNote: scan already backfilled %d "+ + "rfc822_message_id value(s) for %s from "+ + "stored MIME. This is metadata derivation "+ + "and is kept regardless of your answer.\n", + report.BackfilledCount, src.Identifier, + ) + } + fmt.Printf( + "\nProceed with deduplication for %s? "+ + "This will hide %d duplicates "+ + "(reversible with --undo). [y/N]: ", + src.Identifier, report.DuplicateMessages, + ) + ok, err := readDedupYesNo(cmd) + if err != nil { + return err + } + if !ok { + fmt.Println("Skipped.") + continue + } + } + + if !backedUp && !dedupNoBackup { + backedUp = true + backupPath := fmt.Sprintf( + "%s.dedup-backup-%s", dbPath, + time.Now().Format("20060102-150405"), + ) + fmt.Printf("Backing up database to %s...\n", + filepath.Base(backupPath)) + if err := backupDatabase(st, backupPath); err != nil { + return fmt.Errorf("backup database: %w", err) + } + } + + batchID := fmt.Sprintf( + "dedup-%s-%d-%s-%s", + time.Now().Format("20060102-150405"), + src.ID, + dedup.SanitizeFilenameComponent(src.Identifier), + randomBatchToken(), + ) + summary, err := engineScoped.Execute( + cmd.Context(), report, batchID, + ) + if err != nil { + if summary != nil && summary.GroupsMerged > 0 { + printDedupSummary(summary) + fmt.Println() + } + // Surface the undo hint for any prior sources that DID + // succeed in this run before returning the error. 
Without + // this, a user who hit an error on source N has no + // visibility into how to undo sources 1..N-1's changes + // without grepping the slog output. + printAccumulatedUndoHint(executedBatches) + return fmt.Errorf("execute %s: %w", src.Identifier, err) + } + executedBatches = append(executedBatches, summary.BatchID) + printDedupSummary(summary) + fmt.Println() + } + + if cfgBase.DryRun { + fmt.Println("\nDry run complete. No changes made.") + } else if !anyRan { + fmt.Println("No duplicates found in any source.") + } else if len(executedBatches) > 1 { + printAccumulatedUndoHint(executedBatches) + } + return nil +} + +// printAccumulatedUndoHint prints the multi-batch undo recipe for an +// in-progress per-source dedup run. Called from both the happy path +// (after all sources complete) and the Execute-error path (so a user +// who hit an error mid-loop still sees how to undo what already ran). +// No-op for fewer than 2 batches. +func printAccumulatedUndoHint(executedBatches []string) { + if len(executedBatches) < 2 { + return + } + var b strings.Builder + b.WriteString("\nTo undo all of the above:\n msgvault deduplicate") + for _, id := range executedBatches { + fmt.Fprintf(&b, " --undo %s", id) + } + b.WriteString("\n") + fmt.Print(b.String()) +} + +func runDeduplicateOnce( + cmd *cobra.Command, + st *store.Store, + dbPath string, + cfgScoped dedup.Config, + engine *dedup.Engine, +) error { + fmt.Println("Scanning for duplicate messages...") + report, err := engine.Scan(cmd.Context()) + if err != nil { + return fmt.Errorf("scan: %w", err) + } + + fmt.Print(engine.FormatMethodology()) + fmt.Print(engine.FormatReport(report)) + + if cfgScoped.DryRun { + fmt.Println("\nDry run complete. 
No changes made.") + return nil + } + if report.DuplicateGroups == 0 { + fmt.Println("\nNo duplicates found.") + return nil + } + + if !dedupYes { + // Surface the rfc822 backfill that scan already performed so + // the user knows what state the database is in before they + // answer. The backfill is idempotent metadata derivation + // (fills a previously-NULL column from stored MIME, never + // overwrites or changes content) and is kept regardless of + // this answer; the prompt and the backup that follows are + // scoped to the dedup merge itself. + if report.BackfilledCount > 0 { + fmt.Printf( + "\nNote: scan already backfilled %d rfc822_message_id "+ + "value(s) from stored MIME. This is metadata "+ + "derivation and is kept regardless of your answer.\n", + report.BackfilledCount, + ) + } + fmt.Printf( + "\nProceed with deduplication? This will hide %d "+ + "duplicates (reversible with --undo). [y/N]: ", + report.DuplicateMessages, + ) + ok, err := readDedupYesNo(cmd) + if err != nil { + return err + } + if !ok { + fmt.Println("Aborted.") + return nil + } + } + + if !dedupNoBackup { + backupPath := fmt.Sprintf( + "%s.dedup-backup-%s", dbPath, + time.Now().Format("20060102-150405"), + ) + fmt.Printf("Backing up database to %s...\n", + filepath.Base(backupPath)) + if err := backupDatabase(st, backupPath); err != nil { + return fmt.Errorf("backup database: %w", err) + } + } + + batchID := fmt.Sprintf( + "dedup-%s-run-%s", + time.Now().Format("20060102-150405"), + randomBatchToken(), + ) + fmt.Println("Merging duplicates...") + summary, err := engine.Execute(cmd.Context(), report, batchID) + if err != nil { + if summary != nil && summary.GroupsMerged > 0 { + printDedupSummary(summary) + fmt.Println() + } + return fmt.Errorf("execute: %w", err) + } + + printDedupSummary(summary) + // The analytics cache picks up dedup hides on the next TUI launch + // (cacheNeedsBuild detects deleted_at after LastSyncAt and forces a + // full rebuild). No manual rebuild required. 
+ return nil +} + +func printDedupSummary(summary *dedup.ExecutionSummary) { + fmt.Printf("\n=== Deduplication Complete ===\n") + fmt.Printf("Batch ID: %s\n", summary.BatchID) + fmt.Printf("Groups merged: %d\n", summary.GroupsMerged) + fmt.Printf("Messages pruned: %d\n", summary.MessagesRemoved) + fmt.Printf("Labels transferred: %d\n", summary.LabelsTransferred) + fmt.Printf("Raw MIME backfilled: %d\n", summary.RawMIMEBackfilled) + + if len(summary.StagedManifests) > 0 { + fmt.Println("\nStaged deletion manifests (pending):") + for _, m := range summary.StagedManifests { + fmt.Printf(" %s [%s] %d messages (%s)\n", + m.ManifestID, m.SourceType, m.MessageCount, m.Account) + } + fmt.Println( + "\nRun 'msgvault delete-staged --list' to inspect, or " + + "MSGVAULT_ENABLE_REMOTE_DELETE=1 msgvault delete-staged " + + "to remove the duplicates from the remote server.", + ) + } + fmt.Printf("\nTo undo: msgvault deduplicate --undo %s\n", + summary.BatchID) +} + +func readDedupYesNo(cmd *cobra.Command) (bool, error) { + reader := bufio.NewReader(cmd.InOrStdin()) + response, err := reader.ReadString('\n') + if err != nil && !errors.Is(err, io.EOF) { + return false, fmt.Errorf("read confirmation: %w", err) + } + response = strings.TrimSpace(strings.ToLower(response)) + return response == "y" || response == "yes", nil +} + +// randomBatchToken returns a short random hex token used to disambiguate +// single-run dedup batch IDs from per-source batch IDs that may have been +// generated in the same second. +func randomBatchToken() string { + var b [4]byte + if _, err := rand.Read(b[:]); err != nil { + return fmt.Sprintf("%08x", time.Now().UnixNano()&0xffffffff) + } + return hex.EncodeToString(b[:]) +} + +// backupDatabase writes a point-in-time consistent copy of the SQLite +// database to dst using VACUUM INTO. Unlike a file-system copy of the +// main/-wal/-shm triple, this is atomic and handles uncheckpointed WAL +// pages without any external coordination. 
+func backupDatabase(st *store.Store, dst string) error { + if _, err := os.Stat(dst); err == nil { + return fmt.Errorf("backup target already exists: %s", dst) + } + if _, err := st.DB().Exec("VACUUM INTO ?", dst); err != nil { + return fmt.Errorf("vacuum into %s: %w", dst, err) + } + return nil +} + +// loadPerSourceIdentities builds a per-source identity map for the given +// source IDs by calling GetIdentitiesForScope once per source. Addresses +// are normalized via store.NormalizeIdentifierForCompare so the dedup +// engine's lookup uses the same case-aware rule as the store layer: +// email-shaped identities lowercase, synthetic identifiers (Matrix +// MXIDs, chat handles, phone E.164) preserve case. Without this, +// blanket-lowercasing would misclassify case-sensitive synthetic +// identifiers as sent copies. +func loadPerSourceIdentities(st *store.Store, sourceIDs []int64) (map[int64]map[string]struct{}, error) { + out := make(map[int64]map[string]struct{}, len(sourceIDs)) + for _, id := range sourceIDs { + addrs, err := st.GetIdentitiesForScope([]int64{id}) + if err != nil { + return nil, fmt.Errorf("get identities for source %d: %w", id, err) + } + if len(addrs) == 0 { + continue + } + normalized := make(map[string]struct{}, len(addrs)) + for addr := range addrs { + normalized[store.NormalizeIdentifierForCompare(addr)] = struct{}{} + } + out[id] = normalized + } + return out, nil +} + +func printStillRunningWarning(ids []string) { + if len(ids) == 0 { + return + } + // "Currently executing" specifically — these manifests have already + // been promoted from pending to in-progress, so they can't be + // cancelled (the executor will run them to completion). This is a + // different class of message from a pending-cancel *failure* + // (which surfaces as a returned error from Undo, not via this + // warning). 
+ fmt.Printf( + "\nWarning: the following deletion manifests are currently " + + "executing\nand cannot be cancelled (the executor will run " + + "them to completion):\n", + ) + for _, id := range ids { + fmt.Printf(" - %s\n", id) + } +} + +func init() { + rootCmd.AddCommand(deduplicateCmd) + deduplicateCmd.Flags().BoolVar(&dedupDryRun, "dry-run", false, + "Scan and report only; do not modify data") + deduplicateCmd.Flags().BoolVar(&dedupNoBackup, "no-backup", false, + "Skip database backup before merging (backup covers pre-dedup state for all sources, not per-batch)") + deduplicateCmd.Flags().StringVar(&dedupPrefer, "prefer", "", + "Comma-separated source type preference order "+ + "(default: gmail,imap,mbox,emlx,hey)") + deduplicateCmd.Flags().BoolVar(&dedupContentHash, "content-hash", false, + "Also detect duplicates by normalized raw MIME content") + deduplicateCmd.Flags().StringArrayVar(&dedupUndo, "undo", nil, + "Undo a previous dedup run by batch ID "+ + "(repeat for multiple batches; failures on one batch do not "+ + "skip later batches and errors are aggregated; cannot be "+ + "combined with --account or --collection)") + deduplicateCmd.Flags().StringVar(&dedupAccount, "account", "", + "Scope dedup to one account; never crosses source boundaries") + deduplicateCmd.Flags().StringVar(&dedupCollection, "collection", "", + "Dedup across every member of a collection; opts into "+ + "cross-source comparison (use --dry-run to preview)") + deduplicateCmd.MarkFlagsMutuallyExclusive("account", "collection") + // --undo executes a write; --dry-run promises no writes. Reject the + // combination explicitly rather than silently letting --undo win. + deduplicateCmd.MarkFlagsMutuallyExclusive("dry-run", "undo") + // --undo is keyed by batch ID; the batch already records its scope. + // Combining --undo with --account/--collection is meaningless and + // would force a stale-account lookup before reaching the undo path. 
+ deduplicateCmd.MarkFlagsMutuallyExclusive("undo", "account") + deduplicateCmd.MarkFlagsMutuallyExclusive("undo", "collection") + deduplicateCmd.Flags().BoolVar(&dedupDeleteFromSourceSrvr, + "delete-dups-from-source-server", false, + "DESTRUCTIVE: stage pruned duplicates for remote deletion "+ + "(execution requires MSGVAULT_ENABLE_REMOTE_DELETE=1)") + deduplicateCmd.Flags().BoolVarP(&dedupYes, "yes", "y", false, + "Skip confirmation prompt") + // --undo restores rows from a recorded batch; none of the + // scan/merge/stage flags below apply. Reject the combinations + // explicitly so a user invoking + // `msgvault deduplicate --undo X --delete-dups-from-source-server` + // gets an error instead of having the destructive flag silently + // ignored. + deduplicateCmd.MarkFlagsMutuallyExclusive("undo", "delete-dups-from-source-server") + deduplicateCmd.MarkFlagsMutuallyExclusive("undo", "prefer") + deduplicateCmd.MarkFlagsMutuallyExclusive("undo", "content-hash") + deduplicateCmd.MarkFlagsMutuallyExclusive("undo", "no-backup") + deduplicateCmd.MarkFlagsMutuallyExclusive("undo", "yes") +} diff --git a/cmd/msgvault/cmd/deduplicate_test.go b/cmd/msgvault/cmd/deduplicate_test.go new file mode 100644 index 00000000..1d38557f --- /dev/null +++ b/cmd/msgvault/cmd/deduplicate_test.go @@ -0,0 +1,133 @@ +package cmd + +import ( + "strings" + "testing" + + "github.com/spf13/cobra" + "github.com/wesm/msgvault/internal/testutil" + "github.com/wesm/msgvault/internal/testutil/storetest" +) + +// TestDeduplicateMutualExclusion confirms that passing both --account and +// --collection to the deduplicate command is rejected by cobra. +func TestDeduplicateMutualExclusion(t *testing.T) { + // Build a minimal parent so Execute() returns errors rather than printing + // them and swallowing them via the global rootCmd error handler. 
+ var a, b string + cmd := &cobra.Command{Use: "dedup-test", SilenceErrors: true} + sub := &cobra.Command{Use: "deduplicate", RunE: func(cmd *cobra.Command, args []string) error { return nil }} + sub.Flags().StringVar(&a, "account", "", "") + sub.Flags().StringVar(&b, "collection", "", "") + sub.MarkFlagsMutuallyExclusive("account", "collection") + cmd.AddCommand(sub) + cmd.SetArgs([]string{"deduplicate", "--account", "alpha@example.com", "--collection", "work"}) + + err := cmd.Execute() + if err == nil { + t.Fatal("expected error when both --account and --collection are set, got nil") + } + msg := err.Error() + if !strings.Contains(msg, "account") || !strings.Contains(msg, "collection") { + t.Errorf("error should mention both flag names; got: %q", msg) + } + _ = a + _ = b +} + +// TestDeduplicateCollectionResolution confirms that --collection resolves +// successfully when the name matches a real collection in the store. +func TestDeduplicateCollectionResolution(t *testing.T) { + f, _, collectionName := setupScopeFixture(t) + + scope, err := ResolveCollectionFlag(f.Store, collectionName) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if scope.Collection == nil { + t.Fatal("expected Collection to be populated") + } + if scope.Collection.Name != collectionName { + t.Errorf("collection name = %q, want %q", scope.Collection.Name, collectionName) + } + ids := scope.SourceIDs() + if len(ids) == 0 { + t.Error("expected non-empty SourceIDs for collection") + } +} + +// TestDeduplicateCollectionResolution_MultiSource confirms SourceIDs expands +// to all members when a collection has more than one source. 
+func TestDeduplicateCollectionResolution_MultiSource(t *testing.T) { + f := storetest.New(t) + + src2, err := f.Store.GetOrCreateSource("mbox", "backup@example.com") + testutil.MustNoErr(t, err, "GetOrCreateSource src2") + + collName := "two-account-collection" + _, err = f.Store.CreateCollection(collName, "", []int64{f.Source.ID, src2.ID}) + testutil.MustNoErr(t, err, "CreateCollection") + + scope, err := ResolveCollectionFlag(f.Store, collName) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + ids := scope.SourceIDs() + if len(ids) != 2 { + t.Errorf("expected 2 source IDs, got %d: %v", len(ids), ids) + } + if scope.DisplayName() != collName { + t.Errorf("DisplayName = %q, want %q", scope.DisplayName(), collName) + } +} + +// TestPrintAccumulatedUndoHint asserts the helper's behavior: +// no-op for <2 batches, prints recipe for ≥2. Iter15 follow-up: +// the exit-on-Execute-error path now also calls this helper so a +// user who hits an error mid-loop still sees how to undo what +// already ran. 
+func TestPrintAccumulatedUndoHint(t *testing.T) { + for _, tc := range []struct { + name string + batches []string + wantContains []string + wantNoOutput bool + }{ + { + name: "no batches", + batches: nil, + wantNoOutput: true, + }, + { + name: "single batch", + batches: []string{"dedup-1"}, + wantNoOutput: true, + }, + { + name: "two batches", + batches: []string{"dedup-a", "dedup-b"}, + wantContains: []string{ + "To undo all of the above", + "--undo dedup-a", + "--undo dedup-b", + }, + }, + } { + t.Run(tc.name, func(t *testing.T) { + done := captureStdout(t) + printAccumulatedUndoHint(tc.batches) + out := done() + if tc.wantNoOutput { + if out != "" { + t.Errorf("expected no output, got %q", out) + } + return + } + for _, want := range tc.wantContains { + if !strings.Contains(out, want) { + t.Errorf("output missing %q; got:\n%s", want, out) + } + } + }) + } +} diff --git a/cmd/msgvault/cmd/delete_deduped.go b/cmd/msgvault/cmd/delete_deduped.go new file mode 100644 index 00000000..df864a9f --- /dev/null +++ b/cmd/msgvault/cmd/delete_deduped.go @@ -0,0 +1,194 @@ +package cmd + +import ( + "fmt" + "path/filepath" + "time" + + "github.com/spf13/cobra" +) + +var deleteDedupedCmd = &cobra.Command{ + Use: "delete-deduped", + Short: "Permanently delete dedup-hidden messages from the local archive", + Long: `Permanently delete dedup-hidden messages from the local archive. This is +the third rung of the safety progression: scan -> hide -> local hard +delete -> remote delete. Each rung is a separate, explicit user action. + +Use --batch to delete rows hidden by a specific dedup batch. +Use --all-hidden to delete every dedup-hidden row regardless of batch. + +Deleted rows cannot be recovered with --undo. Pending remote-deletion +manifests still reference Gmail/IMAP message IDs and remain valid +after a local delete. + +Parquet analytics and the vector index may contain stale entries for +deleted rows until rebuilt; the rebuild commands are separate. 
Run +'msgvault build-cache --full-rebuild' for parquet analytics and +'msgvault build-embeddings --full-rebuild' for the vector index.`, + RunE: runDeleteDeduped, +} + +var ( + deleteDedupedBatchIDs []string + deleteDedupedAllHidden bool + deleteDedupedNoBackup bool + deleteDedupedYes bool +) + +func runDeleteDeduped(cmd *cobra.Command, _ []string) error { + // delete-deduped mutates local SQLite directly, has no remote API + // equivalent, and the local DB is not reachable in remote mode. + // Reject upfront so the user gets a clear error rather than the + // generic "must specify --batch or --all-hidden" hint. + if IsRemoteMode() { + return fmt.Errorf("delete-deduped is local-only; not supported in remote mode") + } + + if len(deleteDedupedBatchIDs) == 0 && !deleteDedupedAllHidden { + return fmt.Errorf("must specify --batch or --all-hidden") + } + + st, err := openStoreAndInit() + if err != nil { + return err + } + defer func() { _ = st.Close() }() + + // compute pre-delete stats and totalN before prompting. + var totalN int64 + if deleteDedupedAllHidden { + // Match DeleteAllDeduped's predicate exactly: only rows that + // the dedup pipeline soft-hid (deleted_at IS NOT NULL AND + // delete_batch_id IS NOT NULL) are eligible for purge, so the + // prompt counts must use the same gate. A bare deleted_at row + // would over-report compared to the actual delete. 
+ var distinctBatches int64 + err = st.DB().QueryRow( + st.Rebind("SELECT COUNT(*) FROM messages WHERE deleted_at IS NOT NULL AND delete_batch_id IS NOT NULL"), + ).Scan(&totalN) + if err != nil { + return fmt.Errorf("count hidden messages: %w", err) + } + err = st.DB().QueryRow( + st.Rebind("SELECT COUNT(DISTINCT delete_batch_id) FROM messages WHERE deleted_at IS NOT NULL AND delete_batch_id IS NOT NULL"), + ).Scan(&distinctBatches) + if err != nil { + return fmt.Errorf("count distinct batches: %w", err) + } + _, _ = fmt.Fprintf(cmd.OutOrStdout(), "Will permanently delete %d hidden message(s) from %d distinct batch(es).\n", + totalN, distinctBatches) + } else { + type batchStat struct { + id string + cnt int64 + } + stats := make([]batchStat, 0, len(deleteDedupedBatchIDs)) + for _, id := range deleteDedupedBatchIDs { + var cnt int64 + err = st.DB().QueryRow( + st.Rebind("SELECT COUNT(*) FROM messages WHERE delete_batch_id = ? AND deleted_at IS NOT NULL"), + id, + ).Scan(&cnt) + if err != nil { + return fmt.Errorf("count rows for batch %q: %w", id, err) + } + totalN += cnt + stats = append(stats, batchStat{id: id, cnt: cnt}) + } + out := cmd.OutOrStdout() + _, _ = fmt.Fprintf(out, "Will permanently delete %d hidden message(s) from %d batch(es):\n", + totalN, len(deleteDedupedBatchIDs)) + for _, s := range stats { + _, _ = fmt.Fprintf(out, " %s: %d row(s)\n", s.id, s.cnt) + } + } + + if totalN == 0 { + _, _ = fmt.Fprintln(cmd.OutOrStdout(), "Nothing to delete.") + return nil + } + + // --all-hidden always prompts, even when --yes is set; spec rung 03 invariant. + // Mode picks how EOF is handled: AllHidden treats closed stdin as a contract + // violation (must not be silently bypassed), YesNo treats it as cancel. 
+ if !deleteDedupedYes || deleteDedupedAllHidden { + mode := ConfirmModeYesNo + if deleteDedupedAllHidden { + mode = ConfirmModeAllHidden + } + ok, err := confirmDestructive(cmd.InOrStdin(), cmd.OutOrStdout(), mode) + if err != nil { + return err + } + if !ok { + _, _ = fmt.Fprintln(cmd.OutOrStdout(), "Aborted.") + return nil + } + } + + if !deleteDedupedNoBackup { + // Resolve the DSN to a real filesystem path so backups work + // when [data].database_url is a "file:" URI; reject non-file + // DSNs (postgres://, etc.) which the VACUUM INTO backup path + // can't operate on. + dbFilePath, err := cfg.DatabasePath() + if err != nil { + return fmt.Errorf("resolve database path: %w", err) + } + backupPath := filepath.Join( + filepath.Dir(dbFilePath), + filepath.Base(dbFilePath)+".delete-deduped-backup-"+time.Now().Format("20060102-150405"), + ) + _, _ = fmt.Fprintf(cmd.OutOrStdout(), "Backing up database to %s...\n", filepath.Base(backupPath)) + if err := backupDatabase(st, backupPath); err != nil { + return fmt.Errorf("backup database: %w", err) + } + } + + // Note: parquet analytics and the vector index may contain entries + // for deleted rows; the post-run summary recommends rebuilding each + // separately ('build-cache --full-rebuild' and + // 'build-embeddings --full-rebuild'). 
+ + var deletedTotal int64 + var batchCount int64 + if deleteDedupedAllHidden { + deleted, distinct, err := st.DeleteAllDeduped() + if err != nil { + return fmt.Errorf("delete all dedup-hidden: %w", err) + } + deletedTotal = deleted + batchCount = distinct + } else { + batchCount = int64(len(deleteDedupedBatchIDs)) + for _, id := range deleteDedupedBatchIDs { + deleted, err := st.DeleteDedupedBatch(id) + if err != nil { + return fmt.Errorf("delete dedup batch %q: %w", id, err) + } + deletedTotal += deleted + } + } + + out := cmd.OutOrStdout() + _, _ = fmt.Fprintf(out, "\nDeleted %d message(s) from %d batch(es).\n\n", deletedTotal, batchCount) + _, _ = fmt.Fprintln(out, "Caches may have stale entries; rebuild each separately:") + _, _ = fmt.Fprintln(out, " 'msgvault build-cache --full-rebuild' (parquet analytics)") + _, _ = fmt.Fprintln(out, " 'msgvault build-embeddings --full-rebuild' (vector index, if enabled)") + + return nil +} + +func init() { + rootCmd.AddCommand(deleteDedupedCmd) + deleteDedupedCmd.Flags().StringArrayVar(&deleteDedupedBatchIDs, "batch", nil, + "Delete rows hidden by this batch ID (repeat for multiple batches)") + deleteDedupedCmd.Flags().BoolVar(&deleteDedupedAllHidden, "all-hidden", false, + "Delete every dedup-hidden row regardless of batch") + deleteDedupedCmd.MarkFlagsMutuallyExclusive("batch", "all-hidden") + deleteDedupedCmd.Flags().BoolVar(&deleteDedupedNoBackup, "no-backup", false, + "Skip database backup before deleting") + deleteDedupedCmd.Flags().BoolVarP(&deleteDedupedYes, "yes", "y", false, + "Skip confirmation prompt") +} diff --git a/cmd/msgvault/cmd/delete_deduped_test.go b/cmd/msgvault/cmd/delete_deduped_test.go new file mode 100644 index 00000000..a1f823c9 --- /dev/null +++ b/cmd/msgvault/cmd/delete_deduped_test.go @@ -0,0 +1,65 @@ +package cmd + +import ( + "fmt" + "strings" + "testing" + + "github.com/spf13/cobra" +) + +// TestDeleteDeduped_NeitherFlag verifies that omitting both --batch and +// --all-hidden produces an 
error mentioning both flag names. +func TestDeleteDeduped_NeitherFlag(t *testing.T) { + var batch []string + var allHidden bool + cmd := &cobra.Command{Use: "delete-test", SilenceErrors: true} + sub := &cobra.Command{ + Use: "delete-deduped", + RunE: func(cmd *cobra.Command, args []string) error { + if len(batch) == 0 && !allHidden { + return fmt.Errorf("must specify --batch or --all-hidden") + } + return nil + }, + } + sub.Flags().StringArrayVar(&batch, "batch", nil, "") + sub.Flags().BoolVar(&allHidden, "all-hidden", false, "") + sub.MarkFlagsMutuallyExclusive("batch", "all-hidden") + cmd.AddCommand(sub) + cmd.SetArgs([]string{"delete-deduped"}) + + err := cmd.Execute() + if err == nil { + t.Fatal("expected error when neither --batch nor --all-hidden is set, got nil") + } + msg := err.Error() + if !strings.Contains(msg, "--batch") || !strings.Contains(msg, "--all-hidden") { + t.Errorf("error should mention both flag names; got: %q", msg) + } +} + +// TestDeleteDeduped_MutualExclusion verifies that passing both --batch and +// --all-hidden is rejected by cobra. 
+func TestDeleteDeduped_MutualExclusion(t *testing.T) { + var batch []string + var allHidden bool + cmd := &cobra.Command{Use: "delete-test", SilenceErrors: true} + sub := &cobra.Command{Use: "delete-deduped", RunE: func(cmd *cobra.Command, args []string) error { return nil }} + sub.Flags().StringArrayVar(&batch, "batch", nil, "") + sub.Flags().BoolVar(&allHidden, "all-hidden", false, "") + sub.MarkFlagsMutuallyExclusive("batch", "all-hidden") + cmd.AddCommand(sub) + cmd.SetArgs([]string{"delete-deduped", "--batch", "some-id", "--all-hidden"}) + + err := cmd.Execute() + if err == nil { + t.Fatal("expected error when both --batch and --all-hidden are set, got nil") + } + msg := err.Error() + if !strings.Contains(msg, "batch") || !strings.Contains(msg, "all-hidden") { + t.Errorf("error should mention both flag names; got: %q", msg) + } + _ = batch + _ = allHidden +} diff --git a/cmd/msgvault/cmd/deletions.go b/cmd/msgvault/cmd/deletions.go index b9456951..b9d3424e 100644 --- a/cmd/msgvault/cmd/deletions.go +++ b/cmd/msgvault/cmd/deletions.go @@ -1,9 +1,11 @@ package cmd import ( + "bufio" "context" "errors" "fmt" + "io" "os" "os/signal" "path/filepath" @@ -30,55 +32,62 @@ with their ID, status, message count, and creation date.`, if err != nil { return fmt.Errorf("create manager: %w", err) } + return runListDeletionsForManager(manager, cmd.OutOrStdout()) + }, +} - // List all statuses - pending, err := manager.ListPending() - if err != nil { - return fmt.Errorf("list pending deletions: %w", err) - } - inProgress, err := manager.ListInProgress() - if err != nil { - return fmt.Errorf("list in-progress deletions: %w", err) - } - completed, err := manager.ListCompleted() - if err != nil { - return fmt.Errorf("list completed deletions: %w", err) - } - failed, err := manager.ListFailed() - if err != nil { - return fmt.Errorf("list failed deletions: %w", err) - } +func runListDeletionsForManager(mgr *deletion.Manager, w io.Writer) error { + pending, err := mgr.ListPending() 
+ if err != nil { + return fmt.Errorf("list pending deletions: %w", err) + } + inProgress, err := mgr.ListInProgress() + if err != nil { + return fmt.Errorf("list in-progress deletions: %w", err) + } + completed, err := mgr.ListCompleted() + if err != nil { + return fmt.Errorf("list completed deletions: %w", err) + } + failed, err := mgr.ListFailed() + if err != nil { + return fmt.Errorf("list failed deletions: %w", err) + } + cancelled, err := mgr.ListCancelled() + if err != nil { + return fmt.Errorf("list cancelled deletions: %w", err) + } - if len(pending) == 0 && len(inProgress) == 0 && len(completed) == 0 && len(failed) == 0 { - fmt.Println("No deletion batches found.") - fmt.Println("\nTo stage messages for deletion, use the TUI or create a manifest manually.") - return nil - } + if len(pending) == 0 && len(inProgress) == 0 && len(completed) == 0 && len(failed) == 0 && len(cancelled) == 0 { + _, _ = fmt.Fprintln(w, "No deletion batches found.") + _, _ = fmt.Fprintln(w, "\nTo stage messages for deletion, use the TUI or create a manifest manually.") + return nil + } - printManifestTable := func(status string, manifests []*deletion.Manifest) { - if len(manifests) == 0 { - return - } - fmt.Printf("\n%s:\n", status) - fmt.Printf(" %-25s %-10s %10s %s\n", "ID", "Status", "Messages", "Created") - fmt.Printf(" %-25s %-10s %10s %s\n", "---", "------", "--------", "-------") - for _, m := range manifests { - fmt.Printf(" %-25s %-10s %10d %s\n", - truncate(m.ID, 25), - m.Status, - len(m.GmailIDs), - m.CreatedAt.Format("2006-01-02 15:04"), - ) - } + printManifestTable := func(status string, manifests []*deletion.Manifest) { + if len(manifests) == 0 { + return } + _, _ = fmt.Fprintf(w, "\n%s:\n", status) + _, _ = fmt.Fprintf(w, " %-25s %-10s %10s %s\n", "ID", "Status", "Messages", "Created") + _, _ = fmt.Fprintf(w, " %-25s %-10s %10s %s\n", "---", "------", "--------", "-------") + for _, m := range manifests { + _, _ = fmt.Fprintf(w, " %-25s %-10s %10d %s\n", + 
truncate(m.ID, 25), + m.Status, + len(m.GmailIDs), + m.CreatedAt.Format("2006-01-02 15:04"), + ) + } + } - printManifestTable("Pending", pending) - printManifestTable("In Progress", inProgress) - printManifestTable("Completed (recent)", limitManifests(completed, 10)) - printManifestTable("Failed", failed) + printManifestTable("Pending", pending) + printManifestTable("In Progress", inProgress) + printManifestTable("Completed (recent)", limitManifests(completed, 10)) + printManifestTable("Failed", failed) + printManifestTable("Cancelled (recent)", limitManifests(cancelled, 10)) - return nil - }, + return nil } var showDeletionCmd = &cobra.Command{ @@ -201,27 +210,50 @@ Examples: } var ( - deleteTrash bool // Use trash instead of permanent delete - deleteYes bool - deleteDryRun bool - deleteList bool - deleteAccount string + // deletePermanent opts in to permanent batch deletion. Default is + // trash (30-day Gmail recovery), which is the safer choice for the + // v1 release: every other rung of the deletion progression + // (dedup-hide, local hard delete) is locally reversible, so the + // remote rung should be too unless the user explicitly says + // otherwise. + deletePermanent bool + deleteYes bool + deleteDryRun bool + deleteList bool + deleteAccount string ) +// remoteDeleteEnvVar gates execution of staged deletions against Gmail +// for the v1 release. Staging, listing, and inspecting manifests stay +// available unconditionally so the rest of the pipeline can be exercised; +// only the destructive Gmail-API call is gated. +const remoteDeleteEnvVar = "MSGVAULT_ENABLE_REMOTE_DELETE" + +func remoteDeleteEnabled() bool { + return os.Getenv(remoteDeleteEnvVar) == "1" +} + var deleteStagedCmd = &cobra.Command{ Use: "delete-staged [batch-id]", Short: "Execute staged deletions", Long: `Execute pending deletion batches. -By default, messages are permanently deleted using batch API (fast, no recovery). 
-Use --trash to move messages to Gmail trash instead (recoverable for 30 days, slower). +By default, messages are moved to Gmail trash (recoverable for 30 days). +Use --permanent for batch-API permanent deletion (fast, no recovery). +The default is trash because every other rung of the deletion progression +in msgvault is locally reversible; the remote rung is too unless the user +explicitly opts out of recoverability. + +Execution is gated for the v1 release. Set MSGVAULT_ENABLE_REMOTE_DELETE=1 to +opt in. Read-only modes (--list, --dry-run) work without the gate. Examples: - msgvault delete-staged # Permanent delete all pending (fast) - msgvault delete-staged batch-123 # Delete specific batch - msgvault delete-staged --list # Show staged batches without executing - msgvault delete-staged --trash # Move to trash instead (slower) - msgvault delete-staged --yes # Skip confirmation`, + msgvault delete-staged --list # Show staged batches (always allowed) + msgvault delete-staged --dry-run # Preview without executing (always allowed) + MSGVAULT_ENABLE_REMOTE_DELETE=1 msgvault delete-staged + MSGVAULT_ENABLE_REMOTE_DELETE=1 msgvault delete-staged batch-123 + MSGVAULT_ENABLE_REMOTE_DELETE=1 msgvault delete-staged --permanent + MSGVAULT_ENABLE_REMOTE_DELETE=1 msgvault delete-staged --yes`, RunE: func(cmd *cobra.Command, args []string) error { deletionsDir := filepath.Join(cfg.Data.DataDir, "deletions") manager, err := deletion.NewManager(deletionsDir) @@ -285,9 +317,9 @@ Examples: } // Show summary - method := "PERMANENT DELETE (fast, no recovery)" - if deleteTrash { - method = "trash (30-day recovery, slower)" + method := "trash (30-day recovery)" + if deletePermanent { + method = "PERMANENT DELETE (fast, no recovery)" } fmt.Printf("Deletion Summary:\n") @@ -307,13 +339,46 @@ Examples: return nil } + // Gate the destructive Gmail-API call for the v1 release. + // --list and --dry-run already returned above without hitting this. 
+ if !remoteDeleteEnabled() { + return fmt.Errorf( + "remote deletion is gated in this release; "+ + "set %s=1 to opt in "+ + "(use 'msgvault delete-staged --list' or --dry-run to inspect "+ + "staged batches without executing)", + remoteDeleteEnvVar, + ) + } + // Require confirmation - if !deleteYes { - fmt.Print("Proceed with deletion? [y/N]: ") - var response string - _, _ = fmt.Scanln(&response) - if response != "y" && response != "Y" { - fmt.Println("Cancelled.") + if deletePermanent { + ok, err := confirmDestructive(cmd.InOrStdin(), cmd.OutOrStdout(), ConfirmModePermanent) + if err != nil { + return err + } + if !ok { + return nil + } + } else if !deleteYes { + // Trash path is reversible (~30-day Gmail recovery), so the + // shared confirmDestructive helper's "irreversible" wording + // would be misleading here. Hand-rolled prompt matches the + // action's reversibility — accepts y/Y/yes/Yes for parity + // with the shared helper's input contract. + out := cmd.OutOrStdout() + _, _ = fmt.Fprint(out, "Proceed with deletion? Messages move to Gmail/Trash (recoverable ~30 days). 
[y/N]: ") + scanner := bufio.NewScanner(cmd.InOrStdin()) + if !scanner.Scan() { + if err := scanner.Err(); err != nil { + return fmt.Errorf("read confirmation: %w", err) + } + _, _ = fmt.Fprintln(out, "Cancelled.") + return nil + } + answer := strings.TrimSpace(strings.ToLower(scanner.Text())) + if answer != "y" && answer != "yes" { + _, _ = fmt.Fprintln(out, "Cancelled.") return nil } } @@ -329,6 +394,9 @@ Examples: if err := s.InitSchema(); err != nil { return fmt.Errorf("init schema: %w", err) } + if err := runStartupMigrations(s); err != nil { + return fmt.Errorf("startup migrations: %w", err) + } // Collect unique accounts from manifests accountSet := make(map[string]bool) @@ -437,7 +505,7 @@ Examples: return err } - needsBatchDelete := !deleteTrash + needsBatchDelete := deletePermanent if needsBatchDelete { requiredScopes := oauth.ScopesDeletion oauthMgr, err := oauth.NewManagerWithScopes(clientSecretsPath, cfg.TokensDir(), logger, requiredScopes) @@ -467,7 +535,7 @@ Examples: } } scopes := oauth.Scopes - if !deleteTrash { + if deletePermanent { scopes = oauth.ScopesDeletion } return oauth.NewManagerWithScopes(secretsPath, cfg.TokensDir(), logger, scopes) @@ -475,7 +543,7 @@ Examples: // For permanent deletion (not trash), service-account flows need the // elevated mail.google.com scope; trash-only uses the standard set. saScopes := oauth.Scopes - if !deleteTrash { + if deletePermanent { saScopes = oauth.ScopesDeletion } client, err := buildAPIClient(ctx, src, getOAuthMgr, saScopes) @@ -498,8 +566,8 @@ Examples: var execErr error // For in-progress manifests, honor the stored method to avoid - // accidentally switching from trash to permanent delete mid-batch - useTrash := deleteTrash + // accidentally switching between trash and permanent mid-batch. 
+ useTrash := !deletePermanent if m.Status == deletion.StatusInProgress && m.Execution != nil { useTrash = (m.Execution.Method == deletion.MethodTrash) } @@ -714,7 +782,7 @@ func promptScopeEscalation(ctx context.Context, oauthMgr *oauth.Manager, account _, _ = fmt.Scanln(&response) if response != "y" && response != "Y" { if batchDelete { - fmt.Println("Cancelled. Use --trash for slower deletion without elevated permissions.") + fmt.Println("Cancelled. Drop --permanent to use trash deletion without elevated permissions.") } else { fmt.Println("Cancelled.") } @@ -755,12 +823,13 @@ func isInsufficientScopeError(err error) bool { } func init() { - deleteStagedCmd.Flags().BoolVar(&deleteTrash, "trash", false, "Move to trash instead of permanent delete (slower)") + deleteStagedCmd.Flags().BoolVar(&deletePermanent, "permanent", false, "DESTRUCTIVE: permanently delete via batch API instead of moving to trash (fast, no recovery)") deleteStagedCmd.Flags().BoolVarP(&deleteYes, "yes", "y", false, "Skip confirmation") deleteStagedCmd.Flags().BoolVar(&deleteDryRun, "dry-run", false, "Show what would be deleted") deleteStagedCmd.Flags().BoolVarP(&deleteList, "list", "l", false, "List staged batches without executing") deleteStagedCmd.Flags().StringVar(&deleteAccount, "account", "", "Account to use (Gmail or IMAP)") + deleteStagedCmd.MarkFlagsMutuallyExclusive("permanent", "yes") rootCmd.AddCommand(listDeletionsCmd) rootCmd.AddCommand(showDeletionCmd) cancelDeletionCmd.Flags().BoolVar(&cancelAll, "all", false, "Cancel all pending and in-progress batches") diff --git a/cmd/msgvault/cmd/deletions_test.go b/cmd/msgvault/cmd/deletions_test.go new file mode 100644 index 00000000..c5d11859 --- /dev/null +++ b/cmd/msgvault/cmd/deletions_test.go @@ -0,0 +1,66 @@ +package cmd + +import ( + "bytes" + "path/filepath" + "strings" + "testing" + + "github.com/spf13/cobra" + "github.com/wesm/msgvault/internal/deletion" +) + +func TestDeleteStaged_PermanentAndYesMutuallyExclusive(t *testing.T) { 
+ cmd := &cobra.Command{ + Use: "delete-staged", + RunE: func(cmd *cobra.Command, args []string) error { return nil }, + } + var permanent, yes bool + cmd.Flags().BoolVar(&permanent, "permanent", false, "") + cmd.Flags().BoolVarP(&yes, "yes", "y", false, "") + cmd.MarkFlagsMutuallyExclusive("permanent", "yes") + cmd.SetArgs([]string{"--permanent", "--yes"}) + cmd.SetOut(new(bytes.Buffer)) + cmd.SetErr(new(bytes.Buffer)) + err := cmd.Execute() + if err == nil { + t.Fatalf("err = nil, want mutual-exclusion error") + } + if !strings.Contains(err.Error(), "permanent") || !strings.Contains(err.Error(), "yes") { + t.Errorf("err = %q, want substrings 'permanent' and 'yes'", err.Error()) + } +} + +func TestListDeletions_ShowsCancelled(t *testing.T) { + tmpDir := t.TempDir() + mgr, err := deletion.NewManager(tmpDir) + if err != nil { + t.Fatalf("NewManager: %v", err) + } + + manifest := deletion.NewManifest("test cancel", []string{"abc123"}) + if err := manifest.Save(filepath.Join(tmpDir, "pending", manifest.ID+".json")); err != nil { + t.Fatalf("save manifest: %v", err) + } + if err := mgr.CancelManifest(manifest.ID); err != nil { + t.Fatalf("CancelManifest: %v", err) + } + + var buf bytes.Buffer + if err := runListDeletionsForManager(mgr, &buf); err != nil { + t.Fatalf("runListDeletionsForManager: %v", err) + } + + if !strings.Contains(buf.String(), "Cancelled") { + t.Errorf("output missing 'Cancelled' header:\n%s", buf.String()) + } + // The ID is truncated to 25 chars in the table; check the first 20 chars + // (the timestamp prefix) which always survive truncation. 
+ idPrefix := manifest.ID + if len(idPrefix) > 20 { + idPrefix = idPrefix[:20] + } + if !strings.Contains(buf.String(), idPrefix) { + t.Errorf("output missing manifest ID prefix %q:\n%s", idPrefix, buf.String()) + } +} diff --git a/cmd/msgvault/cmd/embed_vector_test.go b/cmd/msgvault/cmd/embed_vector_test.go index 3348aad4..b8042330 100644 --- a/cmd/msgvault/cmd/embed_vector_test.go +++ b/cmd/msgvault/cmd/embed_vector_test.go @@ -37,6 +37,7 @@ func openTestBackend(t *testing.T) *sqlitevec.Backend { schema := ` CREATE TABLE messages ( id INTEGER PRIMARY KEY, + deleted_at DATETIME, deleted_from_source_at DATETIME );` if _, err := main.Exec(schema); err != nil { diff --git a/cmd/msgvault/cmd/export_attachments.go b/cmd/msgvault/cmd/export_attachments.go index 578e4607..1186b6b2 100644 --- a/cmd/msgvault/cmd/export_attachments.go +++ b/cmd/msgvault/cmd/export_attachments.go @@ -45,6 +45,9 @@ func runExportAttachments(cmd *cobra.Command, args []string) error { if err := s.InitSchema(); err != nil { return fmt.Errorf("init schema: %w", err) } + if err := runStartupMigrations(s); err != nil { + return fmt.Errorf("startup migrations: %w", err) + } engine := query.NewSQLiteEngine(s.DB()) diff --git a/cmd/msgvault/cmd/export_eml.go b/cmd/msgvault/cmd/export_eml.go index 82b610f7..d6ef263b 100644 --- a/cmd/msgvault/cmd/export_eml.go +++ b/cmd/msgvault/cmd/export_eml.go @@ -94,6 +94,9 @@ func runExportEML(cmd *cobra.Command, messageRef, outputPath string) error { if err := s.InitSchema(); err != nil { return fmt.Errorf("init schema: %w", err) } + if err := runStartupMigrations(s); err != nil { + return fmt.Errorf("startup migrations: %w", err) + } engine := query.NewSQLiteEngine(s.DB()) diff --git a/cmd/msgvault/cmd/identity.go b/cmd/msgvault/cmd/identity.go new file mode 100644 index 00000000..af213b51 --- /dev/null +++ b/cmd/msgvault/cmd/identity.go @@ -0,0 +1,419 @@ +package cmd + +import ( + "encoding/json" + "fmt" + "io" + "slices" + "sort" + "strings" + "text/tabwriter" 
+ "time" + + "github.com/spf13/cobra" + "github.com/wesm/msgvault/internal/store" +) + +var ( + identityListAccount string + identityListCollection string + identityListJSON bool + identityShowJSON bool + identityAddSignal string +) + +var identityCmd = &cobra.Command{ + Use: "identity", + Short: "Manage the confirmed \"me\" identifiers for each account", + Long: `Each account has one identity: the set of identifiers (email +addresses, phone numbers, chat handles, synthetic identifiers) that mean +"me" inside that account. Dedup's sent-copy detection compares a message's +From: against the identifiers confirmed for the message's account. + +Identifiers are stored verbatim; case is preserved so synthetic identifiers +like Slack member IDs and Matrix MXIDs round-trip correctly. Email-address +case-insensitivity is handled at compare time by consumers, not at the store.`, +} + +var identityListCmd = &cobra.Command{ + Use: "list", + Short: "List confirmed identifiers across one or more accounts", + RunE: runIdentityList, +} + +func runIdentityList(cmd *cobra.Command, _ []string) error { + st, err := openStoreAndInit() + if err != nil { + return err + } + defer func() { _ = st.Close() }() + + var sourceIDs []int64 + switch { + case identityListAccount != "": + scope, err := ResolveAccountFlag(st, identityListAccount) + if err != nil { + return err + } + sourceIDs = scope.SourceIDs() + case identityListCollection != "": + scope, err := ResolveCollectionFlag(st, identityListCollection) + if err != nil { + return err + } + sourceIDs = scope.SourceIDs() + default: + sources, err := st.ListSources("") + if err != nil { + return fmt.Errorf("list sources: %w", err) + } + sourceIDs = make([]int64, len(sources)) + for i, src := range sources { + sourceIDs[i] = src.ID + } + } + + rows, err := collectIdentityRows(st, sourceIDs) + if err != nil { + return err + } + w := cmd.OutOrStdout() + if identityListJSON { + return writeIdentityJSON(w, rows) + } + return writeIdentityTable(w, 
rows) +} + +// identityRow is the unified view used by both `identity list` and +// `identity show`. (none) rows have empty Identifier and Signal. +type identityRow struct { + Account string + SourceID int64 + SourceType string + Identifier string + Signals []string + ConfirmedAt time.Time + None bool +} + +// collectIdentityRows assembles per-source rows for the given source IDs. +// For each source, it emits one row per confirmed identifier; if a source +// has zero confirmed identifiers, it emits a single (none) row so the +// account is still visible. +func collectIdentityRows(st *store.Store, sourceIDs []int64) ([]identityRow, error) { + var out []identityRow + for _, sid := range sourceIDs { + src, err := st.GetSourceByID(sid) + if err != nil { + return nil, fmt.Errorf("get source %d: %w", sid, err) + } + identifiers, err := st.ListAccountIdentities(sid) + if err != nil { + return nil, fmt.Errorf("list identities for source %d: %w", sid, err) + } + if len(identifiers) == 0 { + out = append(out, identityRow{ + Account: src.Identifier, + SourceID: src.ID, + SourceType: src.SourceType, + None: true, + }) + continue + } + for _, ai := range identifiers { + out = append(out, identityRow{ + Account: src.Identifier, + SourceID: src.ID, + SourceType: src.SourceType, + Identifier: ai.Address, + Signals: splitSignalSet(ai.SourceSignal), + ConfirmedAt: ai.ConfirmedAt, + }) + } + } + sort.SliceStable(out, func(i, j int) bool { + if out[i].Account != out[j].Account { + return out[i].Account < out[j].Account + } + return out[i].Identifier < out[j].Identifier + }) + return out, nil +} + +// splitSignalSet parses a stored source_signal field into a sorted slice. +// Empty input returns an empty slice (so JSON encoding emits [], not null). +// Empty parts (from stray commas in legacy data) are filtered to mirror +// mergeSignalSet's producer-side normalization. 
+func splitSignalSet(s string) []string { + if s == "" { + return []string{} + } + parts := strings.Split(s, ",") + out := make([]string, 0, len(parts)) + for _, p := range parts { + if p != "" { + out = append(out, p) + } + } + sort.Strings(out) + return out +} + +func writeIdentityTable(w io.Writer, rows []identityRow) error { + if len(rows) == 0 { + _, _ = fmt.Fprintln(w, "No accounts in scope.") + return nil + } + tw := tabwriter.NewWriter(w, 0, 0, 2, ' ', 0) + _, _ = fmt.Fprintln(tw, "ACCOUNT\tSOURCE_TYPE\tIDENTIFIER\tSIGNALS\tCONFIRMED") + confirmedCount := 0 + accountCount := 0 + seenAccounts := make(map[int64]struct{}) + noIdentityCount := 0 + for _, r := range rows { + if _, seen := seenAccounts[r.SourceID]; !seen { + accountCount++ + seenAccounts[r.SourceID] = struct{}{} + } + if r.None { + noIdentityCount++ + _, _ = fmt.Fprintf(tw, "%s\t%s\t(none)\t-\t-\n", + r.Account, r.SourceType) + continue + } + confirmedCount++ + _, _ = fmt.Fprintf(tw, "%s\t%s\t%s\t%s\t%s\n", + r.Account, r.SourceType, r.Identifier, + strings.Join(r.Signals, ","), + r.ConfirmedAt.Format("2006-01-02 15:04")) + } + _ = tw.Flush() + _, _ = fmt.Fprintf(w, "---\n%d confirmed identifier(s) across %d account(s); %d account(s) have no identity.\n", + confirmedCount, accountCount, noIdentityCount) + return nil +} + +func writeIdentityJSON(w io.Writer, rows []identityRow) error { + type entry struct { + Account string `json:"account"` + SourceID int64 `json:"source_id"` + SourceType string `json:"source_type"` + Identifier string `json:"identifier"` + Signals []string `json:"signals"` + ConfirmedAt time.Time `json:"confirmed_at"` + } + out := make([]entry, 0, len(rows)) + for _, r := range rows { + if r.None { + continue + } + out = append(out, entry{ + Account: r.Account, + SourceID: r.SourceID, + SourceType: r.SourceType, + Identifier: r.Identifier, + Signals: r.Signals, + ConfirmedAt: r.ConfirmedAt, + }) + } + enc := json.NewEncoder(w) + enc.SetIndent("", " ") + return enc.Encode(out) +} 
+
+var identityShowCmd = &cobra.Command{
+	// Fixed: Use line was missing the positional-arg placeholder even
+	// though Args demands exactly one argument.
+	Use:   "show <account>",
+	Short: "Show one account's identity in detail",
+	Args:  cobra.ExactArgs(1),
+	RunE:  runIdentityShow,
+}
+
+// runIdentityShow resolves one account and prints its identity rows,
+// as a table by default or JSON with --json. An account with no
+// confirmed identifiers gets a (none) row plus an "identity add" hint.
+func runIdentityShow(cmd *cobra.Command, args []string) error {
+	st, err := openStoreAndInit()
+	if err != nil {
+		return err
+	}
+	defer func() { _ = st.Close() }()
+
+	scope, err := ResolveAccountFlag(st, args[0])
+	if err != nil {
+		return err
+	}
+	if scope.Source == nil {
+		return fmt.Errorf("no account found for %q", args[0])
+	}
+
+	rows, err := collectIdentityRows(st, []int64{scope.Source.ID})
+	if err != nil {
+		return err
+	}
+	if identityShowJSON {
+		return writeIdentityJSON(cmd.OutOrStdout(), rows)
+	}
+	if err := writeIdentityTable(cmd.OutOrStdout(), rows); err != nil {
+		return err
+	}
+	if len(rows) == 1 && rows[0].None {
+		_, _ = fmt.Fprintf(cmd.OutOrStdout(), "\nThis account has no confirmed identity. Add one with:\n")
+		_, _ = fmt.Fprintf(cmd.OutOrStdout(), "  msgvault identity add %s <identifier>\n", scope.Source.Identifier)
+	}
+	return nil
+}
+
+var identityAddCmd = &cobra.Command{
+	// Fixed: Use line was missing both positional-arg placeholders.
+	Use:   "add <account> <identifier>",
+	Short: "Add a confirmed identifier to an account's identity",
+	Args:  cobra.ExactArgs(2),
+	RunE:  runIdentityAdd,
+}
+
+// runIdentityAdd confirms one identifier on an account's identity,
+// reporting whether it was new, already confirmed with this signal,
+// or an additional signal on an existing identifier.
+func runIdentityAdd(cmd *cobra.Command, args []string) error {
+	st, err := openStoreAndInit()
+	if err != nil {
+		return err
+	}
+	defer func() { _ = st.Close() }()
+
+	accountArg, identifierArg := args[0], args[1]
+	identifier := strings.TrimSpace(identifierArg)
+	if identifier == "" {
+		return fmt.Errorf("identifier cannot be empty")
+	}
+	// Signals are stored comma-joined, so a comma in one signal name
+	// would corrupt the set encoding.
+	if strings.Contains(identityAddSignal, ",") {
+		return fmt.Errorf("signal names cannot contain commas: %q", identityAddSignal)
+	}
+
+	scope, err := ResolveAccountFlag(st, accountArg)
+	if err != nil {
+		return err
+	}
+	if scope.Source == nil {
+		return fmt.Errorf("no account found for %q", accountArg)
+	}
+
+	existing, err := st.ListAccountIdentities(scope.Source.ID)
+	if err != nil {
+		return fmt.Errorf("list existing: %w", err)
+	}
+	// Match the SQL-side LOWER() rule used by AddAccountIdentity so a
+	// re-add of "Foo@x.com" against a stored "foo@x.com" hits the
+	// "already confirmed" / "additional signal" branches instead of
+	// silently looking new at the CLI layer.
+	var prevSignals []string
+	for _, ai := range existing {
+		if store.EqualIdentifier(ai.Address, identifier) {
+			prevSignals = splitSignalSet(ai.SourceSignal)
+			break
+		}
+	}
+
+	if err := st.AddAccountIdentity(scope.Source.ID, identifier, identityAddSignal); err != nil {
+		return fmt.Errorf("add identity: %w", err)
+	}
+
+	switch {
+	case len(prevSignals) == 0:
+		_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Added %s to %s (signal: %s).\n",
+			identifier, scope.Source.Identifier, identityAddSignal)
+	case slices.Contains(prevSignals, identityAddSignal):
+		_, _ = fmt.Fprintf(cmd.OutOrStdout(), "%s already confirmed for %s with signal %s.\n",
+			identifier, scope.Source.Identifier, identityAddSignal)
+	default:
+		_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Recorded additional signal %s for %s on %s.\n",
+			identityAddSignal, identifier, scope.Source.Identifier)
+	}
+	return nil
+}
+
+var identityRemoveCmd = &cobra.Command{
+	// Fixed: Use line was missing both positional-arg placeholders.
+	Use:   "remove <account> <identifier>",
+	Short: "Remove a confirmed identifier from an account's identity",
+	Args:  cobra.ExactArgs(2),
+	RunE:  runIdentityRemove,
+}
+
+// runIdentityRemove deletes one confirmed identifier from an account,
+// explaining a miss (with the current set when available) and warning
+// when the account is left with no identity at all.
+func runIdentityRemove(cmd *cobra.Command, args []string) error {
+	st, err := openStoreAndInit()
+	if err != nil {
+		return err
+	}
+	defer func() { _ = st.Close() }()
+
+	identifier := strings.TrimSpace(args[1])
+	if identifier == "" {
+		return fmt.Errorf("identifier must not be empty")
+	}
+
+	scope, err := ResolveAccountFlag(st, args[0])
+	if err != nil {
+		return err
+	}
+	if scope.Source == nil {
+		return fmt.Errorf("no account found for %q", args[0])
+	}
+
+	removed, err := st.RemoveAccountIdentity(scope.Source.ID, identifier)
+	if err != nil {
+		return fmt.Errorf("remove identity: %w", err)
+	}
+	if removed == 0 {
+		existing, listErr := st.ListAccountIdentities(scope.Source.ID)
+		if listErr != nil {
+			return fmt.Errorf("%s is not in %s's identity (and looking up the current set failed: %w)",
+				identifier, scope.Source.Identifier, listErr)
+		}
+		var have []string
+		for _, ai := range existing {
+			have = append(have, ai.Address)
+		}
+		if len(have) == 0 {
+			return fmt.Errorf("%s is not in %s's identity (no confirmed identifiers on this account)",
+				identifier, scope.Source.Identifier)
+		}
+		return fmt.Errorf("%s is not in %s's identity. Currently confirmed: %s",
+			identifier, scope.Source.Identifier, strings.Join(have, ", "))
+	}
+	switch removed {
+	case 1:
+		_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Removed %s from %s.\n", identifier, scope.Source.Identifier)
+	default:
+		// >1 means a legacy database held case-variant duplicates of an
+		// email-shaped identifier; the case-fold remove cleaned them up
+		// in one call. Report the count so the user knows.
+		_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Removed %d entries matching %s from %s.\n",
+			removed, identifier, scope.Source.Identifier)
+	}
+
+	// Best-effort post-remove warning. If the lookup errors we suppress
+	// the warning rather than risk a misleading "no identity left"
+	// message — the remove itself already succeeded and was reported.
+	rest, listErr := st.ListAccountIdentities(scope.Source.ID)
+	if listErr == nil && len(rest) == 0 {
+		_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Warning: %s now has no confirmed identity. "+
+			"Dedup sent-copy detection for this account will rely on is_from_me "+
+			"and SENT label signals only.\n", scope.Source.Identifier)
+	}
+	return nil
+}
+
+func init() {
+	rootCmd.AddCommand(identityCmd)
+	identityCmd.AddCommand(identityListCmd)
+	identityCmd.AddCommand(identityShowCmd)
+	identityCmd.AddCommand(identityAddCmd)
+	identityCmd.AddCommand(identityRemoveCmd)
+
+	identityListCmd.Flags().StringVar(&identityListAccount,
+		"account", "", "Restrict to a single account")
+	identityListCmd.Flags().StringVar(&identityListCollection,
+		"collection", "", "Restrict to all member accounts of one collection")
+	identityListCmd.MarkFlagsMutuallyExclusive("account", "collection")
+	identityListCmd.Flags().BoolVar(&identityListJSON,
+		"json", false, "Output as JSON")
+	identityShowCmd.Flags().BoolVar(&identityShowJSON,
+		"json", false, "Output as JSON")
+	identityAddCmd.Flags().StringVar(&identityAddSignal,
+		"signal", "manual",
+		"Evidence signal name (e.g. manual, account-identifier, phone-e164). "+
+			"Cannot contain commas.")
+}
diff --git a/cmd/msgvault/cmd/identity_test.go b/cmd/msgvault/cmd/identity_test.go
new file mode 100644
index 00000000..6a060e82
--- /dev/null
+++ b/cmd/msgvault/cmd/identity_test.go
@@ -0,0 +1,448 @@
+package cmd
+
+import (
+	"bytes"
+	"encoding/json"
+	"io"
+	"log/slog"
+	"path/filepath"
+	"strings"
+	"testing"
+
+	"github.com/spf13/cobra"
+	"github.com/wesm/msgvault/internal/config"
+	"github.com/wesm/msgvault/internal/store"
+)
+
+// newIdentityCLITest creates an isolated store and test root command for
+// identity subcommand tests. Returns (store, root, stdout buffer, stderr buffer).
+func newIdentityCLITest(t *testing.T) (*store.Store, *cobra.Command, *bytes.Buffer, *bytes.Buffer) { + t.Helper() + tmpDir := t.TempDir() + dbPath := filepath.Join(tmpDir, "msgvault.db") + + s, err := store.Open(dbPath) + if err != nil { + t.Fatal(err) + } + if err := s.InitSchema(); err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = s.Close() }) + + // Save and restore package-level globals. + savedCfg := cfg + savedLogger := logger + savedAccount := identityListAccount + savedCollection := identityListCollection + savedListJSON := identityListJSON + savedShowJSON := identityShowJSON + savedAddSignal := identityAddSignal + t.Cleanup(func() { + cfg = savedCfg + logger = savedLogger + identityListAccount = savedAccount + identityListCollection = savedCollection + identityListJSON = savedListJSON + identityShowJSON = savedShowJSON + identityAddSignal = savedAddSignal + // Reset cobra's "Changed" state so mutually-exclusive flag groups + // don't carry over between tests that share the package-level command. 
+ for _, name := range []string{"account", "collection", "json"} { + if f := identityListCmd.Flags().Lookup(name); f != nil { + f.Changed = false + } + } + if f := identityShowCmd.Flags().Lookup("json"); f != nil { + f.Changed = false + } + if f := identityAddCmd.Flags().Lookup("signal"); f != nil { + f.Changed = false + } + }) + + cfg = &config.Config{ + HomeDir: tmpDir, + Data: config.DataConfig{DataDir: tmpDir}, + } + logger = slog.New(slog.NewTextHandler(io.Discard, nil)) + + var stdout, stderr bytes.Buffer + root := newTestRootCmd() + root.SetOut(&stdout) + root.SetErr(&stderr) + root.AddCommand(identityCmd) + + return s, root, &stdout, &stderr +} + +func TestIdentityList_NoScope(t *testing.T) { + s, root, out, _ := newIdentityCLITest(t) + a, _ := s.GetOrCreateSource("gmail", "alice@example.com") + b, _ := s.GetOrCreateSource("imap", "bob@example.com") + _ = s.AddAccountIdentity(a.ID, "alice@example.com", "account-identifier") + _ = s.AddAccountIdentity(b.ID, "bob@example.com", "account-identifier") + + root.SetArgs([]string{"identity", "list"}) + if err := root.Execute(); err != nil { + t.Fatal(err) + } + text := out.String() + if !strings.Contains(text, "alice@example.com") { + t.Errorf("missing alice in output: %s", text) + } + if !strings.Contains(text, "bob@example.com") { + t.Errorf("missing bob in output: %s", text) + } + if !strings.Contains(text, "ACCOUNT") { + t.Errorf("missing header in output: %s", text) + } +} + +func TestIdentityList_AccountFilter(t *testing.T) { + s, root, out, _ := newIdentityCLITest(t) + a, _ := s.GetOrCreateSource("gmail", "alice@example.com") + _, _ = s.GetOrCreateSource("imap", "bob@example.com") + _ = s.AddAccountIdentity(a.ID, "alice@example.com", "manual") + + root.SetArgs([]string{"identity", "list", "--account", "alice@example.com"}) + if err := root.Execute(); err != nil { + t.Fatal(err) + } + text := out.String() + if !strings.Contains(text, "alice@example.com") { + t.Errorf("missing alice: %s", text) + } + if 
strings.Contains(text, "bob@example.com") { + t.Errorf("bob leaked into account-filtered output: %s", text) + } +} + +func TestIdentityList_AccountWithNoneRow(t *testing.T) { + s, root, out, _ := newIdentityCLITest(t) + _, _ = s.GetOrCreateSource("mbox", "old-mbox-2018") + + root.SetArgs([]string{"identity", "list"}) + if err := root.Execute(); err != nil { + t.Fatal(err) + } + text := out.String() + if !strings.Contains(text, "(none)") { + t.Errorf("expected (none) row for account with no identifiers: %s", text) + } +} + +func TestIdentityList_JSONShape(t *testing.T) { + s, root, out, _ := newIdentityCLITest(t) + a, _ := s.GetOrCreateSource("gmail", "alice@example.com") + _ = s.AddAccountIdentity(a.ID, "alice@example.com", "manual") + + root.SetArgs([]string{"identity", "list", "--json"}) + if err := root.Execute(); err != nil { + t.Fatal(err) + } + var rows []map[string]any + if err := json.Unmarshal(out.Bytes(), &rows); err != nil { + t.Fatalf("json decode: %v\n%s", err, out.String()) + } + if len(rows) != 1 { + t.Fatalf("want 1 row, got %d: %+v", len(rows), rows) + } + sigs, ok := rows[0]["signals"].([]any) + if !ok || len(sigs) != 1 || sigs[0] != "manual" { + t.Errorf("signals=%v", rows[0]["signals"]) + } +} + +func TestIdentityList_JSONEmptySignals(t *testing.T) { + s, root, out, _ := newIdentityCLITest(t) + a, _ := s.GetOrCreateSource("gmail", "alice@example.com") + _ = s.AddAccountIdentity(a.ID, "alice@example.com", "") // empty signal + + root.SetArgs([]string{"identity", "list", "--json"}) + if err := root.Execute(); err != nil { + t.Fatal(err) + } + // Unmarshal into raw JSON to check the literal value (not Go nil). 
+ raw := out.Bytes() + if !strings.Contains(string(raw), `"signals": []`) { + t.Errorf("expected signals to be [] not null; got: %s", raw) + } + var rows []map[string]any + if err := json.Unmarshal(raw, &rows); err != nil { + t.Fatalf("json decode: %v\n%s", err, raw) + } + if len(rows) != 1 { + t.Fatalf("want 1 row, got %d", len(rows)) + } + sigs, ok := rows[0]["signals"].([]any) + if !ok { + t.Errorf("signals field is not a JSON array; got %T(%v)", rows[0]["signals"], rows[0]["signals"]) + } else if len(sigs) != 0 { + t.Errorf("want empty signals array, got %v", sigs) + } +} + +func TestIdentityShow_Populated(t *testing.T) { + s, root, out, _ := newIdentityCLITest(t) + a, _ := s.GetOrCreateSource("gmail", "alice@example.com") + _ = s.AddAccountIdentity(a.ID, "alice@example.com", "account-identifier") + + root.SetArgs([]string{"identity", "show", "alice@example.com"}) + if err := root.Execute(); err != nil { + t.Fatal(err) + } + if !strings.Contains(out.String(), "alice@example.com") { + t.Errorf("missing alice: %s", out.String()) + } +} + +func TestIdentityShow_Empty(t *testing.T) { + s, root, out, _ := newIdentityCLITest(t) + _, _ = s.GetOrCreateSource("gmail", "alice@example.com") + + root.SetArgs([]string{"identity", "show", "alice@example.com"}) + if err := root.Execute(); err != nil { + t.Fatal(err) + } + text := out.String() + if !strings.Contains(text, "(none)") { + t.Errorf("missing (none) row: %s", text) + } + if !strings.Contains(text, "identity add") { + t.Errorf("missing hint: %s", text) + } +} + +func TestIdentityShow_UnknownAccount(t *testing.T) { + _, root, _, _ := newIdentityCLITest(t) + root.SetArgs([]string{"identity", "show", "ghost@example.com"}) + err := root.Execute() + if err == nil { + t.Fatal("expected error") + } +} + +func TestIdentityShow_JSONShape(t *testing.T) { + s, root, out, _ := newIdentityCLITest(t) + a, _ := s.GetOrCreateSource("gmail", "alice@example.com") + _ = s.AddAccountIdentity(a.ID, "alice@example.com", "manual") + + 
root.SetArgs([]string{"identity", "show", "alice@example.com", "--json"}) + if err := root.Execute(); err != nil { + t.Fatal(err) + } + var rows []map[string]any + if err := json.Unmarshal(out.Bytes(), &rows); err != nil { + t.Fatalf("json decode: %v\n%s", err, out.String()) + } + if len(rows) != 1 { + t.Fatalf("want 1 row, got %d: %+v", len(rows), rows) + } + if rows[0]["account"] != "alice@example.com" { + t.Errorf("account=%v", rows[0]["account"]) + } + if rows[0]["identifier"] != "alice@example.com" { + t.Errorf("identifier=%v", rows[0]["identifier"]) + } + sigs, ok := rows[0]["signals"].([]any) + if !ok { + t.Errorf("signals field is not a JSON array; got %T(%v)", rows[0]["signals"], rows[0]["signals"]) + } else if len(sigs) != 1 || sigs[0] != "manual" { + t.Errorf("signals=%v", sigs) + } +} + +func TestIdentityShow_JSONEmpty(t *testing.T) { + s, root, out, _ := newIdentityCLITest(t) + _, _ = s.GetOrCreateSource("gmail", "alice@example.com") + + root.SetArgs([]string{"identity", "show", "alice@example.com", "--json"}) + if err := root.Execute(); err != nil { + t.Fatal(err) + } + var rows []map[string]any + if err := json.Unmarshal(out.Bytes(), &rows); err != nil { + t.Fatalf("json decode: %v\n%s", err, out.String()) + } + if len(rows) != 0 { + t.Fatalf("want empty slice, got %d rows: %+v", len(rows), rows) + } +} + +func TestIdentityAdd_FirstTime(t *testing.T) { + s, root, out, _ := newIdentityCLITest(t) + _, _ = s.GetOrCreateSource("gmail", "alice@example.com") + + root.SetArgs([]string{"identity", "add", "alice@example.com", "extra@example.com"}) + if err := root.Execute(); err != nil { + t.Fatal(err) + } + if !strings.Contains(out.String(), "Added extra@example.com") { + t.Errorf("missing add confirmation: %s", out.String()) + } +} + +func TestIdentityAdd_IdempotentSameSignal(t *testing.T) { + s, root, out, _ := newIdentityCLITest(t) + a, _ := s.GetOrCreateSource("gmail", "alice@example.com") + _ = s.AddAccountIdentity(a.ID, "extra@example.com", "manual") + 
+ root.SetArgs([]string{"identity", "add", "alice@example.com", "extra@example.com"}) + if err := root.Execute(); err != nil { + t.Fatal(err) + } + if !strings.Contains(out.String(), "already confirmed") { + t.Errorf("missing idempotent confirmation: %s", out.String()) + } +} + +func TestIdentityAdd_AdditionalSignal(t *testing.T) { + s, root, out, _ := newIdentityCLITest(t) + a, _ := s.GetOrCreateSource("gmail", "alice@example.com") + _ = s.AddAccountIdentity(a.ID, "extra@example.com", "manual") + + root.SetArgs([]string{"identity", "add", "alice@example.com", "extra@example.com", + "--signal", "account-identifier"}) + if err := root.Execute(); err != nil { + t.Fatal(err) + } + if !strings.Contains(out.String(), "additional signal") { + t.Errorf("missing additional-signal confirmation: %s", out.String()) + } +} + +func TestIdentityAdd_RejectsCommaInSignal(t *testing.T) { + s, root, _, _ := newIdentityCLITest(t) + _, _ = s.GetOrCreateSource("gmail", "alice@example.com") + root.SetArgs([]string{"identity", "add", "alice@example.com", "foo@example.com", + "--signal", "a,b"}) + err := root.Execute() + if err == nil || !strings.Contains(err.Error(), "comma") { + t.Fatalf("want comma error, got %v", err) + } +} + +func TestIdentityAdd_RejectsEmptyIdentifier(t *testing.T) { + s, root, _, _ := newIdentityCLITest(t) + _, _ = s.GetOrCreateSource("gmail", "alice@example.com") + root.SetArgs([]string{"identity", "add", "alice@example.com", " "}) + err := root.Execute() + if err == nil || !strings.Contains(err.Error(), "empty") { + t.Fatalf("want empty-identifier error, got %v", err) + } +} + +func TestIdentityAdd_RejectsCollectionAsAccount(t *testing.T) { + s, root, _, _ := newIdentityCLITest(t) + a, _ := s.GetOrCreateSource("gmail", "alice@example.com") + _, _ = s.CreateCollection("team", "", []int64{a.ID}) + + root.SetArgs([]string{"identity", "add", "team", "extra@example.com"}) + err := root.Execute() + if err == nil || !strings.Contains(err.Error(), "collection") { + 
t.Fatalf("want collection-rejection error, got %v", err) + } +} + +func TestIdentityRemove_Hit(t *testing.T) { + s, root, out, _ := newIdentityCLITest(t) + a, _ := s.GetOrCreateSource("gmail", "alice@example.com") + _ = s.AddAccountIdentity(a.ID, "alice@example.com", "manual") + _ = s.AddAccountIdentity(a.ID, "extra@example.com", "manual") + + root.SetArgs([]string{"identity", "remove", "alice@example.com", "extra@example.com"}) + if err := root.Execute(); err != nil { + t.Fatal(err) + } + if !strings.Contains(out.String(), "Removed extra@example.com") { + t.Errorf("missing remove confirmation: %s", out.String()) + } +} + +func TestIdentityRemove_Miss(t *testing.T) { + s, root, out, errOut := newIdentityCLITest(t) + a, _ := s.GetOrCreateSource("gmail", "alice@example.com") + _ = s.AddAccountIdentity(a.ID, "alice@example.com", "manual") + + root.SetArgs([]string{"identity", "remove", "alice@example.com", "ghost@example.com"}) + err := root.Execute() + if err == nil { + t.Fatal("expected error on miss") + } + combined := out.String() + errOut.String() + err.Error() + if !strings.Contains(combined, "Currently confirmed:") { + t.Errorf("error should hint at present identifiers: %s", combined) + } +} + +func TestIdentityRemove_MissOnEmptyAccount(t *testing.T) { + s, root, _, _ := newIdentityCLITest(t) + _, _ = s.GetOrCreateSource("gmail", "alice@example.com") + + root.SetArgs([]string{"identity", "remove", "alice@example.com", "ghost@example.com"}) + err := root.Execute() + if err == nil { + t.Fatal("expected error on miss") + } + if !strings.Contains(err.Error(), "no confirmed identifiers") { + t.Errorf("missing zero-identifiers explanation: %v", err) + } +} + +func TestIdentityRemove_WhitespaceIdentifier(t *testing.T) { + _, root, _, _ := newIdentityCLITest(t) + + root.SetArgs([]string{"identity", "remove", "alice@example.com", " "}) + err := root.Execute() + if err == nil { + t.Fatal("expected error for whitespace identifier") + } + if !strings.Contains(err.Error(), 
"identifier must not be empty") { + t.Errorf("unexpected error: %v", err) + } +} + +func TestIdentityRemove_LastIdentifierWarns(t *testing.T) { + s, root, out, _ := newIdentityCLITest(t) + a, _ := s.GetOrCreateSource("gmail", "alice@example.com") + _ = s.AddAccountIdentity(a.ID, "alice@example.com", "manual") + + root.SetArgs([]string{"identity", "remove", "alice@example.com", "alice@example.com"}) + if err := root.Execute(); err != nil { + t.Fatal(err) + } + if !strings.Contains(out.String(), "no confirmed identity") { + t.Errorf("missing degraded-dedup warning: %s", out.String()) + } +} + +func TestIdentityList_CollectionFilter(t *testing.T) { + s, root, out, _ := newIdentityCLITest(t) + a, _ := s.GetOrCreateSource("gmail", "alice@example.com") + b, _ := s.GetOrCreateSource("gmail", "bob@example.com") + c, _ := s.GetOrCreateSource("gmail", "carol@example.com") + _ = s.AddAccountIdentity(a.ID, "alice@example.com", "account-identifier") + _ = s.AddAccountIdentity(b.ID, "bob@example.com", "account-identifier") + _ = s.AddAccountIdentity(c.ID, "carol@example.com", "account-identifier") + + _, err := s.CreateCollection("team", "", []int64{a.ID, b.ID}) + if err != nil { + t.Fatal(err) + } + + root.SetArgs([]string{"identity", "list", "--collection", "team"}) + if err := root.Execute(); err != nil { + t.Fatal(err) + } + text := out.String() + if !strings.Contains(text, "alice@example.com") { + t.Errorf("missing alice in collection output: %s", text) + } + if !strings.Contains(text, "bob@example.com") { + t.Errorf("missing bob in collection output: %s", text) + } + if strings.Contains(text, "carol@example.com") { + t.Errorf("carol leaked into collection-filtered output: %s", text) + } +} diff --git a/cmd/msgvault/cmd/import.go b/cmd/msgvault/cmd/import.go index 62895865..4570f1d1 100644 --- a/cmd/msgvault/cmd/import.go +++ b/cmd/msgvault/cmd/import.go @@ -16,11 +16,12 @@ import ( ) var ( - importPhone string - importMediaDir string - importContacts string - importLimit 
int - importDisplayName string + importPhone string + importMediaDir string + importContacts string + importLimit int + importDisplayName string + noDefaultIdentityImportWhatsApp bool ) var importWhatsappCmd = &cobra.Command{ @@ -73,6 +74,9 @@ func runWhatsAppImport(cmd *cobra.Command, sourcePath string) error { if err := s.InitSchema(); err != nil { return fmt.Errorf("init schema: %w", err) } + if err := runStartupMigrationsForIngest(s); err != nil { + return fmt.Errorf("startup migrations: %w", err) + } // Set up context with cancellation. ctx, cancel := context.WithCancel(cmd.Context()) @@ -119,6 +123,18 @@ func runWhatsAppImport(cmd *cobra.Command, sourcePath string) error { return fmt.Errorf("import failed: %w", err) } + // Auto-default-identity must run BEFORE the legacy migration + // retry — see comment in account_identity.go. + if !noDefaultIdentityImportWhatsApp && summary.Errors == 0 && summary.SourceID != 0 { + confirmDefaultIdentity(cmd.OutOrStdout(), s, summary.SourceID, importPhone, importPhone, "phone-e164") + } + + if summary.SourceID != 0 { + if err := runPostSourceCreateMigrations(s); err != nil { + return fmt.Errorf("post-source-create migrations: %w", err) + } + } + // Import contacts if provided. 
if importContacts != "" { fmt.Printf("\nImporting contacts from %s...\n", importContacts) @@ -249,6 +265,7 @@ func init() { importWhatsappCmd.Flags().StringVar(&importContacts, "contacts", "", "path to contacts .vcf file for name resolution (optional)") importWhatsappCmd.Flags().IntVar(&importLimit, "limit", 0, "limit number of messages (for testing)") importWhatsappCmd.Flags().StringVar(&importDisplayName, "display-name", "", "display name for the phone owner") + importWhatsappCmd.Flags().BoolVar(&noDefaultIdentityImportWhatsApp, "no-default-identity", false, noDefaultIdentityHelp) _ = importWhatsappCmd.MarkFlagRequired("phone") rootCmd.AddCommand(importWhatsappCmd) diff --git a/cmd/msgvault/cmd/import_emlx.go b/cmd/msgvault/cmd/import_emlx.go index c1f8b052..d9014057 100644 --- a/cmd/msgvault/cmd/import_emlx.go +++ b/cmd/msgvault/cmd/import_emlx.go @@ -24,6 +24,7 @@ var ( importEmlxAccountsDB string importEmlxAccounts []string importEmlxIdentifier string + noDefaultIdentityImportEmlx bool ) var importEmlxCmd = &cobra.Command{ @@ -159,6 +160,9 @@ Examples: if err := st.InitSchema(); err != nil { return fmt.Errorf("init schema: %w", err) } + if err := runStartupMigrationsForIngest(st); err != nil { + return fmt.Errorf("startup migrations: %w", err) + } attachmentsDir := cfg.AttachmentsDir() if importEmlxNoAttachments { @@ -199,6 +203,31 @@ func importSingleAccount( return err } + // Auto-default-identity must run BEFORE the legacy migration + // retry — see comment in account_identity.go. 
+ if ctx.Err() == nil && !summary.HardErrors && !noDefaultIdentityImportEmlx { + if summary.SourceID != 0 { + confirmDefaultIdentity(cmd.OutOrStdout(), st, summary.SourceID, identifier, identifier, "account-identifier") + } else { + logger.Warn("auto-default-identity: missing source id", "identifier", identifier) + } + } + + if summary.SourceID != 0 { + // Migration error returns before printImportSummary on + // purpose: this is the minimal in-scope shape for the deferred + // legacy [identity] migration retry that #304 introduces. The + // alternative (capture error, print summary, return after) + // would restructure pre-existing summary-print code that this + // PR otherwise leaves alone. The migration is idempotent — the + // next invocation retries and prints the summary then. UX + // polish tracked separately in + // private/drafts/2026-05-02-issue-import-migration-error-ux.md. + if err := runPostSourceCreateMigrations(st); err != nil { + return fmt.Errorf("post-source-create migrations: %w", err) + } + } + printImportSummary(cmd, ctx, *summary) return importResultError(ctx, *summary) } @@ -313,6 +342,32 @@ func importAutoAccounts( continue } + // Auto-default-identity must run BEFORE the legacy migration + // retry — see comment in account_identity.go. + if ctx.Err() == nil && !summary.HardErrors && !noDefaultIdentityImportEmlx { + accountDisplay := account.Identifier() + if account.Email != "" { + accountDisplay = account.Email + } + if summary.SourceID != 0 { + confirmDefaultIdentity(cmd.OutOrStdout(), st, summary.SourceID, accountDisplay, identifier, "account-identifier") + } else { + logger.Warn("auto-default-identity: missing source id", "identifier", identifier) + } + } + + if summary.SourceID != 0 { + // `continue` skips the per-account summary print on + // purpose — same scope rationale as importSingleAccount + // above. UX polish tracked in + // private/drafts/2026-05-02-issue-import-migration-error-ux.md + // for a follow-up PR. 
+ if err := runPostSourceCreateMigrations(st); err != nil { + importErrors = append(importErrors, fmt.Errorf("%s: post-source-create migrations: %w", identifier, err)) + continue + } + } + printImportSummary(cmd, ctx, *summary) _, _ = fmt.Fprintln(out) @@ -434,4 +489,8 @@ func init() { &importEmlxIdentifier, "identifier", "", "Explicit email/identifier for single-directory import (manual fallback)", ) + importEmlxCmd.Flags().BoolVar( + &noDefaultIdentityImportEmlx, "no-default-identity", false, + noDefaultIdentityHelp, + ) } diff --git a/cmd/msgvault/cmd/import_gvoice.go b/cmd/msgvault/cmd/import_gvoice.go index beb2c403..0e56b9b9 100644 --- a/cmd/msgvault/cmd/import_gvoice.go +++ b/cmd/msgvault/cmd/import_gvoice.go @@ -14,9 +14,10 @@ import ( ) var ( - importGvoiceBefore string - importGvoiceAfter string - importGvoiceLimit int + importGvoiceBefore string + importGvoiceAfter string + importGvoiceLimit int + noDefaultIdentityImportGVoice bool ) var importGvoiceCmd = &cobra.Command{ @@ -39,7 +40,7 @@ Examples: func runImportGvoice(cmd *cobra.Command, args []string) error { takeoutDir := args[0] - s, err := openStoreAndInit() + s, err := openStoreAndInitForIngest() if err != nil { return err } @@ -95,6 +96,16 @@ func runImportGvoice(cmd *cobra.Command, args []string) error { return fmt.Errorf("import failed: %w", err) } + phone := client.Identifier() + // Auto-default-identity must run BEFORE the legacy migration + // retry — see comment in account_identity.go. 
+ if !noDefaultIdentityImportGVoice && strings.HasPrefix(phone, "+") { + confirmDefaultIdentity(cmd.OutOrStdout(), s, src.ID, phone, phone, "phone-e164") + } + if err := runPostSourceCreateMigrations(s); err != nil { + return fmt.Errorf("post-source-create migrations: %w", err) + } + printGvoiceSummary(summary, startTime) rebuildCacheAfterWrite(cfg.DatabaseDSN()) return nil @@ -196,5 +207,9 @@ func init() { &importGvoiceLimit, "limit", 0, "limit number of messages (for testing)", ) + importGvoiceCmd.Flags().BoolVar( + &noDefaultIdentityImportGVoice, "no-default-identity", false, + noDefaultIdentityHelp, + ) rootCmd.AddCommand(importGvoiceCmd) } diff --git a/cmd/msgvault/cmd/import_imessage.go b/cmd/msgvault/cmd/import_imessage.go index b537846f..bb327c61 100644 --- a/cmd/msgvault/cmd/import_imessage.go +++ b/cmd/msgvault/cmd/import_imessage.go @@ -47,7 +47,7 @@ Examples: } func runImportImessage(cmd *cobra.Command, _ []string) error { - s, err := openStoreAndInit() + s, err := openStoreAndInitForIngest() if err != nil { return err } @@ -78,6 +78,9 @@ func runImportImessage(cmd *cobra.Command, _ []string) error { if err != nil { return fmt.Errorf("get or create source: %w", err) } + if err := runPostSourceCreateMigrations(s); err != nil { + return fmt.Errorf("post-source-create migrations: %w", err) + } ctx, cancel := context.WithCancel(cmd.Context()) defer cancel() @@ -119,6 +122,26 @@ func runImportImessage(cmd *cobra.Command, _ []string) error { } func openStoreAndInit() (*store.Store, error) { + // Shared by 15 commands (deduplicate, identity, collection, + // import-imessage/gvoice, delete-deduped, …). store.Open + InitSchema + // create the database file on first use, which is the right behavior + // for a freshly-installed CLI: a missing file is not an error here. + // init-db remains the explicit setup command for users who want to + // pre-create the DB. 
+ return openStoreAndInitWith(runStartupMigrations) +} + +// openStoreAndInitForIngest is the ingest-command variant of +// openStoreAndInit. It uses runStartupMigrationsForIngest so that, on a +// fresh install with [identity] addresses configured but no source yet, +// the misleading "migration will run on the next command" notice does +// not fire. The post-source-create migration call run by ingest commands +// after GetOrCreateSource emits the accurate "applied" notice once. +func openStoreAndInitForIngest() (*store.Store, error) { + return openStoreAndInitWith(runStartupMigrationsForIngest) +} + +func openStoreAndInitWith(migrate func(*store.Store) error) (*store.Store, error) { dbPath := cfg.DatabaseDSN() s, err := store.Open(dbPath) if err != nil { @@ -128,6 +151,10 @@ func openStoreAndInit() (*store.Store, error) { _ = s.Close() return nil, fmt.Errorf("init schema: %w", err) } + if err := migrate(s); err != nil { + _ = s.Close() + return nil, fmt.Errorf("startup migrations: %w", err) + } return s, nil } diff --git a/cmd/msgvault/cmd/import_imessage_test.go b/cmd/msgvault/cmd/import_imessage_test.go index b6107dd7..bf8ac24c 100644 --- a/cmd/msgvault/cmd/import_imessage_test.go +++ b/cmd/msgvault/cmd/import_imessage_test.go @@ -1,11 +1,45 @@ package cmd import ( + "path/filepath" "testing" "github.com/wesm/msgvault/internal/store" ) +// TestImportIMessage_NoAutoDefaultIdentity pins the documented behavior: the +// apple_messages source uses identifier "local" and the spec explicitly excludes +// this ingest path from auto-default-identity. After source creation via +// resolveImessageSource, account_identities must remain empty. +func TestImportIMessage_NoAutoDefaultIdentity(t *testing.T) { + // After a successful import, account_identities has zero rows for the + // apple_messages source. The source identifier is "local"; we never + // auto-write because there's no per-user identifier known at source + // creation time. 
Spec § Auto-default-identity § "Ingest paths that do + // not auto-write". + s, err := store.Open(filepath.Join(t.TempDir(), "msgvault.db")) + if err != nil { + t.Fatal(err) + } + defer func() { _ = s.Close() }() + if err := s.InitSchema(); err != nil { + t.Fatal(err) + } + + src, err := resolveImessageSource(s) + if err != nil { + t.Fatalf("resolveImessageSource: %v", err) + } + + rows, err := s.ListAccountIdentities(src.ID) + if err != nil { + t.Fatalf("ListAccountIdentities: %v", err) + } + if len(rows) != 0 { + t.Errorf("expected no account_identities rows for apple_messages source, got %d: %+v", len(rows), rows) + } +} + func TestResolveImessageSource(t *testing.T) { tests := []struct { name string diff --git a/cmd/msgvault/cmd/import_mbox.go b/cmd/msgvault/cmd/import_mbox.go index 0bc677b3..856a3dbd 100644 --- a/cmd/msgvault/cmd/import_mbox.go +++ b/cmd/msgvault/cmd/import_mbox.go @@ -21,6 +21,7 @@ var ( importMboxNoResume bool importMboxCheckpointInterval int importMboxNoAttachments bool + noDefaultIdentityImportMbox bool ) type mboxCheckpoint struct { @@ -106,6 +107,9 @@ Examples: if err := st.InitSchema(); err != nil { return fmt.Errorf("init schema: %w", err) } + if err := runStartupMigrationsForIngest(st); err != nil { + return fmt.Errorf("startup migrations: %w", err) + } attachmentsDir := cfg.AttachmentsDir() if importMboxNoAttachments { @@ -118,6 +122,9 @@ Examples: } // If we're resuming, start from the active file in a multi-file zip export. + // Source creation here is for resume detection only; the post-import + // runPostSourceCreateMigrations call below covers both resume and + // --no-resume paths. 
if !importMboxNoResume { src, err := st.GetOrCreateSource(importMboxSourceType, identifier) if err != nil { @@ -240,6 +247,7 @@ Examples: totalErrors int64 totalBytes int64 hadHardErrors bool + sourceID int64 ) type processedFile struct { Path string @@ -268,6 +276,9 @@ Examples: totalLabelsUpdated += summary.LabelsUpdated totalErrors += summary.Errors totalBytes += summary.BytesProcessed + if sourceID == 0 && summary.SourceID != 0 { + sourceID = summary.SourceID + } if summary.HardErrors { hadHardErrors = true } @@ -291,6 +302,36 @@ Examples: } } + // Auto-default-identity must run BEFORE the legacy migration + // retry — see comment in account_identity.go. Earlier shape + // ran the migration first and confirmDefaultIdentity later, + // which suppressed the source's own account identifier + // whenever the legacy [identity] block had populated rows. + if ctx.Err() == nil && !hadHardErrors && !noDefaultIdentityImportMbox { + if sourceID != 0 { + confirmDefaultIdentity(cmd.OutOrStdout(), st, sourceID, identifier, identifier, "account-identifier") + } + } + + // Re-run startup migrations after the importer has had a chance + // to create the first source. Required when the deferred legacy + // identity migration parked at startup because no source existed. + // Cheap no-op once the migration sentinel is set. + // + // Migration error returns before the summary print on purpose: + // minimal in-scope shape for #304's deferred legacy [identity] + // migration retry. The alternative (capture, print summary, + // return after) would restructure pre-existing summary code + // this PR otherwise leaves alone. Migration is idempotent — + // next invocation retries and prints summary then. UX polish + // tracked in + // private/drafts/2026-05-02-issue-import-migration-error-ux.md. 
+ if sourceID != 0 { + if err := runPostSourceCreateMigrations(st); err != nil { + return fmt.Errorf("post-source-create migrations: %w", err) + } + } + out := cmd.OutOrStdout() if ctx.Err() != nil { _, _ = fmt.Fprintln(out, "Import interrupted. Run again to resume.") @@ -334,4 +375,5 @@ func init() { importMboxCmd.Flags().BoolVar(&importMboxNoResume, "no-resume", false, "Do not resume from an interrupted import") importMboxCmd.Flags().IntVar(&importMboxCheckpointInterval, "checkpoint-interval", 200, "Save progress every N messages") importMboxCmd.Flags().BoolVar(&importMboxNoAttachments, "no-attachments", false, "Do not store attachments (disk or database). Messages will still be marked as having attachments. Note: rerunning later without --no-attachments will not backfill attachments for already-imported messages.") + importMboxCmd.Flags().BoolVar(&noDefaultIdentityImportMbox, "no-default-identity", false, noDefaultIdentityHelp) } diff --git a/cmd/msgvault/cmd/initdb.go b/cmd/msgvault/cmd/initdb.go index e33ebc75..42dc431c 100644 --- a/cmd/msgvault/cmd/initdb.go +++ b/cmd/msgvault/cmd/initdb.go @@ -28,6 +28,9 @@ created if they don't already exist.`, if err := s.InitSchema(); err != nil { return fmt.Errorf("init schema: %w", err) } + if err := runStartupMigrations(s); err != nil { + return fmt.Errorf("startup migrations: %w", err) + } logger.Info("database initialized successfully") diff --git a/cmd/msgvault/cmd/list_accounts.go b/cmd/msgvault/cmd/list_accounts.go index b21b313b..ce4a3c57 100644 --- a/cmd/msgvault/cmd/list_accounts.go +++ b/cmd/msgvault/cmd/list_accounts.go @@ -74,6 +74,9 @@ func listLocalAccounts() error { if err := s.InitSchema(); err != nil { return fmt.Errorf("init schema: %w", err) } + if err := runStartupMigrations(s); err != nil { + return fmt.Errorf("startup migrations: %w", err) + } sources, err := s.ListSources("") if err != nil { diff --git a/cmd/msgvault/cmd/list_domains.go b/cmd/msgvault/cmd/list_domains.go index a45c30bf..7ddb6f8a 
100644 --- a/cmd/msgvault/cmd/list_domains.go +++ b/cmd/msgvault/cmd/list_domains.go @@ -37,6 +37,9 @@ Examples: if err := s.InitSchema(); err != nil { return fmt.Errorf("init schema: %w", err) } + if err := runStartupMigrations(s); err != nil { + return fmt.Errorf("startup migrations: %w", err) + } // Create query engine engine := query.NewSQLiteEngine(s.DB()) diff --git a/cmd/msgvault/cmd/list_labels.go b/cmd/msgvault/cmd/list_labels.go index efebeeeb..96597495 100644 --- a/cmd/msgvault/cmd/list_labels.go +++ b/cmd/msgvault/cmd/list_labels.go @@ -37,6 +37,9 @@ Examples: if err := s.InitSchema(); err != nil { return fmt.Errorf("init schema: %w", err) } + if err := runStartupMigrations(s); err != nil { + return fmt.Errorf("startup migrations: %w", err) + } // Create query engine engine := query.NewSQLiteEngine(s.DB()) diff --git a/cmd/msgvault/cmd/list_senders.go b/cmd/msgvault/cmd/list_senders.go index e00b8154..7e702895 100644 --- a/cmd/msgvault/cmd/list_senders.go +++ b/cmd/msgvault/cmd/list_senders.go @@ -37,6 +37,9 @@ Examples: if err := s.InitSchema(); err != nil { return fmt.Errorf("init schema: %w", err) } + if err := runStartupMigrations(s); err != nil { + return fmt.Errorf("startup migrations: %w", err) + } // Create query engine engine := query.NewSQLiteEngine(s.DB()) diff --git a/cmd/msgvault/cmd/quickstart.md b/cmd/msgvault/cmd/quickstart.md index a149c4f0..fd60c930 100644 --- a/cmd/msgvault/cmd/quickstart.md +++ b/cmd/msgvault/cmd/quickstart.md @@ -208,24 +208,34 @@ msgvault show-deletion # Cancel a pending batch msgvault cancel-deletion -# Execute pending deletions (permanent, fast — no recovery) -msgvault delete-staged --yes +# List staged batches without executing (always allowed) +msgvault delete-staged --list + +# Execute pending deletions (gated; default = trash, recoverable for 30 days) +MSGVAULT_ENABLE_REMOTE_DELETE=1 msgvault delete-staged --yes # Execute a specific batch -msgvault delete-staged +MSGVAULT_ENABLE_REMOTE_DELETE=1 msgvault 
delete-staged -# Move to trash instead (recoverable for 30 days, slower) -msgvault delete-staged --trash +# Permanently delete via batch API (fast, no recovery — opt-in) +MSGVAULT_ENABLE_REMOTE_DELETE=1 msgvault delete-staged --permanent -# Dry run — show what would be deleted without doing it +# Dry run — show what would be deleted without doing it (always allowed) msgvault delete-staged --dry-run # Specify which account to delete from -msgvault delete-staged --account user@gmail.com +MSGVAULT_ENABLE_REMOTE_DELETE=1 msgvault delete-staged --account user@gmail.com ``` -**Warning:** `delete-staged` without `--trash` permanently deletes messages from -Gmail. This is irreversible. Always verify with `--dry-run` first. +**Note:** Remote deletion is gated for the v1 release. Staging, listing, +inspecting, and dry-running deletion batches work without the gate; +executing against Gmail requires `MSGVAULT_ENABLE_REMOTE_DELETE=1` in the +environment. + +**Warning:** `delete-staged --permanent` permanently deletes messages from +Gmail with no Gmail-side recovery. The default mode moves messages to Gmail +trash, which is recoverable for 30 days. Always verify with `--dry-run` +first regardless of mode. 
## Verify archive integrity diff --git a/cmd/msgvault/cmd/remove_account.go b/cmd/msgvault/cmd/remove_account.go index a8c8e58c..fcb770cc 100644 --- a/cmd/msgvault/cmd/remove_account.go +++ b/cmd/msgvault/cmd/remove_account.go @@ -72,6 +72,9 @@ func runRemoveAccount(cmd *cobra.Command, args []string) error { if err := s.InitSchema(); err != nil { return fmt.Errorf("init schema: %w", err) } + if err := runStartupMigrations(s); err != nil { + return fmt.Errorf("startup migrations: %w", err) + } source, err := resolveSource(s, email, sourceType) if err != nil { diff --git a/cmd/msgvault/cmd/repair_encoding.go b/cmd/msgvault/cmd/repair_encoding.go index 00066579..2055e7b5 100644 --- a/cmd/msgvault/cmd/repair_encoding.go +++ b/cmd/msgvault/cmd/repair_encoding.go @@ -48,6 +48,9 @@ charset detection issues in the MIME parser.`, if err := s.InitSchema(); err != nil { return fmt.Errorf("init schema: %w", err) } + if err := runStartupMigrations(s); err != nil { + return fmt.Errorf("startup migrations: %w", err) + } reembedNeededIDs, err := repairEncoding(s) if err != nil { diff --git a/cmd/msgvault/cmd/search.go b/cmd/msgvault/cmd/search.go index 8f6e29a2..2ec98eca 100644 --- a/cmd/msgvault/cmd/search.go +++ b/cmd/msgvault/cmd/search.go @@ -14,12 +14,13 @@ import ( ) var ( - searchLimit int - searchOffset int - searchJSON bool - searchAccount string - searchMode string - searchExplain bool + searchLimit int + searchOffset int + searchJSON bool + searchAccount string + searchCollection string + searchMode string + searchExplain bool ) var searchCmd = &cobra.Command{ @@ -57,8 +58,8 @@ Examples: // Join all args to form the query (allows unquoted multi-term searches) queryStr := strings.Join(args, " ") - if queryStr == "" && searchAccount == "" { - return fmt.Errorf("provide a search query or --account flag") + if queryStr == "" && searchAccount == "" && searchCollection == "" { + return fmt.Errorf("provide a search query or --account/--collection flag") } // Use remote search 
if configured @@ -68,26 +69,142 @@ Examples: "--account is not supported in remote mode", ) } + if searchCollection != "" { + return fmt.Errorf("--collection is not supported in remote mode") + } if searchMode != "fts" { return fmt.Errorf("--mode is not supported in remote mode") } return runRemoteSearch(queryStr) } + // Validate mode before any scope work so we fail fast on a typo. + if searchMode != "fts" && searchMode != "vector" && searchMode != "hybrid" { + return fmt.Errorf("invalid --mode: %q (want fts|vector|hybrid)", searchMode) + } + if searchMode != "fts" && searchOffset > 0 { + return fmt.Errorf("--offset is not supported with --mode=%s (pagination is single-page)", searchMode) + } + // Vector and hybrid modes need free-text terms to embed; both + // an empty raw query and a filter-only query (e.g. `from:alice`) + // would fail at the embed call. Check both up front and surface + // a CLI error rather than a late engine-level one. FTS still + // allows scoped queryless searches. if searchMode != "fts" { - if searchMode != "vector" && searchMode != "hybrid" { - return fmt.Errorf("invalid --mode: %q (want fts|vector|hybrid)", searchMode) + if queryStr == "" { + return fmt.Errorf("--mode=%s requires query text to embed; pass a query or use --mode=fts", searchMode) } - if searchOffset > 0 { - return fmt.Errorf("--offset is not supported with --mode=%s (pagination is single-page)", searchMode) + if len(search.Parse(queryStr).TextTerms) == 0 { + return fmt.Errorf("--mode=%s requires free-text terms to embed; %q parsed to filters only — add a search phrase or use --mode=fts", searchMode, queryStr) } - return runHybridSearch(cmd, queryStr, searchMode, searchExplain) } - return runLocalSearch(cmd, queryStr) + // Resolve --account / --collection once, before the mode branch, + // so FTS, vector, and hybrid all see the same SourceIDs. 
Earlier, + // scope was resolved inside runLocalSearch only and the vector + // path applied --account directly while ignoring --collection. + scope, scopedStore, err := resolveSearchScope(searchAccount, searchCollection) + if err != nil { + return err + } + + if searchMode != "fts" { + // Hybrid/vector path opens its own sql.DB directly. When a + // scoped store is in hand, schema init has already run on + // this DSN; otherwise we have to run it ourselves so the + // raw sql.DB inside runHybridSearch sees the deleted_at / + // deleted_from_source_at columns the vector backend filters + // on. Close immediately — migration state persists in the + // file on disk, and runHybridSearch will reopen. + if scopedStore != nil { + _ = scopedStore.Close() + } else { + if err := initLocalSchema(); err != nil { + return err + } + } + return runHybridSearch(cmd, queryStr, searchMode, searchExplain, scope) + } + return runLocalSearch(cmd, queryStr, scope, scopedStore) }, } +// initLocalSchema opens the local store, runs InitSchema and the +// startup migrations, then closes it. Used by the unscoped vector/ +// hybrid path so the raw sql.DB that runHybridSearch opens sees a +// fully-migrated schema (notably the deleted_at column the vector +// backend filters on, added by this branch's ALTER TABLE migration). +// Scoped queries don't need this because resolveSearchScope already +// runs the same init on the same DSN. +func initLocalSchema() error { + s, err := store.Open(cfg.DatabaseDSN()) + if err != nil { + return fmt.Errorf("open database: %w", err) + } + defer func() { _ = s.Close() }() + if err := s.InitSchema(); err != nil { + return fmt.Errorf("init schema: %w", err) + } + if err := runStartupMigrations(s); err != nil { + return fmt.Errorf("startup migrations: %w", err) + } + return nil +} + +// resolveSearchScope opens the local store just long enough to resolve +// the user-supplied --account/--collection flag into a Scope. 
Returning +// the Scope (rather than just SourceIDs) lets callers print a banner +// using its DisplayName. The opened store is returned so a caller that +// needs it (runLocalSearch) can reuse it instead of re-running +// InitSchema + runStartupMigrations a second time. Callers that don't +// need the store must Close it themselves. +// +// When no scope flag was supplied, returns (Scope{}, nil, nil) and the +// caller is responsible for opening its own store. +func resolveSearchScope(account, collection string) (Scope, *store.Store, error) { + if account == "" && collection == "" { + return Scope{}, nil, nil + } + s, err := store.Open(cfg.DatabaseDSN()) + if err != nil { + return Scope{}, nil, fmt.Errorf("open database: %w", err) + } + if err := s.InitSchema(); err != nil { + _ = s.Close() + return Scope{}, nil, fmt.Errorf("init schema: %w", err) + } + if err := runStartupMigrations(s); err != nil { + _ = s.Close() + return Scope{}, nil, fmt.Errorf("startup migrations: %w", err) + } + switch { + case account != "": + scope, err := ResolveAccountFlag(s, account) + if err != nil { + _ = s.Close() + return Scope{}, nil, err + } + if scope.IsEmpty() { + _ = s.Close() + return Scope{}, nil, fmt.Errorf("--account %q resolved to zero sources", account) + } + return scope, s, nil + case collection != "": + scope, err := ResolveCollectionFlag(s, collection) + if err != nil { + _ = s.Close() + return Scope{}, nil, err + } + if len(scope.SourceIDs()) == 0 { + _ = s.Close() + return Scope{}, nil, fmt.Errorf("--collection %q has no member accounts", collection) + } + return scope, s, nil + } + _ = s.Close() + return Scope{}, nil, nil +} + // runRemoteSearch performs a search against the remote API. 
func runRemoteSearch(queryStr string) error { fmt.Fprintf(os.Stderr, "Searching %s...", cfg.Remote.URL) @@ -115,44 +232,61 @@ func runRemoteSearch(queryStr string) error { return outputRemoteSearchResultsTable(results, total) } -// runLocalSearch performs a search against the local database. -func runLocalSearch(cmd *cobra.Command, queryStr string) error { - // Parse the query +// runLocalSearch performs a search against the local database. The +// caller is expected to have resolved scope already; an empty Scope +// means no --account or --collection was supplied. If scopedStore is +// non-nil it carries an already-initialized store from scope +// resolution; runLocalSearch reuses it (avoiding a second +// InitSchema + runStartupMigrations pass). When scopedStore is nil +// (no scope flag supplied), runLocalSearch opens and initializes +// its own store. +func runLocalSearch(cmd *cobra.Command, queryStr string, scope Scope, scopedStore *store.Store) error { + // Parse the query and apply any pre-resolved scope before the + // emptiness check so a bare --account/--collection is enough to + // produce a non-empty query. q := search.Parse(queryStr) - - // Fail fast on invalid queries before touching the database, - // unless --account is set (which requires a DB lookup to resolve). - if searchAccount == "" && q.IsEmpty() { - return fmt.Errorf("empty search query") - } - - // Open database - dbPath := cfg.DatabaseDSN() - s, err := store.Open(dbPath) - if err != nil { - return fmt.Errorf("open database: %w", err) + if !scope.IsEmpty() { + q.AccountIDs = scope.SourceIDs() } - defer func() { _ = s.Close() }() - - // Ensure schema is up to date and FTS index is populated - if err := s.InitSchema(); err != nil { - return fmt.Errorf("init schema: %w", err) + if q.IsEmpty() { + if scopedStore != nil { + _ = scopedStore.Close() + } + return fmt.Errorf("empty search query") } - // Resolve --account and recheck emptiness. 
- if searchAccount != "" { - src, err := s.GetSourceByIdentifier(searchAccount) + var s *store.Store + if scopedStore != nil { + s = scopedStore + } else { + var err error + s, err = store.Open(cfg.DatabaseDSN()) if err != nil { - return fmt.Errorf("look up account: %w", err) + return fmt.Errorf("open database: %w", err) } - if src == nil { - return fmt.Errorf("account %q not found", searchAccount) + if err := s.InitSchema(); err != nil { + _ = s.Close() + return fmt.Errorf("init schema: %w", err) + } + if err := runStartupMigrations(s); err != nil { + _ = s.Close() + return fmt.Errorf("startup migrations: %w", err) } - q.AccountID = &src.ID } + defer func() { _ = s.Close() }() - if q.IsEmpty() { - return fmt.Errorf("empty search query") + // Print a scope banner when searching a collection. + if scope.IsCollection() { + members := scope.SourceIDs() + n := len(members) + suffix := "s" + if n == 1 { + suffix = "" + } + fmt.Fprintf(os.Stderr, + "Searching collection %q (%d account%s)\n", + scope.DisplayName(), n, suffix, + ) } fmt.Fprintf(os.Stderr, "Searching...") @@ -164,7 +298,7 @@ func runLocalSearch(cmd *cobra.Command, queryStr string) error { // Log the search operation. Raw query text and account // identifiers may contain PII — log coarse metadata at // info and full values only at debug. 
- hasAccount := q.AccountID != nil + hasAccount := len(q.AccountIDs) > 0 logger.Info("search start", "query_len", len(queryStr), "has_account", hasAccount, @@ -279,6 +413,9 @@ func init() { searchCmd.Flags().IntVar(&searchOffset, "offset", 0, "Skip first N results") searchCmd.Flags().BoolVar(&searchJSON, "json", false, "Output as JSON") searchCmd.Flags().StringVar(&searchAccount, "account", "", "Limit results to a specific account (email address)") + searchCmd.Flags().StringVar(&searchCollection, "collection", "", + "Limit results to all member accounts of one collection") + searchCmd.MarkFlagsMutuallyExclusive("account", "collection") searchCmd.Flags().StringVar(&searchMode, "mode", "fts", "Search mode: fts|vector|hybrid") searchCmd.Flags().BoolVar(&searchExplain, "explain", false, "Include per-signal scores in output (hybrid/vector modes)") } diff --git a/cmd/msgvault/cmd/search_test.go b/cmd/msgvault/cmd/search_test.go index 1978ad59..718b2b59 100644 --- a/cmd/msgvault/cmd/search_test.go +++ b/cmd/msgvault/cmd/search_test.go @@ -7,6 +7,8 @@ import ( "strings" "testing" + "github.com/spf13/cobra" + "github.com/spf13/pflag" "github.com/wesm/msgvault/internal/config" "github.com/wesm/msgvault/internal/store" ) @@ -49,9 +51,17 @@ func captureStdout(t *testing.T) func() string { func resetSearchFlags() { searchAccount = "" + searchCollection = "" searchLimit = 50 searchOffset = 0 searchJSON = false + searchMode = "fts" + searchExplain = false + // Cobra remembers per-flag `Changed` state on the global searchCmd + // across test invocations. Without clearing it, mutually-exclusive + // pairs (--account / --collection) trip when a subsequent test only + // passes one of them. 
+ searchCmd.Flags().VisitAll(func(f *pflag.Flag) { f.Changed = false }) } func TestSearchCmd_AccountFlagRejectsRemoteMode(t *testing.T) { @@ -272,3 +282,197 @@ func TestSearchCmd_NoQueryNoAccount(t *testing.T) { t.Errorf("error = %q, want 'provide a search query'", err) } } + +// TestSearchCmd_CollectionFlagScopesResults seeds two accounts and one +// collection containing only the first, then runs FTS search with +// --collection. Only the first account's message must come back. +func TestSearchCmd_CollectionFlagScopesResults(t *testing.T) { + tmpDir := t.TempDir() + dbPath := tmpDir + "/msgvault.db" + + s, err := store.Open(dbPath) + if err != nil { + t.Fatalf("open store: %v", err) + } + if err := s.InitSchema(); err != nil { + t.Fatalf("init schema: %v", err) + } + src1, err := s.GetOrCreateSource("gmail", "alice@example.com") + if err != nil { + t.Fatalf("create source 1: %v", err) + } + src2, err := s.GetOrCreateSource("gmail", "bob@example.com") + if err != nil { + t.Fatalf("create source 2: %v", err) + } + conv1, err := s.EnsureConversation(src1.ID, "c1", "") + if err != nil { + t.Fatalf("create conv 1: %v", err) + } + conv2, err := s.EnsureConversation(src2.ID, "c2", "") + if err != nil { + t.Fatalf("create conv 2: %v", err) + } + if _, err := s.UpsertMessage(&store.Message{ + SourceID: src1.ID, ConversationID: conv1, + SourceMessageID: "m1", MessageType: "email", + Subject: sql.NullString{String: "Alice msg", Valid: true}, + SizeEstimate: 100, + }); err != nil { + t.Fatalf("insert msg 1: %v", err) + } + if _, err := s.UpsertMessage(&store.Message{ + SourceID: src2.ID, ConversationID: conv2, + SourceMessageID: "m2", MessageType: "email", + Subject: sql.NullString{String: "Bob msg", Valid: true}, + SizeEstimate: 200, + }); err != nil { + t.Fatalf("insert msg 2: %v", err) + } + if _, err := s.CreateCollection("alice-only", "", []int64{src1.ID}); err != nil { + t.Fatalf("create collection: %v", err) + } + _ = s.Close() + + savedCfg := cfg + defer func() { cfg 
= savedCfg; resetSearchFlags() }() + + cfg = &config.Config{ + HomeDir: tmpDir, + Data: config.DataConfig{DataDir: tmpDir}, + } + + done := captureStdout(t) + root := newTestRootCmd() + root.AddCommand(searchCmd) + root.SetArgs([]string{ + "search", "--collection", "alice-only", "--json", + }) + err = root.Execute() + out := done() + if err != nil { + t.Fatalf("collection-only search failed: %v", err) + } + if !strings.Contains(out, "Alice msg") { + t.Errorf("expected Alice's message in output, got: %s", out) + } + if strings.Contains(out, "Bob msg") { + t.Errorf("Bob's message must be filtered out, got: %s", out) + } +} + +// TestSearchCmd_CollectionFlagUnknown returns a clear error when the +// named collection does not exist. +func TestSearchCmd_CollectionFlagUnknown(t *testing.T) { + tmpDir := t.TempDir() + dbPath := tmpDir + "/msgvault.db" + s, err := store.Open(dbPath) + if err != nil { + t.Fatalf("open store: %v", err) + } + if err := s.InitSchema(); err != nil { + t.Fatalf("init schema: %v", err) + } + _ = s.Close() + + savedCfg := cfg + defer func() { cfg = savedCfg; resetSearchFlags() }() + cfg = &config.Config{ + HomeDir: tmpDir, + Data: config.DataConfig{DataDir: tmpDir}, + } + + root := newTestRootCmd() + root.AddCommand(searchCmd) + root.SetArgs([]string{ + "search", "--collection", "does-not-exist", "anything", + }) + err = root.Execute() + if err == nil { + t.Fatal("expected error for unknown collection") + } + if !strings.Contains(err.Error(), "no collection") { + t.Errorf("error = %q, want substring 'no collection'", err) + } +} + +// TestSearchCmd_VectorOrHybridRequireQueryText rejects empty-query +// vector/hybrid invocations even when scope flags are supplied. +// FTS allows queryless scoped searches; vector/hybrid don't, because +// the embeddings client needs text to vectorize. 
+func TestSearchCmd_VectorOrHybridRequireQueryText(t *testing.T) { + for _, mode := range []string{"vector", "hybrid"} { + t.Run(mode, func(t *testing.T) { + savedCfg := cfg + defer func() { cfg = savedCfg; resetSearchFlags() }() + + cfg = &config.Config{} + + root := newTestRootCmd() + root.AddCommand(searchCmd) + root.SetArgs([]string{ + "search", "--mode", mode, + "--account", "alice@example.com", + }) + err := root.Execute() + if err == nil { + t.Fatalf("expected error for queryless --mode=%s", mode) + } + if !strings.Contains(err.Error(), "requires query text") { + t.Errorf("error = %q, want substring 'requires query text'", err) + } + }) + } +} + +// TestSearchCmd_VectorOrHybridRejectFilterOnlyQuery rejects vector/ +// hybrid invocations whose query parses to filter terms only (no +// free-text). The embed client needs text to vectorize, so a query +// like `from:alice` would fail at the engine layer; reject it at the +// CLI surface instead. +func TestSearchCmd_VectorOrHybridRejectFilterOnlyQuery(t *testing.T) { + for _, mode := range []string{"vector", "hybrid"} { + t.Run(mode, func(t *testing.T) { + savedCfg := cfg + defer func() { cfg = savedCfg; resetSearchFlags() }() + + cfg = &config.Config{} + + root := newTestRootCmd() + root.AddCommand(searchCmd) + root.SetArgs([]string{ + "search", "--mode", mode, "from:alice", + }) + err := root.Execute() + if err == nil { + t.Fatalf("expected error for filter-only --mode=%s query", mode) + } + if !strings.Contains(err.Error(), "free-text terms") { + t.Errorf("error = %q, want substring 'free-text terms'", err) + } + }) + } +} + +// TestSearchCmd_MutualExclusion confirms --account and --collection are rejected together. 
+func TestSearchCmd_MutualExclusion(t *testing.T) { + var a, b string + cmd := &cobra.Command{Use: "search-test", SilenceErrors: true} + sub := &cobra.Command{Use: "search", RunE: func(cmd *cobra.Command, args []string) error { return nil }} + sub.Flags().StringVar(&a, "account", "", "") + sub.Flags().StringVar(&b, "collection", "", "") + sub.MarkFlagsMutuallyExclusive("account", "collection") + cmd.AddCommand(sub) + cmd.SetArgs([]string{"search", "--account", "alpha@example.com", "--collection", "work"}) + + err := cmd.Execute() + if err == nil { + t.Fatal("expected error when both --account and --collection are set, got nil") + } + msg := err.Error() + if !strings.Contains(msg, "account") || !strings.Contains(msg, "collection") { + t.Errorf("error should mention both flag names; got: %q", msg) + } + _ = a + _ = b +} diff --git a/cmd/msgvault/cmd/search_vector.go b/cmd/msgvault/cmd/search_vector.go index 109a9777..d453021a 100644 --- a/cmd/msgvault/cmd/search_vector.go +++ b/cmd/msgvault/cmd/search_vector.go @@ -16,7 +16,6 @@ import ( _ "github.com/mattn/go-sqlite3" "github.com/spf13/cobra" "github.com/wesm/msgvault/internal/search" - "github.com/wesm/msgvault/internal/store" "github.com/wesm/msgvault/internal/vector" "github.com/wesm/msgvault/internal/vector/embed" "github.com/wesm/msgvault/internal/vector/hybrid" @@ -26,7 +25,9 @@ import ( // runHybridSearch executes a vector or hybrid search against the local // msgvault archive using the sqlite-vec backend and configured embedding // endpoint. Invoked from search.go when --mode is "vector" or "hybrid". -func runHybridSearch(cmd *cobra.Command, queryStr, mode string, explain bool) error { +// scope carries any resolved --account/--collection scope; an empty +// Scope means no scope flag was supplied. 
+func runHybridSearch(cmd *cobra.Command, queryStr, mode string, explain bool, scope Scope) error { if queryStr == "" { return fmt.Errorf("empty search query") } @@ -93,27 +94,22 @@ func runHybridSearch(cmd *cobra.Command, queryStr, mode string, explain bool) er return fmt.Errorf("build filter: %w", err) } - // Resolve --account to a SourceID so vector/hybrid respects - // account scoping the same way FTS mode does. A missing account - // must return a clear error rather than silently searching the - // whole corpus. - if searchAccount != "" { - s, err := store.Open(cfg.DatabaseDSN()) - if err != nil { - return fmt.Errorf("open store for account lookup: %w", err) - } - src, err := s.GetSourceByIdentifier(searchAccount) - closeErr := s.Close() - if err != nil { - return fmt.Errorf("look up account: %w", err) - } - if closeErr != nil { - return fmt.Errorf("close store: %w", closeErr) - } - if src == nil { - return fmt.Errorf("account %q not found", searchAccount) + // Apply resolved --account/--collection scope so vector and hybrid + // modes honour the same scope as FTS. Earlier this branch only + // looked at --account directly and silently ignored --collection. + if !scope.IsEmpty() { + filter.SourceIDs = scope.SourceIDs() + if scope.IsCollection() { + n := len(filter.SourceIDs) + suffix := "s" + if n == 1 { + suffix = "" + } + fmt.Fprintf(os.Stderr, + "Searching collection %q (%d account%s)\n", + scope.DisplayName(), n, suffix, + ) } - filter.SourceIDs = []int64{src.ID} } freeText := strings.Join(q.TextTerms, " ") @@ -192,6 +188,11 @@ func hydrateHybridResults(ctx context.Context, db *sql.DB, hits []vector.FusedHi args[i] = h.MessageID orderIdx[h.MessageID] = i } + // Liveness is enforced upstream in the sqlite-vec backend's filter + // CTE used for ranking; re-filtering here would silently drop hits + // whose row was soft-deleted between ranking and hydration, + // returning a result list shorter than the ranked hits. Hydrate + // whatever was ranked. 
q := fmt.Sprintf(` SELECT m.id, COALESCE(m.subject,''), COALESCE(p.email_address,''), m.sent_at FROM messages m @@ -202,6 +203,10 @@ func hydrateHybridResults(ctx context.Context, db *sql.DB, hits []vector.FusedHi return nil, fmt.Errorf("query messages: %w", err) } defer func() { _ = rows.Close() }() + // hits ranked by the vector engine may have been soft-deleted between + // ranking and hydration. Track which slots got filled so we can drop + // the empty ones and warn about the gap. + filled := make([]bool, len(hits)) out := make([]hybridResultRow, len(hits)) for rows.Next() { var id int64 @@ -228,11 +233,25 @@ func hydrateHybridResults(ctx context.Context, db *sql.DB, hits []vector.FusedHi row.SentAt = sentAt.Time } out[idx] = row + filled[idx] = true } if err := rows.Err(); err != nil { return nil, fmt.Errorf("iterate messages: %w", err) } - return out, nil + dropped := 0 + compact := out[:0] + for i, ok := range filled { + if ok { + compact = append(compact, out[i]) + } else { + dropped++ + } + } + if dropped > 0 { + logger.Warn("hydration dropped hits (likely soft-deleted between rank and hydrate)", + "dropped", dropped, "ranked", len(hits)) + } + return compact, nil } func outputHybridResultsTable(results []hybridResultRow, meta hybrid.ResultMeta, explain bool) error { diff --git a/cmd/msgvault/cmd/search_vector_stub.go b/cmd/msgvault/cmd/search_vector_stub.go index b170d5de..969c5500 100644 --- a/cmd/msgvault/cmd/search_vector_stub.go +++ b/cmd/msgvault/cmd/search_vector_stub.go @@ -12,7 +12,7 @@ import ( // tag. The sqlite-vec extension is required for vector search; binaries // produced by `make build` (which sets `-tags "fts5 sqlite_vec"`) use // the real implementation in search_vector.go. 
-func runHybridSearch(_ *cobra.Command, _ string, mode string, _ bool) error { +func runHybridSearch(_ *cobra.Command, _ string, mode string, _ bool, _ Scope) error { return fmt.Errorf( "--mode=%s requires sqlite-vec support; rebuild with `go build -tags \"fts5 sqlite_vec\"`", mode) diff --git a/cmd/msgvault/cmd/search_vector_test.go b/cmd/msgvault/cmd/search_vector_test.go index a7f8ca5a..56e2f3d4 100644 --- a/cmd/msgvault/cmd/search_vector_test.go +++ b/cmd/msgvault/cmd/search_vector_test.go @@ -151,8 +151,8 @@ func TestSearchCmd_VectorMode_UnknownAccount(t *testing.T) { if err == nil { t.Fatal("expected error for unknown --account, got nil") } - if !strings.Contains(err.Error(), "not found") { - t.Errorf("error = %q, want substring 'not found'", err) + if !strings.Contains(err.Error(), "no account found") { + t.Errorf("error = %q, want substring 'no account found'", err) } } @@ -189,6 +189,67 @@ func TestSearchCmd_VectorMode_AccountScopingResolves(t *testing.T) { } } +// TestSearchCmd_VectorMode_CollectionScopingResolves verifies that +// --collection is plumbed through to filter.SourceIDs in the vector +// path. Earlier the vector branch only looked at --account directly +// and silently ignored --collection. 
+func TestSearchCmd_VectorMode_CollectionScopingResolves(t *testing.T) { + srv := fakeEmbedServer(t, 4) + defer srv.Close() + + s, restore := newVectorSearchTestEnv(t, srv.URL) + defer restore() + src, err := s.GetOrCreateSource("gmail", "alice@example.com") + if err != nil { + t.Fatalf("seed source: %v", err) + } + if _, err := s.CreateCollection("alice-only", "", []int64{src.ID}); err != nil { + t.Fatalf("create collection: %v", err) + } + + done := captureStdout(t) + root := newTestRootCmd() + root.AddCommand(searchCmd) + root.SetArgs([]string{ + "search", "--mode", "vector", + "--collection", "alice-only", + "hello", + }) + err = root.Execute() + out := done() + if err != nil { + t.Fatalf("expected no error for known collection, got %v (out=%s)", err, out) + } + if !strings.Contains(out, "No messages found") { + t.Errorf("expected 'No messages found' (empty index), got: %s", out) + } +} + +// TestSearchCmd_VectorMode_CollectionUnknown mirrors the FTS path's +// unknown-collection rejection. +func TestSearchCmd_VectorMode_CollectionUnknown(t *testing.T) { + srv := fakeEmbedServer(t, 4) + defer srv.Close() + + _, restore := newVectorSearchTestEnv(t, srv.URL) + defer restore() + + root := newTestRootCmd() + root.AddCommand(searchCmd) + root.SetArgs([]string{ + "search", "--mode", "vector", + "--collection", "does-not-exist", + "hello", + }) + err := root.Execute() + if err == nil { + t.Fatal("expected error for unknown --collection, got nil") + } + if !strings.Contains(err.Error(), "no collection") { + t.Errorf("error = %q, want substring 'no collection'", err) + } +} + // TestSearchCmd_HybridMode_UnknownAccount mirrors the vector test for // mode=hybrid, since the account-lookup path is shared. 
func TestSearchCmd_HybridMode_UnknownAccount(t *testing.T) { @@ -209,7 +270,63 @@ func TestSearchCmd_HybridMode_UnknownAccount(t *testing.T) { if err == nil { t.Fatal("expected error for unknown --account, got nil") } - if !strings.Contains(err.Error(), "not found") { - t.Errorf("error = %q, want substring 'not found'", err) + if !strings.Contains(err.Error(), "no account found") { + t.Errorf("error = %q, want substring 'no account found'", err) + } +} + +// TestSearchCmd_VectorMode_UnscopedRunsMigrations regression-guards +// the upgrade path: a user upgrading to a build that adds the +// deleted_at column whose first command is an unscoped +// `search --mode=vector|hybrid` must not crash with +// "no such column: deleted_at". The unscoped path skips +// resolveSearchScope (which is what runs InitSchema for scoped +// queries) and runHybridSearch opens a raw sql.DB, so the dispatch +// itself must run the migrations. Verified directly: drop +// deleted_at, then assert the dispatch path restores it before +// runHybridSearch's raw sql.DB sees the schema. +func TestSearchCmd_VectorMode_UnscopedRunsMigrations(t *testing.T) { + srv := fakeEmbedServer(t, 4) + defer srv.Close() + + s, restore := newVectorSearchTestEnv(t, srv.URL) + defer restore() + + if _, err := s.DB().Exec(`ALTER TABLE messages DROP COLUMN deleted_at`); err != nil { + t.Fatalf("drop deleted_at to simulate pre-migration DB: %v", err) + } + // Sanity: column is gone. 
+ var cnt int + if err := s.DB().QueryRow( + `SELECT COUNT(*) FROM pragma_table_info('messages') WHERE name = 'deleted_at'`, + ).Scan(&cnt); err != nil { + t.Fatalf("pragma_table_info pre-dispatch: %v", err) + } + if cnt != 0 { + t.Fatalf("setup: deleted_at still present (cnt=%d)", cnt) + } + + done := captureStdout(t) + root := newTestRootCmd() + root.AddCommand(searchCmd) + root.SetArgs([]string{ + "search", "--mode", "vector", + "hello", + }) + // Error from the engine itself is fine for this test — what we're + // guarding is that the dispatch path runs the schema migration + // before runHybridSearch opens its raw sql.DB. Other engine-level + // errors (no vectors, missing fts in the test build) don't + // invalidate the migration check. + _ = root.Execute() + _ = done() + + if err := s.DB().QueryRow( + `SELECT COUNT(*) FROM pragma_table_info('messages') WHERE name = 'deleted_at'`, + ).Scan(&cnt); err != nil { + t.Fatalf("pragma_table_info post-dispatch: %v", err) + } + if cnt != 1 { + t.Fatalf("dispatch did not re-add deleted_at column (cnt=%d) — runHybridSearch would query a missing column on an upgraded DB", cnt) } } diff --git a/cmd/msgvault/cmd/serve.go b/cmd/msgvault/cmd/serve.go index d7ec5795..a2935b3e 100644 --- a/cmd/msgvault/cmd/serve.go +++ b/cmd/msgvault/cmd/serve.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "io" "net" "net/http" "os" @@ -87,6 +88,12 @@ func runServe(cmd *cobra.Command, args []string) error { if err := s.InitSchema(); err != nil { return fmt.Errorf("init schema: %w", err) } + // Legacy [identity] migration is deferred to the first scheduled sync's + // runPostSourceCreateMigrations call, which fires AFTER that sync's + // confirmDefaultIdentity. Calling the migration here would race + // confirmDefaultIdentity for upgraded DBs with sources + a legacy + // [identity] block — same ordering hole the ingest commands already + // close by routing the legacy migration exclusively post-source-create. 
// Set up cancellable context early so vector-backend initialization // (which may open files and run migrations) respects Ctrl+C. @@ -396,6 +403,14 @@ func runScheduledSync(ctx context.Context, email string, s *store.Store, getOAut if err != nil { return fmt.Errorf("get source: %w", err) } + // Auto-default-identity must run BEFORE the legacy migration retry + // — see comment in account_identity.go. serve is a daemon, so the + // confirmation message has no terminal; discard it. Helper logs any + // failure path through its own logger.Warn. + confirmDefaultIdentity(io.Discard, s, source.ID, email, email, "account-identifier") + if err := runPostSourceCreateMigrations(s); err != nil { + return fmt.Errorf("post-source-create migrations: %w", err) + } // Run incremental sync summary, err := syncer.Incremental(ctx, source) diff --git a/cmd/msgvault/cmd/show_message.go b/cmd/msgvault/cmd/show_message.go index 8b325f17..9ef266ab 100644 --- a/cmd/msgvault/cmd/show_message.go +++ b/cmd/msgvault/cmd/show_message.go @@ -85,6 +85,9 @@ func showLocalMessage(cmd *cobra.Command, idStr string) error { if err := s.InitSchema(); err != nil { return fmt.Errorf("init schema: %w", err) } + if err := runStartupMigrations(s); err != nil { + return fmt.Errorf("startup migrations: %w", err) + } // Create query engine engine := query.NewSQLiteEngine(s.DB()) diff --git a/cmd/msgvault/cmd/stats.go b/cmd/msgvault/cmd/stats.go index a66b1181..1bdb2ab7 100644 --- a/cmd/msgvault/cmd/stats.go +++ b/cmd/msgvault/cmd/stats.go @@ -4,6 +4,12 @@ import ( "fmt" "github.com/spf13/cobra" + "github.com/wesm/msgvault/internal/store" +) + +var ( + statsAccount string + statsCollection string ) var statsCmd = &cobra.Command{ @@ -14,44 +20,140 @@ var statsCmd = &cobra.Command{ Uses remote server if [remote].url is configured, otherwise uses local database. 
Use --local to force local database.`, RunE: func(cmd *cobra.Command, args []string) error { - s, err := OpenStore() + scoped := statsAccount != "" || statsCollection != "" + + if IsRemoteMode() { + if statsAccount != "" { + return fmt.Errorf("--account is not supported in remote mode") + } + if statsCollection != "" { + return fmt.Errorf("--collection is not supported in remote mode") + } + } + + // Scoped stats require a local store for scope resolution and GetStatsForScope. + if scoped { + st, err := openLocalStoreAndInit() + if err != nil { + return fmt.Errorf("open store: %w", err) + } + defer func() { _ = st.Close() }() + + var scope Scope + if statsAccount != "" { + scope, err = ResolveAccountFlag(st, statsAccount) + if err != nil { + return err + } + if scope.IsEmpty() { + return fmt.Errorf("--account %q resolved to zero sources", statsAccount) + } + } else { + scope, err = ResolveCollectionFlag(st, statsCollection) + if err != nil { + return err + } + if scope.IsEmpty() { + return fmt.Errorf("--collection %q has no member accounts", statsCollection) + } + } + + sourceIDs := scope.SourceIDs() + // A collection with zero member sources resolves to a non-nil + // Scope (Collection set, Source nil) so IsEmpty above is false, + // but SourceIDs is empty. GetStatsForScope treats an empty + // slice as "unscoped" and would silently return archive-wide + // counts. Reject explicitly with the same shape as the + // IsEmpty branch above. 
+ if len(sourceIDs) == 0 { + return fmt.Errorf("--collection %q has no member accounts", statsCollection) + } + dbStats, err := st.GetStatsForScope(sourceIDs) + if err != nil { + logger.Warn("stats failed", "error", err.Error()) + return fmt.Errorf("get stats: %w", err) + } + logger.Info("stats", + "messages", dbStats.MessageCount, + "threads", dbStats.ThreadCount, + "attachments", dbStats.AttachmentCount, + "labels", dbStats.LabelCount, + "accounts", dbStats.SourceCount, + "db_bytes", dbStats.DatabaseSize, + ) + + if statsAccount != "" { + fmt.Printf("Stats for account %q:\n", scope.DisplayName()) + } else { + n := len(sourceIDs) + suffix := "s" + if n == 1 { + suffix = "" + } + fmt.Printf("Stats for collection %q (%d account%s):\n", + scope.DisplayName(), n, suffix) + } + + printStats(dbStats) + fmt.Printf("\nNote: Size is global (not scoped).\n") + return nil + } + + // Unscoped: route remote to OpenStore (HTTP path), local to + // openLocalStoreAndInit so InitSchema and runStartupMigrations + // run consistently with every other command. 
+ var ( + s MessageStore + err error + ) + if IsRemoteMode() { + s, err = OpenStore() + } else { + s, err = openLocalStoreAndInit() + } if err != nil { return fmt.Errorf("open store: %w", err) } defer func() { _ = s.Close() }() - stats, err := s.GetStats() + dbStats, err := s.GetStats() if err != nil { logger.Warn("stats failed", "error", err.Error()) return fmt.Errorf("get stats: %w", err) } logger.Info("stats", - "messages", stats.MessageCount, - "threads", stats.ThreadCount, - "attachments", stats.AttachmentCount, - "labels", stats.LabelCount, - "accounts", stats.SourceCount, - "db_bytes", stats.DatabaseSize, + "messages", dbStats.MessageCount, + "threads", dbStats.ThreadCount, + "attachments", dbStats.AttachmentCount, + "labels", dbStats.LabelCount, + "accounts", dbStats.SourceCount, + "db_bytes", dbStats.DatabaseSize, ) - // Show source indicator if IsRemoteMode() { fmt.Printf("Remote: %s\n", cfg.Remote.URL) } else { fmt.Printf("Database: %s\n", cfg.DatabaseDSN()) } - fmt.Printf(" Messages: %d\n", stats.MessageCount) - fmt.Printf(" Threads: %d\n", stats.ThreadCount) - fmt.Printf(" Attachments: %d\n", stats.AttachmentCount) - fmt.Printf(" Labels: %d\n", stats.LabelCount) - fmt.Printf(" Accounts: %d\n", stats.SourceCount) - fmt.Printf(" Size: %.2f MB\n", float64(stats.DatabaseSize)/(1024*1024)) - + printStats(dbStats) return nil }, } +func printStats(s *store.Stats) { + fmt.Printf(" Messages: %d\n", s.MessageCount) + fmt.Printf(" Threads: %d\n", s.ThreadCount) + fmt.Printf(" Attachments: %d\n", s.AttachmentCount) + fmt.Printf(" Labels: %d\n", s.LabelCount) + fmt.Printf(" Accounts: %d\n", s.SourceCount) + fmt.Printf(" Size: %.2f MB\n", float64(s.DatabaseSize)/(1024*1024)) +} + func init() { rootCmd.AddCommand(statsCmd) + statsCmd.Flags().StringVar(&statsAccount, "account", "", "Show stats for a specific account") + statsCmd.Flags().StringVar(&statsCollection, "collection", "", + "Show stats for all member accounts of one collection") + 
statsCmd.MarkFlagsMutuallyExclusive("account", "collection") } diff --git a/cmd/msgvault/cmd/stats_test.go b/cmd/msgvault/cmd/stats_test.go new file mode 100644 index 00000000..c176431a --- /dev/null +++ b/cmd/msgvault/cmd/stats_test.go @@ -0,0 +1,103 @@ +package cmd + +import ( + "log/slog" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/spf13/cobra" + "github.com/wesm/msgvault/internal/config" + "github.com/wesm/msgvault/internal/store" +) + +// TestStatsCommand_AccountAndCollectionMutuallyExclusive confirms that passing +// both --account and --collection to the stats command is rejected by cobra. +func TestStatsCommand_AccountAndCollectionMutuallyExclusive(t *testing.T) { + var a, b string + cmd := &cobra.Command{Use: "stats-test", SilenceErrors: true} + sub := &cobra.Command{Use: "stats", RunE: func(cmd *cobra.Command, args []string) error { return nil }} + sub.Flags().StringVar(&a, "account", "", "") + sub.Flags().StringVar(&b, "collection", "", "") + sub.MarkFlagsMutuallyExclusive("account", "collection") + cmd.AddCommand(sub) + cmd.SetArgs([]string{"stats", "--account", "foo@example.com", "--collection", "bar"}) + + err := cmd.Execute() + if err == nil { + t.Fatal("expected error when both --account and --collection are set, got nil") + } + msg := err.Error() + if !strings.Contains(msg, "account") || !strings.Contains(msg, "collection") { + t.Errorf("error should mention both flag names; got: %q", msg) + } + _ = a + _ = b +} + +// TestStatsCommand_EmptyCollectionRejected verifies that +// `stats --collection ` errors out when the named collection +// has zero member sources, instead of silently falling through to +// archive-wide stats. Regression test for iter13 codex Medium: +// previously, an empty collection produced a non-IsEmpty Scope but +// SourceIDs() returned an empty slice, and GetStatsForScope treats +// an empty slice as unscoped/global. 
+func TestStatsCommand_EmptyCollectionRejected(t *testing.T) { + tmpDir := t.TempDir() + dbPath := filepath.Join(tmpDir, "msgvault.db") + + // Pre-create the store and an empty collection. CreateCollection + // requires at least one source, so create a source, attach, and + // then remove the source from the collection to leave it empty. + st, err := store.Open(dbPath) + if err != nil { + t.Fatalf("open store: %v", err) + } + if err := st.InitSchema(); err != nil { + t.Fatalf("init schema: %v", err) + } + src, err := st.GetOrCreateSource("gmail", "alice@example.com") + if err != nil { + t.Fatalf("create source: %v", err) + } + if _, err := st.CreateCollection("empty", "test", []int64{src.ID}); err != nil { + t.Fatalf("create collection: %v", err) + } + if err := st.RemoveSourcesFromCollection("empty", []int64{src.ID}); err != nil { + t.Fatalf("remove source from collection: %v", err) + } + _ = st.Close() + + savedCfg := cfg + savedLogger := logger + savedStatsCollection := statsCollection + defer func() { + cfg = savedCfg + logger = savedLogger + statsCollection = savedStatsCollection + }() + + cfg = &config.Config{ + HomeDir: tmpDir, + Data: config.DataConfig{DataDir: tmpDir}, + } + logger = slog.New(slog.NewTextHandler(os.Stderr, nil)) + statsCollection = "empty" + + testCmd := &cobra.Command{Use: "stats", RunE: statsCmd.RunE} + testCmd.Flags().StringVar(&statsAccount, "account", "", "") + testCmd.Flags().StringVar(&statsCollection, "collection", "empty", "") + + root := newTestRootCmd() + root.AddCommand(testCmd) + root.SetArgs([]string{"stats", "--collection", "empty"}) + + err = root.Execute() + if err == nil { + t.Fatal("expected error for empty collection, got nil") + } + if !strings.Contains(err.Error(), "no member accounts") { + t.Errorf("error message = %q, want substring \"no member accounts\"", err.Error()) + } +} diff --git a/cmd/msgvault/cmd/store_resolver.go b/cmd/msgvault/cmd/store_resolver.go index 69753ec4..f080f4a5 100644 --- 
a/cmd/msgvault/cmd/store_resolver.go +++ b/cmd/msgvault/cmd/store_resolver.go @@ -2,12 +2,79 @@ package cmd import ( "fmt" + "os" "time" "github.com/wesm/msgvault/internal/remote" "github.com/wesm/msgvault/internal/store" ) +// runStartupMigrations pulls legacy identity addresses from the global config +// and runs the one-time migration. If migration was performed, the notice is +// logged and printed to stderr. If the migration is deferred because no source +// exists yet, it will be retried on a later command after a source has been +// created — and ingest commands that create the first source should call +// runPostSourceCreateMigrations after GetOrCreateSource so the deferred +// migration applies on the same invocation. +// +// Always returns nil unless the migration itself errors. +func runStartupMigrations(s *store.Store) error { + addrs := cfg.Identity.Addresses + res, err := s.RunStartupMigrations(addrs) + if err != nil { + logger.Warn("startup migration failed", "error", err) + return err + } + // Success cases log at Info (the operation succeeded; res.Notice is + // the user-facing surface on stderr). Reserved Warn for the actual + // error path above. + switch { + case res.Deferred: + logger.Info("legacy [identity] block in config detected (migration deferred until a source exists)", + "address_count", res.AddressCount, + "hint", "run 'msgvault add-account ...' to create a source; the migration will retry on the next command") + case res.Applied: + logger.Info("legacy identity migrated", + "addresses", res.AddressCount, + "sources", res.SourceCount) + } + if res.Notice != "" { + fmt.Fprintln(os.Stderr, res.Notice) + } + return nil +} + +// runStartupMigrationsForIngest is the pre-source-create hook for ingest +// commands. 
The only startup migration today is MigrateLegacyIdentityConfig, +// which writes to account_identities — and any pre-source-create write +// races confirmDefaultIdentity by populating identity rows before the +// source's own identifier is confirmed, causing confirmDefaultIdentity's +// `len(existing) > 0` guard to skip the source's own address (regression +// caught upstream at iter20). +// +// All ingest paths already invoke runPostSourceCreateMigrations after +// confirmDefaultIdentity, which handles the legacy migration correctly +// in the deferred (no-source) case and is a no-op once the migration +// sentinel is set. So this pre-source call is intentionally a no-op +// to avoid the race. Kept as a named hook so future startup work that +// genuinely belongs *before* source creation has an obvious place to +// land without re-introducing the legacy-identity race. +func runStartupMigrationsForIngest(s *store.Store) error { + _ = s + return nil +} + +// runPostSourceCreateMigrations re-runs startup migrations after the caller +// has just created a source. The legacy identity migration defers when no +// source exists at startup, so on a fresh install the very first +// add-account / add-imap / add-o365 / import-* invocation needs a second +// pass to actually apply the migration on the same invocation that created +// the first source. Subsequent calls are O(1) — once the migration sentinel +// is set, MigrateLegacyIdentityConfig short-circuits. +func runPostSourceCreateMigrations(s *store.Store) error { + return runStartupMigrations(s) +} + // MessageStore is the interface for commands that need basic message operations. // Both store.Store and remote.Store implement this interface. type MessageStore interface { @@ -66,6 +133,25 @@ func openLocalStore() (*store.Store, error) { return store.Open(dbPath) } +// openLocalStoreAndInit opens the local SQLite database, initializes the +// schema, and runs startup migrations. 
Callers that previously called +// store.Open + s.InitSchema separately should migrate to this helper. +func openLocalStoreAndInit() (*store.Store, error) { + s, err := openLocalStore() + if err != nil { + return nil, err + } + if err := s.InitSchema(); err != nil { + _ = s.Close() + return nil, fmt.Errorf("init schema: %w", err) + } + if err := runStartupMigrations(s); err != nil { + _ = s.Close() + return nil, fmt.Errorf("startup migrations: %w", err) + } + return s, nil +} + // openRemoteStore creates a remote store client. func openRemoteStore() (*remote.Store, error) { return remote.New(remote.Config{ diff --git a/cmd/msgvault/cmd/sync.go b/cmd/msgvault/cmd/sync.go index 51bd7374..fec6bff4 100644 --- a/cmd/msgvault/cmd/sync.go +++ b/cmd/msgvault/cmd/sync.go @@ -52,6 +52,9 @@ Examples: if err := s.InitSchema(); err != nil { return fmt.Errorf("init schema: %w", err) } + if err := runStartupMigrations(s); err != nil { + return fmt.Errorf("startup migrations: %w", err) + } // Set up context with cancellation before any sync calls // so Ctrl+C always saves checkpoints. 
diff --git a/cmd/msgvault/cmd/syncfull.go b/cmd/msgvault/cmd/syncfull.go index 10b00d08..8e8730a5 100644 --- a/cmd/msgvault/cmd/syncfull.go +++ b/cmd/msgvault/cmd/syncfull.go @@ -75,6 +75,9 @@ Examples: if err := s.InitSchema(); err != nil { return fmt.Errorf("init schema: %w", err) } + if err := runStartupMigrations(s); err != nil { + return fmt.Errorf("startup migrations: %w", err) + } getOAuthMgr := oauthManagerCache() diff --git a/cmd/msgvault/cmd/tui.go b/cmd/msgvault/cmd/tui.go index fb716d44..d3928b48 100644 --- a/cmd/msgvault/cmd/tui.go +++ b/cmd/msgvault/cmd/tui.go @@ -91,6 +91,9 @@ Remote Mode: if err := s.InitSchema(); err != nil { return fmt.Errorf("init schema: %w", err) } + if err := runStartupMigrations(s); err != nil { + return fmt.Errorf("startup migrations: %w", err) + } // Build FTS index in background — TUI uses DuckDB/Parquet for // aggregates and only needs FTS for deep search (Tab to switch). @@ -284,6 +287,34 @@ func cacheNeedsBuild(dbPath, analyticsDir string) cacheStaleness { fmt.Sprintf("%d deletions", deletedSinceBuild)) } + // Dedup-hidden rows (deleted_at) are excluded from the messages + // Parquet export, so a dedup run after the last cache build leaves + // stale duplicate rows in the cache. Detect that by counting hides + // since LastSyncAt and force a full rebuild if any are present. + // The deleted_from_source_at IS NULL clause keeps the count + // disjoint from the deletedSinceBuild count above so a row that is + // both source-deleted and dedup-hidden after LastSyncAt is reported + // once (as a deletion), not double-counted in the reason string. + var hiddenSinceBuild int64 + err = db.DB().QueryRow(` + SELECT COUNT(*) FROM messages + WHERE deleted_at IS NOT NULL + AND deleted_at >= ? 
+ AND deleted_from_source_at IS NULL + `, syncAtStr).Scan(&hiddenSinceBuild) + if err != nil { + return cacheStaleness{ + NeedsBuild: true, FullRebuild: true, + Reason: "cannot verify dedup state", + } + } + if hiddenSinceBuild > 0 { + result.HasDeleted = true + result.FullRebuild = true + reasons = append(reasons, + fmt.Sprintf("%d dedup-hidden", hiddenSinceBuild)) + } + var hasSyncRunsTable int err = db.DB().QueryRow(` SELECT COUNT(*) FROM sqlite_master diff --git a/cmd/msgvault/cmd/update_account.go b/cmd/msgvault/cmd/update_account.go index 08f2e645..af01893a 100644 --- a/cmd/msgvault/cmd/update_account.go +++ b/cmd/msgvault/cmd/update_account.go @@ -37,6 +37,9 @@ Examples: if err := s.InitSchema(); err != nil { return fmt.Errorf("init schema: %w", err) } + if err := runStartupMigrations(s); err != nil { + return fmt.Errorf("startup migrations: %w", err) + } source, err := s.GetSourceByIdentifier(email) if err != nil { diff --git a/cmd/msgvault/cmd/verify.go b/cmd/msgvault/cmd/verify.go index 509862b7..45569a0a 100644 --- a/cmd/msgvault/cmd/verify.go +++ b/cmd/msgvault/cmd/verify.go @@ -54,6 +54,9 @@ Examples: if err := s.InitSchema(); err != nil { return fmt.Errorf("init schema: %w", err) } + if err := runStartupMigrations(s); err != nil { + return fmt.Errorf("startup migrations: %w", err) + } // Run SQLite integrity check before any Gmail work. Users with a // corrupt database should see the repair hint even if their OAuth diff --git a/docs/accounts-identities-collections-dedup/README.md b/docs/accounts-identities-collections-dedup/README.md new file mode 100644 index 00000000..f94948a6 --- /dev/null +++ b/docs/accounts-identities-collections-dedup/README.md @@ -0,0 +1,339 @@ +# Accounts, Identities, Collections, and Deduplication + +**Shipped May 3, 2026.** Features for managing the growing size and +complexity of accounts in msgvault. 
The first deduplication command +landed in [commit `9ace189`](https://github.com/jesserobbins/msgvault/commit/9ace189d86949ccd595a012bf29864abcffa4dda) +on April 8, alongside the [implementation plan](https://github.com/jesserobbins/msgvault/commit/8d8cedeba20920a74dc5b1c0acb97f7548b64ff5) +and [test plan](https://github.com/jesserobbins/msgvault/commit/df61cee) +committed earlier the same day. The unified model that emerged over +the next four weeks is proposed in upstream issue +[wesm/msgvault#278](https://github.com/wesm/msgvault/issues/278) and +ships via [PR #304](https://github.com/wesm/msgvault/pull/304). + +## What you need to know + +A long-running msgvault archive accumulates overlapping sources: a +current Gmail sync, an old mbox export, Apple Mail from a retired +laptop, IMAP backups, chat exports, SMS history. Each source is +valuable. Together they create three problems. The same message +appears multiple times. Old imports no longer remember which +addresses are "you." Duplicates dominate search results. + +These features fix all three and keep every source's provenance +intact. + +The whole release rests on two halves: a data model (Account, +Identity, Collection) and a safety story (hide is the default; delete +is always opt-in). Read these if you read nothing else. + +### The data model: Account, Identity, Collection (AIC) + +1. **An *account* is one ingest source.** A Gmail sync is one + account. An mbox import is another. Two imports of the same + real-world mailbox produce two accounts. msgvault never silently + merges them. +2. **An *identity* is the addresses, phone numbers, and identifiers + that mean "me" inside one account.** Identity is per-account + because the same address can mean different things in different + imports. +3. **A *collection* is a named group of accounts.** `All` exists by + default and contains every account. Create `work`, `personal`, or + any other named group. 
Collections are the boundary for + cross-account features: search, stats, dedup. A collection's + identity is the union of its members'. + +![Accounts and collections — left side shows six per-import accounts (Personal Gmail, Old mbox, Apple Mail archive, iMessage, Old work account, College email), each with the addresses or phone numbers that identify the owner inside that source. Right side shows three collections — All (every account), Personal (a deliberate subset), Work (another named view) — each composed of accounts.](./assets/account-collection-concept.png) + +Each account on the left is one ingest source with its own owner +identity. Collections on the right are user-named groups of those +accounts. `All` is built automatically and contains everything. +`Personal` and `Work` are named subsets. One account can belong to +multiple collections. Collections contain accounts only, not other +collections. + +### The safety story: hide, don't delete + +1. **`deduplicate` hides redundant copies. It does not delete them.** + One survivor stays visible. The other copies stay on disk and drop + out of normal reads. `--undo ` restores them. +2. **Deletion is never required.** msgvault never escalates from + "hide" to "delete locally" to "delete from the source server" on + its own. Each step is a separate, opt-in command. Stay on "hide" + forever if you want. We call this design the **dedup safety + ladder**: four explicit rungs the user climbs deliberately, with + no automatic escalation between them. + +![Data safety ladder — five rungs. Rung 00 (Backup, default-on) writes a point-in-time database copy before any rung that modifies data. Rung 01 (Scan, deduplicate --dry-run) reports what would change with no data touched. Rung 02 (Hide, deduplicate) soft-deletes redundant copies; reversible via --undo. Rung 03 (Local hard delete, delete-deduped --batch) permanently removes hidden rows from the local archive. 
Rung 04 (Remote delete, delete-staged) deletes from the source server, source-scoped, moves to source trash by default. A banner reads: deletion is never required — you can run deduplicate as many times as you want and stay on rung 02 forever.](./assets/safety-ladder-concept.png) + +The diagram is the mental model for every dedup-related command. +**Each rung is a separate, explicit user action.** Rung 00 (backup) +and rung 02 (hide) happen by default when you run `deduplicate`. +Rungs 03 and 04 happen only when you invoke a different command: +`delete-deduped` or `delete-staged`. msgvault never escalates between +rungs on its own. "Apply dedup" never implies "hard-delete locally." +"Hard-delete locally" never implies "delete from the source server." +"Delete from the source server" never implies "permanently delete +from the source server." + +The rest of this document is a HOWTO with worked examples. + +## HOWTO + +The examples below use a worked scenario with real command shapes. +Substitute your own account and collection names. + +### List your accounts + +Start by seeing what msgvault knows about: + +```sh +msgvault list-accounts +``` + +Each row is one account. Note the identifier (typically the email +address or source name); every command in this guide takes it. + +### Confirm your identities + +For sent vs. received to mean anything in old imports, msgvault needs +to know which addresses are "you" inside each account. + +List what's confirmed across all accounts: + +```sh +msgvault identity list +``` + +Show one account's identity in detail: + +```sh +msgvault identity show me@example.com +``` + +Add a confirmed identifier to an account's identity: + +```sh +msgvault identity add me@example.com me@oldco.example +``` + +Remove one: + +```sh +msgvault identity remove me@example.com me@oldco.example +``` + +Identity is per-account because an address safe to treat as "you" in +one source can be misleading in another. 
A collection's identity is +the union of its member accounts' identities, computed at read time. +You don't manage it directly. + +### Group accounts into a collection + +For a unified view across several accounts (for search, stats, or +dedup), group them into a collection. `All` already exists. Custom +collections are explicit. + +Create a `work` collection from two accounts: + +```sh +msgvault collection create work --accounts me@oldco.example,me@newco.example +``` + +List all collections: + +```sh +msgvault collection list +``` + +Show one in detail: + +```sh +msgvault collection show work +``` + +Add or remove members later: + +```sh +msgvault collection add work --accounts contractor@example.org +msgvault collection remove work --accounts contractor@example.org +``` + +Delete a collection when you're done with it (the underlying accounts +and their messages are untouched): + +```sh +msgvault collection delete work +``` + +`All` is auto-managed and immutable — msgvault rejects +`collection delete All` and explicit membership edits on `All`. + +### Search or count across a collection + +Wherever `--account` works, `--collection` works too. Cross-account +operations always go through a collection boundary. + +Search inside one account: + +```sh +msgvault search --account me@example.com "dinner friday" +``` + +Search across the `Personal` collection: + +```sh +msgvault search --collection Personal "dinner friday" +``` + +Get stats for one collection: + +```sh +msgvault stats --collection work +``` + +`--account` and `--collection` are mutually exclusive. msgvault +rejects a collection name passed to `--account` (or an account +identifier passed to `--collection`) with a hint to use the right +flag. 
+ +### Run dedup safely + +Three flavors of dedup, ordered by risk: + +```sh +msgvault deduplicate # per-account, each in isolation +msgvault deduplicate --account <account> # one account +msgvault deduplicate --collection <collection> # cross-account, inside one collection +``` + +The unscoped form is the safest default. It processes each account +independently and never crosses source boundaries. Cross-account +dedup is higher-risk: it can collapse duplicates between independent +archives whose provenance you may want to preserve. So it requires +an explicit `--collection`. To dedup across every account, write +`--collection All`. + +### Walking the safety ladder + +The recommended sequence: + +**1. Scan first (rung 01).** See what dedup would do before it does +it: + +```sh +msgvault deduplicate --collection Personal --dry-run +``` + +The output shows duplicate groups, which copy is the proposed +survivor, and why it was chosen. Nothing is modified. + +**2. Apply dedup (rung 02).** When the dry-run output looks right: + +```sh +msgvault deduplicate --collection Personal +``` + +Before modifying any row, msgvault writes a point-in-time database +backup alongside your DB file (e.g. +`msgvault.db.dedup-backup-20260503-091500`). Opt out only with +`--no-backup`. The command then hides redundant copies and prints a +batch ID like `dedup-2026-05-03-1`. + +The hidden copies are excluded from search, the TUI, vector and +hybrid retrieval, the API, MCP responses, exports, and stats — but +they remain on disk. + +**3. Undo if you change your mind (still rung 02).** Use the batch ID +from step 2: + +```sh +msgvault deduplicate --undo dedup-2026-05-03-1 +``` + +Undo restores the rows hidden by that batch and cancels any pending +remote-deletion manifest the batch staged. It does not reverse every +side effect — for the precise guarantees, see the spec. + +**4. 
Hard-delete locally (rung 03 — opt-in, irreversible).** Only run +this when you're confident you'll never want the hidden rows back: + +```sh +msgvault delete-deduped --batch dedup-2026-05-03-1 +``` + +This permanently removes rows the named batch hid. It refuses to +operate on rows it didn't hide. The `--all-hidden` form purges every +hidden row from every batch and requires interactive confirmation. +Backup runs again before the purge unless you pass `--no-backup`. +Undo cannot recover purged rows. + +**5. Delete from the source server (rung 04 — opt-in, does not touch +the local archive).** Cross-source duplicate groups produce zero +remote-deletion entries — only same-source pairs stage. List, inspect, +and execute pending deletion manifests: + +```sh +msgvault list-deletions +msgvault show-deletion +msgvault delete-staged +``` + +By default, `delete-staged` moves messages to the source's trash +(e.g. Gmail's ~30-day Gmail/Trash). Permanent removal requires an +explicit `--permanent` flag and interactive confirmation. As a v1 +guardrail, the whole remote-delete command is gated behind an +environment variable (`MSGVAULT_ENABLE_REMOTE_DELETE=1`); read-only +inspection (`list-deletions`, `show-deletion`, `--list`, `--dry-run`) +is always permitted. + +To cancel a staged batch before it executes: + +```sh +msgvault cancel-deletion +``` + +### Common scenarios + +#### "I just want to clean up duplicates inside my Gmail account." + +```sh +msgvault deduplicate --account me@example.com --dry-run +msgvault deduplicate --account me@example.com +``` + +Stay on rung 02. You're done. + +#### "I imported the same mailbox twice from two sources and want one clean view." + +Create a collection containing both, scan it, apply dedup. msgvault +leaves the originals on each source server untouched. 
+ +```sh +msgvault collection create gmail-plus-mbox \ + --accounts me@example.com,me@example.org +msgvault deduplicate --collection gmail-plus-mbox --dry-run +msgvault deduplicate --collection gmail-plus-mbox +``` + +#### "I want to hard-delete the duplicates I hid last week to reclaim disk." + +```sh +msgvault list-deletions # find the batch ID +msgvault delete-deduped --batch dedup-2026-04-26-1 +``` + +#### "I want to actually remove the duplicates from Gmail." + +This is rung 04. Same-source only — only duplicate pairs that lived +on the same Gmail account stage remote-delete entries. Cross-source +groups (e.g. one copy on Gmail, another in mbox) stage nothing. + +```sh +msgvault list-deletions +msgvault show-deletion +MSGVAULT_ENABLE_REMOTE_DELETE=1 msgvault delete-staged +``` + +The default moves messages to Gmail/Trash, where they're recoverable +for ~30 days. If you want them gone for good, add `--permanent` and +confirm interactively. diff --git a/docs/accounts-identities-collections-dedup/assets/README.md b/docs/accounts-identities-collections-dedup/assets/README.md new file mode 100644 index 00000000..555053b6 --- /dev/null +++ b/docs/accounts-identities-collections-dedup/assets/README.md @@ -0,0 +1,109 @@ +# Diagram assets + +The PNGs in this directory are rendered from the matching HTML files. Each +HTML file is self-contained — no external CSS, no JavaScript, no fonts to +load. Re-rendering is a single headless-Chrome screenshot. + +## Files + +| File | What it shows | +| ----------------------------------- | ---------------------------------------------------------------------------------------------- | +| `account-collection-concept.html` | The accounts vs. collections distinction: per-import sources and the named groups across them. | +| `deduplication-concept.html` | Dedup before/after, with the safety-ladder strip showing this is rung 02 of 4. | +| `safety-ladder-concept.html` | The full data safety ladder: rung 00 backup, then 4 explicit rungs. 
| +| `survivor-selection-concept.html` | Survivor selection: the sent-message eligibility filter, then the priority list. | + +Each PNG is rendered at 1600 px wide. Heights vary because the panel +content drives layout — render with a tall viewport (1700 is enough for +all four) so nothing is clipped, then trim the bottom whitespace so the +PNG ends with the same padding it starts with. + +## Re-rendering + +Headless Chrome plus ImageMagick. The commands: + +```bash +"$CHROME" --headless=new \ + --disable-gpu \ + --hide-scrollbars \ + --window-size=1600,1700 \ + --default-background-color=0a0a0aff \ + --screenshot=safety-ladder-concept.png \ + "file://$(pwd)/safety-ladder-concept.html" + +# Trim vertical whitespace, restore the full 1600 px width, and pad +# 82 px of breathing room above and below. +magick safety-ladder-concept.png \ + -background "#0a0a0a" -fuzz 2% -trim +repage \ + -gravity center -extent 1600x \ + -gravity north -splice 0x82 \ + -gravity south -splice 0x82 \ + safety-ladder-concept.png +``` + +`$CHROME` is the path to the Chrome (or Chromium) binary. 
Common +locations: + +| Platform | Path | +| -------- | --------------------------------------------------------------- | +| macOS | `/Applications/Google Chrome.app/Contents/MacOS/Google Chrome` | +| Linux | `/usr/bin/google-chrome` or `/usr/bin/chromium-browser` | +| Windows | `C:\Program Files\Google\Chrome\Application\chrome.exe` | + +A small shell helper that picks one and re-renders all four: + +```bash +#!/usr/bin/env bash +set -euo pipefail +cd "$(dirname "$0")" + +CHROME="${CHROME:-}" +if [ -z "$CHROME" ]; then + for c in \ + "/Applications/Google Chrome.app/Contents/MacOS/Google Chrome" \ + "/usr/bin/google-chrome" \ + "/usr/bin/chromium-browser" \ + "/usr/bin/chromium"; do + if [ -x "$c" ]; then CHROME="$c"; break; fi + done +fi +[ -n "$CHROME" ] || { echo "set CHROME to your Chrome/Chromium binary"; exit 1; } + +for f in account-collection-concept deduplication-concept safety-ladder-concept survivor-selection-concept; do + "$CHROME" --headless=new --disable-gpu --hide-scrollbars \ + --window-size=1600,1700 --default-background-color=0a0a0aff \ + --screenshot="${f}.png" "file://$(pwd)/${f}.html" + magick "${f}.png" \ + -background "#0a0a0a" -fuzz 2% -trim +repage \ + -gravity center -extent 1600x \ + -gravity north -splice 0x82 \ + -gravity south -splice 0x82 \ + "${f}.png" + echo "rendered ${f}.png" +done +``` + +## Editing the diagrams + +Edit the HTML directly — the styles are inline, the data is hand-written, +and there is no build step. 
The shared palette lives in the `:root` block +of each file and follows the `msgvault.io` site: + +```css +--bg: #0a0a0a; /* page background */ +--surface-1: #161616; /* panel surface */ +--surface-2: #212121; /* nested surface */ +--hairline: #3a3a3a; /* borders */ +--text: #e8e8e8; /* body text */ +--text-2: #c0c0c0; /* secondary text */ +--muted: #a0a0a0; /* hints, eyebrows */ +--accent: #ffffff; /* headings, key emphasis */ +``` + +If you change palette tokens, change them across all four files so the +set stays visually coherent. + +The design here should be refreshed periodically against `msgvault.io` +— the site is the source of truth for type, palette, and surface +treatments. If the site evolves, pull the updated tokens back into these +diagrams so they keep reading as part of the same product. diff --git a/docs/accounts-identities-collections-dedup/assets/account-collection-concept.html b/docs/accounts-identities-collections-dedup/assets/account-collection-concept.html new file mode 100644 index 00000000..bfd69633 --- /dev/null +++ b/docs/accounts-identities-collections-dedup/assets/account-collection-concept.html @@ -0,0 +1,403 @@ + + + + + + Accounts and Collections + + + +
+
msgvault · core model
+

Accounts & Collections: Imports vs. Named Groups

+

Each account keeps one import or sync, plus the addresses and numbers that show who the account belongs to. Collections are the user's explicit grouping of accounts.

+ +
+
+
+

Accounts

+
Messages from one import, plus owner identities for sent vs. received.
+
+ + +
+ +
+
+

Collections

+
Collections are groups of accounts.
+
+ +
+
+
+ All + everything in msgvault +
+
+ Gmail + old mbox + Apple Mail + iMessage + old work + college +
+
+ +
+
+ Personal + a deliberate subset +
+
+ Personal Gmail + iMessage + College email +
+
+ +
+
+ Work + another named view +
+
+ Old work account + Apple Mail archive +
+
+
+ + +
A collection's identity is the union of the identities on its member accounts.
+
+
+
+ + diff --git a/docs/accounts-identities-collections-dedup/assets/account-collection-concept.png b/docs/accounts-identities-collections-dedup/assets/account-collection-concept.png new file mode 100644 index 00000000..579a8d6d Binary files /dev/null and b/docs/accounts-identities-collections-dedup/assets/account-collection-concept.png differ diff --git a/docs/accounts-identities-collections-dedup/assets/deduplication-concept.html b/docs/accounts-identities-collections-dedup/assets/deduplication-concept.html new file mode 100644 index 00000000..5a6155f4 --- /dev/null +++ b/docs/accounts-identities-collections-dedup/assets/deduplication-concept.html @@ -0,0 +1,589 @@ + + + + + + Deduplication + + + +
+
msgvault · safety ladder rung 02
+

Dedup Mode: Hide First, Optionally Delete Later

+

Dedup hides redundant copies. Hard delete and remote delete are separate, opt-in steps. One survivor per duplicate group; the other copies are hidden from normal reads but kept on disk and undo restores them. Permanently removing them from local storage or from the source server are distinct user actions on top of dedup.

+ +
+
+
+

Before

+
The same message arrived through three independent sources.
+
+ +
+ $ deduplicate --collection Personal + + Gmail + mbox + Mail + +
+ +
duplicate group · 3 copies
+ +
+
+
+
+
Personal Gmail
+
current mailbox sync
+
+
+
+
Re: dinner Friday
+
Apr 12 · sam@example.com · raw MIME, 4 labels
+
+
DUP
+
+ +
+
+
+
+
Old mbox
+
historical email backup
+
+
+
+
Re: dinner Friday
+
Apr 12 · sam@example.com · plain body, no labels
+
+
DUP
+
+ +
+
+
+
+
Apple Mail archive
+
retired laptop import
+
+
+
+
Re: dinner Friday
+
Apr 12 · sam@example.com · raw MIME, 1 folder
+
+
DUP
+
+ + +
Survivor selection is deterministic and visible in dry-run output, so the choice is auditable before anything is hidden.
+
+ +
+
+

After dedup

+
One survivor stays visible. The others are hidden on this rung; nothing has been deleted yet.
+
+ +
batch dedup-2026-04-29-01 · undo available
+ +
+
+
+
+
Personal Gmail
+
richer labels, raw MIME present
+
+
+
+
Re: dinner Friday
+
visible to search, TUI, API, MCP, vector
+
+
SURVIVOR
+
+ +
+
+
+
+
Old mbox
+
still in archive · undo restores
+
+
+
+
Re: dinner Friday
+
excluded from normal read paths
+
+ +
+ +
+
+
+
+
Apple Mail archive
+
still in archive · undo restores
+
+
+
+
Re: dinner Friday
+
excluded from normal read paths
+
+ +
+ + +
Remote deletion is a further separate, source-scoped step — moves to the source's trash by default (e.g. Gmail/Trash); permanent removal only with an explicit flag and interactive confirmation. Never inferred from a collection-scope dedup.
+
+
+ +
+
where this sits on the safety ladder
+
+
+
01Scan
+
dry-run · no data touched
+
+
+
02Hide (this step)
+
soft-delete · undo restores
+
+
+
03Local hard deleteopt-in
+
delete-deduped · irreversible
+
+
+
04Remote deleteopt-in
+
source-scoped · moves to source trash
+
+
+
Backup first, by default. Before this step modifies any row, msgvault writes a point-in-time database copy (*.dedup-backup-<ts>) — opt out only with --no-backup. And deletion is never required: you can run deduplicate as many times as you want and stay on rung 02 forever. delete-deduped and delete-staged only run if you invoke them.
+
+
+ + diff --git a/docs/accounts-identities-collections-dedup/assets/deduplication-concept.png b/docs/accounts-identities-collections-dedup/assets/deduplication-concept.png new file mode 100644 index 00000000..04903530 Binary files /dev/null and b/docs/accounts-identities-collections-dedup/assets/deduplication-concept.png differ diff --git a/docs/accounts-identities-collections-dedup/assets/safety-ladder-concept.html b/docs/accounts-identities-collections-dedup/assets/safety-ladder-concept.html new file mode 100644 index 00000000..c132e804 --- /dev/null +++ b/docs/accounts-identities-collections-dedup/assets/safety-ladder-concept.html @@ -0,0 +1,522 @@ + + + + + + Data Safety Ladder + + + +
+
msgvault · data safety ladder
+

🪜 Dedup Safety: A Ladder with 4 Rungs

+

Each rung is a separate, explicit user action. msgvault never escalates from one rung to the next on its own. "Apply dedup" never implies hard delete. "Hard delete locally" never implies remote delete. "Remote delete" never implies permanent remote delete.

+ +
+ deletion is never required +
You can run deduplicate as many times as you want and stay on rung 02 forever. Rungs 03 and 04 are opt-in — delete-deduped and delete-staged only run if you choose to invoke them.
+
+ +
+ 00 +
+
Backup (default-on, before any rung)
+
Before deduplicate or delete-deduped changes a single row, msgvault writes a point-in-time copy of your database alongside it (e.g. msgvault.db.dedup-backup-20260503-083100). Opt out only with --no-backup.
+
+ --no-backup to skip +
+ +
+
+
+
+ 01 +
Scan
+
+ deduplicate --dry-run +
Detect duplicates and report what would change. No rows are modified.
+
    +
  • Survivor choices visible in dry-run output
  • +
  • Auditable before any data moves
  • +
  • Default for first-time runs
  • +
+
no data touched
+
+ +
+
+
+ 02 +
Hide (apply dedup)
+
+ deduplicate +
Soft-delete redundant copies. Survivor stays visible; losers are excluded from every normal read path but stay on disk.
+
    +
  • Hidden from search, TUI, API, MCP, vector / hybrid, stats
  • +
  • Batch ID written for audit
  • +
  • --undo <batch-id> restores visibility
  • +
+
reversible · undo
+
+ +
+
+
+ 03 +
Local hard delete
+
+ delete-deduped --batch <id> +
Permanently remove rows that a named dedup batch hid. Refuses to operate on rows it did not hide.
+
    +
  • Separate command — never bundled with dedup
  • +
  • --all-hidden requires interactive confirmation
  • +
  • Undo cannot recover purged rows
  • +
+
irreversible locally
+
+ +
+
+
+ 04 +
Remote delete
+
+ delete-staged (per-source) +
Delete from the source server (Gmail, IMAP, …). Stays source-scoped even when dedup ran across a collection.
+
    +
  • Same-source-only — cross-source dup groups stage zero remote deletions
  • +
  • Moves to source trash by default (e.g. Gmail/Trash, ~30-day recovery)
  • +
  • Permanent removal needs explicit flag + interactive confirmation
  • +
+
leaves the local archive
+
+
+ + + +
+
Undo scope + --undo <batch-id> restores rows hidden by that dedup batch and cancels the batch's pending remote-deletion manifest where the source has not yet executed it. It does not restore an exact pre-run database state, and it cannot recover rows that have moved past rung 02. +
+
+ Cross-source rule at rung 4 + Even under deduplicate --collection, a remote-deletion entry is only staged when the loser and the survivor share the same source. Cross-source duplicate groups produce no remote-deletion entries — the originals on each source server stay put. +
+
+
+ + diff --git a/docs/accounts-identities-collections-dedup/assets/safety-ladder-concept.png b/docs/accounts-identities-collections-dedup/assets/safety-ladder-concept.png new file mode 100644 index 00000000..602c59a4 Binary files /dev/null and b/docs/accounts-identities-collections-dedup/assets/safety-ladder-concept.png differ diff --git a/docs/accounts-identities-collections-dedup/assets/survivor-selection-concept.html b/docs/accounts-identities-collections-dedup/assets/survivor-selection-concept.html new file mode 100644 index 00000000..ce47a743 --- /dev/null +++ b/docs/accounts-identities-collections-dedup/assets/survivor-selection-concept.html @@ -0,0 +1,435 @@ + + + + + + Survivor Selection + + + +
+
msgvault · how survivors are picked
+

Survivor Selection: A Filter, Then a Priority List

+

When any message in a duplicate group looks like a sent copy, only sent copies are eligible to survive. The priority list runs only on the eligible set — earlier rules win outright; later rules only break ties left from earlier ones.

+ +
+
+
+

Stage 1 — Eligibility filter

+
Runs before any tie-breaking. Removes received copies from the group when a sent copy is present.
+
+ +
if any sent-signal fires on any copy
+ +
+ sent-message safety rule +
A copy is treated as sent when ANY of these fires:
+
+
+ A +
Gmail SENT label on the message
+
+
+ B +
is_from_me ingest metadata flag set
+
+
+ C +
From address matches a confirmed identity for the message's account
+
+
+
A · or · B · or · C
+
+ +
Received-copy candidates are dropped from the group before tie-breaking. Losing the sent signal silently changes how the archive reads — "I sent this" is harder to recover than "I received this."
+ +
If no sent-signal fires anywhere in the group, every copy stays eligible and the priority list runs over the full group.
+
+ +
+
+

Stage 2 — Priority list

+
Earlier rules win outright. Later rules only apply when all earlier ones tie.
+
+ +
+
+
01
+
+
Source preference (when configured)
+
user-set ordering, e.g. prefer current sync over old archive
+
+
policy
+
+
+
+
02
+
+
Has raw MIME / complete original payload
+
prefer the copy that preserves the full message bytes
+
+
payload
+
+
+
+
03
+
+
Source metadata quality
+
provider IDs, threading info, message-id presence
+
+
metadata
+
+
+
+
04
+
+
Richer label or folder metadata
+
Gmail labels, IMAP folders, Apple Mail mailboxes
+
+
organization
+
+
+
+
05
+
+
Earlier archived timestamp (when meaningful)
+
older archive entry, all else equal
+
+
recency
+
+
+
+
06
+
+
Stable row ID
+
final tie-breaker — guarantees deterministic output
+
+
deterministic
+
+
+ + +
+
+
+ + diff --git a/docs/accounts-identities-collections-dedup/assets/survivor-selection-concept.png b/docs/accounts-identities-collections-dedup/assets/survivor-selection-concept.png new file mode 100644 index 00000000..b1a69f01 Binary files /dev/null and b/docs/accounts-identities-collections-dedup/assets/survivor-selection-concept.png differ diff --git a/docs/accounts-identities-collections-dedup/spec.md b/docs/accounts-identities-collections-dedup/spec.md new file mode 100644 index 00000000..3fb19d3e --- /dev/null +++ b/docs/accounts-identities-collections-dedup/spec.md @@ -0,0 +1,926 @@ +# Accounts, Identities, Collections, and Deduplication — Specification + +This is the authoritative reference for how msgvault organizes +ingested communications, identifies which messages belong to whom, +and removes redundant local copies without destroying the underlying +archive. It defines the conceptual model, the schema, the CLI +surface, the read-side contract, the manifest formats, and the +errors that implementations must produce. + +If a code change disagrees with this document, one of them is wrong. +Open a PR against either to reconcile. + +## Order of introduction: Account, Identity, Collection (AIC) + +We always introduce the data model in the order **Account, Identity, +Collection** — AIC. The order is not cosmetic. It tracks the +dependency chain. + +- **Account** is the atomic unit. One ingest source, with no + reference to any other concept. Every other piece of the model + depends on it. +- **Identity** is per-account. The question "who am I in this + source?" only makes sense once an account exists to ask it + about. Identity reads as a property of an account. +- **Collection** is a grouping of accounts. It composes accounts and + inherits the union of their identities. Collections only make + sense once both halves are in place. + +Reverse the order and the definitions wobble. 
"A collection is a +named group of accounts" with no account defined yet is a forward +reference. "An identity belongs to an account" introduced after +collections invites the wrong first question — whether collections +have their own identity — before the right one: whose addresses +count as "me" here? + +Code, prose, diagrams, and CLI help text introduce the model in AIC +order. Deduplication arrives last, because it operates over the +model the first three concepts define. + +## Reading order + +| Section | What it specifies | +| ----------------------------------------------------------------------- | ---------------------------------------------------------------- | +| [Conceptual model](#conceptual-model) | Account, Identity, Collection (AIC), `All` | +| [Scope semantics](#scope-semantics) | `--account` / `--collection`, name-conflict rules | +| [Identity model](#identity-model) | Per-account confirmed identifiers, signal sets, comparison rules | +| [Deduplication model](#deduplication-model) | Detection, survivor selection, sent-message eligibility filter | +| [Data safety ladder](#data-safety-ladder) | Rungs 00–04, escalation rules, `--no-backup` | +| [Live-message contract](#live-message-contract) | The `LiveMessagesWhere` predicate and where it applies | +| [Remote deletion model](#remote-deletion-model) | Same-source rule, manifest format, env-var release guardrail | +| [Undo model](#undo-model) | What undo restores, what it does not | +| [Schema](#schema) | DDL for `collections`, `account_identities`, `applied_migrations`, message-row dedup columns | +| [CLI surface](#cli-surface) | Every command and flag, with mutual-exclusion rules and defaults | +| [Backup behavior](#backup-behavior) | When backup runs, where it writes, opt-out | +| [Batch identifiers](#batch-identifiers) | Format and uniqueness rules for dedup-batch IDs | +| [Error catalog](#error-catalog) | Verbatim error strings the CLI emits | +| [Migration semantics](#migration-semantics) | 
Legacy `[identity]` config → per-account records | +| [Cache and index policy](#cache-and-index-policy) | What "filtering is the contract" means | +| [Scope review checklist](#scope-review-checklist) | What to verify before merging changes that touch this area | + +## Conceptual model + +The model has three concepts, introduced in AIC order: Account, +Identity, Collection. `All` is named separately as the default +collection that ships pre-built. Deduplication operates over all +three and is specified in its own section below. + +### Account + +An **account** is one ingested message source/archive. It is the +smallest durable provenance unit in msgvault. One Gmail sync source. +One IMAP source. One mbox import. One Apple Mail import. One iMessage +import. One SMS import. One Facebook Messenger import. One +meeting-transcript import source. + +The same real-world mailbox imported through Gmail sync and later +through an old mbox export creates two accounts. They may represent +the same human mailbox. They are distinct archives with distinct +provenance and distinct source-specific deletion semantics. + +msgvault never infers that two imports belong together because an +email address, display name, or message content overlaps. + +### Identity + +An **identity** belongs to an account. It is the set of addresses, +phone numbers, or other protocol-specific identifiers that mean "me" +in that source. A **confirmed identity** means messages from that +address or identifier can be treated as "from me" within that +account's context. + +Identity is account-scoped because the same address can appear in +multiple imports — one address is "me" in one source and not in +another. A global identity list collapses that distinction. The +[Identity model](#identity-model) section below specifies the +discovery signals, comparison rules, and storage shape. + +### Collection + +A **collection** is a named grouping of accounts. 
It is the user's +explicit statement that multiple sources should be viewed or operated +on together. Examples: `All`, `work`, `personal`, `old laptop +imports`, `gmail plus exports`, `family messages`. + +Collections are many-to-many. An account can belong to multiple +collections. A collection can contain multiple accounts. A collection +contains account/source IDs, not other collections. + +Collections are the boundary for cross-account features. To search, +count, deduplicate, or export across two independent archives, the +user puts them in a collection. + +A collection's identity is the union of confirmed identities from its +member accounts, computed at read time. It is not a separately stored +object. + +### `All` + +`All` is the default collection containing every account/source. +msgvault creates and maintains it automatically. It is still a +collection. Operations against `All` are collection-scoped operations +and the CLI reports them that way. + +`All` is immutable through the CLI. msgvault rejects +`collection delete All` and explicit membership edits on `All`. New +accounts join `All` automatically when they are created. + +![Accounts and collections](./assets/account-collection-concept.png) + +## Scope semantics + +The user-facing scope vocabulary is small. + +| Scope | Meaning | +| ---------------- | ------------------------------------------------------- | +| Account scope | One source/archive. | +| Collection scope | All member accounts of one collection. | +| All scope | Every source/archive. The `All` collection is the model's named handle for this set; implementations may resolve it through the collection record or by omitting the source-id filter, provided membership equivalence holds. | + +CLI flags expose those boundaries directly. + +| Flag | Resolves to | +| --------------------------- | ------------------------------------------------------- | +| `--account ` | Exactly one account/source. | +| `--collection ` | Exactly one collection. 
| +| (omitted, where supported) | The command's documented default — typically per-account iteration for dedup, or `All` for search and browse. | + +`--account` and `--collection` are mutually exclusive. Implementations +must enforce this at the cobra layer, not at runtime. + +### Name-conflict rules + +A collection name and an account identifier can collide. The +resolver rejects each form against the wrong flag with a hint: + +- If `work` is a collection, `--account work` returns + `"work" is a collection, not an account; use --collection work`. +- If `alice@example.com` is an account, `--collection alice@example.com` + returns `"alice@example.com" is an account, not a collection; use --account alice@example.com`. +- If `--account` matches no source and no collection, + `no account found for "" (try 'msgvault list-accounts')`. +- If `--collection` matches no collection and no source, + `no collection named "" (try 'msgvault collection list')`. +- If `--account ` matches multiple sources by identifier or + display name, + `ambiguous account "" matches multiple sources: []`. + +The full verbatim catalog is in [Error catalog](#error-catalog). + +## Identity model + +The [Conceptual model](#conceptual-model) introduces identity. This +section specifies the implementation: which signals confirm an +identity, how identifiers compare, and how the collection-level +union resolves at read time. + +A collection's identity surfaces through `GetIdentitiesForScope` (or +equivalent) when running scoped operations. It is the union of +confirmed identities from member accounts, computed at read time, not +a separately stored object. + +### Discovery signals + +Confirmed identifiers carry a `source_signal` field that records +which signals confirmed them. A confirmed identifier may carry one or +more of: + +- `is_from_me` — ingest metadata flagged the message as sent by the + account owner. 
+- `sent-folder` / `sent-label` — the message was found in a + sent-mail folder or had a Gmail `SENT` label. +- `account-identifier` — the address matches the account's primary + identifier (e.g. the Gmail address itself). +- `oauth` — OAuth or provider account metadata named the address. +- `manual` — the user added the identifier interactively via + `identity add`. +- `config_migration` — the identifier was inserted by the one-time + migration that promotes a legacy `[identity]` config block into + per-account records. Distinct from `manual` so `identity list` can + show provenance accurately. + +Signals are stored as a sorted comma-separated set in +`account_identities.source_signal` (e.g. `account-identifier,manual`). +An identity gains signals over time as new evidence appears; signals +are never removed except by `identity remove`. + +Global identity configuration is not part of the model. Legacy +`[identity]` config is migrated once on upgrade — see +[Migration semantics](#migration-semantics). + +### Identifier comparison + +Email-shaped identifiers compare case-insensitively for the local +part and the domain. Other identifiers (phone numbers, account +strings) compare exactly. Phone numbers are stored in normalized +E.164 form (`+1...`), which the CLI applies on `identity add`. + +## Collection behavior + +Required: + +- `All` is created and maintained automatically. +- Users can create named collections from accounts. +- Users can add and remove accounts from collections. +- Collection membership accepts only accounts/sources. +- Collection views preserve account provenance. + +Out of scope: + +- Nested collections. +- Implicit collection creation from matching email addresses. +- Treating a collection as an account. + +## Deduplication model + +Deduplication removes redundant local copies from normal user-facing +results without destroying the underlying archive. 
+ +### Valid scopes + +| Invocation | Boundary | +| --------------------------------------- | ----------------------------------------------------------- | +| `deduplicate --account ` | Compare messages only within that account/source. | +| `deduplicate --collection ` | Compare messages across member accounts in that collection. | +| `deduplicate` | Process each account independently, in iteration. | + +The unscoped form is per-account cleanup. It does not compare all +messages across all accounts as one global set. + +The unscoped default is per-account iteration rather than +`--collection All` because cross-account dedup is the higher-risk +operation. It can collapse duplicates between independent archives +whose provenance the user may want to preserve. Cross-account dedup +requires explicit `--collection`. To dedup across every account, +write `--collection All`. + +### Detection + +Duplicate detection uses multiple signals: + +- RFC822 `Message-ID` header (`messages.rfc822_message_id`). +- Normalized raw MIME or body content hash. +- Provider/source message IDs where appropriate + (`messages.source_message_id`). +- Attachment content hashes where relevant. + +Detection runs in two passes: + +1. **Message-ID pass.** Group messages by RFC822 Message-ID. Each + distinct ID forms one duplicate group. +2. **Content-hash pass.** Among the messages that survive the + Message-ID pass (i.e., excluding identified losers), group by + normalized content hash. Messages without a Message-ID and + Message-ID survivors are both eligible — they may form + content-hash groups together when their normalized payloads + match. + +The two passes are sequential, not transitively unioned. A +content-hash group with two Message-ID survivors keeps both as +winners (one per Message-ID group). A content-hash group containing +both a Message-ID survivor and a sent-copy orphan (a sent copy with +no Message-ID) is skipped to preserve the sent-message eligibility +filter. 
A future revision may treat every signal as a +transitive-union surface; today's behavior is sequential. + +### Survivor selection + +Survivor selection is deterministic and explainable. Two stages run +in order. + +**Stage 1 — eligibility filter.** When any message in a duplicate +group looks like a sent copy, only sent copies are eligible to +survive. Received-copy candidates drop from the group before +tie-breaking. A message looks like a sent copy when any of the +following fires (OR): + +- a Gmail `SENT` label on the message. +- an `is_from_me` flag on the message from ingest metadata. +- the `From` address matches a confirmed identity for the message's + account (case-insensitive for email-shaped identifiers). + +This is an eligibility filter, not a tie-breaker. Letting a received +copy win on payload richness silently changes how the archive reads. +"I sent this" is harder to recover from data than "I received this." + +**Stage 2 — priority list.** Within the eligible set, survivor +preference runs in this order: + +1. Source preference (when `--prefer` is configured, or the default + order: `gmail,imap,mbox,emlx,hey`). +2. Has raw MIME / complete original payload. +3. Source metadata quality — provider IDs, threading info, presence + of Message-ID. +4. Richer label or folder metadata. +5. Earlier `archived_at` timestamp (when meaningful). +6. Stable row ID, as the final tie-breaker. + +Earlier rules win outright. Later rules apply only when all earlier +ones tie. The exact policy is visible in dry-run output. + +![Survivor selection: filter then priority list](./assets/survivor-selection-concept.png) + +### Effects of applying dedup + +A successful `deduplicate` run: + +- Chooses one survivor per duplicate group. +- Hides redundant local rows by setting `messages.deleted_at` and + `messages.delete_batch_id`. +- Unions labels from non-survivors onto the survivor. +- Backfills the survivor's raw MIME if a non-survivor had it and the + survivor did not. 
+- Writes the batch ID to a manifest entry for audit and undo. +- Stages remote-deletion manifest entries only when explicitly + requested via `--delete-dups-from-source-server` AND the loser + and survivor share a `source_id`. + +Dedup does not silently escalate from local hiding to local hard +deletion or remote deletion. + +![Dedup hides; delete is separate](./assets/deduplication-concept.png) + +## Data safety ladder + +We call the design the **dedup safety ladder**: four explicit rungs +the user climbs deliberately, plus a rung-zero backup that runs by +default. + +**Deletion is never required.** A user can run `deduplicate` as many +times as they want and stay on rung 02 forever. Rungs 03 and 04 are +opt-in and only run if the user invokes a different command. + +Each rung is a separate, explicit user action. msgvault never +escalates from one rung to the next on its own. "Apply dedup" never +implies hard delete. "Hard-delete locally" never implies remote +delete. "Remote delete" never implies permanent remote delete. + +| Rung | Action | Command | Reversibility | +| ---- | ------------------- | -------------------------------- | ---------------------------------- | +| 00 | Backup (default-on) | (automatic before 02 and 03) | n/a — produces a recoverable file | +| 01 | Scan | `deduplicate --dry-run` | no data touched | +| 02 | Hide (apply dedup) | `deduplicate` | reversible via `--undo ` | +| 03 | Local hard delete | `delete-deduped --batch ` | irreversible locally | +| 04 | Remote delete | `delete-staged` (per-source) | leaves the local archive | + +**Rung 00 — Backup.** Before `deduplicate` or `delete-deduped` +modifies any row, msgvault writes a point-in-time copy of the +database. See [Backup behavior](#backup-behavior). Scan (rung 01) +modifies no rows and triggers no backup. + +**Rung 01 — Scan.** `deduplicate --dry-run` detects duplicates and +reports what would change. No rows are modified. Survivor choices +are visible in the output. 
The result is auditable before any data +moves. + +**Rung 02 — Hide.** `deduplicate` applies the scan. Pruned copies +are soft-deleted: hidden from normal reads, kept on disk. +`--undo ` restores visibility. + +**Rung 03 — Local hard delete.** `delete-deduped` permanently +removes hidden rows from the local archive. By default it acts on +named batches via `--batch ` and refuses to operate on rows it +didn't hide. The `--all-hidden` form purges every hidden row and +requires interactive confirmation. Backup runs again before the +purge unless `--no-backup` is set. Undo cannot recover purged rows. + +**Rung 04 — Remote delete.** `delete-staged` deletes from the source +server. Same-source-only — see [Remote deletion model](#remote-deletion-model). +The default action moves messages to the source's trash. Permanent +removal requires `--permanent` and an interactive confirmation that +checks for the literal word `delete`. + +Attachment dedup is independent of message dedup. Attachments live +in a content-addressed pool, so identical files are stored once +regardless of how many messages reference them. Hiding or +hard-deleting a duplicate message does not delete the underlying +attachment blob unless no remaining message references it. + +![Data safety ladder](./assets/safety-ladder-concept.png) + +## Live-message contract + +A **live message** is a row that has not been locally hidden by +dedup and has not been recorded as deleted from the source server. +The term is internal vocabulary for this contract. It appears in +implementation slices and code as the `LiveMessagesWhere` predicate. + +Normal user-facing reads return live messages only. + +The contract applies to: + +- Message search (FTS5, deep search). +- Vector and hybrid search. +- TUI browsing and drill-downs. +- Stats and aggregates. +- API responses. +- MCP responses. +- Exports that claim to represent the visible archive. 
+ +### The `LiveMessagesWhere` predicate + +`LiveMessagesWhere(alias, hideDeletedFromSource)` returns a SQL +fragment that filters out non-live rows. The two flag values are: + +| `hideDeletedFromSource` | Filters out rows where | +| ----------------------- | -------------------------------------------------------------- | +| `false` | `deleted_at IS NOT NULL` (dedup-hidden only) | +| `true` | `deleted_at IS NOT NULL` OR `deleted_from_source_at IS NOT NULL` | + +The `alias` argument is the table alias the query uses for +`messages` (most commonly `m` or `msg`; pass `""` for unaliased +references). + +Every read path constructs its WHERE clause through this predicate. +Implementations must not inline the comparison. + +Indexes and caches may lag behind canonical SQLite state. Normal +retrieval still filters through `LiveMessagesWhere`. Rebuilding +derived surfaces is operational hygiene; it is not the only thing +keeping hidden duplicates out of results. + +### Query scope is consistent across backends + +| Scope | Source-ID set passed into the predicate | +| ---------------- | ---------------------------------------------------- | +| Account scope | One source ID. | +| Collection scope | The set of source IDs in `collection_sources`. | +| All scope | Every source ID (resolved from the `All` collection). | + +Backend differences are acceptable for ranking or performance. They +are not acceptable for scope membership or live-message visibility. + +## Remote deletion model + +Remote deletion is a separate operation from local dedup. Even when +duplicate detection runs across a collection, remote deletion +decisions remain source-specific. + +### Rules + +- **Same-source constraint.** A remote-deletion entry is staged only + when the loser and the survivor share a `source_id`. Cross-source + duplicate groups produce no remote-deletion entries even when the + dedup scope is a collection that spans those sources. 
+- **Source-scoped manifests.** Manifest filenames and reporting
+  labels reflect the source, never the collection name, even when
+  dedup was invoked under `--collection`.
+- **Move to source trash by default.** Where the source supports a
+  recoverable trash state (e.g. Gmail's ~30-day Gmail/Trash), the
+  default remote-deletion behavior moves messages there.
+- **Permanent deletion is opt-in.** Permanent remote deletion
+  requires `--permanent` and interactive confirmation. It is never
+  the default, never inferred from dedup, and never applied in batch
+  without the user acknowledging the source and scope at the moment
+  of the action.
+- **Release guardrail.** The destructive `delete-staged` execute
+  path is gated behind the environment variable
+  `MSGVAULT_ENABLE_REMOTE_DELETE=1` for the v1 release. Read-only
+  modes (`--list`, `--dry-run`, `list-deletions`, `show-deletion`)
+  are always permitted. The gate is independent of `--permanent`.
+  Removal of the guardrail is a future release decision.
+
+### Manifest format
+
+`Manager` writes one JSON manifest per deletion batch under
+`~/.msgvault/deletions/<status>/<id>.json`, where `<status>`
+is one of `pending`, `in_progress`, `completed`, `cancelled`,
+`failed`. Status transitions move the file between subdirectories.
+
+```json
+{
+  "version": 1,
+  "id": "<uuid>",
+  "created_at": "2026-05-03T16:45:00Z",
+  "created_by": "cli",           // "tui" | "cli" | "api"
+  "description": "dedup batch dedup-...",
+  "filters": { /* source-specific filter restoration */ },
+  "summary": { /* counts, per-sender breakdown */ },
+  "gmail_ids": ["<gmail-id>", "..."],
+  "status": "pending",           // pending | in_progress | completed | cancelled | failed
+  "execution": null              // populated after delete-staged runs
+}
+```
+
+`gmail_ids` is the source-scoped list of provider IDs to delete.
+It is named `gmail_ids` for historical reasons; for IMAP sources it
+holds the provider's UID. A future revision may rename it.
+
+`Manifest.Version` is `1`. 
A later schema bump will increment it and +implementations must reject mismatches with a clear error. + +The on-disk directory location (`/.json`) is the +spec-authoritative state for a manifest. The inline `status` field is +convenience metadata that should match the directory; implementations +should update it before moving the file. A future revision may +declare directory-as-truth and reduce `status` to a derived field — +implementations should not treat the inline value as load-bearing on +disagreement. + +## Undo model + +Undo is not full time travel. + +`--undo ` restores rows hidden by that dedup batch by +clearing `deleted_at` and `delete_batch_id` on every row whose +`delete_batch_id` matches. It also cancels the batch's pending +remote-deletion manifest where the source has not yet executed it. + +`--undo` does not: + +- Reverse survivor label unioning. +- Reverse raw-MIME enrichment from non-survivors onto the survivor. +- Restore derived-index state. Indexes catch up on the next rebuild. +- Recover rows the user has since purged through `delete-deduped`. +- Reverse remote deletions already executed against a source. + +Implementations must surface this gap in the user-facing +`--undo` output text, not bury it in documentation. + +`--undo` is repeatable: passing multiple `--undo ` flags is an +ordered sequence of independent undos. Failures on one batch do not +skip later batches and errors are aggregated. `--undo` is mutually +exclusive with `--account`, `--collection`, and `--dry-run`. + +## Schema + +This section is the canonical schema reference. Identifier types are +SQLite; PostgreSQL definitions in `schema.sql` follow the same +shape with dialect-appropriate types. + +### Sources (account ingest unit) + +Defined in the broader `schema.sql`. Relevant columns for this spec: + +```sql +CREATE TABLE IF NOT EXISTS sources ( + id INTEGER PRIMARY KEY, + source_type TEXT NOT NULL, -- 'gmail', 'imap', 'mbox', 'emlx', ... 
+ identifier TEXT NOT NULL, -- email, phone number, account ID + display_name TEXT, + -- ... sync state, oauth_app, timestamps ... + UNIQUE(source_type, identifier) +); +``` + +### Collections + +```sql +CREATE TABLE IF NOT EXISTS collections ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL UNIQUE, + description TEXT NOT NULL DEFAULT '', + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP +); + +CREATE TABLE IF NOT EXISTS collection_sources ( + collection_id INTEGER NOT NULL REFERENCES collections(id) ON DELETE CASCADE, + source_id INTEGER NOT NULL REFERENCES sources(id) ON DELETE CASCADE, + PRIMARY KEY (collection_id, source_id) +); + +CREATE INDEX IF NOT EXISTS idx_collection_sources_source_id + ON collection_sources(source_id); +``` + +The `All` collection is a row in `collections` with name `All`. It +is auto-managed by store bootstrap. CLI mutations on it return +`ErrCollectionImmutable`. + +### Account identities + +```sql +CREATE TABLE IF NOT EXISTS account_identities ( + source_id INTEGER NOT NULL REFERENCES sources(id) ON DELETE CASCADE, + address TEXT NOT NULL, -- case-preserved + source_signal TEXT NOT NULL DEFAULT '', -- sorted comma-separated signal set + confirmed_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (source_id, address) +); + +CREATE INDEX IF NOT EXISTS idx_account_identities_address + ON account_identities(address); +``` + +`address` preserves case as the user entered it. Comparison uses +`identifierMatch` (case-insensitive for email-shaped identifiers, +exact otherwise). + +### Dedup soft-delete columns on `messages` + +The `messages` table carries three columns that this feature uses: + +```sql +deleted_at DATETIME, -- set when dedup hides the row +deleted_from_source_at DATETIME, -- set when delete-staged executes against a source +delete_batch_id TEXT -- ties hidden rows to their dedup batch +``` + +`deleted_at IS NULL` is the local-hide gate. `deleted_from_source_at +IS NULL` is the remote-deletion gate. 
The `LiveMessagesWhere` +predicate combines them. + +### Applied migrations + +```sql +CREATE TABLE IF NOT EXISTS applied_migrations ( + name TEXT PRIMARY KEY, + applied_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP +); +``` + +DDL changes use `IF NOT EXISTS`. This table records *data* +migrations that must run exactly once (e.g. +`legacy_identity_to_per_account`). + +## CLI surface + +This section is the source of truth for command names and flags. +Long-form flag names are stable; short forms are noted where they +exist. Defaults are the literal in-code defaults. + +### `msgvault deduplicate` + +Find duplicate messages and (with `--undo`) reverse a previous run. + +| Flag | Type | Default | Notes | +| ----------------------------------- | --------- | ----------------------------- | ------------------------------------------------------------------ | +| `--dry-run` | bool | `false` | Scan and report only. | +| `--no-backup` | bool | `false` | Skip the pre-execute database backup. | +| `--prefer ` | string | (none) | Source-type preference order for survivor selection. When the flag is empty, implementations fall back to the documented default order: `gmail,imap,mbox,emlx,hey`. The fall-through gives a single source of truth for the default — implementations should not register the literal default string as the cobra-layer default. | +| `--content-hash` | bool | `false` | Run the second-pass content-hash detection. | +| `--undo ` (repeatable) | string... | (none) | Reverse one or more named batches. | +| `--account ` | string | (none) | Per-source scope. | +| `--collection ` | string | (none) | Cross-source scope inside one collection. | +| `--delete-dups-from-source-server` | bool | `false` | DESTRUCTIVE: stage pruned duplicates for remote deletion. Execution still requires `MSGVAULT_ENABLE_REMOTE_DELETE=1`. | +| `--yes` / `-y` | bool | `false` | Skip the confirmation prompt. 
| + +Mutually exclusive flag pairs (enforced at the cobra layer): + +- `--account` ⊕ `--collection` +- `--dry-run` ⊕ `--undo` +- `--undo` ⊕ `--account` +- `--undo` ⊕ `--collection` +- `--undo` ⊕ `--delete-dups-from-source-server` + +Unscoped invocation iterates over every account in isolation. It +never crosses source boundaries. + +### `msgvault delete-deduped` + +Permanently remove rows that a named dedup batch hid (rung 03). + +| Flag | Type | Default | Notes | +| ----------------------------------- | --------- | ------------- | -------------------------------------------------- | +| `--batch ` (repeatable) | string... | (none) | Named batches to purge. Refuses unrelated rows. | +| `--all-hidden` | bool | `false` | Purge every hidden row across all batches. | +| `--no-backup` | bool | `false` | Skip the pre-execute database backup. | +| `--yes` / `-y` | bool | `false` | Skip the confirmation prompt. | + +`--all-hidden` always requires interactive confirmation (the `-y` +shortcut does not bypass it). Without `--batch` or `--all-hidden`, +the command errors with `must specify --batch or --all-hidden`. + +### `msgvault collection` + +| Subcommand | Notes | +| --------------------------------------------------- | ------------------------------------------------------------------------------------------- | +| `collection create --accounts ` | Create a new collection with the given members. Rejects `name = All` with `ErrCollectionImmutable`. | +| `collection list` | List every collection with member counts. | +| `collection show ` | Show collection details and members. | +| `collection add --accounts ` | Add accounts to a collection. Rejects `name = All`. | +| `collection remove --accounts ` | Remove accounts from a collection. Rejects `name = All`. | +| `collection delete ` | Delete a collection. Rejects `name = All`. Underlying sources and messages stay untouched. 
| + +### `msgvault identity` + +| Subcommand | Notes | +| --------------------------------------------------- | ------------------------------------------------------------------------------------------- | +| `identity list` | List confirmed identifiers across one or more accounts. Accepts `--account` / `--collection`. | +| `identity show ` | Show one account's identity in detail, including signal sets. | +| `identity add ` | Add a confirmed identifier with `manual` signal. | +| `identity remove ` | Remove a confirmed identifier. Returns the rows-affected count. | + +### `msgvault list-deletions` / `show-deletion` / `cancel-deletion` / `delete-staged` + +| Command | Notes | +| ------------------------------------ | ------------------------------------------------------------------------------------------------ | +| `list-deletions` | List pending and recent deletion batches. Always permitted regardless of the env-var guardrail. | +| `show-deletion ` | Show one batch's manifest. Read-only; permitted regardless of the guardrail. | +| `cancel-deletion [batch-id] [--all]` | Cancel pending or in-progress batches. `--all` cancels every pending or in-progress batch. | +| `delete-staged [batch-id]` | Execute pending remote deletions. Gated behind `MSGVAULT_ENABLE_REMOTE_DELETE=1`. | + +`delete-staged` flags: + +| Flag | Type | Default | Notes | +| ------------------- | ------ | ------- | ------------------------------------------------------------------------------------------- | +| `--permanent` | bool | `false` | DESTRUCTIVE: permanent deletion via batch API. Cannot combine with `--yes`. | +| `--yes` / `-y` | bool | `false` | Skip non-permanent confirmation. Has no effect when `--permanent` is set. | +| `--dry-run` | bool | `false` | Show what would be deleted; never call the source API. | +| `--list` / `-l` | bool | `false` | List staged batches without executing. 
+| `--account <name>` | string | (none) | Required when multiple accounts have pending batches and the manifest does not name one. |
+
+Permanent deletion requires the literal word `delete` typed at the
+confirmation prompt. `y` is not enough.
+
+### `msgvault search`, `msgvault stats`, `msgvault tui`
+
+These commands accept the scope flags symmetrically:
+
+- `--account <name>` — one account.
+- `--collection <name>` — one collection.
+- (omitted) — defaults to `All` for read commands.
+
+`--account` and `--collection` are mutually exclusive. The same
+name-conflict resolver runs as for `deduplicate`.
+
+## Backup behavior
+
+Before any rung that modifies data (rungs 02 and 03), msgvault
+writes a point-in-time copy of the database alongside the live DB
+file.
+
+### Naming
+
+| Command | Filename pattern |
+| ----------------- | ---------------------------------------------------------------------- |
+| `deduplicate` | `<db-path>.dedup-backup-<timestamp>` |
+| `delete-deduped` | `<db-path>.delete-deduped-backup-<timestamp>` |
+
+`<db-path>` is the active database path. The timestamp uses the local
+time zone in the format `20060102-150405`.
+
+### Mechanism
+
+Backup uses SQLite's `VACUUM INTO` to produce a point-in-time
+consistent copy. Implementations must reject non-file DSNs (e.g.
+`postgres://`) up front rather than at the first backup attempt.
+
+### Opt-out
+
+`--no-backup` on either command suppresses the backup. The flag does
+not affect any other behavior.
+
+### Lifecycle
+
+msgvault never deletes a backup file. Disk-space management is the
+user's responsibility. A future release may add a `--prune-backups`
+helper.
+
+## Batch identifiers
+
+Dedup batches are identified by a string of the form:
+
+```
+dedup-<timestamp>-<source-id>-<sanitized-identifier>-<random>
+```
+
+Components:
+
+- `<timestamp>` — local time, format `20060102-150405`.
+- `<source-id>` — integer primary key from `sources`.
+- `<sanitized-identifier>` — the source's `identifier`, sanitized
+  for filename safety (alphanumeric, `-`, `_`).
+- `<random>` — 8 hex characters from a CSPRNG, to disambiguate
+  same-second runs against the same source. 
+ +Example: `dedup-20260503-091500-7-me_at_example.com-0d4cb6f1`. + +Implementations must guarantee uniqueness. The random token exists +to prevent collisions when two runs land in the same second against +the same source. + +## Error catalog + +These are the verbatim error strings the CLI emits for the +identities/collections/dedup feature. Translations are not provided; +implementations match these strings exactly so users and scripts +can recognize them. + +### Scope flag resolution + +| Condition | Message | +| ------------------------------------------------------------------ | -------------------------------------------------------------------------------------------------------- | +| `--account X` where `X` is a collection | `"X" is a collection, not an account; use --collection X` | +| `--collection X` where `X` is an account | `"X" is an account, not a collection; use --account X` | +| `--account X` matches multiple sources | `ambiguous account "X" matches multiple sources: [a (gmail, id=1), b (mbox, id=2)]` | +| `--account X` matches nothing | `no account found for "X" (try 'msgvault list-accounts')` | +| `--collection X` matches nothing | `no collection named "X" (try 'msgvault collection list')` | +| Both `--account` and `--collection` set | (cobra layer) `if any flags in the group [account collection] are set none of the others can be; [account collection] were all set` | + +### Collection mutations + +| Condition | Sentinel error | +| ------------------------------------------------------------------ | ----------------------------------------------------------- | +| `collection create All` / `collection delete All` / membership edit on `All` | `ErrCollectionImmutable: cannot modify the auto-managed "All" collection` | +| Lookup of a collection that does not exist | `ErrCollectionNotFound: collection not found` | + +### Deletion-staged + +| Condition | Message | +| ------------------------------------------------------------------ | 
------------------------------------------------------------------------------------------------------------- | +| `delete-staged` invoked without `MSGVAULT_ENABLE_REMOTE_DELETE=1` | (release-guardrail message naming the env var; refused before any API call) | +| `delete-staged` finds no account in manifest, `--account` not set | `no account in deletion manifest - use --account flag` | +| `delete-staged` finds multiple accounts pending, `--account` not set | `multiple accounts in pending batches () - use --account flag to specify which account` | +| Permanent confirmation prompt receives anything other than `delete` | `Cancelled. Drop --permanent to use trash deletion without elevated permissions.` | + +### Deduplicate undo + +| Condition | Message | +| ------------------------------------------------------------------ | ---------------------------------------------------------------------------------------- | +| `--undo X` where `X` matches no batch | `undo dedup "X": batch not found` | +| `--undo X` partially restores (some rows already purged) | summary line plus ` already purged from rung 03 cannot be restored` | + +## Migration semantics + +### Legacy `[identity]` config → per-account records + +On first startup after upgrade, if `config.toml` contains a +top-level `[identity]` block with `addresses = [...]`, msgvault +runs the one-time data migration `legacy_identity_to_per_account`: + +1. For each address in the legacy block: + - For each existing account whose source supports email-shaped + identifiers, insert a confirmed identity record with + `source_signal = manual` if the address is not already + confirmed. +2. Insert a row into `applied_migrations` with + `name = 'legacy_identity_to_per_account'`. +3. Log a warning naming the migration and the number of records + inserted. +4. Print a one-time CLI notice asking the user to review per-account + identities via `msgvault identity list`. + +After migration, the `[identity]` block is no longer read. 
The +migration runs exactly once: subsequent startups see the +`applied_migrations` row and skip. + +If startup happens before any source exists, the migration defers +until the first source is created, then runs against that source +only. Subsequent source creations apply the legacy block to the new +source until the migration row is finally inserted. + +Removing the legacy `[identity]` block after migration is safe but +not required. + +### Schema migrations + +Schema DDL is idempotent (`CREATE TABLE IF NOT EXISTS`, +`CREATE INDEX IF NOT EXISTS`, dialect-aware `ALTER TABLE … ADD +COLUMN` guards). The `applied_migrations` table is reserved for +data migrations only. + +## Cache and index policy + +The product contract: + +- Dedup changes the canonical archive state. +- Normal reads filter rows that are no longer live, via + `LiveMessagesWhere`. +- Derived indexes (FTS5 shadow tables, Parquet snapshots, vector + indexes) may be rebuilt, updated, or marked stale as an + operational concern. + +Filtering through `LiveMessagesWhere` is mandatory for correctness. +Best-effort derived index cleanup is allowed. Manual rebuild +commands remain available. Any known stale derived surface is +visible in command output or logs. + +## Scope review checklist + +Use this checklist when reviewing changes that touch this area. +Every question should answer cleanly without qualifications. + +- Does "account" always mean one ingest source/archive? +- Is every cross-account operation expressed through a collection? +- Can users tell from the command or UI when they are crossing + account/source boundaries? +- Are identities account-scoped rather than global, with a defined + migration from any legacy global config? +- Is `All` modeled as a collection, and is it immutable through the + CLI? +- Are collections first-class query scopes across every read + surface? 
+- Are hidden duplicates excluded from every normal read path through + the `LiveMessagesWhere` predicate, not by inline filtering? +- Does dedup honor sent-message eligibility before falling back to + the survivor priority list? +- Does the safety ladder keep scan / hide / local hard delete / + remote delete as four separate user actions, with no automatic + escalation between them? +- Does remote deletion stay same-source-only, default to moving to + source trash, and require explicit confirmation for permanent + removal? +- Is the v1 release guardrail (`MSGVAULT_ENABLE_REMOTE_DELETE=1`) + still enforced for the destructive `delete-staged` execute path? +- Does undo avoid promising exact rollback, both in code and in + user-facing text? +- Do error messages match the verbatim strings in the + [Error catalog](#error-catalog)? + +--- + +*Authored by [@jesserobbins](https://github.com/jesserobbins), with +support from [@wesm](https://github.com/wesm). Tools used: +[Primeradiant Superpowers](https://github.com/obra/superpowers) and +[roborev](https://github.com/wesm/roborev), backed by Claude and Codex.* diff --git a/internal/config/config.go b/internal/config/config.go index 9006eeae..d39e7d60 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -5,6 +5,7 @@ import ( "fmt" "log/slog" "net" + "net/url" "os" "path/filepath" "runtime" @@ -70,6 +71,11 @@ type RemoteConfig struct { } // Config represents the msgvault configuration. +// IdentityConfig holds the user's curated identity addresses. 
+type IdentityConfig struct { + Addresses []string `toml:"addresses"` +} + type Config struct { Data DataConfig `toml:"data"` Log LogConfig `toml:"log"` @@ -80,6 +86,7 @@ type Config struct { Server ServerConfig `toml:"server"` Remote RemoteConfig `toml:"remote"` Vector vector.Config `toml:"vector"` + Identity IdentityConfig `toml:"identity"` Accounts []AccountSchedule `toml:"accounts"` // Computed paths (not from config file) @@ -347,6 +354,51 @@ func (c *Config) DatabaseDSN() string { return filepath.Join(c.Data.DataDir, "msgvault.db") } +// DatabasePath returns the on-disk SQLite filesystem path for backup +// operations (VACUUM INTO, copies). It accepts the plain filesystem +// path and the SQLite "file:" URI form, decoding any percent-encoded +// bytes (e.g. "file:/var/lib/my%20vault.db" -> "/var/lib/my vault.db") +// and dropping the URI query string. Returns an error for non-file +// DSNs (e.g. "postgres://..."), which the SQLite-only backup helpers +// cannot operate on. +func (c *Config) DatabasePath() (string, error) { + dsn := c.DatabaseDSN() + if strings.HasPrefix(dsn, "file:") { + u, err := url.Parse(dsn) + if err != nil { + return "", fmt.Errorf("parse file: URI %q: %w", dsn, err) + } + // SQLite accepts both file:/abs/path (Path) and file:rel/path + // (Opaque) shapes. url.Parse decodes percent-encoding for Path + // but NOT for Opaque, so a relative file: URI like + // "file:my%20vault.db" leaves the encoding intact in u.Opaque + // and the on-disk filename never matches. PathUnescape handles + // the relative-form case explicitly. + path := u.Path + if path == "" { + decoded, err := url.PathUnescape(u.Opaque) + if err != nil { + return "", fmt.Errorf("decode file: URI opaque part %q: %w", u.Opaque, err) + } + path = decoded + } + if path == "" { + return "", fmt.Errorf("empty file: URI in database DSN: %q", dsn) + } + return path, nil + } + if strings.Contains(dsn, "://") { + // postgres://, mysql://, etc. 
— non-file DSN; backup is + // SQLite-specific and the caller can't operate on these. + return "", fmt.Errorf( + "backup operations require a SQLite filesystem DSN; "+ + "got non-file DSN %q (set [data].database_url to a "+ + "plain filesystem path or file: URI)", dsn, + ) + } + return dsn, nil +} + // AttachmentsDir returns the path to the attachments directory. func (c *Config) AttachmentsDir() string { return filepath.Join(c.Data.DataDir, "attachments") diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 8979e3a2..16aad1de 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -1688,3 +1688,101 @@ func TestMicrosoftConfig_DefaultTenant(t *testing.T) { t.Errorf("EffectiveTenantID() = %q, want %q", cfg.Microsoft.EffectiveTenantID(), "common") } } + +func TestDatabasePath(t *testing.T) { + t.Run("plain filesystem path passes through", func(t *testing.T) { + cfg := &Config{} + cfg.Data.DataDir = "/tmp/data" + got, err := cfg.DatabasePath() + if err != nil { + t.Fatalf("DatabasePath: %v", err) + } + want := filepath.Join("/tmp/data", "msgvault.db") + if got != want { + t.Errorf("DatabasePath() = %q, want %q", got, want) + } + }) + + t.Run("file: URI is stripped", func(t *testing.T) { + cfg := &Config{} + cfg.Data.DatabaseURL = "file:/var/lib/msgvault.db" + got, err := cfg.DatabasePath() + if err != nil { + t.Fatalf("DatabasePath: %v", err) + } + if got != "/var/lib/msgvault.db" { + t.Errorf("DatabasePath() = %q, want '/var/lib/msgvault.db'", got) + } + }) + + t.Run("file: URI with query string drops query", func(t *testing.T) { + cfg := &Config{} + cfg.Data.DatabaseURL = "file:/var/lib/msgvault.db?_journal_mode=WAL&_busy_timeout=5000" + got, err := cfg.DatabasePath() + if err != nil { + t.Fatalf("DatabasePath: %v", err) + } + if got != "/var/lib/msgvault.db" { + t.Errorf("DatabasePath() = %q, want '/var/lib/msgvault.db'", got) + } + }) + + t.Run("file: URI decodes percent-encoded path", func(t *testing.T) { 
+ cfg := &Config{} + cfg.Data.DatabaseURL = "file:/var/lib/my%20vault.db" + got, err := cfg.DatabasePath() + if err != nil { + t.Fatalf("DatabasePath: %v", err) + } + if got != "/var/lib/my vault.db" { + t.Errorf("DatabasePath() = %q, want '/var/lib/my vault.db'", got) + } + }) + + t.Run("file: URI relative path (Opaque)", func(t *testing.T) { + // SQLite accepts file:rel/path; url.Parse routes that into u.Opaque. + cfg := &Config{} + cfg.Data.DatabaseURL = "file:msgvault.db" + got, err := cfg.DatabasePath() + if err != nil { + t.Fatalf("DatabasePath: %v", err) + } + if got != "msgvault.db" { + t.Errorf("DatabasePath() = %q, want 'msgvault.db'", got) + } + }) + + t.Run("file: URI relative path with percent-encoding (Opaque)", func(t *testing.T) { + // url.Parse decodes percent-encoding for u.Path but not u.Opaque, + // so DatabasePath has to PathUnescape the relative-form bytes + // itself. Without that, "file:my%20vault.db" never matches the + // on-disk filename "my vault.db" and backups break. 
+ cfg := &Config{} + cfg.Data.DatabaseURL = "file:my%20vault.db" + got, err := cfg.DatabasePath() + if err != nil { + t.Fatalf("DatabasePath: %v", err) + } + if got != "my vault.db" { + t.Errorf("DatabasePath() = %q, want 'my vault.db'", got) + } + }) + + t.Run("postgres:// is rejected", func(t *testing.T) { + cfg := &Config{} + cfg.Data.DatabaseURL = "postgres://user@host:5432/db" + _, err := cfg.DatabasePath() + if err == nil { + t.Fatal("DatabasePath: expected error for non-file DSN, got nil") + } + }) + + t.Run("empty file: URI is rejected", func(t *testing.T) { + cfg := &Config{} + cfg.Data.DatabaseURL = "file:" + _, err := cfg.DatabasePath() + if err == nil { + t.Fatal("DatabasePath: expected error for empty file: URI, got nil") + } + }) +} diff --git a/internal/dedup/dedup.go b/internal/dedup/dedup.go new file mode 100644 index 00000000..42386d09 --- /dev/null +++ b/internal/dedup/dedup.go @@ -0,0 +1,1341 @@ +// Package dedup provides duplicate detection and merging for msgvault. +// +// # Terminology +// +// "Account" means one ingest source/archive (a single Gmail OAuth +// connection, one mbox import, one IMAP source, etc.). "Collection" +// means a named, user-defined grouping of accounts. Cross-source dedup +// is only available via --collection; --account always operates on a +// single source. +// +// # Scoping rules +// +// Without explicit scope, dedup operates on one account at a time and +// duplicate groups can only contain messages ingested twice into the +// same account (for example, re-importing the same mbox twice). +// +// With --account, dedup is restricted to the named account and behaves +// the same way — source boundaries are never crossed. +// +// With --collection, dedup compares messages across every account in +// the collection. This is the only way to merge duplicates that span +// sources, and it is an explicit user opt-in. Pruned losers are hidden +// locally and reversible via --undo. 
Remote-deletion staging stays +// same-source-only even under collection scope, so the user's +// authoritative remote mailbox can never be touched because of a +// duplicate found in a different account. +// +// Outside collection scope, dedup never merges messages across +// different accounts. This is critical for sent messages: a message +// alice sends to bob is one logical message but it has a legitimate +// copy in alice's Sent folder and another in bob's Inbox. Both copies +// share the same RFC822 Message-ID. If both accounts are archived in +// msgvault, they must be preserved independently because deleting one +// would change the other user's view of history. Sent-message handling +// is covered in more detail by FormatMethodology. +package dedup + +import ( + "bufio" + "bytes" + "compress/zlib" + "context" + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "io" + "log/slog" + "net/textproto" + "path/filepath" + "runtime" + "sort" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/wesm/msgvault/internal/deletion" + "github.com/wesm/msgvault/internal/store" +) + +// Config controls the dedup engine behaviour. +type Config struct { + // SourcePreference orders source types when picking a survivor + // inside a duplicate group. Earlier entries win. + SourcePreference []string + + // DryRun reports what would happen without mutating the database + // or writing deletion manifests. + DryRun bool + + // ContentHashFallback also groups messages by normalized raw MIME + // content after the RFC822 Message-ID pass. This is slower, but can + // catch duplicates where Message-ID is missing or transport headers + // are the only difference between copies. + ContentHashFallback bool + + // AccountSourceIDs restricts dedup to the listed sources and + // allows cross-source grouping between them. Callers that want + // strict per-source dedup should leave this empty. 
	AccountSourceIDs []int64

	// Account is the canonical identifier for the scoped account
	// (for example, "alice@gmail.com"). It is used when building
	// deletion manifests and in the methodology output.
	Account string

	// ScopeIsCollection is true when AccountSourceIDs spans multiple
	// distinct accounts via --collection. The methodology output
	// branches on this: collection mode intentionally crosses account
	// boundaries, while account/per-source modes do not.
	ScopeIsCollection bool

	// DeleteDupsFromSourceServer, when true, writes pending
	// deletion manifests for pruned duplicates that meet ALL of:
	//  1. the pruned copy lives in a remote source whose type
	//     appears in remoteSourceTypes (gmail today; imap is gated
	//     until staged manifests can be routed to an IMAP executor),
	//  2. the surviving copy is in the SAME source_id (i.e. the
	//     very same remote mailbox holds the winner).
	//
	// This second rule is load-bearing: it guarantees that a
	// merged-pile dedup run can never cause deletions from the
	// user's authoritative Gmail/IMAP account just because a
	// duplicate was found in a local archive. Only true
	// intra-mailbox duplicates are ever proposed for remote
	// deletion.
	//
	// Even with this rule, the field defaults to false so that
	// destructive side effects never happen without an explicit
	// --delete-dups-from-source-server opt-in at the CLI layer.
	DeleteDupsFromSourceServer bool

	// DeletionsDir is the directory where staged deletion manifests
	// are written. Required when DeleteDupsFromSourceServer is true.
	DeletionsDir string

	// IdentityAddressesBySource maps each source ID to the set of
	// confirmed "me" addresses for that source. When a pruned
	// candidate's From: matches the address set for its source,
	// the survivor-selection rule treats the message as a sent
	// copy — in addition to the existing Gmail SENT label and
	// messages.is_from_me signals. Per-source keying ensures that
	// an address confirmed for one account is not treated as "me"
	// for a different account. Keys are expected in the form
	// produced by store.NormalizeIdentifierForCompare.
	IdentityAddressesBySource map[int64]map[string]struct{}
}

// DefaultSourcePreference is the default source-type authority order.
// Earlier entries win survivor selection; used when
// Config.SourcePreference is left empty.
var DefaultSourcePreference = []string{
	"gmail", "imap", "mbox", "emlx", "hey",
}

// remoteSourceTypes lists source types whose messages can be deleted
// via the deletion-staging machinery.
//
// Only gmail is listed today: the staged-deletion manifest format and
// executor are Gmail-specific (manifest.GmailIDs, gmail.API client). Adding
// "imap" here would let an IMAP dedup run with --delete-dups-from-source-server
// stage manifests that delete-staged would then try to execute through Gmail.
// Re-add IMAP only after manifests record source type and delete-staged can
// route to an IMAP executor.
var remoteSourceTypes = map[string]bool{
	"gmail": true,
}

// Engine orchestrates duplicate detection and merging.
type Engine struct {
	store  *store.Store
	config Config
	logger *slog.Logger
}

// NewEngine creates a new dedup engine. An empty cfg.SourcePreference
// falls back to DefaultSourcePreference; a nil logger falls back to
// slog.Default().
func NewEngine(st *store.Store, cfg Config, logger *slog.Logger) *Engine {
	if len(cfg.SourcePreference) == 0 {
		cfg.SourcePreference = DefaultSourcePreference
	}
	if logger == nil {
		logger = slog.Default()
	}
	return &Engine{store: st, config: cfg, logger: logger}
}

// DuplicateGroup represents a set of messages that are duplicates of
// each other (share the same RFC822 Message-ID in the scoped sources).
type DuplicateGroup struct {
	Key      string             // RFC822 Message-ID or normalized hash
	KeyType  string             // "message-id" or "normalized-hash"
	Messages []DuplicateMessage // all messages in the group
	Survivor int                // index into Messages of the chosen survivor
}

// DuplicateMessage holds metadata for a single message in a duplicate
// group, including sent-message signals for safety checks.
type DuplicateMessage struct {
	ID               int64     // local message row ID
	SourceID         int64     // owning source row ID
	SourceType       string    // e.g. "gmail", "imap", "mbox"
	SourceIdentifier string    // account identifier of the owning source
	SourceMessageID  string    // provider-side message ID (used in manifests)
	Subject          string    // display subject for report output
	SentAt           time.Time // message sent timestamp
	HasRawMIME       bool      // true when a raw MIME blob is stored
	LabelCount       int       // number of labels on this copy
	ArchivedAt       time.Time // when this copy was archived locally
	IsFromMe         bool      // ingest-time "sent by me" signal
	HasSentLabel     bool      // Gmail SENT system label present
	FromEmail        string    // From: address (may be empty)
	MatchedIdentity  bool      // From: matched a configured identity address
}

// IsSentCopy reports whether this message appears to be the sender-side
// copy of an outbound email. Three independent signals (OR-combined):
//   - Gmail SENT system label on the message
//   - messages.is_from_me set at ingest time
//   - From: address matches a configured identity address
func (m DuplicateMessage) IsSentCopy() bool {
	return m.HasSentLabel || m.IsFromMe || m.MatchedIdentity
}

// Report summarises the results of a dedup scan.
type Report struct {
	TotalMessages              int64
	DuplicateGroups            int
	DuplicateMessages          int            // messages that would be pruned
	BySourcePair               map[string]int // "gmail+mbox" -> groups
	SampleGroups               []DuplicateGroup
	Groups                     []DuplicateGroup
	BackfilledCount            int64 // negative means "needed but skipped" (dry-run)
	ContentHashGroups          int
	SkippedDecompressionErrors int
}

// ExecutionSummary summarises the results of dedup execution.
type ExecutionSummary struct {
	GroupsMerged      int
	MessagesRemoved   int
	LabelsTransferred int
	RawMIMEBackfilled int
	BatchID           string
	StagedManifests   []StagedManifest
}

// StagedManifest records a single deletion manifest created by dedup.
type StagedManifest struct {
	Account      string
	SourceType   string
	ManifestID   string
	MessageCount int
}

// remoteKey groups remote source IDs by the (account, source_type) pair so
// that a user with multiple remote sources sharing the same account
// identifier (e.g. gmail + imap for the same address) gets one manifest per
// source type rather than a single manifest whose SourceType label reflects
// only the first contributor.
type remoteKey struct {
	Account    string
	SourceType string
}

// Scan finds all duplicate groups that dedup would prune.
// AccountSourceIDs must be non-empty to prevent accidental cross-account
// grouping; the CLI ensures this by iterating sources one at a time when
// no explicit --account is given.
//
// Side effect (non-dry-run only): if the scoped sources contain messages
// with no rfc822_message_id but with stored MIME, Scan calls
// store.BackfillRFC822IDs to derive the column from the stored headers
// before grouping. The backfill is idempotent metadata derivation — it
// fills a previously-NULL column from data already on the row, never
// overwrites a non-NULL value, and changes no message content. It happens
// during scan (rather than during merge) because the duplicate groups it
// surfaces are needed for the report the user is shown before the merge
// confirmation. The dedup-batch backup-before-merge contract still holds:
// the backup is sized for the merge (which soft-deletes losers and
// transfers labels), not for this metadata catch-up. Dry-run mode skips
// the backfill and reports the count as a "would-backfill" preview.
func (e *Engine) Scan(ctx context.Context) (*Report, error) {
	// Hard guard: an empty scope would silently group across every
	// account in the database. Refuse rather than guess.
	if len(e.config.AccountSourceIDs) == 0 {
		return nil, fmt.Errorf("AccountSourceIDs must be non-empty; use per-source iteration for unscoped dedup")
	}

	started := time.Now()
	e.logger.Info("dedup scan start",
		"account", e.config.Account,
		"sources", len(e.config.AccountSourceIDs),
		"is_collection", e.config.ScopeIsCollection,
		"content_hash_fallback", e.config.ContentHashFallback,
		"dry_run", e.config.DryRun,
	)

	// Phase 1: derive rfc822_message_id for rows that lack it, so the
	// primary Message-ID grouping below sees as many rows as possible.
	count, err := e.store.CountMessagesWithoutRFC822ID(
		e.config.AccountSourceIDs...,
	)
	if err != nil {
		return nil, fmt.Errorf("count messages without rfc822 id: %w", err)
	}

	var backfilledCount int64
	if count > 0 && e.config.DryRun {
		// Dry-run must not mutate the database, so only report how
		// many rows would be backfilled on a real run.
		e.logger.Info(
			"dry-run: backfill needed before dedup can run — "+
				"messages missing rfc822_message_id will be skipped",
			"count", count)
		backfilledCount = -count // negative signals "needed but skipped"
	} else if count > 0 {
		e.logger.Info("backfilling rfc822_message_id from stored MIME",
			"count", count)
		var backfillFailed int64
		backfilledCount, backfillFailed, err = e.store.BackfillRFC822IDs(
			e.config.AccountSourceIDs,
			func(done, total int64) {
				e.logger.Info("backfill progress",
					"done", done, "total", total)
			},
		)
		if err != nil {
			return nil, fmt.Errorf("backfill rfc822 ids: %w", err)
		}
		if backfilledCount > 0 {
			e.logger.Info("backfilled rfc822_message_id",
				"count", backfilledCount)
		}
		if backfillFailed > 0 {
			// Unparseable MIME is logged and skipped; it simply stays
			// out of the Message-ID pass.
			e.logger.Warn("backfill: some messages could not be parsed",
				"failed", backfillFailed)
		}
	}

	totalMessages, err := e.store.CountActiveMessages(
		e.config.AccountSourceIDs...,
	)
	if err != nil {
		return nil, fmt.Errorf("count active messages: %w", err)
	}

	// Phase 2: primary pass — group by RFC822 Message-ID.
	storeGroups, err := e.store.FindDuplicatesByRFC822ID(
		e.config.AccountSourceIDs...,
	)
	if err != nil {
		return nil, fmt.Errorf("find duplicates: %w", err)
	}

	report := &Report{
		TotalMessages:   totalMessages,
		BySourcePair:    make(map[string]int),
		BackfilledCount: backfilledCount,
	}

	for _, sg := range storeGroups {
		// Cooperative cancellation between groups (one store query per
		// group, so this is the natural checkpoint).
		if ctx.Err() != nil {
			return nil, ctx.Err()
		}
		msgs, err := e.store.GetDuplicateGroupMessages(
			sg.RFC822MessageID, e.config.AccountSourceIDs...,
		)
		if err != nil {
			return nil, fmt.Errorf(
				"get group messages for %s: %w",
				sg.RFC822MessageID, err,
			)
		}
		if len(msgs) < 2 {
			continue
		}

		group := DuplicateGroup{
			Key:     sg.RFC822MessageID,
			KeyType: "message-id",
		}
		for _, m := range msgs {
			// Per-source identity match feeds the IsSentCopy signal.
			matched := false
			if m.FromEmail != "" {
				if addrs := e.config.IdentityAddressesBySource[m.SourceID]; addrs != nil {
					_, matched = addrs[store.NormalizeIdentifierForCompare(m.FromEmail)]
				}
			}
			group.Messages = append(group.Messages, DuplicateMessage{
				ID:               m.ID,
				SourceID:         m.SourceID,
				SourceType:       m.SourceType,
				SourceIdentifier: m.SourceIdentifier,
				SourceMessageID:  m.SourceMessageID,
				Subject:          m.Subject,
				SentAt:           m.SentAt,
				HasRawMIME:       m.HasRawMIME,
				LabelCount:       m.LabelCount,
				ArchivedAt:       m.ArchivedAt,
				IsFromMe:         m.IsFromMe,
				HasSentLabel:     m.HasSentLabel,
				FromEmail:        m.FromEmail,
				MatchedIdentity:  matched,
			})
		}

		e.selectSurvivor(&group)
		report.Groups = append(report.Groups, group)
		report.BySourcePair[sourcePairKey(group.Messages)]++
	}

	// Phase 3 (optional): supplementary content-hash pass.
	if e.config.ContentHashFallback {
		// Exclude only losers (messages already selected for pruning) from
		// the content-hash pass, not survivors. A message missing
		// Message-ID can legitimately match the content of a survivor that
		// anchored a Message-ID group; survivors stay eligible so the
		// second pass can link orphan rows back to that anchor.
		//
		// Survivors are tracked separately so we can guarantee a survivor
		// of a Message-ID group cannot be demoted to a loser by the
		// content-hash pass (which would silently prune it after labels
		// from the Message-ID group's losers were already merged in).
		excludeIDs := make(map[int64]bool, len(report.Groups)*2)
		messageIDSurvivors := make(map[int64]bool, len(report.Groups))
		for _, g := range report.Groups {
			for j, m := range g.Messages {
				if j == g.Survivor {
					messageIDSurvivors[m.ID] = true
					continue
				}
				excludeIDs[m.ID] = true
			}
		}

		contentHashGroups, skipped, err := e.scanNormalizedHashGroups(excludeIDs)
		if err != nil {
			return nil, fmt.Errorf(
				"scan normalized content hashes: %w", err,
			)
		}
		report.SkippedDecompressionErrors = skipped
		for _, g := range contentHashGroups {
			// Spec § Detection: "A content-hash group with two Message-ID
			// survivors keeps both as winners (one per Message-ID group)."
			// Count how many Message-ID-pass survivors landed in this group;
			// if more than one, neither should be demoted — skip entirely.
			//
			// Spec § Survivor selection: "When any message in a duplicate
			// group looks like a sent copy, only sent copies are eligible to
			// survive." A separate skip handles the case where an MID
			// survivor and a sent-copy orphan share a content hash — forcing
			// the (non-sent) MID survivor in over the sent orphan would
			// silently violate the eligibility filter. Sent-copy detection
			// excludes the MID survivor itself; if the MID survivor is
			// already a sent copy, the override only confirms selectSurvivor's
			// choice, which is harmless.
			midSurvivorCount := 0
			hasSentOrphan := false
			for _, m := range g.Messages {
				if messageIDSurvivors[m.ID] {
					midSurvivorCount++
					continue
				}
				if m.IsSentCopy() {
					hasSentOrphan = true
				}
			}
			if midSurvivorCount > 1 {
				continue
			}
			if midSurvivorCount >= 1 && hasSentOrphan {
				continue
			}

			// If this content-hash group contains exactly one Message-ID
			// survivor that did not win the content-hash survivor selection,
			// force that survivor to win. Demoting a survivor that has already
			// absorbed labels from its Message-ID losers would silently destroy
			// that union when MergeDuplicates soft-deletes the demoted survivor.
			for j, m := range g.Messages {
				if j == g.Survivor {
					continue
				}
				if messageIDSurvivors[m.ID] {
					g.Survivor = j
					break
				}
			}
			report.Groups = append(report.Groups, g)
			report.ContentHashGroups++
			report.BySourcePair[sourcePairKey(g.Messages)]++
		}
	}

	report.DuplicateGroups = len(report.Groups)
	for _, g := range report.Groups {
		// Every group keeps exactly one survivor; the rest are pruned.
		report.DuplicateMessages += len(g.Messages) - 1
	}

	maxSamples := min(10, len(report.Groups))
	report.SampleGroups = append(
		[]DuplicateGroup(nil), report.Groups[:maxSamples]...,
	)

	e.logger.Info("dedup scan done",
		"groups", report.DuplicateGroups,
		"messages_to_prune", report.DuplicateMessages,
		"content_hash_groups", report.ContentHashGroups,
		"backfilled", report.BackfilledCount,
		"skipped_decompression_errors", report.SkippedDecompressionErrors,
		"duration_ms", time.Since(started).Milliseconds(),
	)
	return report, nil
}

// rawWorkItem carries one compressed raw-MIME blob to a worker.
type rawWorkItem struct {
	candidate store.ContentHashCandidate
	rawData   []byte
	compress  string
}

// hashResult carries the normalized hash plus message metadata.
// skipped marks candidates dropped due to decompression failure.
type hashResult struct {
	hash    string
	msg     DuplicateMessage
	skipped bool
}

// scanNormalizedHashGroups hashes raw MIME after stripping transport-specific
// headers. It skips messages already matched by the primary Message-ID pass.
// Returns the duplicate groups plus a count of candidates skipped due to
// zlib decompression failure.
func (e *Engine) scanNormalizedHashGroups(
	excludeIDs map[int64]bool,
) ([]DuplicateGroup, int, error) {
	candidates, err := e.store.GetAllRawMIMECandidates(
		e.config.AccountSourceIDs...,
	)
	if err != nil {
		return nil, 0, err
	}

	// Drop candidates already pruned by the Message-ID pass.
	candidateMap := make(map[int64]store.ContentHashCandidate, len(candidates))
	for _, c := range candidates {
		if !excludeIDs[c.ID] {
			candidateMap[c.ID] = c
		}
	}
	if len(candidateMap) == 0 {
		return nil, 0, nil
	}

	// Sorted IDs give the store a deterministic streaming order.
	ids := make([]int64, 0, len(candidateMap))
	for id := range candidateMap {
		ids = append(ids, id)
	}
	sort.Slice(ids, func(i, j int) bool { return ids[i] < ids[j] })

	// Bounded worker pool: capped at 16 and at the number of items,
	// floor of 1.
	numWorkers := runtime.NumCPU()
	if numWorkers > 16 {
		numWorkers = 16
	}
	if numWorkers > len(ids) {
		numWorkers = len(ids)
	}
	if numWorkers < 1 {
		numWorkers = 1
	}

	work := make(chan rawWorkItem, numWorkers*4)
	results := make(chan hashResult, numWorkers*4)
	// Only the first few decompression failures are logged; the rest
	// are counted and summarised at the end.
	const maxDecompressionWarns = 5
	var decompressionFailures atomic.Int32

	var wg sync.WaitGroup
	for i := 0; i < numWorkers; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for item := range work {
				raw := item.rawData
				if item.compress == "zlib" {
					r, err := zlib.NewReader(bytes.NewReader(raw))
					if err != nil {
						if decompressionFailures.Add(1) <= maxDecompressionWarns {
							e.logger.Warn("content-hash: zlib open failed",
								"message_id", item.candidate.ID, "err", err)
						}
						results <- hashResult{skipped: true}
						continue
					}
					decompressed, err := io.ReadAll(r)
					_ = r.Close()
					if err != nil {
						if decompressionFailures.Add(1) <= maxDecompressionWarns {
							e.logger.Warn("content-hash: zlib read failed",
								"message_id", item.candidate.ID, "err", err)
						}
						results <- hashResult{skipped: true}
						continue
					}
					raw = decompressed
				}

				// Same per-source identity matching as the primary pass.
				matched := false
				if item.candidate.FromEmail != "" {
					if addrs := e.config.IdentityAddressesBySource[item.candidate.SourceID]; addrs != nil {
						_, matched = addrs[store.NormalizeIdentifierForCompare(item.candidate.FromEmail)]
					}
				}

				results <- hashResult{
					hash: sha256Hex(normalizeRawMIME(raw)),
					msg: DuplicateMessage{
						ID:               item.candidate.ID,
						SourceID:         item.candidate.SourceID,
						SourceType:       item.candidate.SourceType,
						SourceIdentifier: item.candidate.SourceIdentifier,
						SourceMessageID:  item.candidate.SourceMessageID,
						Subject:          item.candidate.Subject,
						SentAt:           item.candidate.SentAt,
						HasRawMIME:       true,
						LabelCount:       item.candidate.LabelCount,
						ArchivedAt:       item.candidate.ArchivedAt,
						IsFromMe:         item.candidate.IsFromMe,
						HasSentLabel:     item.candidate.HasSentLabel,
						FromEmail:        item.candidate.FromEmail,
						MatchedIdentity:  matched,
					},
				}
			}
		}()
	}

	// Single collector goroutine owns hashMap/skipped, so no locking
	// is needed around them.
	type hashEntry struct {
		msgs []DuplicateMessage
	}
	hashMap := make(map[string]*hashEntry)
	skipped := 0
	collectDone := make(chan struct{})
	go func() {
		for r := range results {
			if r.skipped {
				skipped++
				continue
			}
			if entry, ok := hashMap[r.hash]; ok {
				entry.msgs = append(entry.msgs, r.msg)
			} else {
				hashMap[r.hash] = &hashEntry{msgs: []DuplicateMessage{r.msg}}
			}
		}
		close(collectDone)
	}()

	readErr := e.store.StreamMessageRaw(
		ids,
		func(messageID int64, rawData []byte, compression string) {
			c, ok := candidateMap[messageID]
			if !ok {
				return
			}
			// Copy: the store may reuse rawData's buffer after the
			// callback returns, but workers consume it asynchronously.
			dataCopy := make([]byte, len(rawData))
			copy(dataCopy, rawData)
			work <- rawWorkItem{
				candidate: c,
				rawData:   dataCopy,
				compress:  compression,
			}
		},
	)
	// Shutdown order matters: stop feeding workers, wait for them to
	// drain, close results so the collector can finish, then wait for
	// the collector before reading hashMap/skipped.
	close(work)
	wg.Wait()
	close(results)
	<-collectDone

	if readErr != nil {
		return nil, skipped, fmt.Errorf("stream message raw: %w", readErr)
	}

	var groups []DuplicateGroup
	for hash, entry := range hashMap {
		if len(entry.msgs) < 2 {
			continue
		}
		g := DuplicateGroup{
			Key:      hash,
			KeyType:  "normalized-hash",
			Messages: entry.msgs,
		}
		e.selectSurvivor(&g)
		groups = append(groups, g)
	}
	if skipped > maxDecompressionWarns {
		e.logger.Warn("content-hash: additional zlib failures suppressed",
			"suppressed", skipped-maxDecompressionWarns)
	}
	return groups, skipped, nil
}

// transportHeaders vary across otherwise-identical copies of the same email.
// Keys are in canonical MIME case (textproto.CanonicalMIMEHeaderKey form).
var transportHeaders = map[string]bool{
	"Received":                   true,
	"Delivered-To":               true,
	"Return-Path":                true,
	"X-Received":                 true,
	"X-Gmail-Labels":             true,
	"X-Gmail-Received":           true,
	"X-Google-Smtp-Source":       true,
	"X-Gm-Message-State":         true,
	"Authentication-Results":     true,
	"Dkim-Signature":             true,
	"Arc-Seal":                   true,
	"Arc-Message-Signature":      true,
	"Arc-Authentication-Results": true,
	"X-Google-Dkim-Signature":    true,
	"X-Forwarded-To":             true,
	"X-Forwarded-For":            true,
	"X-Original-To":              true,
	"X-Apple-Mail-Labels":        true,
}

// normalizeRawMIME strips transport/export-specific headers before hashing.
// Kept headers are re-emitted sorted by key with LF separators, so the hash
// is insensitive to header order and transport noise but remains
// byte-sensitive in the body. Inputs without a recognizable header/body
// boundary are returned unchanged.
func normalizeRawMIME(raw []byte) []byte {
	crlfEnd := bytes.Index(raw, []byte("\r\n\r\n"))
	lfEnd := bytes.Index(raw, []byte("\n\n"))
	headerEnd := -1
	switch {
	case crlfEnd >= 0 && lfEnd >= 0:
		headerEnd = min(crlfEnd, lfEnd)
	case crlfEnd >= 0:
		headerEnd = crlfEnd
	case lfEnd >= 0:
		headerEnd = lfEnd
	}
	if headerEnd == -1 {
		return raw
	}

	headerSection := raw[:headerEnd]
	// Find the start of the actual body after the blank line.
	var bodyStart int
	switch {
	case bytes.HasPrefix(raw[headerEnd:], []byte("\r\n\r\n")):
		bodyStart = headerEnd + 4
	case bytes.HasPrefix(raw[headerEnd:], []byte("\n\n")):
		bodyStart = headerEnd + 2
	default:
		return raw
	}
	body := raw[bodyStart:]

	// Copy headerSection before appending to avoid mutating the
	// underlying raw buffer (headerSection is a sub-slice of raw).
	hdrBuf := make([]byte, len(headerSection)+4)
	copy(hdrBuf, headerSection)
	copy(hdrBuf[len(headerSection):], "\r\n\r\n")
	reader := textproto.NewReader(bufio.NewReader(bytes.NewReader(hdrBuf)))
	mimeHeader, err := reader.ReadMIMEHeader()
	if err != nil {
		// Unparseable headers: fall back to hashing the raw bytes.
		return raw
	}

	var kept []string
	for key := range mimeHeader {
		if !transportHeaders[textproto.CanonicalMIMEHeaderKey(key)] {
			kept = append(kept, key)
		}
	}
	sort.Strings(kept)

	var buf bytes.Buffer
	for _, key := range kept {
		for _, val := range mimeHeader[key] {
			fmt.Fprintf(&buf, "%s: %s\n", key, val)
		}
	}
	buf.WriteString("\n") // canonical header/body separator
	buf.Write(body)
	return buf.Bytes()
}

// sha256Hex returns the lowercase hex SHA-256 digest of data.
func sha256Hex(data []byte) string {
	h := sha256.Sum256(data)
	return hex.EncodeToString(h[:])
}

// selectSurvivor picks the best message to keep in a duplicate group
// and records its index in group.Survivor. When any member looks like a
// sent copy, only sent copies are eligible; within the eligible set the
// isBetter comparison chain decides.
func (e *Engine) selectSurvivor(group *DuplicateGroup) {
	if len(group.Messages) <= 1 {
		group.Survivor = 0
		return
	}

	priorityMap := make(map[string]int)
	for i, st := range e.config.SourcePreference {
		priorityMap[st] = i
	}

	candidates := allIndexes(len(group.Messages))
	var sentIdxs []int
	for _, i := range candidates {
		if group.Messages[i].IsSentCopy() {
			sentIdxs = append(sentIdxs, i)
		}
	}
	if len(sentIdxs) > 0 {
		candidates = sentIdxs
	}

	best := candidates[0]
	for _, i := range candidates[1:] {
		if e.isBetter(
			group.Messages[i], group.Messages[best], priorityMap,
		) {
			best = i
		}
	}
	group.Survivor = best
}

// allIndexes returns [0, 1, ..., n-1].
func allIndexes(n int) []int {
	out := make([]int, n)
	for i := range out {
		out[i] = i
	}
	return out
}

// isBetter returns true if candidate is a better survivor than current.
+func (e *Engine) isBetter( + candidate, current DuplicateMessage, priorityMap map[string]int, +) bool { + candPri := sourcePriority(candidate.SourceType, priorityMap) + currPri := sourcePriority(current.SourceType, priorityMap) + if candPri != currPri { + return candPri < currPri + } + if candidate.HasRawMIME != current.HasRawMIME { + return candidate.HasRawMIME + } + if candidate.LabelCount != current.LabelCount { + return candidate.LabelCount > current.LabelCount + } + if !candidate.ArchivedAt.IsZero() && !current.ArchivedAt.IsZero() { + return candidate.ArchivedAt.Before(current.ArchivedAt) + } + return candidate.ID < current.ID +} + +func sourcePriority(sourceType string, priorityMap map[string]int) int { + if p, ok := priorityMap[sourceType]; ok { + return p + } + return len(priorityMap) +} + +// Execute merges every duplicate group: unions labels onto the +// survivor, soft-deletes the pruned duplicates, and — when +// DeleteDupsFromSourceServer is enabled AND a pruned copy shares a +// source_id with its survivor — writes a deletion manifest. 
func (e *Engine) Execute(
	ctx context.Context, report *Report, batchID string,
) (*ExecutionSummary, error) {
	summary := &ExecutionSummary{BatchID: batchID}

	started := time.Now()
	e.logger.Info("dedup execute start",
		"batch", batchID,
		"account", e.config.Account,
		"groups", report.DuplicateGroups,
		"messages_to_prune", report.DuplicateMessages,
		"stage_remote_deletion", e.config.DeleteDupsFromSourceServer,
	)

	// Remote-deletion candidates collected during the merge loop,
	// keyed by (account, source type); manifests are written at the end.
	remoteByKey := make(map[remoteKey][]string)

	for i, group := range report.Groups {
		// Cooperative cancellation between groups; the partial summary
		// is still returned so callers can report progress so far.
		if ctx.Err() != nil {
			return summary, ctx.Err()
		}

		survivor := group.Messages[group.Survivor]
		survivorID := survivor.ID
		var dupIDs []int64
		for j, m := range group.Messages {
			if j == group.Survivor {
				continue
			}
			dupIDs = append(dupIDs, m.ID)

			// Remote-deletion staging gates (all must hold):
			// explicit opt-in, supported remote type, and the pruned
			// copy must live in the SAME source as the survivor so a
			// cross-source duplicate can never delete remote mail.
			if !e.config.DeleteDupsFromSourceServer {
				continue
			}
			if !remoteSourceTypes[m.SourceType] {
				continue
			}
			if m.SourceID != survivor.SourceID {
				continue
			}
			acct := m.SourceIdentifier
			if acct == "" {
				acct = e.config.Account
			}
			key := remoteKey{Account: acct, SourceType: m.SourceType}
			remoteByKey[key] = append(
				remoteByKey[key], m.SourceMessageID,
			)
		}

		mergeResult, err := e.store.MergeDuplicates(
			survivorID, dupIDs, batchID,
		)
		if err != nil {
			return summary, fmt.Errorf(
				"merge group %d (%s): %w", i, group.Key, err,
			)
		}

		summary.GroupsMerged++
		summary.MessagesRemoved += len(dupIDs)
		summary.LabelsTransferred += mergeResult.LabelsTransferred
		summary.RawMIMEBackfilled += mergeResult.RawMIMEBackfilled
	}

	if e.config.DeleteDupsFromSourceServer && len(remoteByKey) > 0 {
		staged, err := e.stageDeletionManifests(batchID, remoteByKey)
		if err != nil {
			return summary, err
		}
		summary.StagedManifests = staged
	}

	e.logger.Info("dedup execute done",
		"batch", batchID,
		"groups_merged", summary.GroupsMerged,
		"messages_removed", summary.MessagesRemoved,
		"labels_transferred", summary.LabelsTransferred,
		"raw_mime_backfilled", summary.RawMIMEBackfilled,
		"staged_manifests", len(summary.StagedManifests),
		"duration_ms", time.Since(started).Milliseconds(),
	)
	return summary, nil
}

// stageDeletionManifests writes one pending deletion manifest per
// (account, source type) key collected by Execute. Returns a record of
// every manifest written; on a save failure it returns the manifests
// staged so far along with the error.
func (e *Engine) stageDeletionManifests(
	batchID string,
	byKey map[remoteKey][]string,
) ([]StagedManifest, error) {
	if e.config.DeletionsDir == "" {
		return nil, fmt.Errorf(
			"deletions dir not configured but " +
				"DeleteDupsFromSourceServer is true",
		)
	}

	mgr, err := deletion.NewManager(e.config.DeletionsDir)
	if err != nil {
		return nil, fmt.Errorf("open deletion manager: %w", err)
	}

	// Deterministic manifest order: by account, then source type.
	keys := make([]remoteKey, 0, len(byKey))
	for k := range byKey {
		keys = append(keys, k)
	}
	sort.Slice(keys, func(i, j int) bool {
		if keys[i].Account != keys[j].Account {
			return keys[i].Account < keys[j].Account
		}
		return keys[i].SourceType < keys[j].SourceType
	})

	// Single-type accounts keep the original manifest ID (no source-type
	// suffix) so existing consumers — and test fixtures — don't see a
	// rename. Only accounts contributing duplicates from more than one
	// source type need disambiguation.
	typesPerAccount := make(map[string]int)
	for k := range byKey {
		typesPerAccount[k.Account]++
	}

	var staged []StagedManifest
	for _, k := range keys {
		ids := dedupStrings(byKey[k])
		if len(ids) == 0 {
			continue
		}

		description := fmt.Sprintf("Dedup pruned duplicates (%s)", batchID)
		manifest := deletion.NewManifest(description, ids)
		if typesPerAccount[k.Account] > 1 {
			manifest.ID = manifestIDFor(batchID, k.Account+"-"+k.SourceType)
		} else {
			manifest.ID = manifestIDFor(batchID, k.Account)
		}
		manifest.CreatedBy = "dedup"
		manifest.Filters.Account = k.Account

		path := filepath.Join(
			mgr.PendingDir(), manifest.ID+".json",
		)
		if err := manifest.Save(path); err != nil {
			return staged, fmt.Errorf(
				"save manifest for %s: %w", k.Account, err,
			)
		}
		staged = append(staged, StagedManifest{
			Account:      k.Account,
			SourceType:   k.SourceType,
			ManifestID:   manifest.ID,
			MessageCount: len(ids),
		})
	}
	return staged, nil
}

// manifestIDFor builds the batch-prefixed manifest ID; Undo relies on
// this "batchID-" prefix to find manifests belonging to a batch.
func manifestIDFor(batchID, account string) string {
	return fmt.Sprintf("%s-%s", batchID, SanitizeFilenameComponent(account))
}

// SanitizeFilenameComponent strips or replaces characters that are unsafe
// for use in filenames, ensuring the result contains only alphanumeric,
// hyphens, and underscores (with @ and . replaced by hyphens).
+func SanitizeFilenameComponent(a string) string { + var b strings.Builder + for _, r := range a { + switch { + case (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || + (r >= '0' && r <= '9') || r == '-' || r == '_': + b.WriteRune(r) + case r == '@' || r == '.': + b.WriteRune('-') + } + } + s := b.String() + if s == "" { + s = "account" + } + if len(s) > 40 { + sum := sha256.Sum256([]byte(a)) + s = s[:31] + "-" + hex.EncodeToString(sum[:4]) + } + return s +} + +func dedupStrings(in []string) []string { + seen := make(map[string]bool, len(in)) + out := make([]string, 0, len(in)) + for _, s := range in { + if seen[s] { + continue + } + seen[s] = true + out = append(out, s) + } + sort.Strings(out) + return out +} + +// Undo restores every message with the given batch ID and cancels any +// pending deletion manifests that dedup created for that batch. +// +// Manifest cancellation is best-effort: if cancelling one manifest +// fails, the remaining manifests are still attempted, and any errors +// are joined into a single returned error alongside the restored row +// count and the list of manifests already in progress. 
+func (e *Engine) Undo(batchID string) (int64, []string, error) { + started := time.Now() + e.logger.Info("dedup undo start", "batch", batchID) + + restored, err := e.store.UndoDedup(batchID) + if err != nil { + e.logger.Warn("dedup undo failed", + "batch", batchID, + "duration_ms", time.Since(started).Milliseconds(), + "error", err.Error(), + ) + return 0, nil, err + } + + if e.config.DeletionsDir == "" { + e.logger.Info("dedup undo done", + "batch", batchID, + "restored", restored, + "manifests_cancelled", 0, + "manifests_still_running", 0, + "duration_ms", time.Since(started).Milliseconds(), + ) + return restored, nil, nil + } + + mgr, err := deletion.NewManager(e.config.DeletionsDir) + if err != nil { + return restored, nil, fmt.Errorf("open deletion manager: %w", err) + } + pending, err := mgr.ListPending() + if err != nil { + return restored, nil, fmt.Errorf("list pending: %w", err) + } + inProgress, err := mgr.ListInProgress() + if err != nil { + return restored, nil, fmt.Errorf("list in-progress: %w", err) + } + + var stillExecuting []string + var cancelErrs []error + cancelled := 0 + prefix := batchID + "-" + for _, m := range pending { + if !strings.HasPrefix(m.ID, prefix) { + continue + } + if err := mgr.CancelManifest(m.ID); err != nil { + cancelErrs = append(cancelErrs, fmt.Errorf( + "cancel manifest %s: %w", m.ID, err, + )) + continue + } + cancelled++ + } + for _, m := range inProgress { + if !strings.HasPrefix(m.ID, prefix) { + continue + } + stillExecuting = append(stillExecuting, m.ID) + } + e.logger.Info("dedup undo done", + "batch", batchID, + "restored", restored, + "manifests_cancelled", cancelled, + "manifests_still_running", len(stillExecuting), + "cancel_errors", len(cancelErrs), + "duration_ms", time.Since(started).Milliseconds(), + ) + if len(cancelErrs) > 0 { + return restored, stillExecuting, errors.Join(cancelErrs...) + } + return restored, stillExecuting, nil +} + +// FormatReport renders a human-readable report of the scan results. 
func (e *Engine) FormatReport(r *Report) string {
	var sb strings.Builder
	sb.WriteString("\n=== Deduplication Report ===\n\n")

	// Negative BackfilledCount is the dry-run "needed but skipped"
	// sentinel set by Scan.
	if r.BackfilledCount < 0 {
		fmt.Fprintf(&sb,
			"Note: %d messages need RFC822 Message-ID backfill "+
				"from stored MIME (skipped in dry-run).\n"+
				"These messages will be backfilled and included "+
				"when you re-run without --dry-run.\n\n",
			-r.BackfilledCount)
	} else if r.BackfilledCount > 0 {
		fmt.Fprintf(&sb,
			"Backfilled %d messages with RFC822 Message-ID "+
				"from stored MIME.\n\n",
			r.BackfilledCount)
	}

	if r.DuplicateGroups == 0 {
		sb.WriteString("No duplicates found.\n")
		return sb.String()
	}

	fmt.Fprintf(&sb, "Duplicate groups found: %d\n", r.DuplicateGroups)
	fmt.Fprintf(&sb, "Messages to prune: %d\n", r.DuplicateMessages)
	if r.ContentHashGroups > 0 {
		fmt.Fprintf(&sb, "Content-hash groups: %d\n", r.ContentHashGroups)
	}
	if r.SkippedDecompressionErrors > 0 {
		fmt.Fprintf(&sb,
			"Skipped (decompression error): %d "+
				"(see log for per-message details)\n",
			r.SkippedDecompressionErrors)
	}

	if len(r.BySourcePair) > 0 {
		sb.WriteString("\nBreakdown by source pair:\n")
		// Sort pair keys for stable output (map order is random).
		pairs := make([]string, 0, len(r.BySourcePair))
		for k := range r.BySourcePair {
			pairs = append(pairs, k)
		}
		sort.Strings(pairs)
		for _, pair := range pairs {
			fmt.Fprintf(&sb, " %-20s %d groups\n",
				pair, r.BySourcePair[pair])
		}
	}

	// Count groups containing at least one sent copy (one per group).
	sentGroups := 0
	for _, g := range r.Groups {
		for _, m := range g.Messages {
			if m.IsSentCopy() {
				sentGroups++
				break
			}
		}
	}
	if sentGroups > 0 {
		fmt.Fprintf(&sb,
			"\nSent-copy groups detected: %d "+
				"(survivor forced to a sent copy)\n",
			sentGroups)
	}

	if len(r.SampleGroups) > 0 {
		sb.WriteString("\nSample duplicate groups:\n")
		for i, g := range r.SampleGroups {
			label := g.Key
			if g.KeyType != "" && g.KeyType != "message-id" {
				label = fmt.Sprintf("%s (%s)", g.Key, g.KeyType)
			}
			fmt.Fprintf(&sb, "\n Group %d: %s\n", i+1, label)
			for j, m := range g.Messages {
				// "*" marks the chosen survivor.
				marker := " "
				if j == g.Survivor {
					marker = "* "
				}
				sent := ""
				if m.IsSentCopy() {
					sent = " [sent]"
				}
				fmt.Fprintf(&sb,
					" %s[%s:%s]%s %s "+
						"(labels: %d, raw: %v)\n",
					marker, m.SourceType, m.SourceIdentifier,
					sent, m.Subject, m.LabelCount, m.HasRawMIME,
				)
			}
		}
	}

	return sb.String()
}

// FormatMethodology returns a detailed explanation of how dedup works.
// The wording branches on scope (collection vs account vs per-source)
// and on whether the content-hash fallback is enabled.
func (e *Engine) FormatMethodology() string {
	var sb strings.Builder
	sb.WriteString("\n=== Deduplication Methodology ===\n\n")

	sb.WriteString("Scope:\n")
	switch {
	case e.config.ScopeIsCollection && len(e.config.AccountSourceIDs) > 1:
		fmt.Fprintf(&sb,
			" Scoped to collection: %s (%d account(s)). "+
				"Cross-account dedup\n"+
				" is enabled within this collection.\n",
			e.config.Account, len(e.config.AccountSourceIDs))
	case e.config.ScopeIsCollection:
		// Single-member collection — wording matches the
		// account-scope branch since no cross-account merging can
		// happen with one source.
		fmt.Fprintf(&sb,
			" Scoped to collection: %s (1 account). Source "+
				"boundaries are\n never crossed (collection has "+
				"only one member).\n",
			e.config.Account)
	case e.config.Account != "":
		fmt.Fprintf(&sb,
			" Scoped to account: %s. Source boundaries are "+
				"never crossed.\n",
			e.config.Account)
	case len(e.config.AccountSourceIDs) > 0:
		fmt.Fprintf(&sb,
			" Scoped to %d source(s). Source boundaries are "+
				"never crossed.\n",
			len(e.config.AccountSourceIDs))
	default:
		sb.WriteString(
			" No scope specified — only messages that appear " +
				"twice in the\n" +
				" SAME account are eligible. To compare across " +
				"accounts, group\n" +
				" them in a collection and rerun with " +
				"--collection .\n",
		)
	}
	sb.WriteString("\n")

	sb.WriteString("Detection:\n")
	sb.WriteString(" Message-ID is primary; content-hash is a " +
		"supplementary fallback.\n")
	sb.WriteString(" Messages are grouped by the RFC822 Message-ID " +
		"header.\n")
	sb.WriteString(" Messages missing that header are backfilled " +
		"from stored MIME\n")
	sb.WriteString(" before the scan runs.")
	if e.config.ContentHashFallback {
		sb.WriteString(" Every remaining message with stored MIME is then compared via\n")
		sb.WriteString(" a normalized raw-MIME hash that strips transport " +
			"headers such as\n")
		sb.WriteString(" Received, Delivered-To, X-Gmail-Labels, and " +
			"DKIM/ARC traces.\n")
		sb.WriteString(" The hash is byte-sensitive below the header " +
			"boundary, so two\n")
		sb.WriteString(" messages whose bodies differ only in line-ending " +
			"style (CRLF vs LF)\n")
		sb.WriteString(" will not match via content-hash.\n\n")
	} else {
		sb.WriteString(" Messages still without an ID are ignored.\n\n")
	}

	sb.WriteString("Survivor selection:\n")
	for i, st := range e.config.SourcePreference {
		fmt.Fprintf(&sb, " %d. %s\n", i+1, st)
	}
	sb.WriteString(" Tiebreakers: has raw MIME > more labels > " +
		"earlier archived_at > lower id.\n\n")

	sb.WriteString("Sent messages:\n")
	if e.config.ScopeIsCollection && len(e.config.AccountSourceIDs) > 1 {
		sb.WriteString(
			" Collection mode INTENTIONALLY merges messages " +
				"across the accounts in this\n" +
				" collection. A message alice sent to bob will " +
				"have one copy in alice's\n" +
				" Sent folder and one in bob's Inbox; if both " +
				"accounts are members of\n" +
				" this collection, the loser will be hidden " +
				"locally (reversible via\n" +
				" --undo). Remote deletion remains " +
				"same-source-only and will not\n" +
				" touch a different account's mailbox. Only " +
				"add accounts to a collection\n" +
				" when you actually want their copies merged.\n\n",
		)
	} else {
		sb.WriteString(
			" Dedup NEVER merges messages across different " +
				"accounts. A message that\n" +
				" alice sent to bob is two distinct mailbox " +
				"copies — one in alice's\n" +
				" Sent folder and one in bob's Inbox. Both are " +
				"preserved independently\n" +
				" because deleting either would alter the other " +
				"user's archive.\n\n",
		)
	}

	sb.WriteString("Merge behaviour:\n")
	sb.WriteString(" - Labels from every copy are unioned onto " +
		"the survivor.\n")
	sb.WriteString(" - Raw MIME is backfilled onto the survivor " +
		"if it lacks it.\n")
	sb.WriteString(" - Only raw MIME is backfilled; parsed " +
		"message_bodies are not.\n")
	sb.WriteString("   If a survivor is missing text for display, run\n")
	sb.WriteString("   'msgvault repair-encoding' or " +
		"'msgvault build-cache --full-rebuild'.\n")
	sb.WriteString(" - Pruned duplicates are hidden in the msgvault " +
		"database (reversible via --undo).\n")
	sb.WriteString(" - Remote mailboxes (Gmail, IMAP) are NEVER " +
		"modified by default.\n")

	return sb.String()
}

// sourcePairKey builds the sorted "typeA+typeB" key used for the
// report's by-source-pair breakdown (a single-type group yields just
// that type).
func sourcePairKey(msgs []DuplicateMessage) string {
	types := make(map[string]bool)
	for _, m := range msgs {
		types[m.SourceType] = true
	}
	sorted := make([]string, 0, len(types))
	for t := range types {
		sorted = append(sorted, t)
	}
	sort.Strings(sorted)
	return strings.Join(sorted, "+")
}
diff --git a/internal/dedup/dedup_test.go b/internal/dedup/dedup_test.go
new file mode 100644
index 00000000..e5aadbdc
--- /dev/null
+++ b/internal/dedup/dedup_test.go
@@ -0,0 +1,744 @@
+package dedup_test
+
+import (
+	"context"
+	"database/sql"
+	"path/filepath"
+	"strings"
+	"testing"
+
+	"github.com/wesm/msgvault/internal/dedup"
+	"github.com/wesm/msgvault/internal/deletion"
+	"github.com/wesm/msgvault/internal/store"
+	"github.com/wesm/msgvault/internal/testutil"
+
"github.com/wesm/msgvault/internal/testutil/storetest" +) + +func addMessage( + t *testing.T, + st *store.Store, + source *store.Source, + srcMsgID, rfc822ID string, + fromMe bool, +) int64 { + t.Helper() + convID, err := st.EnsureConversation( + source.ID, "thread-"+srcMsgID, "Subject", + ) + testutil.MustNoErr(t, err, "EnsureConversation") + id, err := st.UpsertMessage(&store.Message{ + ConversationID: convID, + SourceID: source.ID, + SourceMessageID: srcMsgID, + RFC822MessageID: sql.NullString{ + String: rfc822ID, Valid: rfc822ID != "", + }, + MessageType: "email", + IsFromMe: fromMe, + SizeEstimate: 1000, + }) + testutil.MustNoErr(t, err, "UpsertMessage") + return id +} + +func assertSoftDeleted( + t *testing.T, st *store.Store, msgID int64, wantDeleted bool, +) { + t.Helper() + var deletedAt sql.NullTime + err := st.DB().QueryRow( + "SELECT deleted_at FROM messages WHERE id = ?", msgID, + ).Scan(&deletedAt) + testutil.MustNoErr(t, err, "query deleted_at") + if wantDeleted && !deletedAt.Valid { + t.Errorf("message %d: deleted_at should be set", msgID) + } + if !wantDeleted && deletedAt.Valid { + t.Errorf("message %d: deleted_at should be NULL", msgID) + } +} + +func linkLabel( + t *testing.T, + st *store.Store, + sourceID, msgID int64, + sourceLabelID, name, typ string, +) { + t.Helper() + lid, err := st.EnsureLabel(sourceID, sourceLabelID, name, typ) + testutil.MustNoErr(t, err, "EnsureLabel "+sourceLabelID) + testutil.MustNoErr(t, + st.LinkMessageLabel(msgID, lid), + "LinkMessageLabel "+sourceLabelID, + ) +} + +func TestEngine_Scan_UnionsLabelsOnSurvivor(t *testing.T) { + f := storetest.New(t) + st := f.Store + gmail := f.Source + + mbox, err := st.GetOrCreateSource("mbox", "test@example.com-mbox") + testutil.MustNoErr(t, err, "GetOrCreateSource mbox") + + idGmail := addMessage(t, st, gmail, "gmail-1", "rfc-union", false) + idMbox := addMessage(t, st, mbox, "mbox-1", "rfc-union", false) + + linkLabel(t, st, gmail.ID, idGmail, "INBOX", "Inbox", "system") + 
linkLabel(t, st, mbox.ID, idMbox, "Archive", "Archive", "user") + linkLabel(t, st, mbox.ID, idMbox, "Work", "Work", "user") + + eng := dedup.NewEngine(st, dedup.Config{ + AccountSourceIDs: []int64{gmail.ID, mbox.ID}, + Account: "test@example.com", + }, nil) + + report, err := eng.Scan(context.Background()) + testutil.MustNoErr(t, err, "Scan") + if report.DuplicateGroups != 1 { + t.Fatalf("groups = %d, want 1", report.DuplicateGroups) + } + if report.DuplicateMessages != 1 { + t.Fatalf("prune count = %d, want 1", report.DuplicateMessages) + } + + group := report.Groups[0] + survivor := group.Messages[group.Survivor] + if survivor.ID != idGmail { + t.Errorf("survivor = %d, want %d (gmail)", survivor.ID, idGmail) + } + + summary, err := eng.Execute( + context.Background(), report, "batch-union", + ) + testutil.MustNoErr(t, err, "Execute") + if summary.GroupsMerged != 1 { + t.Errorf("groupsMerged = %d, want 1", summary.GroupsMerged) + } + + f.AssertLabelCount(idGmail, 3) + assertSoftDeleted(t, st, idMbox, true) +} + +func TestEngine_Scan_RejectsEmptyAccountSourceIDs(t *testing.T) { + f := storetest.New(t) + st := f.Store + + cases := []struct { + name string + ids []int64 + }{ + {"nil", nil}, + {"empty slice", []int64{}}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + eng := dedup.NewEngine(st, dedup.Config{ + AccountSourceIDs: tc.ids, + }, nil) + _, err := eng.Scan(context.Background()) + if err == nil { + t.Fatal("expected error for empty AccountSourceIDs") + } + if !strings.Contains(err.Error(), "AccountSourceIDs must be non-empty") { + t.Errorf("unexpected error: %v", err) + } + }) + } +} + +func TestEngine_SurvivorFavorsSentCopy(t *testing.T) { + f := storetest.New(t) + st := f.Store + gmail := f.Source + + idInbox := addMessage(t, st, gmail, "inbox-sent", "rfc-sent", false) + idSent := addMessage(t, st, gmail, "sent-sent", "rfc-sent", true) + + linkLabel(t, st, gmail.ID, idInbox, "INBOX", "Inbox", "system") + linkLabel(t, st, gmail.ID, 
idSent, "SENT", "Sent", "system") + + eng := dedup.NewEngine(st, dedup.Config{ + AccountSourceIDs: []int64{gmail.ID}, + Account: "test@example.com", + }, nil) + + report, err := eng.Scan(context.Background()) + testutil.MustNoErr(t, err, "Scan") + if report.DuplicateGroups != 1 { + t.Fatalf("groups = %d, want 1", report.DuplicateGroups) + } + + group := report.Groups[0] + survivor := group.Messages[group.Survivor] + if survivor.ID != idSent { + t.Errorf("survivor = %d, want sent copy %d", + survivor.ID, idSent) + } + if !survivor.IsSentCopy() { + t.Errorf("survivor should be a sent copy") + } +} + +func TestEngine_DefaultConfig_NeverStagesRemote(t *testing.T) { + f := storetest.New(t) + st := f.Store + gmail := f.Source + + _ = addMessage(t, st, gmail, "g-1", "rfc-default", false) + _ = addMessage(t, st, gmail, "g-2", "rfc-default", false) + + deletionsDir := filepath.Join(t.TempDir(), "deletions") + eng := dedup.NewEngine(st, dedup.Config{ + AccountSourceIDs: []int64{gmail.ID}, + Account: "test@example.com", + DeletionsDir: deletionsDir, + }, nil) + + report, err := eng.Scan(context.Background()) + testutil.MustNoErr(t, err, "Scan") + summary, err := eng.Execute( + context.Background(), report, "batch-default", + ) + testutil.MustNoErr(t, err, "Execute") + + if summary.MessagesRemoved != 1 { + t.Errorf("messagesRemoved = %d, want 1", summary.MessagesRemoved) + } + if len(summary.StagedManifests) != 0 { + t.Errorf("stagedManifests = %d, want 0", len(summary.StagedManifests)) + } + + mgr, err := deletion.NewManager(deletionsDir) + testutil.MustNoErr(t, err, "NewManager") + pending, err := mgr.ListPending() + testutil.MustNoErr(t, err, "ListPending") + if len(pending) != 0 { + t.Errorf("pending manifests = %d, want 0", len(pending)) + } +} + +func TestEngine_OptIn_StagesOnlyWithinSameSourceID(t *testing.T) { + f := storetest.New(t) + st := f.Store + gmail := f.Source + + otherGmail, err := st.GetOrCreateSource("gmail", "other@example.com") + testutil.MustNoErr(t, 
err, "GetOrCreateSource otherGmail") + mbox, err := st.GetOrCreateSource("mbox", "local.mbox") + testutil.MustNoErr(t, err, "GetOrCreateSource mbox") + + idWinner := addMessage(t, st, gmail, "g-1", "rfc-opt", false) + idLoser := addMessage(t, st, gmail, "g-2", "rfc-opt", false) + idOther := addMessage(t, st, otherGmail, "g-3", "rfc-opt", false) + idMbox := addMessage(t, st, mbox, "m-1", "rfc-opt", false) + + deletionsDir := filepath.Join(t.TempDir(), "deletions") + eng := dedup.NewEngine(st, dedup.Config{ + AccountSourceIDs: []int64{gmail.ID, otherGmail.ID, mbox.ID}, + Account: "pile", + DeleteDupsFromSourceServer: true, + DeletionsDir: deletionsDir, + }, nil) + + report, err := eng.Scan(context.Background()) + testutil.MustNoErr(t, err, "Scan") + summary, err := eng.Execute( + context.Background(), report, "batch-opt", + ) + testutil.MustNoErr(t, err, "Execute") + + if summary.MessagesRemoved != 3 { + t.Errorf("messagesRemoved = %d, want 3", summary.MessagesRemoved) + } + assertSoftDeleted(t, st, idWinner, false) + assertSoftDeleted(t, st, idLoser, true) + assertSoftDeleted(t, st, idOther, true) + assertSoftDeleted(t, st, idMbox, true) + + if len(summary.StagedManifests) != 1 { + t.Fatalf("stagedManifests = %d, want 1", len(summary.StagedManifests)) + } + sm := summary.StagedManifests[0] + if sm.Account != "test@example.com" { + t.Errorf("staged account = %q, want test@example.com", sm.Account) + } + if sm.MessageCount != 1 { + t.Errorf("staged count = %d, want 1", sm.MessageCount) + } + + mgr, err := deletion.NewManager(deletionsDir) + testutil.MustNoErr(t, err, "NewManager") + pending, err := mgr.ListPending() + testutil.MustNoErr(t, err, "ListPending") + if len(pending) != 1 { + t.Fatalf("pending = %d, want 1", len(pending)) + } + if len(pending[0].GmailIDs) != 1 || pending[0].GmailIDs[0] != "g-2" { + t.Errorf("manifest GmailIDs = %v, want [g-2]", pending[0].GmailIDs) + } + + restored, stillExec, err := eng.Undo("batch-opt") + testutil.MustNoErr(t, err, "Undo") 
+ if restored != 3 { + t.Errorf("restored = %d, want 3", restored) + } + if len(stillExec) != 0 { + t.Errorf("stillExec = %v, want empty", stillExec) + } + pending, err = mgr.ListPending() + testutil.MustNoErr(t, err, "ListPending after undo") + if len(pending) != 0 { + t.Errorf("pending after undo = %d, want 0", len(pending)) + } +} + +func TestEngine_ScopedToSingleSource_IgnoresCrossAccount(t *testing.T) { + f := storetest.New(t) + st := f.Store + alice := f.Source + + bob, err := st.GetOrCreateSource("gmail", "bob@example.com") + testutil.MustNoErr(t, err, "GetOrCreateSource bob") + + addMessage(t, st, alice, "a-1", "rfc-cross", true) + addMessage(t, st, bob, "b-1", "rfc-cross", false) + + eng := dedup.NewEngine(st, dedup.Config{ + AccountSourceIDs: []int64{alice.ID}, + Account: "test@example.com", + }, nil) + report, err := eng.Scan(context.Background()) + testutil.MustNoErr(t, err, "Scan") + if report.DuplicateGroups != 0 { + t.Errorf("cross-account dedup happened: groups = %d", + report.DuplicateGroups) + } +} + +func TestEngine_ContentHashFallbackFindsNormalizedDuplicates(t *testing.T) { + f := storetest.New(t) + st := f.Store + gmail := f.Source + + mbox, err := st.GetOrCreateSource("mbox", "test@example.com-mbox") + testutil.MustNoErr(t, err, "GetOrCreateSource mbox") + + id1 := addMessage(t, st, gmail, "hash-1", "", false) + id2 := addMessage(t, st, mbox, "hash-2", "", false) + + raw1 := []byte("Received: from mx1.google.com\r\nDelivered-To: one@example.com\r\nX-Gmail-Labels: INBOX\r\nFrom: sender@example.com\r\nSubject: Meeting tomorrow\r\nDate: Mon, 1 Jan 2024 12:00:00 +0000\r\n\r\nLet's meet tomorrow at 3pm.") + raw2 := []byte("Received: from mx2.google.com\r\nDelivered-To: two@example.com\r\nX-Gmail-Labels: SENT\r\nAuthentication-Results: spf=pass\r\nFrom: sender@example.com\r\nSubject: Meeting tomorrow\r\nDate: Mon, 1 Jan 2024 12:00:00 +0000\r\n\r\nLet's meet tomorrow at 3pm.") + testutil.MustNoErr(t, st.UpsertMessageRaw(id1, raw1), "UpsertMessageRaw 
id1") + testutil.MustNoErr(t, st.UpsertMessageRaw(id2, raw2), "UpsertMessageRaw id2") + + eng := dedup.NewEngine(st, dedup.Config{ + AccountSourceIDs: []int64{gmail.ID, mbox.ID}, + Account: "test@example.com", + ContentHashFallback: true, + }, nil) + + report, err := eng.Scan(context.Background()) + testutil.MustNoErr(t, err, "Scan") + if report.DuplicateGroups != 1 { + t.Fatalf("groups = %d, want 1", report.DuplicateGroups) + } + if report.ContentHashGroups != 1 { + t.Fatalf("contentHashGroups = %d, want 1", report.ContentHashGroups) + } + if got := report.Groups[0].KeyType; got != "normalized-hash" { + t.Fatalf("keyType = %q, want normalized-hash", got) + } +} + +// TestEngine_ContentHash_TwoMessageIDSurvivors_BothPreserved verifies the +// spec contract: "A content-hash group with two Message-ID survivors keeps +// both as winners (one per Message-ID group)." +// +// Four messages, two distinct RFC822 Message-IDs (two messages each). All +// four carry raw MIME that normalizes to the same content hash, so the +// content-hash pass would ordinarily group the two survivors together. +// The correct behaviour is to skip that content-hash group entirely — +// total losers must equal 2 (one per MID group), never 3. +func TestEngine_ContentHash_TwoMessageIDSurvivors_BothPreserved(t *testing.T) { + f := storetest.New(t) + st := f.Store + gmail := f.Source + + // Two MID groups, two messages each. + idA1 := addMessage(t, st, gmail, "src-a1", "mid-A", false) + idA2 := addMessage(t, st, gmail, "src-a2", "mid-A", false) + idB1 := addMessage(t, st, gmail, "src-b1", "mid-B", false) + idB2 := addMessage(t, st, gmail, "src-b2", "mid-B", false) + + // All four messages share the same normalized content (stripped headers + // differ, canonical From/Subject/Date/body are identical) so both + // Message-ID survivors land in the same content-hash group. 
+ makeRaw := func(received, delivered, labels string) []byte { + return []byte( + "Received: " + received + "\r\n" + + "Delivered-To: " + delivered + "\r\n" + + "X-Gmail-Labels: " + labels + "\r\n" + + "From: sender@example.com\r\n" + + "Subject: Two MID survivors\r\n" + + "Date: Mon, 1 Jan 2024 12:00:00 +0000\r\n" + + "\r\n" + + "Body that is identical across all four copies.", + ) + } + testutil.MustNoErr(t, st.UpsertMessageRaw(idA1, makeRaw("mx1.google.com", "a1@example.com", "INBOX")), "raw A1") + testutil.MustNoErr(t, st.UpsertMessageRaw(idA2, makeRaw("mx2.google.com", "a2@example.com", "SENT")), "raw A2") + testutil.MustNoErr(t, st.UpsertMessageRaw(idB1, makeRaw("mx3.google.com", "b1@example.com", "INBOX")), "raw B1") + testutil.MustNoErr(t, st.UpsertMessageRaw(idB2, makeRaw("mx4.google.com", "b2@example.com", "SENT")), "raw B2") + + eng := dedup.NewEngine(st, dedup.Config{ + AccountSourceIDs: []int64{gmail.ID}, + Account: "test@example.com", + ContentHashFallback: true, + }, nil) + + report, err := eng.Scan(context.Background()) + testutil.MustNoErr(t, err, "Scan") + + // Two MID groups, no content-hash group (the group with two MID + // survivors must be skipped, not appended). + if report.DuplicateGroups != 2 { + t.Errorf("DuplicateGroups = %d, want 2", report.DuplicateGroups) + } + if report.ContentHashGroups != 0 { + t.Errorf("ContentHashGroups = %d, want 0 (MID-survivor group must be skipped)", report.ContentHashGroups) + } + // One loser per MID group; the buggy code yields 3 by demoting one + // Message-ID survivor. + if report.DuplicateMessages != 2 { + t.Errorf("DuplicateMessages = %d, want 2 (one loser per MID group)", report.DuplicateMessages) + } +} + +// TestEngine_ContentHash_MIDSurvivorAndSentOrphan_SkipsGroup verifies that the +// content-hash pass does not demote a sent-copy orphan to a loser by forcing +// a non-sent Message-ID survivor to win the content-hash group. 
Per spec +// § Survivor selection: "When any message in a duplicate group looks like a +// sent copy, only sent copies are eligible to survive." +// +// Three messages: two share rfc822_message_id "mid-A" (neither is_from_me), +// one is a sent orphan (no Message-ID, is_from_me=true). All three carry raw +// MIME that normalizes to the same content hash. The MID-pass survivor would +// otherwise be forced in over the sent orphan; the new skip rule prevents +// that. +func TestEngine_ContentHash_MIDSurvivorAndSentOrphan_SkipsGroup(t *testing.T) { + f := storetest.New(t) + st := f.Store + gmail := f.Source + + // MID group: two messages sharing mid-A, neither is_from_me. + idA1 := addMessage(t, st, gmail, "src-a1", "mid-A", false) + idA2 := addMessage(t, st, gmail, "src-a2", "mid-A", false) + + // Sent orphan: no MID, is_from_me=true. Content matches the MID group. + idOrphan := addMessage(t, st, gmail, "src-orphan", "", true) + + makeRaw := func(received string) []byte { + return []byte( + "Received: " + received + "\r\n" + + "From: sender@example.com\r\n" + + "Subject: MID/sent-orphan collision\r\n" + + "Date: Mon, 1 Jan 2024 12:00:00 +0000\r\n" + + "\r\n" + + "Identical body across all three copies.", + ) + } + testutil.MustNoErr(t, st.UpsertMessageRaw(idA1, makeRaw("mx1.google.com")), "raw a1") + testutil.MustNoErr(t, st.UpsertMessageRaw(idA2, makeRaw("mx2.google.com")), "raw a2") + testutil.MustNoErr(t, st.UpsertMessageRaw(idOrphan, makeRaw("mx3.google.com")), "raw orphan") + + eng := dedup.NewEngine(st, dedup.Config{ + AccountSourceIDs: []int64{gmail.ID}, + Account: "test@example.com", + ContentHashFallback: true, + }, nil) + + report, err := eng.Scan(context.Background()) + testutil.MustNoErr(t, err, "Scan") + + // Expect exactly one duplicate group (the MID group). The content-hash + // group must be skipped to preserve the sent-copy eligibility filter. 
+ if report.DuplicateGroups != 1 { + t.Fatalf("DuplicateGroups = %d, want 1 (only the MID group)", report.DuplicateGroups) + } + if report.ContentHashGroups != 0 { + t.Errorf("ContentHashGroups = %d, want 0 (sent-orphan collision must be skipped)", report.ContentHashGroups) + } + // One loser from the MID group; the sent orphan stays live. + if report.DuplicateMessages != 1 { + t.Errorf("DuplicateMessages = %d, want 1 (one MID loser; orphan untouched)", report.DuplicateMessages) + } + + // Per the spec audit recommendation, pin that the orphan did not leak + // into the surviving MID group's Messages slice. The MID group must + // contain only the two MID-sharing rows. + mid := report.Groups[0] + if mid.KeyType != "message-id" { + t.Fatalf("Groups[0].KeyType = %q, want \"message-id\"", mid.KeyType) + } + if len(mid.Messages) != 2 { + t.Fatalf("MID group Messages len = %d, want 2", len(mid.Messages)) + } + for _, m := range mid.Messages { + if m.ID == idOrphan { + t.Errorf("sent orphan id=%d leaked into MID group Messages — must stay out", idOrphan) + } + } +} + +func TestEngine_ContentHashFallbackDisabledByDefault(t *testing.T) { + f := storetest.New(t) + st := f.Store + gmail := f.Source + + mbox, err := st.GetOrCreateSource("mbox", "test@example.com-mbox") + testutil.MustNoErr(t, err, "GetOrCreateSource mbox") + + id1 := addMessage(t, st, gmail, "hash-off-1", "", false) + id2 := addMessage(t, st, mbox, "hash-off-2", "", false) + raw := []byte("Subject: No Message-ID\r\n\r\nIdentical body") + testutil.MustNoErr(t, st.UpsertMessageRaw(id1, raw), "UpsertMessageRaw id1") + testutil.MustNoErr(t, st.UpsertMessageRaw(id2, raw), "UpsertMessageRaw id2") + + eng := dedup.NewEngine(st, dedup.Config{ + AccountSourceIDs: []int64{gmail.ID, mbox.ID}, + Account: "test@example.com", + }, nil) + + report, err := eng.Scan(context.Background()) + testutil.MustNoErr(t, err, "Scan") + if report.DuplicateGroups != 0 { + t.Fatalf("groups = %d, want 0", report.DuplicateGroups) + } +} + 
+func TestEngine_FormatMethodology_MentionsSentPolicy(t *testing.T) { + f := storetest.New(t) + eng := dedup.NewEngine(f.Store, dedup.Config{ + Account: "test@example.com", + AccountSourceIDs: []int64{f.Source.ID}, + }, nil) + out := eng.FormatMethodology() + if !strings.Contains( + strings.ToLower(out), + "never merges messages across different", + ) { + t.Errorf("methodology missing cross-account guarantee") + } +} + +// TestEngine_FormatMethodology_SingleMemberCollection asserts that a +// `--collection` invocation with only one resolved source does NOT +// describe itself as cross-account. Regression test for iter14 +// claude Low: ScopeIsCollection alone gated the cross-account +// wording, even when len(AccountSourceIDs) == 1 made cross-account +// merging impossible. +func TestEngine_FormatMethodology_SingleMemberCollection(t *testing.T) { + f := storetest.New(t) + eng := dedup.NewEngine(f.Store, dedup.Config{ + Account: "myCollection", + AccountSourceIDs: []int64{f.Source.ID}, + ScopeIsCollection: true, + }, nil) + out := eng.FormatMethodology() + lower := strings.ToLower(out) + if strings.Contains(lower, "cross-account dedup\n is enabled") { + t.Errorf("single-member collection should not advertise cross-account dedup; got:\n%s", out) + } + if strings.Contains(lower, "intentionally merges messages") { + t.Errorf("single-member collection should not describe intentional cross-account merging; got:\n%s", out) + } + if !strings.Contains(lower, "never merges messages across different") { + t.Errorf("single-member collection should fall to the same-account guarantee; got:\n%s", out) + } +} + +func TestEngine_SurvivorTiebreakers(t *testing.T) { + t.Run("raw MIME wins over no raw MIME", func(t *testing.T) { + f := storetest.New(t) + st := f.Store + + idNoRaw := addMessage(t, st, f.Source, "no-raw", "rfc-raw-tie", false) + idHasRaw := addMessage(t, st, f.Source, "has-raw", "rfc-raw-tie", false) + testutil.MustNoErr(t, + st.UpsertMessageRaw(idHasRaw, 
[]byte("Subject: test\r\n\r\nBody")), + "UpsertMessageRaw", + ) + + eng := dedup.NewEngine(st, dedup.Config{ + AccountSourceIDs: []int64{f.Source.ID}, + Account: "test", + }, nil) + report, err := eng.Scan(context.Background()) + testutil.MustNoErr(t, err, "Scan") + if report.DuplicateGroups != 1 { + t.Fatalf("groups = %d, want 1", report.DuplicateGroups) + } + survivor := report.Groups[0].Messages[report.Groups[0].Survivor] + if survivor.ID != idHasRaw { + t.Errorf("survivor = %d, want %d (has raw)", survivor.ID, idHasRaw) + } + _ = idNoRaw + }) + + t.Run("more labels wins when raw MIME is equal", func(t *testing.T) { + f := storetest.New(t) + st := f.Store + + idFew := addMessage(t, st, f.Source, "few", "rfc-label-tie", false) + idMany := addMessage(t, st, f.Source, "many", "rfc-label-tie", false) + + lid1, _ := st.EnsureLabel(f.Source.ID, "L1", "Label1", "user") + lid2, _ := st.EnsureLabel(f.Source.ID, "L2", "Label2", "user") + lid3, _ := st.EnsureLabel(f.Source.ID, "L3", "Label3", "user") + _ = st.LinkMessageLabel(idFew, lid1) + _ = st.LinkMessageLabel(idMany, lid1) + _ = st.LinkMessageLabel(idMany, lid2) + _ = st.LinkMessageLabel(idMany, lid3) + + eng := dedup.NewEngine(st, dedup.Config{ + AccountSourceIDs: []int64{f.Source.ID}, + Account: "test", + }, nil) + report, err := eng.Scan(context.Background()) + testutil.MustNoErr(t, err, "Scan") + if report.DuplicateGroups != 1 { + t.Fatalf("groups = %d, want 1", report.DuplicateGroups) + } + survivor := report.Groups[0].Messages[report.Groups[0].Survivor] + if survivor.ID != idMany { + t.Errorf("survivor = %d, want %d (more labels)", survivor.ID, idMany) + } + }) + + t.Run("lower ID wins as final tiebreaker", func(t *testing.T) { + f := storetest.New(t) + st := f.Store + + idFirst := addMessage(t, st, f.Source, "first", "rfc-id-tie", false) + _ = addMessage(t, st, f.Source, "second", "rfc-id-tie", false) + + eng := dedup.NewEngine(st, dedup.Config{ + AccountSourceIDs: []int64{f.Source.ID}, + Account: "test", + }, 
nil) + report, err := eng.Scan(context.Background()) + testutil.MustNoErr(t, err, "Scan") + if report.DuplicateGroups != 1 { + t.Fatalf("groups = %d, want 1", report.DuplicateGroups) + } + survivor := report.Groups[0].Messages[report.Groups[0].Survivor] + if survivor.ID != idFirst { + t.Errorf("survivor = %d, want %d (lower ID)", survivor.ID, idFirst) + } + }) +} + +// addMessageWithFrom is like addMessage but also sets FromEmail via the +// message_recipients table so the dedup query can read it. +func addMessageWithFrom( + t *testing.T, + st *store.Store, + source *store.Source, + srcMsgID, rfc822ID, fromEmail string, +) int64 { + t.Helper() + convID, err := st.EnsureConversation( + source.ID, "thread-"+srcMsgID, "Subject", + ) + testutil.MustNoErr(t, err, "EnsureConversation") + id, err := st.UpsertMessage(&store.Message{ + ConversationID: convID, + SourceID: source.ID, + SourceMessageID: srcMsgID, + RFC822MessageID: sql.NullString{ + String: rfc822ID, Valid: rfc822ID != "", + }, + MessageType: "email", + IsFromMe: false, // no is_from_me so MatchedIdentity is the deciding signal + SizeEstimate: 1000, + }) + testutil.MustNoErr(t, err, "UpsertMessage") + if fromEmail != "" { + pid, pErr := st.EnsureParticipant(fromEmail, "", "") + testutil.MustNoErr(t, pErr, "EnsureParticipant") + testutil.MustNoErr(t, + st.ReplaceMessageRecipients(id, "from", []int64{pid}, []string{""}), + "ReplaceMessageRecipients", + ) + } + return id +} + +// TestEngine_PerSourceIdentity verifies that identity matching is per-source: +// an address confirmed only for source A does not count as "me" in source B. 
+func TestEngine_PerSourceIdentity(t *testing.T) { + f := storetest.New(t) + st := f.Store + sourceA := f.Source // already created by storetest.New + + sourceB, err := st.GetOrCreateSource("mbox", "bob@example.com-mbox") + testutil.MustNoErr(t, err, "GetOrCreateSource sourceB") + + const me = "me@personal.com" + const rfc = "rfc-identity-perscource" + + // Add me@personal.com as confirmed identity only for source A. + testutil.MustNoErr(t, + st.AddAccountIdentity(sourceA.ID, me, "test"), + "AddAccountIdentity sourceA", + ) + + // Two messages with same RFC822 ID, both From: me@personal.com, + // one in each source. Neither has HasSentLabel or IsFromMe. + idA := addMessageWithFrom(t, st, sourceA, "a-identity", rfc, me) + idB := addMessageWithFrom(t, st, sourceB, "b-identity", rfc, me) + + identities := map[int64]map[string]struct{}{ + sourceA.ID: {me: {}}, + // sourceB intentionally omitted + } + + eng := dedup.NewEngine(st, dedup.Config{ + AccountSourceIDs: []int64{sourceA.ID, sourceB.ID}, + Account: "test", + IdentityAddressesBySource: identities, + }, nil) + + report, err := eng.Scan(context.Background()) + testutil.MustNoErr(t, err, "Scan") + if report.DuplicateGroups != 1 { + t.Fatalf("groups = %d, want 1", report.DuplicateGroups) + } + + group := report.Groups[0] + // Find the message structs for each source. + var msgA, msgB dedup.DuplicateMessage + for _, m := range group.Messages { + switch m.ID { + case idA: + msgA = m + case idB: + msgB = m + } + } + + if !msgA.MatchedIdentity { + t.Errorf("source A copy: MatchedIdentity = false, want true") + } + if msgB.MatchedIdentity { + t.Errorf("source B copy: MatchedIdentity = true, want false (identity not confirmed for source B)") + } + + // Survivor should be the source A copy because it is the sent copy. 
+ survivor := group.Messages[group.Survivor] + if survivor.ID != idA { + t.Errorf("survivor = %d (%s), want %d (source A, matched identity)", + survivor.ID, survivor.SourceIdentifier, idA) + } +} diff --git a/internal/dedup/normalize_test.go b/internal/dedup/normalize_test.go new file mode 100644 index 00000000..e79609aa --- /dev/null +++ b/internal/dedup/normalize_test.go @@ -0,0 +1,97 @@ +package dedup + +import ( + "bytes" + "testing" +) + +func TestNormalizeRawMIME(t *testing.T) { + tests := []struct { + name string + input []byte + wantSame bool // true if output should equal input + contains string // substring the output must contain + excludes string // substring the output must NOT contain + }{ + { + name: "strips Received header (CRLF)", + input: []byte("Received: from mx1.google.com\r\nFrom: alice@example.com\r\nSubject: Hi\r\n\r\nBody"), + contains: "From: alice@example.com", + excludes: "Received", + }, + { + name: "strips multiple transport headers", + input: []byte("Delivered-To: bob@example.com\r\nX-Gmail-Labels: INBOX\r\nAuthentication-Results: spf=pass\r\nFrom: alice@example.com\r\nSubject: Test\r\n\r\nBody"), + contains: "From: alice@example.com", + excludes: "Delivered-To", + }, + { + name: "preserves non-transport headers", + input: []byte("From: alice@example.com\r\nTo: bob@example.com\r\nSubject: Meeting\r\nDate: Mon, 1 Jan 2024 12:00:00 +0000\r\n\r\nBody text"), + contains: "Subject: Meeting", + }, + { + name: "handles LF-only line endings", + input: []byte("Received: from mx1\nFrom: alice@example.com\nSubject: Test\n\nBody with LF"), + contains: "From: alice@example.com", + excludes: "Received", + }, + { + name: "no header/body separator returns raw unchanged", + input: []byte("This is just a blob of text with no headers"), + wantSame: true, + }, + { + name: "empty body preserved", + input: []byte("From: alice@example.com\r\nSubject: Empty\r\n\r\n"), + contains: "Subject: Empty", + }, + { + name: "preserves body content exactly", + input: 
[]byte("Received: from mx1\r\nFrom: a@b.com\r\n\r\nExact body content here."), + contains: "Exact body content here.", + }, + { + name: "LF headers with CRLF in body uses earliest boundary", + input: []byte("From: a@b.com\nSubject: Test\n\nBody has \r\n\r\n inside"), + contains: "Body has \r\n\r\n inside", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + inputCopy := make([]byte, len(tt.input)) + copy(inputCopy, tt.input) + + result := normalizeRawMIME(tt.input) + + if !bytes.Equal(tt.input, inputCopy) { + t.Error("normalizeRawMIME mutated its input buffer") + } + + if tt.wantSame { + if !bytes.Equal(result, tt.input) { + t.Errorf("expected unchanged output, got:\n%s", result) + } + return + } + if tt.contains != "" && !bytes.Contains(result, []byte(tt.contains)) { + t.Errorf("output missing %q:\n%s", tt.contains, result) + } + if tt.excludes != "" && bytes.Contains(result, []byte(tt.excludes)) { + t.Errorf("output should not contain %q:\n%s", tt.excludes, result) + } + }) + } +} + +func TestNormalizeRawMIME_DeterministicOutput(t *testing.T) { + raw1 := []byte("Received: from mx1.google.com\r\nFrom: sender@example.com\r\nSubject: Meeting\r\nDate: Mon, 1 Jan 2024 12:00:00 +0000\r\n\r\nLet's meet at 3pm.") + raw2 := []byte("Received: from mx2.google.com\r\nDelivered-To: other@example.com\r\nFrom: sender@example.com\r\nSubject: Meeting\r\nDate: Mon, 1 Jan 2024 12:00:00 +0000\r\n\r\nLet's meet at 3pm.") + + hash1 := sha256Hex(normalizeRawMIME(raw1)) + hash2 := sha256Hex(normalizeRawMIME(raw2)) + if hash1 != hash2 { + t.Errorf("same message with different transport headers produced different hashes") + } +} diff --git a/internal/deletion/manifest.go b/internal/deletion/manifest.go index 5439b8d6..9a1a1908 100644 --- a/internal/deletion/manifest.go +++ b/internal/deletion/manifest.go @@ -203,11 +203,12 @@ var statusDirMap = map[Status]string{ StatusInProgress: "in_progress", StatusCompleted: "completed", StatusFailed: "failed", + 
StatusCancelled: "cancelled", } // persistedStatuses lists all statuses that have on-disk directories. var persistedStatuses = []Status{ - StatusPending, StatusInProgress, StatusCompleted, StatusFailed, + StatusPending, StatusInProgress, StatusCompleted, StatusFailed, StatusCancelled, } // Manager handles deletion manifest files. @@ -273,6 +274,11 @@ func (m *Manager) ListFailed() ([]*Manifest, error) { return m.listManifests(m.dirForStatus(StatusFailed)) } +// ListCancelled returns all cancelled deletion manifests. +func (m *Manager) ListCancelled() ([]*Manifest, error) { + return m.listManifests(m.dirForStatus(StatusCancelled)) +} + func (m *Manager) listManifests(dir string) ([]*Manifest, error) { entries, err := os.ReadDir(dir) if err != nil { @@ -350,7 +356,7 @@ func (m *Manager) MoveManifest(id string, fromStatus, toStatus Status) error { } switch toStatus { - case StatusInProgress, StatusCompleted, StatusFailed: + case StatusInProgress, StatusCompleted, StatusFailed, StatusCancelled: // allowed default: return fmt.Errorf("cannot move to status %s", toStatus) @@ -361,18 +367,41 @@ func (m *Manager) MoveManifest(id string, fromStatus, toStatus Status) error { return os.Rename(fromPath, toPath) } -// CancelManifest removes a pending or in-progress manifest. +// CancelManifest moves a pending or in-progress manifest to the +// cancelled directory and updates its inline Status field. Returns +// an error if the manifest is not found in pending or in_progress. +// +// Order: rename first (atomic on same fs), then rewrite inline Status +// at the new location. The directory is authoritative per spec, so a +// crash between rename and status rewrite leaves a manifest in +// cancelled/ with a stale Status=pending field — readers still see +// it as cancelled and the inline field self-heals on the next save. +// The reverse order risks the worst outcome: a manifest in pending/ +// with Status=cancelled, which contradicts the authoritative dir. 
+// +// Note: Manifest.String() prints the inline Status field. A concurrent +// reader that rendered a manifest between the rename and the inline +// rewrite would see the pre-cancel status. Acceptable because callers +// re-read after a successful CancelManifest return. func (m *Manager) CancelManifest(id string) error { - // Try pending first, then in_progress - for _, dir := range []string{m.PendingDir(), m.InProgressDir()} { - path := filepath.Join(dir, id+".json") - err := os.Remove(path) - if err == nil { - return nil + for _, fromStatus := range []Status{StatusPending, StatusInProgress} { + fromPath := filepath.Join(m.dirForStatus(fromStatus), id+".json") + if _, err := os.Stat(fromPath); os.IsNotExist(err) { + continue + } + if err := m.MoveManifest(id, fromStatus, StatusCancelled); err != nil { + return fmt.Errorf("move manifest %s to cancelled: %w", id, err) + } + toPath := filepath.Join(m.dirForStatus(StatusCancelled), id+".json") + manifest, err := LoadManifest(toPath) + if err != nil { + return fmt.Errorf("reload manifest %s after move: %w", id, err) } - if !os.IsNotExist(err) { - return fmt.Errorf("remove %s: %w", path, err) + manifest.Status = StatusCancelled + if err := manifest.Save(toPath); err != nil { + return fmt.Errorf("update inline status for %s: %w", id, err) } + return nil } return fmt.Errorf("manifest %s not found in pending or in_progress", id) } diff --git a/internal/deletion/manifest_test.go b/internal/deletion/manifest_test.go index 2c158037..e4395355 100644 --- a/internal/deletion/manifest_test.go +++ b/internal/deletion/manifest_test.go @@ -126,6 +126,8 @@ func AssertManifestInState(t *testing.T, mgr *Manager, id string, state Status) dir = mgr.CompletedDir() case StatusFailed: dir = mgr.FailedDir() + case StatusCancelled: + dir = mgr.dirForStatus(StatusCancelled) default: t.Fatalf("unknown state %q", state) } @@ -423,7 +425,7 @@ func TestNewManager(t *testing.T) { } // Verify all directories were created - expectedDirs := 
 []string{"pending", "in_progress", "completed", "failed"} + expectedDirs := []string{"pending", "in_progress", "completed", "failed", "cancelled"} for _, d := range expectedDirs { path := filepath.Join(baseDir, d) if info, err := os.Stat(path); err != nil || !info.IsDir() { @@ -527,6 +529,8 @@ func TestManager_Transitions(t *testing.T) { {"in_progress->failed", [][2]Status{{StatusPending, StatusInProgress}, {StatusInProgress, StatusFailed}}, false, [4]int{0, 0, 0, 1}}, {"completed->pending (invalid)", [][2]Status{{StatusPending, StatusInProgress}, {StatusInProgress, StatusCompleted}, {StatusCompleted, StatusPending}}, true, [4]int{}}, {"pending->pending (invalid)", [][2]Status{{StatusPending, StatusPending}}, true, [4]int{}}, + {"pending->cancelled", [][2]Status{{StatusPending, StatusCancelled}}, false, [4]int{0, 0, 0, 0}}, + {"in_progress->cancelled", [][2]Status{{StatusPending, StatusInProgress}, {StatusInProgress, StatusCancelled}}, false, [4]int{0, 0, 0, 0}}, } for _, tc := range tests { @@ -563,32 +567,67 @@ func TestManager_CancelManifest(t *testing.T) { mgr := testManager(t) - m := createTestManifest(t, mgr, "cancel test") + manifest := createTestManifest(t, mgr, "cancel test") // Cancel it - if err := mgr.CancelManifest(m.ID); err != nil { + if err := mgr.CancelManifest(manifest.ID); err != nil { t.Fatalf("CancelManifest() error = %v", err) } - // Should be gone - assertListCount(t, mgr.ListPending, 0) + baseDir := filepath.Dir(mgr.PendingDir()) + + // File should now exist at cancelled/<id>.json with Status=cancelled. 
+ cancelledPath := filepath.Join(baseDir, "cancelled", manifest.ID+".json") + if _, err := os.Stat(cancelledPath); err != nil { + t.Fatalf("expected cancelled manifest at %s: %v", cancelledPath, err) + } + loaded, err := LoadManifest(cancelledPath) + if err != nil { + t.Fatalf("load cancelled manifest: %v", err) + } + if loaded.Status != StatusCancelled { + t.Errorf("loaded.Status = %q, want %q", loaded.Status, StatusCancelled) + } + // File should be absent from pending/. + pendingPath := filepath.Join(baseDir, "pending", manifest.ID+".json") + if _, err := os.Stat(pendingPath); !os.IsNotExist(err) { + t.Errorf("expected pending manifest gone, got err=%v", err) + } } func TestManager_CancelManifest_InProgress(t *testing.T) { mgr := testManager(t) - m := createTestManifest(t, mgr, "cancel in-progress") + manifest := createTestManifest(t, mgr, "cancel in-progress") // Move to in_progress - if err := mgr.MoveManifest(m.ID, StatusPending, StatusInProgress); err != nil { + if err := mgr.MoveManifest(manifest.ID, StatusPending, StatusInProgress); err != nil { t.Fatalf("MoveManifest() error = %v", err) } // Cancel it - if err := mgr.CancelManifest(m.ID); err != nil { + if err := mgr.CancelManifest(manifest.ID); err != nil { t.Fatalf("CancelManifest() error = %v", err) } - assertListCount(t, mgr.ListInProgress, 0) + baseDir := filepath.Dir(mgr.PendingDir()) + + // File should now exist at cancelled/<id>.json with Status=cancelled. + cancelledPath := filepath.Join(baseDir, "cancelled", manifest.ID+".json") + if _, err := os.Stat(cancelledPath); err != nil { + t.Fatalf("expected cancelled manifest at %s: %v", cancelledPath, err) + } + loaded, err := LoadManifest(cancelledPath) + if err != nil { + t.Fatalf("load cancelled manifest: %v", err) + } + if loaded.Status != StatusCancelled { + t.Errorf("loaded.Status = %q, want %q", loaded.Status, StatusCancelled) + } + // File should be absent from in_progress/. 
+ inProgressPath := filepath.Join(baseDir, "in_progress", manifest.ID+".json") + if _, err := os.Stat(inProgressPath); !os.IsNotExist(err) { + t.Errorf("expected in_progress manifest gone, got err=%v", err) + } } func TestManager_CancelManifest_NotFound(t *testing.T) { @@ -698,6 +737,7 @@ func TestStatusDirMap(t *testing.T) { StatusInProgress: "in_progress", StatusCompleted: "completed", StatusFailed: "failed", + StatusCancelled: "cancelled", } for status, wantDir := range expectedMappings { gotDir, ok := statusDirMap[status] @@ -723,6 +763,7 @@ func TestDirForStatus(t *testing.T) { {StatusInProgress, "in_progress"}, {StatusCompleted, "completed"}, {StatusFailed, "failed"}, + {StatusCancelled, "cancelled"}, } for _, tc := range tests { @@ -760,12 +801,37 @@ func TestPersistedStatusesComplete(t *testing.T) { } } - // StatusCancelled should NOT be in persistedStatuses (cancelled manifests are deleted) + // StatusCancelled MUST be in persistedStatuses; cancelled manifests + // are persisted on disk for audit (per spec § Manifest format). + found := false for _, ps := range persistedStatuses { if ps == StatusCancelled { - t.Errorf("StatusCancelled should not be in persistedStatuses") + found = true + break } } + if !found { + t.Errorf("StatusCancelled missing from persistedStatuses; cancelled manifests must persist on disk") + } +} + +func TestManager_ListCancelled(t *testing.T) { + mgr := testManager(t) + + manifest := createTestManifest(t, mgr, "test cancel") + if err := mgr.CancelManifest(manifest.ID); err != nil { + t.Fatalf("CancelManifest: %v", err) + } + got, err := mgr.ListCancelled() + if err != nil { + t.Fatalf("ListCancelled: %v", err) + } + if len(got) != 1 { + t.Fatalf("ListCancelled returned %d manifests, want 1", len(got)) + } + if got[0].ID != manifest.ID { + t.Errorf("got ID %q, want %q", got[0].ID, manifest.ID) + } } // TestManager_SaveManifest_UnknownStatus tests saving with an unknown status. 
diff --git a/internal/importer/emlx_import.go b/internal/importer/emlx_import.go index 6c22c76a..11641b66 100644 --- a/internal/importer/emlx_import.go +++ b/internal/importer/emlx_import.go @@ -55,6 +55,7 @@ type EmlxImportOptions struct { // EmlxImportSummary reports the results of an emlx import. type EmlxImportSummary struct { + SourceID int64 WasResumed bool Duration time.Duration MailboxesTotal int @@ -130,6 +131,7 @@ func ImportEmlxDir( if err != nil { return nil, fmt.Errorf("get/create source: %w", err) } + summary.SourceID = src.ID // Resume support. var ( diff --git a/internal/importer/mbox_import.go b/internal/importer/mbox_import.go index c91642a2..92b24b95 100644 --- a/internal/importer/mbox_import.go +++ b/internal/importer/mbox_import.go @@ -51,6 +51,7 @@ type MboxImportOptions struct { } type MboxImportSummary struct { + SourceID int64 WasResumed bool ResumedOffset int64 FinalOffset int64 @@ -117,6 +118,7 @@ func ImportMbox(ctx context.Context, st *store.Store, mboxPath string, opts Mbox if err != nil { return nil, fmt.Errorf("get/create source: %w", err) } + summary.SourceID = src.ID // Create or resume the sync run for this source. 
var ( diff --git a/internal/logging/logging.go b/internal/logging/logging.go index 623abc47..02a5f111 100644 --- a/internal/logging/logging.go +++ b/internal/logging/logging.go @@ -211,10 +211,17 @@ func BuildHandler(opts Options) (*Result, error) { _ = f.Close() }) default: + target := path + if target == "" { + target = opts.FilePath + } + if target == "" { + target = opts.LogsDir + } _, _ = fmt.Fprintf(stderr, "warning: could not open msgvault log file in %s: %v "+ "(continuing with stderr-only logging)\n", - opts.LogsDir, err, + target, err, ) } } diff --git a/internal/mcp/handlers.go b/internal/mcp/handlers.go index d667a22d..fbd9544c 100644 --- a/internal/mcp/handlers.go +++ b/internal/mcp/handlers.go @@ -188,7 +188,9 @@ func (h *handlers) searchMessages(ctx context.Context, req mcp.CallToolRequest) } q := search.Parse(queryStr) - q.AccountID = sourceID + if sourceID != nil { + q.AccountIDs = []int64{*sourceID} + } filter := query.MessageFilter{SourceID: sourceID} @@ -888,7 +890,9 @@ func (h *handlers) stageDeletion(ctx context.Context, req mcp.CallToolRequest) ( if hasQuery { // Query-based search q := search.Parse(queryStr) - q.AccountID = sourceID + if sourceID != nil { + q.AccountIDs = []int64{*sourceID} + } // Try fast search first filter := query.MessageFilter{SourceID: sourceID} @@ -1004,7 +1008,7 @@ func (h *handlers) stageDeletion(ctx context.Context, req mcp.CallToolRequest) ( BatchID: manifest.ID, MessageCount: len(gmailIDs), Status: string(manifest.Status), - NextStep: "Run 'msgvault delete-staged' to execute deletion, or 'msgvault cancel-deletion " + manifest.ID + "' to cancel", + NextStep: "Run 'MSGVAULT_ENABLE_REMOTE_DELETE=1 msgvault delete-staged' to execute deletion (gated for v1), or 'msgvault cancel-deletion " + manifest.ID + "' to cancel", } return jsonResult(resp) diff --git a/internal/query/duckdb.go b/internal/query/duckdb.go index 54fd944a..5c0e6ff3 100644 --- a/internal/query/duckdb.go +++ b/internal/query/duckdb.go @@ -15,6 +15,7 @@ 
import ( _ "github.com/marcboeker/go-duckdb" "github.com/wesm/msgvault/internal/search" + "github.com/wesm/msgvault/internal/store" ) // DuckDBEngine implements Engine using DuckDB for fast Parquet queries. @@ -297,6 +298,11 @@ func (e *DuckDBEngine) parquetCTEs() string { } else { msgExtra = append(msgExtra, "'' AS message_type") } + if e.hasCol("messages", "deleted_at") { + msgReplace = append(msgReplace, "TRY_CAST(deleted_at AS TIMESTAMP) AS deleted_at") + } else { + msgExtra = append(msgExtra, "NULL::TIMESTAMP AS deleted_at") + } msgCTE := fmt.Sprintf("SELECT * REPLACE (\n\t\t\t\t%s\n\t\t\t)", strings.Join(msgReplace, ",\n\t\t\t\t")) if len(msgExtra) > 0 { msgCTE += ", " + strings.Join(msgExtra, ", ") @@ -648,10 +654,8 @@ func (e *DuckDBEngine) buildWhereClause(opts AggregateOptions, keyColumns ...str // message_type IS NULL and '' handle old data without the column. conditions = append(conditions, "(msg.message_type = 'email' OR msg.message_type IS NULL OR msg.message_type = '')") - if opts.SourceID != nil { - conditions = append(conditions, "msg.source_id = ?") - args = append(args, *opts.SourceID) - } + conditions = append(conditions, store.LiveMessagesWhere("msg", opts.HideDeletedFromSource)) + conditions, args = appendSourceFilter(conditions, args, "msg.", opts.SourceID, opts.SourceIDs) if opts.After != nil { conditions = append(conditions, "msg.sent_at >= CAST(? AS TIMESTAMP)") @@ -666,9 +670,6 @@ func (e *DuckDBEngine) buildWhereClause(opts AggregateOptions, keyColumns ...str if opts.WithAttachmentsOnly { conditions = append(conditions, "msg.has_attachments = 1") } - if opts.HideDeletedFromSource { - conditions = append(conditions, "msg.deleted_from_source_at IS NULL") - } // Text search filter for aggregates - filter on view's key columns searchConds, searchArgs := e.buildAggregateSearchConditions(opts.SearchQuery, keyColumns...) 
@@ -854,10 +855,8 @@ func (e *DuckDBEngine) buildFilterConditions(filter MessageFilter) (string, []in // message_type IS NULL and '' handle old data without the column. conditions = append(conditions, "(msg.message_type = 'email' OR msg.message_type IS NULL OR msg.message_type = '')") - if filter.SourceID != nil { - conditions = append(conditions, "msg.source_id = ?") - args = append(args, *filter.SourceID) - } + conditions = append(conditions, store.LiveMessagesWhere("msg", filter.HideDeletedFromSource)) + conditions, args = appendSourceFilter(conditions, args, "msg.", filter.SourceID, filter.SourceIDs) if filter.ConversationID != nil { conditions = append(conditions, "msg.conversation_id = ?") @@ -877,9 +876,6 @@ func (e *DuckDBEngine) buildFilterConditions(filter MessageFilter) (string, []in if filter.WithAttachmentsOnly { conditions = append(conditions, "msg.has_attachments = true") } - if filter.HideDeletedFromSource { - conditions = append(conditions, "msg.deleted_from_source_at IS NULL") - } // Sender filter - check both message_recipients (email) and direct sender_id (WhatsApp/chat) // Also checks phone_number for phone-based lookups (e.g., from:+447...) @@ -1041,6 +1037,11 @@ func (e *DuckDBEngine) SubAggregate(ctx context.Context, filter MessageFilter, g return nil, err } + // Reconcile opts.HideDeletedFromSource into filter so the helper + // inside buildFilterConditions sees the OR of both fields. 
+ if opts.HideDeletedFromSource { + filter.HideDeletedFromSource = true + } where, args := e.buildFilterConditions(filter) // Add opts-based conditions (source_id, date range, attachment filter) @@ -1059,9 +1060,6 @@ func (e *DuckDBEngine) SubAggregate(ctx context.Context, filter MessageFilter, g if opts.WithAttachmentsOnly { where += " AND msg.has_attachments = true" } - if opts.HideDeletedFromSource { - where += " AND msg.deleted_from_source_at IS NULL" - } // Add search query conditions using the view's key columns searchConds, searchArgs := e.buildAggregateSearchConditions(opts.SearchQuery, def.keyColumns...) @@ -1117,17 +1115,12 @@ func (e *DuckDBEngine) GetTotalStats(ctx context.Context, opts StatsOptions) (*T // Restrict to email messages only; NULL and '' handle pre-message_type data. conditions = append(conditions, emailOnlyFilterMsg) - if opts.SourceID != nil { - conditions = append(conditions, "msg.source_id = ?") - args = append(args, *opts.SourceID) - } + conditions = append(conditions, store.LiveMessagesWhere("msg", opts.HideDeletedFromSource)) + conditions, args = appendSourceFilter(conditions, args, "msg.", opts.SourceID, opts.SourceIDs) if opts.WithAttachmentsOnly { conditions = append(conditions, "msg.has_attachments = 1") } - if opts.HideDeletedFromSource { - conditions = append(conditions, "msg.deleted_from_source_at IS NULL") - } // Search filter — uses EXISTS subqueries so no row multiplication. // For 1:N views (Recipients, RecipientNames, Labels), filter on the @@ -1478,7 +1471,9 @@ func (e *DuckDBEngine) Search(ctx context.Context, q *search.Query, limit, offse var args []interface{} var joins []string - // Include all messages (deleted messages shown with indicator in TUI) + // Exclude rows soft-deleted by deduplicate; gate source-deleted on + // q.HideDeleted via the helper. 
+ conditions = append(conditions, store.LiveMessagesWhere("m", q.HideDeleted)) // From filter if len(q.FromAddrs) > 0 { @@ -1566,15 +1561,7 @@ func (e *DuckDBEngine) Search(ctx context.Context, q *search.Query, limit, offse } // Account filter - if q.AccountID != nil { - conditions = append(conditions, "m.source_id = ?") - args = append(args, *q.AccountID) - } - - // Hide-deleted filter - if q.HideDeleted { - conditions = append(conditions, "m.deleted_from_source_at IS NULL") - } + conditions, args = appendSourceFilter(conditions, args, "m.", nil, q.AccountIDs) if limit == 0 { limit = 100 @@ -1686,17 +1673,11 @@ func (e *DuckDBEngine) GetGmailIDsByFilter(ctx context.Context, filter MessageFi var conditions []string var args []interface{} - // Always exclude deleted messages - conditions = append(conditions, "msg.deleted_from_source_at IS NULL") - - // Gmail scoping is handled by JOIN src in the query below — this function - // is used for Gmail-specific deletion/staging workflows and must not - // return WhatsApp or other source IDs. - - if filter.SourceID != nil { - conditions = append(conditions, "msg.source_id = ?") - args = append(args, *filter.SourceID) - } + // Always exclude deleted messages. + // Always pass true: this surface feeds remote-deletion staging and + // must never honor an opt-in. 
+ conditions = append(conditions, store.LiveMessagesWhere("msg", true)) + conditions, args = appendSourceFilter(conditions, args, "msg.", filter.SourceID, filter.SourceIDs) // Use EXISTS subqueries for filtering (becomes semi-joins, no duplicates) if filter.Sender != "" { @@ -2336,10 +2317,8 @@ func (e *DuckDBEngine) buildSearchConditions(q *search.Query, filter MessageFilt conditions = append(conditions, emailOnlyFilterMsg) // Apply basic filter conditions (ignoring join flags for search - we handle those differently) - if filter.SourceID != nil { - conditions = append(conditions, "msg.source_id = ?") - args = append(args, *filter.SourceID) - } + conditions = append(conditions, store.LiveMessagesWhere("msg", filter.HideDeletedFromSource)) + conditions, args = appendSourceFilter(conditions, args, "msg.", filter.SourceID, filter.SourceIDs) if filter.After != nil { conditions = append(conditions, "msg.sent_at >= CAST(? AS TIMESTAMP)") args = append(args, filter.After.Format("2006-01-02 15:04:05")) @@ -2351,9 +2330,6 @@ func (e *DuckDBEngine) buildSearchConditions(q *search.Query, filter MessageFilt if filter.WithAttachmentsOnly { conditions = append(conditions, "msg.has_attachments = true") } - if filter.HideDeletedFromSource { - conditions = append(conditions, "msg.deleted_from_source_at IS NULL") - } // Sender filter - check both message_recipients (email/phone) and direct sender_id (WhatsApp/chat) if filter.Sender != "" { conditions = append(conditions, `(EXISTS ( @@ -2495,10 +2471,7 @@ func (e *DuckDBEngine) buildSearchConditions(q *search.Query, filter MessageFilt } // Account filter - if q.AccountID != nil { - conditions = append(conditions, "msg.source_id = ?") - args = append(args, *q.AccountID) - } + conditions, args = appendSourceFilter(conditions, args, "msg.", nil, q.AccountIDs) // Default conditions if none specified if len(conditions) == 0 { diff --git a/internal/query/duckdb_text.go b/internal/query/duckdb_text.go index e8ef94e2..379bc77c 100644 --- 
a/internal/query/duckdb_text.go +++ b/internal/query/duckdb_text.go @@ -5,6 +5,8 @@ import ( "database/sql" "fmt" "strings" + + "github.com/wesm/msgvault/internal/store" ) // Compile-time interface assertion. @@ -409,7 +411,7 @@ func (e *DuckDBEngine) TextSearch( } // Use FTS5 MATCH on messages_fts, filtered to text message types. - sqlQuery := ` + sqlQuery := fmt.Sprintf(` SELECT m.id, COALESCE(m.source_message_id, '') AS source_message_id, @@ -433,9 +435,10 @@ func (e *DuckDBEngine) TextSearch( LEFT JOIN conversations c ON c.id = m.conversation_id WHERE messages_fts MATCH ? AND m.message_type IN ('whatsapp','imessage','sms','google_voice_text') + AND %s ORDER BY m.sent_at DESC LIMIT ? OFFSET ? - ` + `, store.LiveMessagesWhere("m", true)) rows, err := e.sqliteDB.QueryContext(ctx, sqlQuery, query, limit, offset) diff --git a/internal/query/models.go b/internal/query/models.go index 43b13df0..aa63b355 100644 --- a/internal/query/models.go +++ b/internal/query/models.go @@ -212,7 +212,8 @@ type MessageFilter struct { TimeRange TimeRange // Account filter - SourceID *int64 // nil means all accounts + SourceID *int64 // nil means all accounts + SourceIDs []int64 // multi-source filter (collections); overrides SourceID // Date range After *time.Time @@ -282,13 +283,17 @@ func (f MessageFilter) Clone() MessageFilter { clone.EmptyValueTargets[k] = v } } + if f.SourceIDs != nil { + clone.SourceIDs = append([]int64(nil), f.SourceIDs...) + } return clone } // AggregateOptions configures an aggregate query. type AggregateOptions struct { // Account filter - SourceID *int64 // nil means all accounts + SourceID *int64 // nil means all accounts + SourceIDs []int64 // multi-source filter (collections) // Date range After *time.Time @@ -333,6 +338,7 @@ type AccountInfo struct { // StatsOptions configures a stats query. 
type StatsOptions struct { SourceID *int64 // nil means all accounts + SourceIDs []int64 // multi-source filter (collections) WithAttachmentsOnly bool // only count messages with attachments HideDeletedFromSource bool // exclude messages where deleted_from_source_at IS NOT NULL SearchQuery string // when set, stats reflect only messages matching this search diff --git a/internal/query/shared.go b/internal/query/shared.go index cb272c58..a157f36c 100644 --- a/internal/query/shared.go +++ b/internal/query/shared.go @@ -10,6 +10,7 @@ import ( "strings" "github.com/wesm/msgvault/internal/mime" + "github.com/wesm/msgvault/internal/store" ) // emailOnlyFilterMsg is the SQL condition restricting to email messages with "msg." alias (DuckDB). @@ -184,9 +185,10 @@ func extractBodyFromRawShared(ctx context.Context, db *sql.DB, tablePrefix strin } // getMessageRawShared retrieves and decompresses raw MIME data for a message. -// Returns nil, nil if no raw data is stored, or if the message has been -// deleted from source — the listing/search endpoints hide deleted-from-source -// messages, so the raw-MIME path stays aligned and refuses to serve them. +// Returns nil, nil if no raw data is stored, or if the message is hidden from +// normal reads — dedup losers (deleted_at) and source-deleted rows +// (deleted_from_source_at) are both filtered, matching the visibility rule +// the list/search endpoints apply via store.LiveMessagesWhere. func getMessageRawShared(ctx context.Context, db *sql.DB, tablePrefix string, messageID int64) ([]byte, error) { var compressed []byte var compression sql.NullString @@ -195,8 +197,8 @@ func getMessageRawShared(ctx context.Context, db *sql.DB, tablePrefix string, me SELECT mr.raw_data, mr.compression FROM %smessage_raw mr JOIN %smessages m ON m.id = mr.message_id - WHERE mr.message_id = ? AND m.deleted_from_source_at IS NULL - `, tablePrefix, tablePrefix), messageID).Scan(&compressed, &compression) + WHERE mr.message_id = ? 
AND %s + `, tablePrefix, tablePrefix, store.LiveMessagesWhere("m", true)), messageID).Scan(&compressed, &compression) if err == sql.ErrNoRows { return nil, nil } diff --git a/internal/query/source_filter.go b/internal/query/source_filter.go new file mode 100644 index 00000000..501419c6 --- /dev/null +++ b/internal/query/source_filter.go @@ -0,0 +1,37 @@ +package query + +import ( + "fmt" + "strings" +) + +// appendSourceFilter returns conditions/args updated with a source-id +// filter drawn from either SourceIDs (multi) or SourceID (single). +// SourceIDs takes precedence when both are provided. A non-nil but +// empty multiIDs slice produces a 1=0 (match-nothing) condition. +func appendSourceFilter( + conditions []string, args []any, + prefix string, singleID *int64, multiIDs []int64, +) ([]string, []any) { + if multiIDs != nil && len(multiIDs) == 0 { + conditions = append(conditions, "1=0") + return conditions, args + } + if len(multiIDs) > 0 { + placeholders := make([]string, len(multiIDs)) + for i, id := range multiIDs { + placeholders[i] = "?" 
+ args = append(args, id) + } + conditions = append(conditions, fmt.Sprintf( + "%ssource_id IN (%s)", + prefix, strings.Join(placeholders, ","), + )) + return conditions, args + } + if singleID != nil { + conditions = append(conditions, prefix+"source_id = ?") + args = append(args, *singleID) + } + return conditions, args +} diff --git a/internal/query/source_filter_test.go b/internal/query/source_filter_test.go new file mode 100644 index 00000000..3887fd78 --- /dev/null +++ b/internal/query/source_filter_test.go @@ -0,0 +1,112 @@ +package query + +import ( + "testing" +) + +func TestAppendSourceFilter(t *testing.T) { + id42 := int64(42) + + tests := []struct { + name string + singleID *int64 + multiIDs []int64 + prefix string + wantConditions int + wantArgs int + wantCondition string + }{ + { + name: "neither single nor multi", + singleID: nil, + multiIDs: nil, + prefix: "m.", + wantConditions: 0, + wantArgs: 0, + }, + { + name: "single ID", + singleID: &id42, + multiIDs: nil, + prefix: "m.", + wantConditions: 1, + wantArgs: 1, + wantCondition: "m.source_id = ?", + }, + { + name: "empty multi IDs matches nothing", + singleID: nil, + multiIDs: []int64{}, + prefix: "m.", + wantConditions: 1, + wantArgs: 0, + wantCondition: "1=0", + }, + { + name: "empty multi IDs overrides singleID", + singleID: &id42, + multiIDs: []int64{}, + prefix: "m.", + wantConditions: 1, + wantArgs: 0, + wantCondition: "1=0", + }, + { + name: "single multi ID", + singleID: nil, + multiIDs: []int64{7}, + prefix: "m.", + wantConditions: 1, + wantArgs: 1, + wantCondition: "m.source_id IN (?)", + }, + { + name: "multi IDs", + singleID: nil, + multiIDs: []int64{1, 2, 3}, + prefix: "msg.", + wantConditions: 1, + wantArgs: 3, + wantCondition: "msg.source_id IN (?,?,?)", + }, + { + name: "multi IDs take precedence over single", + singleID: &id42, + multiIDs: []int64{10, 20}, + prefix: "", + wantConditions: 1, + wantArgs: 2, + wantCondition: "source_id IN (?,?)", + }, + { + name: "empty prefix", + 
singleID: &id42, + multiIDs: nil, + prefix: "", + wantConditions: 1, + wantArgs: 1, + wantCondition: "source_id = ?", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + conditions, args := appendSourceFilter( + nil, nil, tt.prefix, tt.singleID, tt.multiIDs, + ) + if len(conditions) != tt.wantConditions { + t.Errorf("conditions = %d, want %d: %v", + len(conditions), tt.wantConditions, conditions) + } + if len(args) != tt.wantArgs { + t.Errorf("args = %d, want %d", len(args), tt.wantArgs) + } + if tt.wantCondition != "" && len(conditions) > 0 { + if conditions[0] != tt.wantCondition { + t.Errorf("condition = %q, want %q", + conditions[0], tt.wantCondition) + } + } + }) + } +} diff --git a/internal/query/sqlite.go b/internal/query/sqlite.go index 19ae7444..18e3a8ba 100644 --- a/internal/query/sqlite.go +++ b/internal/query/sqlite.go @@ -10,6 +10,7 @@ import ( "time" "github.com/wesm/msgvault/internal/search" + "github.com/wesm/msgvault/internal/store" ) // SQLiteEngine implements Engine using direct SQLite queries. @@ -189,10 +190,13 @@ func optsToFilterConditions(opts AggregateOptions, prefix string) ([]string, []i // message_type IS NULL and '' handle old data without the column. conditions = append(conditions, "("+prefix+"message_type = 'email' OR "+prefix+"message_type IS NULL OR "+prefix+"message_type = '')") - if opts.SourceID != nil { - conditions = append(conditions, prefix+"source_id = ?") - args = append(args, *opts.SourceID) - } + // Always exclude rows soft-deleted by deduplicate; gate + // source-deleted on opts.HideDeletedFromSource via the helper. 
+ conditions = append(conditions, store.LiveMessagesWhere(strings.TrimSuffix(prefix, "."), opts.HideDeletedFromSource)) + + conditions, args = appendSourceFilter( + conditions, args, prefix, opts.SourceID, opts.SourceIDs, + ) if opts.After != nil { conditions = append(conditions, prefix+"sent_at >= ?") args = append(args, opts.After.Format("2006-01-02 15:04:05")) @@ -204,9 +208,6 @@ func optsToFilterConditions(opts AggregateOptions, prefix string) ([]string, []i if opts.WithAttachmentsOnly { conditions = append(conditions, prefix+"has_attachments = 1") } - if opts.HideDeletedFromSource { - conditions = append(conditions, prefix+"deleted_from_source_at IS NULL") - } return conditions, args } @@ -261,10 +262,13 @@ func buildFilterJoinsAndConditions(filter MessageFilter, tableAlias string) (str // message_type IS NULL and '' handle old data without the column. conditions = append(conditions, "("+prefix+"message_type = 'email' OR "+prefix+"message_type IS NULL OR "+prefix+"message_type = '')") - if filter.SourceID != nil { - conditions = append(conditions, prefix+"source_id = ?") - args = append(args, *filter.SourceID) - } + // Always exclude rows soft-deleted by deduplicate; gate + // source-deleted on filter.HideDeletedFromSource via the helper. 
+ conditions = append(conditions, store.LiveMessagesWhere(strings.TrimSuffix(prefix, "."), filter.HideDeletedFromSource)) + + conditions, args = appendSourceFilter( + conditions, args, prefix, filter.SourceID, filter.SourceIDs, + ) if filter.ConversationID != nil { conditions = append(conditions, prefix+"conversation_id = ?") @@ -284,9 +288,6 @@ func buildFilterJoinsAndConditions(filter MessageFilter, tableAlias string) (str if filter.WithAttachmentsOnly { conditions = append(conditions, prefix+"has_attachments = 1") } - if filter.HideDeletedFromSource { - conditions = append(conditions, prefix+"deleted_from_source_at IS NULL") - } // Sender filter - check both message_recipients (email) and direct sender_id (WhatsApp/chat) // Also checks phone_number for phone-based lookups (e.g., from:+447...) @@ -449,9 +450,21 @@ func buildFilterJoinsAndConditions(filter MessageFilter, tableAlias string) (str // SubAggregate performs aggregation on a filtered subset of messages. // This is used for sub-grouping after drill-down. func (e *SQLiteEngine) SubAggregate(ctx context.Context, filter MessageFilter, groupBy ViewType, opts AggregateOptions) ([]AggregateRow, error) { + // Reconcile opts.HideDeletedFromSource into filter so the helper + // inside buildFilterJoinsAndConditions / optsToFilterConditions + // sees the OR of both fields. Mirrors the DuckDB SubAggregate + // path so both engines emit one authoritative live-message + // predicate per query. + if opts.HideDeletedFromSource { + filter.HideDeletedFromSource = true + } filterJoins, filterConditions, args := buildFilterJoinsAndConditions(filter, "m") - // Add opts-based conditions + // Add opts-based conditions. Note: optsToFilterConditions emits + // its own LiveMessagesWhere clause (correct for the Aggregate + // caller below, which doesn't go through buildFilterJoinsAndConditions). 
+ // In SubAggregate this means both filter-side and opts-side helpers + // emit the same clause, producing a redundant-but-correct AND chain. optsConds, optsArgs := optsToFilterConditions(opts, "m.") filterConditions = append(filterConditions, optsConds...) args = append(args, optsArgs...) @@ -737,8 +750,8 @@ func (e *SQLiteEngine) GetMessageSummariesByIDs(ctx context.Context, ids []int64 LEFT JOIN message_recipients mr_sender ON mr_sender.message_id = m.id AND mr_sender.recipient_type = 'from' LEFT JOIN participants p_sender ON p_sender.id = COALESCE(mr_sender.participant_id, m.sender_id) LEFT JOIN conversations conv ON conv.id = m.conversation_id - WHERE m.id IN (%s) AND m.deleted_from_source_at IS NULL - `, strings.Join(placeholders, ",")) + WHERE m.id IN (%s) AND %s + `, strings.Join(placeholders, ","), store.LiveMessagesWhere("m", true)) rows, err := e.db.QueryContext(ctx, q, args...) if err != nil { @@ -886,17 +899,15 @@ func (e *SQLiteEngine) GetTotalStats(ctx context.Context, opts StatsOptions) (*T var args []interface{} // Restrict to email messages only; NULL and '' handle pre-message_type data. conditions = append(conditions, emailOnlyFilterM) - // Include all messages (deleted messages shown with indicator in TUI) - if opts.SourceID != nil { - conditions = append(conditions, "m.source_id = ?") - args = append(args, *opts.SourceID) - } + // Exclude rows soft-deleted by deduplicate; gate source-deleted on + // opts.HideDeletedFromSource via the helper. + conditions = append(conditions, store.LiveMessagesWhere("m", opts.HideDeletedFromSource)) + conditions, args = appendSourceFilter( + conditions, args, "m.", opts.SourceID, opts.SourceIDs, + ) if opts.WithAttachmentsOnly { conditions = append(conditions, "m.has_attachments = 1") } - if opts.HideDeletedFromSource { - conditions = append(conditions, "m.deleted_from_source_at IS NULL") - } // Merge search conditions conditions = append(conditions, searchConditions...) args = append(args, searchArgs...) 
@@ -1004,13 +1015,12 @@ func (e *SQLiteEngine) GetGmailIDsByFilter(ctx context.Context, filter MessageFi var conditions []string var args []interface{} - // Always exclude deleted messages - conditions = append(conditions, "m.deleted_from_source_at IS NULL") + // Exclude remote-deleted and dedup-soft-deleted messages. + // Always pass true: this surface feeds remote-deletion staging and + // must never honor an opt-in. + conditions = append(conditions, store.LiveMessagesWhere("m", true)) - if filter.SourceID != nil { - conditions = append(conditions, "m.source_id = ?") - args = append(args, *filter.SourceID) - } + conditions, args = appendSourceFilter(conditions, args, "m.", filter.SourceID, filter.SourceIDs) // Build JOIN clauses based on filter type var joins []string @@ -1188,6 +1198,9 @@ func (e *SQLiteEngine) SearchByDomains(ctx context.Context, domains []string, af func (e *SQLiteEngine) buildSearchQueryParts(ctx context.Context, q *search.Query) (conditions []string, args []interface{}, joins []string, ftsJoin string) { // Restrict to email messages only; NULL and '' handle pre-message_type data. conditions = append(conditions, emailOnlyFilterM) + // Exclude rows soft-deleted by deduplicate; gate source-deleted on + // q.HideDeleted via the helper. + conditions = append(conditions, store.LiveMessagesWhere("m", q.HideDeleted)) // From filter - uses EXISTS to avoid join multiplication in aggregates. // Handles both exact addresses and @domain patterns. 
@@ -1335,15 +1348,7 @@ func (e *SQLiteEngine) buildSearchQueryParts(ctx context.Context, q *search.Quer } // Account filter - if q.AccountID != nil { - conditions = append(conditions, "m.source_id = ?") - args = append(args, *q.AccountID) - } - - // Hide-deleted filter - if q.HideDeleted { - conditions = append(conditions, "m.deleted_from_source_at IS NULL") - } + conditions, args = appendSourceFilter(conditions, args, "m.", nil, q.AccountIDs) return conditions, args, joins, ftsJoin } @@ -1478,10 +1483,23 @@ func MergeFilterIntoQuery(q *search.Query, filter MessageFilter) *search.Query { merged.BccAddrs = append([]string(nil), q.BccAddrs...) merged.SubjectTerms = append([]string(nil), q.SubjectTerms...) merged.Labels = append([]string(nil), q.Labels...) - - // Account filter - always apply if set - if filter.SourceID != nil { - merged.AccountID = filter.SourceID + // Deep-copy AccountIDs alongside the other slices so the merged + // query never aliases the original's slice header. Filter overrides + // below replace the deep-copied slice when set. + merged.AccountIDs = append([]int64(nil), q.AccountIDs...) + + // Account filter - always apply if set. Multi-source SourceIDs takes + // precedence over single SourceID, matching appendSourceFilter + // semantics elsewhere in the package: a non-nil but empty SourceIDs + // slice is "match nothing" (the caller explicitly scoped to no + // sources) and must clear any AccountIDs the original query carried. + // Allocate a fresh slice (not append-from-nil, which would collapse + // an explicit empty back to nil and lose the match-nothing signal). 
+ if filter.SourceIDs != nil { + merged.AccountIDs = make([]int64, len(filter.SourceIDs)) + copy(merged.AccountIDs, filter.SourceIDs) + } else if filter.SourceID != nil { + merged.AccountIDs = []int64{*filter.SourceID} } // Sender filter - append to existing from: filters diff --git a/internal/query/sqlite_search_test.go b/internal/query/sqlite_search_test.go index b9b0e99e..3e2142c7 100644 --- a/internal/query/sqlite_search_test.go +++ b/internal/query/sqlite_search_test.go @@ -334,7 +334,7 @@ func TestMergeFilterIntoQuery(t *testing.T) { name: "SourceID", initial: &search.Query{}, filter: MessageFilter{SourceID: &sourceID42}, - expected: &search.Query{AccountID: &sourceID42}, + expected: &search.Query{AccountIDs: []int64{sourceID42}}, }, { name: "SenderAppends", @@ -386,7 +386,7 @@ func TestMergeFilterIntoQuery(t *testing.T) { ToAddrs: []string{"carol@example.com"}, Labels: []string{"starred"}, HasAttachment: ptr.Bool(true), - AccountID: &sourceID1, + AccountIDs: []int64{sourceID1}, }, }, } @@ -412,6 +412,37 @@ func TestMergeFilterIntoQuery_DoesNotMutateOriginal(t *testing.T) { } } +// TestMergeFilterIntoQuery_EmptySourceIDsClearsAccountScope verifies that +// an explicit empty (non-nil) SourceIDs slice is treated as match-nothing, +// matching appendSourceFilter's contract. Previously the code only +// applied SourceIDs when len > 0, so an explicit empty slice silently +// fell through and let the original query's AccountIDs leak through. 
+func TestMergeFilterIntoQuery_EmptySourceIDsClearsAccountScope(t *testing.T) { + q := &search.Query{AccountIDs: []int64{1, 2, 3}} + filter := MessageFilter{SourceIDs: []int64{}} // non-nil, len=0 + + merged := MergeFilterIntoQuery(q, filter) + if merged.AccountIDs == nil { + t.Fatal("merged.AccountIDs is nil; want non-nil empty slice (match-nothing)") + } + if len(merged.AccountIDs) != 0 { + t.Errorf("merged.AccountIDs = %v; want empty (match-nothing)", merged.AccountIDs) + } +} + +// TestMergeFilterIntoQuery_NilSourceIDsPreservesAccountScope verifies the +// flip-side: a nil SourceIDs slice is "no override", and the original +// query's AccountIDs survive unchanged. +func TestMergeFilterIntoQuery_NilSourceIDsPreservesAccountScope(t *testing.T) { + q := &search.Query{AccountIDs: []int64{1, 2, 3}} + filter := MessageFilter{} // SourceIDs is nil + + merged := MergeFilterIntoQuery(q, filter) + if len(merged.AccountIDs) != 3 { + t.Errorf("merged.AccountIDs = %v; want [1 2 3]", merged.AccountIDs) + } +} + func TestMergeFilterIntoQuery_SliceAliasingMutation(t *testing.T) { backing := make([]string, 1, 10) backing[0] = "original@example.com" diff --git a/internal/query/sqlite_text.go b/internal/query/sqlite_text.go index 8b82e1fb..bbffabc4 100644 --- a/internal/query/sqlite_text.go +++ b/internal/query/sqlite_text.go @@ -6,6 +6,8 @@ import ( "fmt" "strings" "time" + + "github.com/wesm/msgvault/internal/store" ) // Compile-time interface assertion. @@ -424,7 +426,7 @@ func (e *SQLiteEngine) TextSearch( limit = 50 } - sqlQuery := ` + sqlQuery := fmt.Sprintf(` SELECT m.id, COALESCE(m.source_message_id, '') AS source_message_id, @@ -448,9 +450,10 @@ func (e *SQLiteEngine) TextSearch( LEFT JOIN conversations c ON c.id = m.conversation_id WHERE fts.messages_fts MATCH ? AND m.message_type IN ('whatsapp','imessage','sms','google_voice_text') + AND %s ORDER BY m.sent_at DESC LIMIT ? OFFSET ? 
- ` + `, store.LiveMessagesWhere("m", true)) rows, err := e.db.QueryContext(ctx, sqlQuery, query, limit, offset) if err != nil { diff --git a/internal/query/text_search_live_test.go b/internal/query/text_search_live_test.go new file mode 100644 index 00000000..93bce042 --- /dev/null +++ b/internal/query/text_search_live_test.go @@ -0,0 +1,128 @@ +package query + +import ( + "context" + "database/sql" + "testing" + + _ "github.com/mattn/go-sqlite3" +) + +// openTextSearchDB creates a minimal in-memory SQLite DB with one text +// message indexed in FTS. The caller may soft-delete the message via +// SQL after this call to verify live-message filtering. +func openTextSearchDB(t *testing.T) (*sql.DB, int64) { + t.Helper() + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatalf("open: %v", err) + } + t.Cleanup(func() { _ = db.Close() }) + + _, err = db.Exec(` + CREATE TABLE sources ( + id INTEGER PRIMARY KEY, + source_type TEXT NOT NULL DEFAULT 'imessage', + identifier TEXT NOT NULL UNIQUE + ); + CREATE TABLE conversations ( + id INTEGER PRIMARY KEY, + source_id INTEGER, + source_conversation_id TEXT, + title TEXT + ); + CREATE TABLE participants ( + id INTEGER PRIMARY KEY, + email_address TEXT, + display_name TEXT, + phone_number TEXT, + domain TEXT + ); + CREATE TABLE messages ( + id INTEGER PRIMARY KEY, + source_id INTEGER, + source_message_id TEXT, + conversation_id INTEGER, + sender_id INTEGER, + subject TEXT, + snippet TEXT, + sent_at DATETIME, + size_estimate INTEGER DEFAULT 0, + has_attachments INTEGER DEFAULT 0, + attachment_count INTEGER DEFAULT 0, + deleted_at DATETIME, + deleted_from_source_at DATETIME, + message_type TEXT NOT NULL DEFAULT 'imessage' + ); + CREATE VIRTUAL TABLE messages_fts USING fts5(subject, body, content='', contentless_delete=1); + `) + if err != nil { + t.Skipf("FTS5 not available: %v", err) + } + + _, err = db.Exec(`INSERT INTO sources (id, identifier) VALUES (1, 'test@example.com')`) + if err != nil { + 
t.Fatalf("insert source: %v", err) + } + _, err = db.Exec(`INSERT INTO conversations (id, source_id) VALUES (1, 1)`) + if err != nil { + t.Fatalf("insert conv: %v", err) + } + res, err := db.Exec(`INSERT INTO messages (id, source_id, conversation_id, subject, message_type) VALUES (1, 1, 1, 'hello world', 'imessage')`) + if err != nil { + t.Fatalf("insert message: %v", err) + } + msgID, _ := res.LastInsertId() + _, err = db.Exec(`INSERT INTO messages_fts (rowid, subject, body) VALUES (?, 'hello world', 'hello world')`, msgID) + if err != nil { + t.Fatalf("insert fts: %v", err) + } + return db, msgID +} + +func TestSQLiteEngine_TextSearch_ExcludesDedupHidden(t *testing.T) { + db, msgID := openTextSearchDB(t) + engine := NewSQLiteEngine(db) + ctx := context.Background() + + // Confirm the message appears before deletion. + results, err := engine.TextSearch(ctx, "hello", 10, 0) + if err != nil { + t.Fatalf("TextSearch before delete: %v", err) + } + if len(results) != 1 { + t.Fatalf("want 1 result before delete, got %d", len(results)) + } + + // Soft-delete via dedup (deleted_at). + if _, err := db.Exec(`UPDATE messages SET deleted_at = CURRENT_TIMESTAMP WHERE id = ?`, msgID); err != nil { + t.Fatalf("set deleted_at: %v", err) + } + + results, err = engine.TextSearch(ctx, "hello", 10, 0) + if err != nil { + t.Fatalf("TextSearch after dedup delete: %v", err) + } + if len(results) != 0 { + t.Errorf("want 0 results after dedup delete, got %d", len(results)) + } +} + +func TestSQLiteEngine_TextSearch_ExcludesSourceDeleted(t *testing.T) { + db, msgID := openTextSearchDB(t) + engine := NewSQLiteEngine(db) + ctx := context.Background() + + // Soft-delete via source deletion (deleted_from_source_at). 
+ if _, err := db.Exec(`UPDATE messages SET deleted_from_source_at = CURRENT_TIMESTAMP WHERE id = ?`, msgID); err != nil { + t.Fatalf("set deleted_from_source_at: %v", err) + } + + results, err := engine.TextSearch(ctx, "hello", 10, 0) + if err != nil { + t.Fatalf("TextSearch after source delete: %v", err) + } + if len(results) != 0 { + t.Errorf("want 0 results after source delete, got %d", len(results)) + } +} diff --git a/internal/remote/engine.go b/internal/remote/engine.go index 7c85b84e..377a9679 100644 --- a/internal/remote/engine.go +++ b/internal/remote/engine.go @@ -595,9 +595,17 @@ func (e *Engine) Search(ctx context.Context, q *search.Query, limit, offset int) params.Set("limit", strconv.Itoa(limit)) // Forward filter-only fields that can't be represented in the - // query string syntax (AccountID, HideDeleted). - if q.AccountID != nil { - params.Set("source_id", strconv.FormatInt(*q.AccountID, 10)) + // query string syntax (AccountIDs, HideDeleted). The remote API + // accepts a single source_id; CLI/MCP layers reject collection + // scope in remote mode, but defend here against any future caller + // that bypasses those checks rather than silently dropping IDs. 
+ if len(q.AccountIDs) > 1 { + return nil, fmt.Errorf( + "remote search does not support multi-account scope; "+ + "got %d account IDs", len(q.AccountIDs)) + } + if len(q.AccountIDs) == 1 { + params.Set("source_id", strconv.FormatInt(q.AccountIDs[0], 10)) } if q.HideDeleted { params.Set("hide_deleted", "true") diff --git a/internal/search/parser.go b/internal/search/parser.go index 8f5efd26..6a3f3156 100644 --- a/internal/search/parser.go +++ b/internal/search/parser.go @@ -22,7 +22,7 @@ type Query struct { AfterDate *time.Time // after: filter LargerThan *int64 // larger: filter (bytes) SmallerThan *int64 // smaller: filter (bytes) - AccountID *int64 // in: account filter + AccountIDs []int64 // in: account filter (one or more source IDs) HideDeleted bool // exclude messages where deleted_from_source_at IS NOT NULL } @@ -40,7 +40,7 @@ func (q *Query) IsEmpty() bool { q.AfterDate == nil && q.LargerThan == nil && q.SmallerThan == nil && - q.AccountID == nil + len(q.AccountIDs) == 0 } // operatorFn handles a parsed operator:value pair by applying it to the query. 
diff --git a/internal/search/parser_test.go b/internal/search/parser_test.go index 2628742a..24209384 100644 --- a/internal/search/parser_test.go +++ b/internal/search/parser_test.go @@ -403,12 +403,12 @@ func TestQuery_IsEmpty(t *testing.T) { }) } - t.Run("AccountID only", func(t *testing.T) { + t.Run("AccountIDs only", func(t *testing.T) { q := &Query{} id := int64(42) - q.AccountID = &id + q.AccountIDs = []int64{id} if q.IsEmpty() { - t.Error("IsEmpty() = true for query with AccountID set") + t.Error("IsEmpty() = true for query with AccountIDs set") } }) } diff --git a/internal/store/account_identities.go b/internal/store/account_identities.go new file mode 100644 index 00000000..843ca16c --- /dev/null +++ b/internal/store/account_identities.go @@ -0,0 +1,219 @@ +package store + +import ( + "database/sql" + "fmt" + "sort" + "strings" + "time" +) + +// AccountIdentity is one confirmed "me" address for one source. +type AccountIdentity struct { + SourceID int64 + Address string + SourceSignal string + ConfirmedAt time.Time +} + +// looksLikeEmail returns true for tokens that have the shape of an +// email address. Emails are matched case-insensitively in the identity +// store; other identifier shapes (phone E.164, Matrix MXIDs like +// "@user:server.org", Slack/IRC handles) preserve case. The check is: +// at least one "@" not at index 0 and the substring after the last "@" +// contains a ".". This excludes Matrix MXIDs (which start with "@") +// and bare handles, and accepts conventional emails. +func looksLikeEmail(addr string) bool { + at := strings.LastIndex(addr, "@") + if at <= 0 || at == len(addr)-1 { + return false + } + return strings.Contains(addr[at+1:], ".") +} + +// AddAccountIdentity confirms an identifier for one source. +// +// Behavior: +// - If (source_id, address) does not exist: insert with the given signal +// and confirmed_at = now. An empty signal inserts an empty source_signal. 
+// - If it exists and the signal is already in the row's source_signal set: +// no-op. +// - If it exists and the signal is not yet in the set: add it (set is kept +// sorted alphabetically, comma-delimited). confirmed_at is NOT updated; +// it records first confirmation. +// - Empty signal on an existing row: no-op (no new evidence to record). +// - All-whitespace identifier: no-op (returns nil). +// - Comma in signal: error. Comma is reserved as the in-column delimiter. +// +// The function trims the identifier; case is preserved (the identifier +// column accommodates email, phone E.164, and synthetic identifiers like +// chat handles where case can be significant). +// +// Read-modify-write inside a transaction. The single-writer SQLite model +// serializes commits within one process; cross-process concurrency is not +// a supported deployment. +func (s *Store) AddAccountIdentity(sourceID int64, address, signal string) error { + addr := strings.TrimSpace(address) + if addr == "" { + return nil + } + if strings.Contains(signal, ",") { + return fmt.Errorf("signal names cannot contain commas: %q", signal) + } + // Email-shaped tokens match case-insensitively to keep the + // add/remove paths symmetric. Synthetic identifiers (phones, + // Matrix MXIDs, chat handles) stay case-sensitive. The branch + // lives in identifierMatch — see identifier_match.go. + match := newIdentifierMatch(addr) + + return s.withTx(func(tx *loggedTx) error { + var existing string + err := tx.QueryRow( + `SELECT source_signal FROM account_identities + WHERE source_id = ? 
AND `+match.WhereClause("address"), + sourceID, match.BindValue(), + ).Scan(&existing) + switch { + case err == sql.ErrNoRows: + _, txErr := tx.Exec( + `INSERT INTO account_identities (source_id, address, source_signal) + VALUES (?, ?, ?)`, + sourceID, addr, signal, + ) + if txErr != nil { + return fmt.Errorf("insert account identity: %w", txErr) + } + return nil + case err != nil: + return fmt.Errorf("read existing source_signal: %w", err) + } + + merged := mergeSignalSet(existing, signal) + if merged == existing { + return nil + } + _, updateErr := tx.Exec( + `UPDATE account_identities SET source_signal = ? + WHERE source_id = ? AND `+match.WhereClause("address"), + merged, sourceID, match.BindValue(), + ) + if updateErr != nil { + return fmt.Errorf("update source_signal: %w", updateErr) + } + return nil + }) +} + +// mergeSignalSet returns the comma-joined sorted union of the existing +// signal set and the new signal. Empty strings (in either argument) are +// treated as the empty set. +func mergeSignalSet(existing, signal string) string { + set := make(map[string]struct{}) + if existing != "" { + for _, s := range strings.Split(existing, ",") { + if s != "" { + set[s] = struct{}{} + } + } + } + if signal != "" { + set[signal] = struct{}{} + } + if len(set) == 0 { + return "" + } + keys := make([]string, 0, len(set)) + for k := range set { + keys = append(keys, k) + } + sort.Strings(keys) + return strings.Join(keys, ",") +} + +// ListAccountIdentities returns all identities for one source, ordered by address. +func (s *Store) ListAccountIdentities(sourceID int64) ([]AccountIdentity, error) { + rows, err := s.db.Query(` + SELECT source_id, address, source_signal, confirmed_at + FROM account_identities + WHERE source_id = ? 
+ ORDER BY address + `, sourceID) + if err != nil { + return nil, fmt.Errorf("list account identities: %w", err) + } + defer func() { _ = rows.Close() }() + + var out []AccountIdentity + for rows.Next() { + var ai AccountIdentity + if err := rows.Scan(&ai.SourceID, &ai.Address, &ai.SourceSignal, &ai.ConfirmedAt); err != nil { + return nil, fmt.Errorf("scan account identity: %w", err) + } + out = append(out, ai) + } + return out, rows.Err() +} + +// RemoveAccountIdentity deletes (source_id, address) rows that match +// under the helper's case-aware rule. Returns the number of rows +// deleted (typically 0 or 1, but can be >1 in legacy databases that +// hold case-variant duplicates pre-dating the case-folding work). +// +// Email-shaped identifiers match case-insensitively because email is +// case-insensitive in practice; this avoids the UX trap where a row +// was inserted as foo@x.com but the user types Foo@x.com on remove. +// Synthetic identifiers (Matrix MXIDs, chat handles, phone numbers) +// match case-sensitively because case can be significant there. The +// shape check is in looksLikeEmail. +func (s *Store) RemoveAccountIdentity(sourceID int64, address string) (int64, error) { + match := newIdentifierMatch(address) + res, err := s.db.Exec( + `DELETE FROM account_identities WHERE source_id = ? AND `+match.WhereClause("address"), + sourceID, match.BindValue(), + ) + if err != nil { + return 0, fmt.Errorf("remove account identity: %w", err) + } + n, err := res.RowsAffected() + if err != nil { + return 0, fmt.Errorf("rows affected: %w", err) + } + return n, nil +} + +// GetIdentitiesForScope returns the union of confirmed identifier addresses +// across the given source IDs. Empty input returns an empty map — no global +// default; an explicit empty scope means no identity matching. +// +// Identifiers are returned with the case the user stored. Callers comparing +// against email-shaped strings should lowercase both sides at compare time. 
+func (s *Store) GetIdentitiesForScope(sourceIDs []int64) (map[string]struct{}, error) { + out := make(map[string]struct{}) + if len(sourceIDs) == 0 { + return out, nil + } + + placeholders := make([]string, len(sourceIDs)) + args := make([]any, len(sourceIDs)) + for i, id := range sourceIDs { + placeholders[i] = "?" + args[i] = id + } + query := `SELECT address FROM account_identities WHERE source_id IN (` + + strings.Join(placeholders, ",") + `)` + + rows, err := s.db.Query(query, args...) + if err != nil { + return nil, fmt.Errorf("get identities for scope: %w", err) + } + defer func() { _ = rows.Close() }() + + for rows.Next() { + var addr string + if err := rows.Scan(&addr); err != nil { + return nil, fmt.Errorf("scan identity address: %w", err) + } + out[addr] = struct{}{} + } + return out, rows.Err() +} diff --git a/internal/store/account_identities_test.go b/internal/store/account_identities_test.go new file mode 100644 index 00000000..b9c7205d --- /dev/null +++ b/internal/store/account_identities_test.go @@ -0,0 +1,422 @@ +package store_test + +import ( + "strings" + "testing" + "time" + + "github.com/wesm/msgvault/internal/testutil" + "github.com/wesm/msgvault/internal/testutil/storetest" +) + +func TestAddAndListAccountIdentities(t *testing.T) { + f := storetest.New(t) + st := f.Store + + if err := st.AddAccountIdentity(f.Source.ID, "me@example.com", "manual"); err != nil { + t.Fatalf("AddAccountIdentity: %v", err) + } + + ids, err := st.ListAccountIdentities(f.Source.ID) + testutil.MustNoErr(t, err, "ListAccountIdentities") + if len(ids) != 1 { + t.Fatalf("got %d identities, want 1", len(ids)) + } + got := ids[0] + if got.Address != "me@example.com" { + t.Errorf("address = %q, want me@example.com", got.Address) + } + if got.SourceSignal != "manual" { + t.Errorf("source_signal = %q, want manual", got.SourceSignal) + } + if got.SourceID != f.Source.ID { + t.Errorf("source_id = %d, want %d", got.SourceID, f.Source.ID) + } + if got.ConfirmedAt.IsZero() { + 
t.Error("confirmed_at should be set after first insert") + } +} + +func TestAddAccountIdentity_Idempotent(t *testing.T) { + f := storetest.New(t) + st := f.Store + + if err := st.AddAccountIdentity(f.Source.ID, "me@example.com", "manual"); err != nil { + t.Fatalf("AddAccountIdentity (1): %v", err) + } + ids1, err := st.ListAccountIdentities(f.Source.ID) + testutil.MustNoErr(t, err, "ListAccountIdentities (1)") + if len(ids1) != 1 { + t.Fatalf("after first insert: got %d rows, want 1", len(ids1)) + } + first := ids1[0].ConfirmedAt + + time.Sleep(2 * time.Millisecond) + + if err := st.AddAccountIdentity(f.Source.ID, "me@example.com", "manual"); err != nil { + t.Fatalf("AddAccountIdentity (2): %v", err) + } + ids2, err := st.ListAccountIdentities(f.Source.ID) + testutil.MustNoErr(t, err, "ListAccountIdentities (2)") + if len(ids2) != 1 { + t.Errorf("after idempotent re-add: got %d rows, want 1", len(ids2)) + } + if !ids2[0].ConfirmedAt.Equal(first) { + t.Errorf("confirmed_at moved on idempotent re-add: %v -> %v", + first, ids2[0].ConfirmedAt) + } +} + +// TestAddAccountIdentity_PreservesCase verifies that the first +// add of an email-shaped identifier wins the stored casing. Subsequent +// adds with different cases merge into the same row (case-insensitive +// match) rather than producing duplicate rows. This preserves the +// "case-preserved storage, email-case-insensitive logical identity" +// contract that the add/remove paths share. 
+func TestAddAccountIdentity_PreservesCase(t *testing.T) { + f := storetest.New(t) + st := f.Store + + if err := st.AddAccountIdentity(f.Source.ID, "Alice@Example.com", "manual"); err != nil { + t.Fatalf("AddAccountIdentity Alice: %v", err) + } + if err := st.AddAccountIdentity(f.Source.ID, "alice@example.com", "manual"); err != nil { + t.Fatalf("AddAccountIdentity alice: %v", err) + } + + rows, err := st.ListAccountIdentities(f.Source.ID) + testutil.MustNoErr(t, err, "ListAccountIdentities") + if len(rows) != 1 { + t.Fatalf("want 1 row (email is case-insensitive), got %d: %+v", len(rows), rows) + } + if rows[0].Address != "Alice@Example.com" { + t.Errorf("address = %q, want first-write 'Alice@Example.com' (case-preserved)", + rows[0].Address) + } +} + +func TestAddAccountIdentity_AdditionalSignal(t *testing.T) { + f := storetest.New(t) + st := f.Store + + if err := st.AddAccountIdentity(f.Source.ID, "alice@example.com", "manual"); err != nil { + t.Fatal(err) + } + rows1, err := st.ListAccountIdentities(f.Source.ID) + testutil.MustNoErr(t, err, "ListAccountIdentities") + first := rows1[0].ConfirmedAt + time.Sleep(2 * time.Millisecond) + + if err := st.AddAccountIdentity(f.Source.ID, "alice@example.com", "account-identifier"); err != nil { + t.Fatal(err) + } + rows2, err := st.ListAccountIdentities(f.Source.ID) + testutil.MustNoErr(t, err, "ListAccountIdentities after second signal") + if rows2[0].SourceSignal != "account-identifier,manual" { + t.Errorf("signal=%q want %q", rows2[0].SourceSignal, "account-identifier,manual") + } + if !rows2[0].ConfirmedAt.Equal(first) { + t.Errorf("confirmed_at moved on signal augment") + } +} + +func TestAddAccountIdentity_ThreeSignalAccumulation(t *testing.T) { + f := storetest.New(t) + st := f.Store + + for _, sig := range []string{"manual", "account-identifier", "is_from_me"} { + if err := st.AddAccountIdentity(f.Source.ID, "alice@example.com", sig); err != nil { + t.Fatal(err) + } + } + rows, err := 
st.ListAccountIdentities(f.Source.ID) + testutil.MustNoErr(t, err, "ListAccountIdentities") + if rows[0].SourceSignal != "account-identifier,is_from_me,manual" { + t.Errorf("signal=%q want account-identifier,is_from_me,manual", rows[0].SourceSignal) + } +} + +func TestAddAccountIdentity_EmptySignalOnExistingRow(t *testing.T) { + f := storetest.New(t) + st := f.Store + + if err := st.AddAccountIdentity(f.Source.ID, "alice@example.com", "manual"); err != nil { + t.Fatal(err) + } + if err := st.AddAccountIdentity(f.Source.ID, "alice@example.com", ""); err != nil { + t.Fatal(err) + } + rows, err := st.ListAccountIdentities(f.Source.ID) + testutil.MustNoErr(t, err, "ListAccountIdentities") + if rows[0].SourceSignal != "manual" { + t.Errorf("signal=%q want manual (empty signal on existing row is no-op)", rows[0].SourceSignal) + } +} + +func TestAddAccountIdentity_EmptySignalOnMissingRow(t *testing.T) { + f := storetest.New(t) + st := f.Store + + if err := st.AddAccountIdentity(f.Source.ID, "alice@example.com", ""); err != nil { + t.Fatal(err) + } + rows, err := st.ListAccountIdentities(f.Source.ID) + testutil.MustNoErr(t, err, "ListAccountIdentities") + if len(rows) != 1 || rows[0].SourceSignal != "" { + t.Fatalf("want one row with empty signal, got %+v", rows) + } + if rows[0].ConfirmedAt.IsZero() { + t.Error("confirmed_at should be set even with empty signal") + } +} + +func TestAddAccountIdentity_NonEmptySignalReplacesEmptyRow(t *testing.T) { + f := storetest.New(t) + st := f.Store + + if err := st.AddAccountIdentity(f.Source.ID, "alice@example.com", ""); err != nil { + t.Fatal(err) + } + if err := st.AddAccountIdentity(f.Source.ID, "alice@example.com", "manual"); err != nil { + t.Fatal(err) + } + rows, err := st.ListAccountIdentities(f.Source.ID) + testutil.MustNoErr(t, err, "ListAccountIdentities") + if rows[0].SourceSignal != "manual" { + t.Errorf("signal=%q want manual", rows[0].SourceSignal) + } +} + +func TestAddAccountIdentity_RejectsCommaInSignal(t *testing.T) 
{ + f := storetest.New(t) + st := f.Store + + err := st.AddAccountIdentity(f.Source.ID, "alice@example.com", "a,b") + if err == nil { + t.Fatal("expected error for comma in signal") + } + if !strings.Contains(err.Error(), "comma") { + t.Errorf("error doesn't mention comma: %v", err) + } +} + +func TestAddAccountIdentity_AllWhitespaceIdentifierIsNoOp(t *testing.T) { + f := storetest.New(t) + st := f.Store + + if err := st.AddAccountIdentity(f.Source.ID, " ", "manual"); err != nil { + t.Fatal(err) + } + rows, err := st.ListAccountIdentities(f.Source.ID) + testutil.MustNoErr(t, err, "ListAccountIdentities") + if len(rows) != 0 { + t.Errorf("whitespace identifier should not insert, got %+v", rows) + } +} + +func TestAccountIdentities_FKCascadeOnSourceDelete(t *testing.T) { + f := storetest.New(t) + st := f.Store + + if err := st.AddAccountIdentity(f.Source.ID, "alice@example.com", "manual"); err != nil { + t.Fatal(err) + } + if err := st.RemoveSource(f.Source.ID); err != nil { + t.Fatal(err) + } + var n int + if err := st.DB().QueryRow( + `SELECT COUNT(*) FROM account_identities WHERE source_id = ?`, f.Source.ID, + ).Scan(&n); err != nil { + t.Fatal(err) + } + if n != 0 { + t.Errorf("FK cascade failed: %d rows remain", n) + } +} + +func TestGetIdentitiesForScope_MultiSource(t *testing.T) { + f := storetest.New(t) + st := f.Store + + src2, err := st.GetOrCreateSource("gmail", "other@example.com") + testutil.MustNoErr(t, err, "GetOrCreateSource") + + testutil.MustNoErr(t, st.AddAccountIdentity(f.Source.ID, "alice@example.com", "manual"), "add alice") + testutil.MustNoErr(t, st.AddAccountIdentity(src2.ID, "bob@example.com", "manual"), "add bob") + + scope, err := st.GetIdentitiesForScope([]int64{f.Source.ID, src2.ID}) + testutil.MustNoErr(t, err, "GetIdentitiesForScope") + + if len(scope) != 2 { + t.Fatalf("got %d addresses, want 2", len(scope)) + } + if _, ok := scope["alice@example.com"]; !ok { + t.Error("alice@example.com missing from scope") + } + if _, ok := 
scope["bob@example.com"]; !ok { + t.Error("bob@example.com missing from scope") + } +} + +func TestGetIdentitiesForScope_EmptyInput(t *testing.T) { + f := storetest.New(t) + st := f.Store + + testutil.MustNoErr(t, st.AddAccountIdentity(f.Source.ID, "me@example.com", "manual"), "add identity") + + scope, err := st.GetIdentitiesForScope([]int64{}) + testutil.MustNoErr(t, err, "GetIdentitiesForScope empty") + if scope == nil { + t.Error("expected non-nil map for empty scope") + } + if len(scope) != 0 { + t.Errorf("got %d entries, want 0 for empty scope", len(scope)) + } +} + +func TestRemoveAccountIdentity_Hit(t *testing.T) { + f := storetest.New(t) + st := f.Store + + testutil.MustNoErr(t, st.AddAccountIdentity(f.Source.ID, "alice@example.com", "manual"), "add identity") + removed, err := st.RemoveAccountIdentity(f.Source.ID, "alice@example.com") + testutil.MustNoErr(t, err, "RemoveAccountIdentity") + if removed != 1 { + t.Errorf("removed=%d, want 1", removed) + } + rows, err := st.ListAccountIdentities(f.Source.ID) + testutil.MustNoErr(t, err, "ListAccountIdentities") + if len(rows) != 0 { + t.Errorf("want empty, got %+v", rows) + } +} + +func TestRemoveAccountIdentity_Miss(t *testing.T) { + f := storetest.New(t) + st := f.Store + + removed, err := st.RemoveAccountIdentity(f.Source.ID, "nope@example.com") + testutil.MustNoErr(t, err, "RemoveAccountIdentity") + if removed != 0 { + t.Errorf("removed=%d, want 0 on miss", removed) + } +} + +// TestRemoveAccountIdentity_EmailIsCaseInsensitive verifies that an +// email-shaped identifier removed with different casing matches the +// stored row, since email addresses are case-insensitive in practice. 
+func TestRemoveAccountIdentity_EmailIsCaseInsensitive(t *testing.T) { + f := storetest.New(t) + st := f.Store + + testutil.MustNoErr(t, + st.AddAccountIdentity(f.Source.ID, "alice@Example.com", "manual"), + "add identity") + + removed, err := st.RemoveAccountIdentity(f.Source.ID, "ALICE@example.com") + testutil.MustNoErr(t, err, "RemoveAccountIdentity") + if removed != 1 { + t.Fatalf("removed=%d, want 1 (email match should be case-insensitive)", removed) + } + + rows, err := st.ListAccountIdentities(f.Source.ID) + testutil.MustNoErr(t, err, "ListAccountIdentities") + if len(rows) != 0 { + t.Errorf("want empty, got %+v", rows) + } +} + +// TestAddAccountIdentity_EmailIsCaseInsensitive verifies that a second +// add with different casing merges signals into the existing row +// instead of inserting a duplicate. This pairs with +// TestRemoveAccountIdentity_EmailIsCaseInsensitive: add/remove must +// agree on case-folding for "@"-shaped identifiers, otherwise an +// 'identity add Foo@x.com' followed by 'identity remove foo@x.com' +// could leave (or remove) the wrong row. 
+func TestAddAccountIdentity_EmailIsCaseInsensitive(t *testing.T) { + f := storetest.New(t) + st := f.Store + + testutil.MustNoErr(t, + st.AddAccountIdentity(f.Source.ID, "alice@example.com", "manual"), + "first add (lowercase)") + testutil.MustNoErr(t, + st.AddAccountIdentity(f.Source.ID, "ALICE@Example.com", "is_from_me"), + "second add (different case)") + + rows, err := st.ListAccountIdentities(f.Source.ID) + testutil.MustNoErr(t, err, "ListAccountIdentities") + if len(rows) != 1 { + t.Fatalf("len(rows) = %d, want 1 (case-folded merge expected); rows=%+v", len(rows), rows) + } + if rows[0].Address != "alice@example.com" { + t.Errorf("address = %q, want first-write 'alice@example.com'", rows[0].Address) + } + if !strings.Contains(rows[0].SourceSignal, "manual") || + !strings.Contains(rows[0].SourceSignal, "is_from_me") { + t.Errorf("source_signal = %q, want both 'manual' and 'is_from_me' merged", + rows[0].SourceSignal) + } +} + +// TestAddAccountIdentity_NonEmailStaysCaseSensitive guards the +// chat-handle invariant: synthetic identifiers can be case-significant +// so two distinct cases must produce two rows. +func TestAddAccountIdentity_NonEmailStaysCaseSensitive(t *testing.T) { + f := storetest.New(t) + st := f.Store + + testutil.MustNoErr(t, + st.AddAccountIdentity(f.Source.ID, "AliceHandle", "manual"), + "first add") + testutil.MustNoErr(t, + st.AddAccountIdentity(f.Source.ID, "alicehandle", "manual"), + "second add (different case)") + + rows, err := st.ListAccountIdentities(f.Source.ID) + testutil.MustNoErr(t, err, "ListAccountIdentities") + if len(rows) != 2 { + t.Fatalf("len(rows) = %d, want 2 distinct rows for non-email; rows=%+v", len(rows), rows) + } +} + +// TestAddAccountIdentity_MatrixMXIDStaysCaseSensitive guards against an +// over-broad email heuristic: Matrix MXIDs like "@user:server.org" start +// with "@" and contain a "." but are not emails. Two distinct cases must +// produce two distinct rows. 
+func TestAddAccountIdentity_MatrixMXIDStaysCaseSensitive(t *testing.T) { + f := storetest.New(t) + st := f.Store + + testutil.MustNoErr(t, + st.AddAccountIdentity(f.Source.ID, "@Alice:matrix.org", "manual"), + "first add (Matrix MXID, mixed case)") + testutil.MustNoErr(t, + st.AddAccountIdentity(f.Source.ID, "@alice:matrix.org", "manual"), + "second add (Matrix MXID, lower case)") + + rows, err := st.ListAccountIdentities(f.Source.ID) + testutil.MustNoErr(t, err, "ListAccountIdentities") + if len(rows) != 2 { + t.Fatalf("len(rows) = %d, want 2 distinct rows for Matrix MXID; rows=%+v", len(rows), rows) + } +} + +// TestRemoveAccountIdentity_NonEmailIsCaseSensitive guards the +// case-preserving path for synthetic identifiers (chat handles, etc.): +// removing with different casing on a non-email value must not match. +func TestRemoveAccountIdentity_NonEmailIsCaseSensitive(t *testing.T) { + f := storetest.New(t) + st := f.Store + + testutil.MustNoErr(t, + st.AddAccountIdentity(f.Source.ID, "AliceHandle", "manual"), + "add identity") + + removed, err := st.RemoveAccountIdentity(f.Source.ID, "alicehandle") + testutil.MustNoErr(t, err, "RemoveAccountIdentity") + if removed != 0 { + t.Fatalf("removed=%d, want 0 on case-mismatch for non-email identifier", removed) + } +} diff --git a/internal/store/api.go b/internal/store/api.go index 1793c0f8..2040dbaf 100644 --- a/internal/store/api.go +++ b/internal/store/api.go @@ -38,15 +38,19 @@ type APIAttachment struct { // ListMessages returns a paginated list of messages with batch-loaded recipients and labels. func (s *Store) ListMessages(offset, limit int) ([]APIMessage, int64, error) { - // Get total count + // Get total count. Use the canonical live-messages predicate so + // dedup-hidden rows (deleted_at) are excluded alongside source- + // deleted rows. 
var total int64 - err := s.db.QueryRow("SELECT COUNT(*) FROM messages WHERE deleted_from_source_at IS NULL").Scan(&total) + err := s.db.QueryRow( + "SELECT COUNT(*) FROM messages WHERE " + LiveMessagesWhere("", true), + ).Scan(&total) if err != nil { return nil, 0, err } // Query messages with sender info - query := ` + query := fmt.Sprintf(` SELECT m.id, COALESCE(m.conversation_id, 0) as conversation_id, @@ -59,10 +63,10 @@ func (s *Store) ListMessages(offset, limit int) ([]APIMessage, int64, error) { FROM messages m LEFT JOIN message_recipients mr ON mr.message_id = m.id AND mr.recipient_type = 'from' LEFT JOIN participants p ON p.id = mr.participant_id - WHERE m.deleted_from_source_at IS NULL + WHERE %s ORDER BY COALESCE(m.sent_at, m.received_at, m.internal_date) DESC LIMIT ? OFFSET ? - ` + `, LiveMessagesWhere("m", true)) rows, err := s.db.Query(query, limit, offset) if err != nil { @@ -210,8 +214,8 @@ func (s *Store) GetMessagesSummariesByIDs(ids []int64) ([]APIMessage, error) { FROM messages m LEFT JOIN message_recipients mr ON mr.message_id = m.id AND mr.recipient_type = 'from' LEFT JOIN participants p ON p.id = mr.participant_id - WHERE m.id IN (%s) AND m.deleted_from_source_at IS NULL - `, strings.Join(placeholders, ",")) + WHERE m.id IN (%s) AND %s + `, strings.Join(placeholders, ","), LiveMessagesWhere("m", true)) rows, err := s.db.Query(q, args...) if err != nil { return nil, fmt.Errorf("get message summaries: %w", err) @@ -262,10 +266,10 @@ func (s *Store) SearchMessages(query string, offset, limit int) ([]APIMessage, i %s LEFT JOIN message_recipients mr ON mr.message_id = m.id AND mr.recipient_type = 'from' LEFT JOIN participants p ON p.id = mr.participant_id - WHERE %s AND m.deleted_from_source_at IS NULL + WHERE %s AND %s ORDER BY %s LIMIT ? OFFSET ? - `, ftsJoin, ftsWhere, ftsOrder) + `, ftsJoin, ftsWhere, LiveMessagesWhere("m", true), ftsOrder) // Bind the search term once for WHERE, plus orderArgCount more times // for any ? 
placeholders the dialect put in the order-by fragment. @@ -298,8 +302,8 @@ func (s *Store) SearchMessages(query string, offset, limit int) ([]APIMessage, i SELECT COUNT(*) FROM messages m %s - WHERE %s AND m.deleted_from_source_at IS NULL - `, ftsJoin, ftsWhere) + WHERE %s AND %s + `, ftsJoin, ftsWhere, LiveMessagesWhere("m", true)) if err := s.db.QueryRow(countQuery, query).Scan(&total); err != nil { return nil, 0, fmt.Errorf("count FTS results: %w", err) } @@ -320,8 +324,7 @@ func (s *Store) SearchMessagesQuery( var conditions []string var args []interface{} - conditions = append(conditions, - "m.deleted_from_source_at IS NULL") + conditions = append(conditions, LiveMessagesWhere("m", true)) // FTS text terms. ftsEnabled is the authoritative signal that FTS is // active — ftsJoin may be empty on dialects (e.g. PostgreSQL) whose @@ -546,17 +549,17 @@ func escapeLike(s string) string { func (s *Store) searchMessagesLike(query string, offset, limit int) ([]APIMessage, int64, error) { likePattern := "%" + escapeLike(query) + "%" - countQuery := ` + countQuery := fmt.Sprintf(` SELECT COUNT(*) FROM messages - WHERE deleted_from_source_at IS NULL + WHERE %s AND (subject LIKE ? ESCAPE '\' OR snippet LIKE ? ESCAPE '\') - ` + `, LiveMessagesWhere("", true)) var total int64 if err := s.db.QueryRow(countQuery, likePattern, likePattern).Scan(&total); err != nil { return nil, 0, fmt.Errorf("count search results: %w", err) } - searchQuery := ` + searchQuery := fmt.Sprintf(` SELECT m.id, COALESCE(m.conversation_id, 0) as conversation_id, @@ -569,11 +572,11 @@ func (s *Store) searchMessagesLike(query string, offset, limit int) ([]APIMessag FROM messages m LEFT JOIN message_recipients mr ON mr.message_id = m.id AND mr.recipient_type = 'from' LEFT JOIN participants p ON p.id = mr.participant_id - WHERE m.deleted_from_source_at IS NULL + WHERE %s AND (m.subject LIKE ? ESCAPE '\' OR m.snippet LIKE ? ESCAPE '\') ORDER BY COALESCE(m.sent_at, m.received_at, m.internal_date) DESC LIMIT ? 
OFFSET ? - ` + `, LiveMessagesWhere("m", true)) rows, err := s.db.Query(searchQuery, likePattern, likePattern, limit, offset) if err != nil { diff --git a/internal/store/collection.go b/internal/store/collection.go new file mode 100644 index 00000000..6f44ef1d --- /dev/null +++ b/internal/store/collection.go @@ -0,0 +1,407 @@ +package store + +import ( + "database/sql" + "errors" + "fmt" + "strings" + "time" +) + +// Collection is a named grouping of sources that should be treated as +// a single logical archive. +type Collection struct { + ID int64 + Name string + Description string + CreatedAt time.Time +} + +// CollectionWithSources bundles a Collection with its member source +// IDs and a message-count aggregate. +type CollectionWithSources struct { + Collection + SourceIDs []int64 + MessageCount int64 +} + +// DefaultCollectionName is the always-present collection that mirrors +// every source. It is auto-managed by EnsureDefaultCollection on every +// schema init, so explicit add/remove/delete operations against it are +// rejected; the next CLI invocation would silently revert any change. +const DefaultCollectionName = "All" + +// ErrCollectionNotFound is returned when a collection lookup has no hits. +var ErrCollectionNotFound = errors.New("collection not found") + +// ErrCollectionImmutable is returned when an explicit mutation is +// attempted against the auto-managed default collection. +var ErrCollectionImmutable = errors.New( + `cannot modify the auto-managed "All" collection`, +) + +// EnsureDefaultCollection creates the auto-managed default collection +// if it doesn't exist and adds all current sources to it. Safe to call +// on every schema init. Mutations to this collection are rejected by +// AddSourcesToCollection / RemoveSourcesFromCollection / DeleteCollection +// so users don't get a silent revert on the next CLI invocation. 
+//
+// Concurrency: the create step uses INSERT OR IGNORE (dialect-rewritten
+// for PostgreSQL via dialect.InsertOrIgnore) followed by an unconditional
+// SELECT, so two processes calling this at the same time both succeed —
+// the second insert is ignored, both selects return the same row id.
+// Earlier this used SELECT-then-INSERT, which raced when a CLI command
+// and `serve` both initialised the schema simultaneously.
+func (s *Store) EnsureDefaultCollection() error {
+	if _, err := s.db.Exec(
+		s.dialect.InsertOrIgnore(
+			`INSERT OR IGNORE INTO collections (name, description)
+			 VALUES (?, 'All accounts')`,
+		),
+		DefaultCollectionName,
+	); err != nil {
+		return fmt.Errorf("create default collection: %w", err)
+	}
+
+	var id int64
+	if err := s.db.QueryRow(
+		`SELECT id FROM collections WHERE name = ?`, DefaultCollectionName,
+	).Scan(&id); err != nil {
+		return fmt.Errorf("look up default collection id: %w", err)
+	}
+
+	// Add all sources not already in it. Wrapped in dialect.InsertOrIgnore
+	// like the create step above so the INSERT OR IGNORE syntax is
+	// rewritten (e.g. to ON CONFLICT DO NOTHING) on PostgreSQL too.
+	if _, err := s.db.Exec(
+		s.dialect.InsertOrIgnore(
+			`INSERT OR IGNORE INTO collection_sources (collection_id, source_id)
+			 SELECT ?, id FROM sources`,
+		),
+		id,
+	); err != nil {
+		return fmt.Errorf("seed default collection membership: %w", err)
+	}
+	return nil
+}
+
+// CreateCollection inserts a new collection with the given name,
+// description, and member source IDs.
+func (s *Store) CreateCollection(
+	name, description string, sourceIDs []int64,
+) (*Collection, error) {
+	name = strings.TrimSpace(name)
+	if name == "" {
+		return nil, fmt.Errorf("collection name is required")
+	}
+	if name == DefaultCollectionName {
+		// Mirror the AddSourcesToCollection / RemoveSourcesFromCollection
+		// / DeleteCollection guards: the default collection is auto-
+		// managed by EnsureDefaultCollection. A manual create of "All"
+		// would have raced the auto-create; rejecting up front gives
+		// a consistent error surface with the rest of the collection
+		// commands.
+ return nil, ErrCollectionImmutable + } + if len(sourceIDs) == 0 { + return nil, fmt.Errorf( + "collection %q needs at least one source", name, + ) + } + + unique := uniqueInt64s(sourceIDs) + if err := s.validateSourceIDs(unique); err != nil { + return nil, err + } + + var created *Collection + err := s.withTx(func(tx *loggedTx) error { + res, err := tx.Exec( + `INSERT INTO collections (name, description) + VALUES (?, ?)`, + name, description, + ) + if err != nil { + if isSQLiteError(err, "UNIQUE constraint failed") { + return fmt.Errorf( + "collection %q already exists", name, + ) + } + return fmt.Errorf("insert collection: %w", err) + } + id, err := res.LastInsertId() + if err != nil { + return fmt.Errorf("last insert id: %w", err) + } + + for _, sid := range unique { + if _, err := tx.Exec( + `INSERT INTO collection_sources + (collection_id, source_id) + VALUES (?, ?)`, + id, sid, + ); err != nil { + return fmt.Errorf("link source %d: %w", sid, err) + } + } + + row := tx.QueryRow( + `SELECT id, name, description, created_at + FROM collections WHERE id = ?`, id, + ) + c, scanErr := scanCollection(row) + if scanErr != nil { + return scanErr + } + created = c + return nil + }) + if err != nil { + return nil, err + } + return created, nil +} + +// GetCollectionByName returns the collection with the given name and +// its member source IDs. +func (s *Store) GetCollectionByName( + name string, +) (*CollectionWithSources, error) { + row := s.db.QueryRow( + `SELECT id, name, description, created_at + FROM collections WHERE name = ?`, name, + ) + c, err := scanCollection(row) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrCollectionNotFound + } + return nil, err + } + return s.hydrateCollection(c) +} + +// ListCollections returns every collection with source IDs and +// message counts. 
+func (s *Store) ListCollections() ([]*CollectionWithSources, error) {
+	rows, err := s.db.Query(
+		`SELECT id, name, description, created_at
+		 FROM collections ORDER BY name`,
+	)
+	if err != nil {
+		return nil, fmt.Errorf("list collections: %w", err)
+	}
+	defer func() { _ = rows.Close() }()
+
+	// Collect the plain rows first, then hydrate outside the rows
+	// cursor: hydrateCollection issues its own queries and nesting
+	// queries inside an open cursor is unsafe on some drivers.
+	var collections []*Collection
+	for rows.Next() {
+		c, scanErr := scanCollection(rows)
+		if scanErr != nil {
+			return nil, scanErr
+		}
+		collections = append(collections, c)
+	}
+	if err := rows.Err(); err != nil {
+		return nil, err
+	}
+
+	result := make([]*CollectionWithSources, 0, len(collections))
+	for _, c := range collections {
+		hydrated, err := s.hydrateCollection(c)
+		if err != nil {
+			return nil, err
+		}
+		result = append(result, hydrated)
+	}
+	return result, nil
+}
+
+// getCollectionID looks up a collection ID by name without hydrating.
+func (s *Store) getCollectionID(name string) (int64, error) {
+	var id int64
+	err := s.db.QueryRow(
+		`SELECT id FROM collections WHERE name = ?`, name,
+	).Scan(&id)
+	if errors.Is(err, sql.ErrNoRows) {
+		return 0, ErrCollectionNotFound
+	}
+	if err != nil {
+		return 0, err
+	}
+	return id, nil
+}
+
+// AddSourcesToCollection attaches sources to a collection. Idempotent.
+// Rejects mutations of the auto-managed default collection.
+func (s *Store) AddSourcesToCollection(name string, sourceIDs []int64) error {
+	if name == DefaultCollectionName {
+		return ErrCollectionImmutable
+	}
+	if err := s.validateSourceIDs(sourceIDs); err != nil {
+		return err
+	}
+	collID, err := s.getCollectionID(name)
+	if err != nil {
+		return err
+	}
+	// Dialect-wrap the upsert (and hoist it out of the loop), matching
+	// EnsureDefaultCollection and MergeDuplicates; a raw INSERT OR
+	// IGNORE is SQLite-only and would break the PostgreSQL dialect.
+	insertSQL := s.dialect.InsertOrIgnore(
+		`INSERT OR IGNORE INTO collection_sources
+		 (collection_id, source_id)
+		 VALUES (?, ?)`,
+	)
+	return s.withTx(func(tx *loggedTx) error {
+		for _, sid := range sourceIDs {
+			if _, err := tx.Exec(insertSQL, collID, sid); err != nil {
+				return fmt.Errorf("add source %d: %w", sid, err)
+			}
+		}
+		return nil
+	})
+}
+
+// RemoveSourcesFromCollection detaches sources. Idempotent.
+// Rejects mutations of the auto-managed default collection. +func (s *Store) RemoveSourcesFromCollection(name string, sourceIDs []int64) error { + if name == DefaultCollectionName { + return ErrCollectionImmutable + } + if err := s.validateSourceIDs(sourceIDs); err != nil { + return err + } + collID, err := s.getCollectionID(name) + if err != nil { + return err + } + return s.withTx(func(tx *loggedTx) error { + for _, sid := range sourceIDs { + if _, err := tx.Exec( + `DELETE FROM collection_sources + WHERE collection_id = ? AND source_id = ?`, + collID, sid, + ); err != nil { + return fmt.Errorf("remove source %d: %w", sid, err) + } + } + return nil + }) +} + +// DeleteCollection drops the collection. Sources and messages untouched. +// Rejects deletion of the auto-managed default collection. +func (s *Store) DeleteCollection(name string) error { + if name == DefaultCollectionName { + return ErrCollectionImmutable + } + res, err := s.db.Exec( + `DELETE FROM collections WHERE name = ?`, name, + ) + if err != nil { + return fmt.Errorf("delete collection: %w", err) + } + n, _ := res.RowsAffected() + if n == 0 { + return ErrCollectionNotFound + } + return nil +} + +func (s *Store) hydrateCollection( + c *Collection, +) (*CollectionWithSources, error) { + rows, err := s.db.Query( + `SELECT source_id FROM collection_sources + WHERE collection_id = ? + ORDER BY source_id`, + c.ID, + ) + if err != nil { + return nil, fmt.Errorf("load sources for %s: %w", c.Name, err) + } + var sourceIDs []int64 + for rows.Next() { + var sid int64 + if err := rows.Scan(&sid); err != nil { + _ = rows.Close() + return nil, err + } + sourceIDs = append(sourceIDs, sid) + } + // Idiomatic ordering: check rows.Err() before closing so the + // iteration error is observed against an open rows handle. + if err := rows.Err(); err != nil { + _ = rows.Close() + return nil, err + } + _ = rows.Close() + + var count int64 + if len(sourceIDs) > 0 { + count, err = s.CountActiveMessages(sourceIDs...) 
+ if err != nil { + return nil, err + } + } + + return &CollectionWithSources{ + Collection: *c, + SourceIDs: sourceIDs, + MessageCount: count, + }, nil +} + +func scanCollection(row interface { + Scan(dest ...any) error +}) (*Collection, error) { + var c Collection + if err := row.Scan( + &c.ID, &c.Name, &c.Description, &c.CreatedAt, + ); err != nil { + return nil, err + } + return &c, nil +} + +func (s *Store) validateSourceIDs(ids []int64) error { + if len(ids) == 0 { + return nil + } + placeholders := make([]string, len(ids)) + args := make([]any, len(ids)) + for i, id := range ids { + placeholders[i] = "?" + args[i] = id + } + query := "SELECT id FROM sources WHERE id IN (" + + strings.Join(placeholders, ",") + ")" + rows, err := s.db.Query(query, args...) + if err != nil { + return fmt.Errorf("validate source IDs: %w", err) + } + defer func() { _ = rows.Close() }() + found := make(map[int64]bool, len(ids)) + for rows.Next() { + var id int64 + if err := rows.Scan(&id); err != nil { + return err + } + found[id] = true + } + if err := rows.Err(); err != nil { + return err + } + for _, id := range ids { + if !found[id] { + return fmt.Errorf("source %d not found", id) + } + } + return nil +} + +func uniqueInt64s(in []int64) []int64 { + seen := make(map[int64]bool, len(in)) + out := make([]int64, 0, len(in)) + for _, v := range in { + if seen[v] { + continue + } + seen[v] = true + out = append(out, v) + } + return out +} diff --git a/internal/store/collection_test.go b/internal/store/collection_test.go new file mode 100644 index 00000000..021aaa59 --- /dev/null +++ b/internal/store/collection_test.go @@ -0,0 +1,211 @@ +package store_test + +import ( + "testing" + + "github.com/wesm/msgvault/internal/store" + "github.com/wesm/msgvault/internal/testutil" + "github.com/wesm/msgvault/internal/testutil/storetest" +) + +func TestCollection_CRUD(t *testing.T) { + f := storetest.New(t) + st := f.Store + + src2, err := st.GetOrCreateSource("mbox", "backup@example.com") + 
testutil.MustNoErr(t, err, "GetOrCreateSource") + + // Create + coll, err := st.CreateCollection("work", "Work emails", []int64{f.Source.ID, src2.ID}) + testutil.MustNoErr(t, err, "CreateCollection") + if coll.Name != "work" { + t.Fatalf("name = %q, want work", coll.Name) + } + + // List — includes the auto-created "All" collection plus "work" + list, err := st.ListCollections() + testutil.MustNoErr(t, err, "ListCollections") + if len(list) != 2 { + t.Fatalf("list = %d, want 2", len(list)) + } + // Find "work" in the list and verify its sources. + var workColl *store.CollectionWithSources + for _, c := range list { + if c.Name == "work" { + workColl = c + break + } + } + if workColl == nil { + t.Fatal("expected 'work' collection in list") + } + if len(workColl.SourceIDs) != 2 { + t.Fatalf("sourceIDs = %d, want 2", len(workColl.SourceIDs)) + } + + // Get by name + got, err := st.GetCollectionByName("work") + testutil.MustNoErr(t, err, "GetCollectionByName") + if got.Name != "work" { + t.Fatalf("got name = %q", got.Name) + } + + // Not found + _, err = st.GetCollectionByName("nonexistent") + if err != store.ErrCollectionNotFound { + t.Fatalf("expected ErrCollectionNotFound, got %v", err) + } + + // Duplicate name rejected + _, err = st.CreateCollection("work", "", []int64{f.Source.ID}) + if err == nil { + t.Fatal("expected error for duplicate name") + } + + // Remove source + err = st.RemoveSourcesFromCollection("work", []int64{src2.ID}) + testutil.MustNoErr(t, err, "RemoveSourcesFromCollection") + got, err = st.GetCollectionByName("work") + testutil.MustNoErr(t, err, "GetCollectionByName after remove") + if len(got.SourceIDs) != 1 { + t.Fatalf("sourceIDs after remove = %d, want 1", len(got.SourceIDs)) + } + + // Add source back + err = st.AddSourcesToCollection("work", []int64{src2.ID}) + testutil.MustNoErr(t, err, "AddSourcesToCollection") + got, err = st.GetCollectionByName("work") + testutil.MustNoErr(t, err, "GetCollectionByName after add") + if 
len(got.SourceIDs) != 2 { + t.Fatalf("sourceIDs after add = %d, want 2", len(got.SourceIDs)) + } + + // Delete + err = st.DeleteCollection("work") + testutil.MustNoErr(t, err, "DeleteCollection") + _, err = st.GetCollectionByName("work") + if err != store.ErrCollectionNotFound { + t.Fatalf("expected not found after delete, got %v", err) + } +} + +func TestCollection_DefaultAll(t *testing.T) { + f := storetest.New(t) + st := f.Store + + err := st.EnsureDefaultCollection() + testutil.MustNoErr(t, err, "EnsureDefaultCollection") + + coll, err := st.GetCollectionByName("All") + testutil.MustNoErr(t, err, "GetCollectionByName All") + if coll.Name != "All" { + t.Fatalf("name = %q, want All", coll.Name) + } + // Should include the fixture's source + if len(coll.SourceIDs) < 1 { + t.Fatalf("All collection should have at least 1 source") + } + + // Idempotent + err = st.EnsureDefaultCollection() + testutil.MustNoErr(t, err, "EnsureDefaultCollection (2nd call)") +} + +func TestCollection_Validation(t *testing.T) { + f := storetest.New(t) + st := f.Store + + t.Run("empty name rejected", func(t *testing.T) { + _, err := st.CreateCollection("", "", []int64{f.Source.ID}) + if err == nil { + t.Fatal("expected error for empty name") + } + }) + + t.Run("zero sources rejected", func(t *testing.T) { + _, err := st.CreateCollection("empty", "", nil) + if err == nil { + t.Fatal("expected error for zero sources") + } + }) + + t.Run("nonexistent source rejected", func(t *testing.T) { + _, err := st.CreateCollection("bad", "", []int64{99999}) + if err == nil { + t.Fatal("expected error for nonexistent source") + } + }) + + t.Run("delete nonexistent returns error", func(t *testing.T) { + err := st.DeleteCollection("nonexistent") + if err != store.ErrCollectionNotFound { + t.Fatalf("expected ErrCollectionNotFound, got %v", err) + } + }) +} + +func TestCollection_Idempotent(t *testing.T) { + f := storetest.New(t) + st := f.Store + + _, err := st.CreateCollection("idem", "", 
[]int64{f.Source.ID}) + testutil.MustNoErr(t, err, "CreateCollection") + + t.Run("add same source twice is no-op", func(t *testing.T) { + err := st.AddSourcesToCollection("idem", []int64{f.Source.ID}) + testutil.MustNoErr(t, err, "AddSourcesToCollection (dupe)") + coll, err := st.GetCollectionByName("idem") + testutil.MustNoErr(t, err, "GetCollectionByName") + if len(coll.SourceIDs) != 1 { + t.Fatalf("sourceIDs = %d, want 1", len(coll.SourceIDs)) + } + }) + + t.Run("remove absent source is no-op", func(t *testing.T) { + src2, err := st.GetOrCreateSource("mbox", "other@example.com") + testutil.MustNoErr(t, err, "GetOrCreateSource") + err = st.RemoveSourcesFromCollection("idem", []int64{src2.ID}) + testutil.MustNoErr(t, err, "RemoveSourcesFromCollection (absent)") + }) +} + +// TestCollection_DefaultAllIsImmutable verifies that explicit +// add/remove/delete on the auto-managed "All" collection are rejected +// with ErrCollectionImmutable. Otherwise the next EnsureDefaultCollection +// call would silently revert the change, surprising the user. 
+func TestCollection_DefaultAllIsImmutable(t *testing.T) { + f := storetest.New(t) + st := f.Store + + testutil.MustNoErr(t, st.EnsureDefaultCollection(), "EnsureDefaultCollection") + + if err := st.AddSourcesToCollection("All", []int64{f.Source.ID}); err != store.ErrCollectionImmutable { + t.Errorf("AddSourcesToCollection(All) = %v, want ErrCollectionImmutable", err) + } + if err := st.RemoveSourcesFromCollection("All", []int64{f.Source.ID}); err != store.ErrCollectionImmutable { + t.Errorf("RemoveSourcesFromCollection(All) = %v, want ErrCollectionImmutable", err) + } + if err := st.DeleteCollection("All"); err != store.ErrCollectionImmutable { + t.Errorf("DeleteCollection(All) = %v, want ErrCollectionImmutable", err) + } +} + +func TestCollection_DefaultAllIncremental(t *testing.T) { + f := storetest.New(t) + st := f.Store + + testutil.MustNoErr(t, st.EnsureDefaultCollection(), "EnsureDefaultCollection 1") + coll, err := st.GetCollectionByName("All") + testutil.MustNoErr(t, err, "GetCollectionByName") + initialCount := len(coll.SourceIDs) + + _, err = st.GetOrCreateSource("mbox", "new@example.com") + testutil.MustNoErr(t, err, "GetOrCreateSource") + + testutil.MustNoErr(t, st.EnsureDefaultCollection(), "EnsureDefaultCollection 2") + coll, err = st.GetCollectionByName("All") + testutil.MustNoErr(t, err, "GetCollectionByName after add") + if len(coll.SourceIDs) != initialCount+1 { + t.Errorf("sourceIDs = %d, want %d", len(coll.SourceIDs), initialCount+1) + } +} diff --git a/internal/store/dedup.go b/internal/store/dedup.go new file mode 100644 index 00000000..62e45ddc --- /dev/null +++ b/internal/store/dedup.go @@ -0,0 +1,625 @@ +package store + +import ( + "bytes" + "compress/zlib" + "database/sql" + "fmt" + "io" + "strings" + "time" + + "github.com/wesm/msgvault/internal/mime" +) + +// DuplicateGroupKey identifies a group of messages sharing the same +// RFC822 Message-ID. Lightweight return type for the store layer. 
+type DuplicateGroupKey struct {
+	RFC822MessageID string
+	Count           int
+}
+
+// DuplicateMessageRow holds metadata needed to select the survivor in a
+// duplicate group. Lightweight return type for the store layer.
+type DuplicateMessageRow struct {
+	ID               int64
+	SourceID         int64
+	SourceType       string
+	SourceIdentifier string
+	SourceMessageID  string
+	Subject          string
+	SentAt           time.Time
+	ArchivedAt       time.Time
+	HasRawMIME       bool
+	LabelCount       int
+	IsFromMe         bool
+	HasSentLabel     bool // true if the message has the Gmail SENT label
+	// Raw From: address with original case preserved. The dedup engine
+	// normalizes via NormalizeIdentifierForCompare for identity-match
+	// sent detection, which is case-insensitive for email shapes and
+	// case-sensitive for synthetic identifiers (Matrix MXIDs, chat
+	// handles).
+	FromEmail string
+}
+
+// MergeResult holds the counts from a MergeDuplicates operation.
+type MergeResult struct {
+	LabelsTransferred int
+	RawMIMEBackfilled int
+}
+
+// ContentHashCandidate holds message metadata for raw-MIME hash scans.
+type ContentHashCandidate struct {
+	ID               int64
+	SourceID         int64
+	SourceType       string
+	SourceIdentifier string
+	SourceMessageID  string
+	Subject          string
+	SentAt           time.Time
+	ArchivedAt       time.Time
+	LabelCount       int
+	IsFromMe         bool
+	HasSentLabel     bool
+	FromEmail        string
+}
+
+// FindDuplicatesByRFC822ID returns the RFC822 Message-IDs shared by more
+// than one live message, with per-group counts, optionally scoped to the
+// given source IDs.
+func (s *Store) FindDuplicatesByRFC822ID(sourceIDs ...int64) ([]DuplicateGroupKey, error) {
+	query := `
+		SELECT rfc822_message_id, COUNT(*) AS cnt
+		FROM messages
+		WHERE rfc822_message_id IS NOT NULL
+		  AND rfc822_message_id != ''
+		  AND ` + LiveMessagesWhere("", true)
+	var args []any
+	if len(sourceIDs) > 0 {
+		placeholders := make([]string, len(sourceIDs))
+		for i, id := range sourceIDs {
+			placeholders[i] = "?"
+			args = append(args, id)
+		}
+		query += " AND source_id IN (" + strings.Join(placeholders, ",") + ")"
+	}
+	// Repeat the aggregate rather than referencing the SELECT alias
+	// (`cnt`) in HAVING: SQLite accepts the alias, PostgreSQL does
+	// not, and this store supports both dialects.
+	query += `
+		GROUP BY rfc822_message_id
+		HAVING COUNT(*) > 1`
+
+	rows, err := s.db.Query(query, args...)
+ if err != nil { + return nil, fmt.Errorf("find duplicates by rfc822 id: %w", err) + } + defer func() { _ = rows.Close() }() + + var groups []DuplicateGroupKey + for rows.Next() { + var g DuplicateGroupKey + if err := rows.Scan(&g.RFC822MessageID, &g.Count); err != nil { + return nil, err + } + groups = append(groups, g) + } + return groups, rows.Err() +} + +func (s *Store) GetDuplicateGroupMessages( + rfc822ID string, sourceIDs ...int64, +) ([]DuplicateMessageRow, error) { + query := ` + SELECT m.id, m.source_id, s.source_type, s.identifier, + m.source_message_id, + COALESCE(m.subject, ''), m.sent_at, m.archived_at, + (CASE WHEN mr.message_id IS NOT NULL THEN 1 ELSE 0 END) AS has_raw, + (SELECT COUNT(*) FROM message_labels ml + WHERE ml.message_id = m.id) AS label_count, + COALESCE(m.is_from_me, 0) AS is_from_me, + CAST(EXISTS ( + SELECT 1 FROM message_labels ml2 + JOIN labels l ON l.id = ml2.label_id + WHERE ml2.message_id = m.id + AND (l.source_label_id = 'SENT' OR UPPER(l.name) = 'SENT') + ) AS INTEGER) AS has_sent_label, + COALESCE(( + SELECT p_from.email_address + FROM message_recipients mr_from + JOIN participants p_from + ON p_from.id = mr_from.participant_id + WHERE mr_from.message_id = m.id + AND mr_from.recipient_type = 'from' + LIMIT 1 + ), '') AS from_email + FROM messages m + JOIN sources s ON s.id = m.source_id + LEFT JOIN message_raw mr ON mr.message_id = m.id + WHERE m.rfc822_message_id = ? AND ` + LiveMessagesWhere("m", true) + args := []any{rfc822ID} + if len(sourceIDs) > 0 { + placeholders := make([]string, len(sourceIDs)) + for i, id := range sourceIDs { + placeholders[i] = "?" + args = append(args, id) + } + query += " AND m.source_id IN (" + strings.Join(placeholders, ",") + ")" + } + query += " ORDER BY m.id" + + rows, err := s.db.Query(query, args...) 
+ if err != nil { + return nil, fmt.Errorf("get duplicate group messages: %w", err) + } + defer func() { _ = rows.Close() }() + + var msgs []DuplicateMessageRow + for rows.Next() { + var dm DuplicateMessageRow + var sentAt, archivedAt sql.NullTime + var hasRaw, isFromMe, hasSent int + if err := rows.Scan( + &dm.ID, &dm.SourceID, &dm.SourceType, &dm.SourceIdentifier, + &dm.SourceMessageID, &dm.Subject, &sentAt, &archivedAt, + &hasRaw, &dm.LabelCount, &isFromMe, &hasSent, + &dm.FromEmail, + ); err != nil { + return nil, err + } + if sentAt.Valid { + dm.SentAt = sentAt.Time + } + if archivedAt.Valid { + dm.ArchivedAt = archivedAt.Time + } + dm.HasRawMIME = hasRaw == 1 + dm.IsFromMe = isFromMe == 1 + dm.HasSentLabel = hasSent == 1 + msgs = append(msgs, dm) + } + return msgs, rows.Err() +} + +func (s *Store) MergeDuplicates( + survivorID int64, duplicateIDs []int64, batchID string, +) (*MergeResult, error) { + if len(duplicateIDs) == 0 { + return &MergeResult{}, nil + } + + result := &MergeResult{} + unionLabelsSQL := s.dialect.InsertOrIgnore(`INSERT OR IGNORE INTO message_labels (message_id, label_id) + SELECT ?, label_id FROM message_labels WHERE message_id = ?`) + backfillRawSQL := s.dialect.InsertOrIgnore(`INSERT OR IGNORE INTO message_raw + (message_id, raw_data, raw_format, compression) + SELECT ?, raw_data, raw_format, compression + FROM message_raw WHERE message_id = ?`) + softDeleteSQL := fmt.Sprintf(`UPDATE messages + SET deleted_at = %s, delete_batch_id = ? 
+ WHERE id = ?`, s.dialect.Now()) + + err := s.withTx(func(tx *loggedTx) error { + for _, dupID := range duplicateIDs { + res, err := tx.Exec(unionLabelsSQL, survivorID, dupID) + if err != nil { + return fmt.Errorf("union labels from %d: %w", dupID, err) + } + affected, _ := res.RowsAffected() + result.LabelsTransferred += int(affected) + } + + var survivorHasRaw int + if err := tx.QueryRow( + `SELECT COUNT(*) FROM message_raw WHERE message_id = ?`, + survivorID, + ).Scan(&survivorHasRaw); err != nil { + return fmt.Errorf("check survivor raw MIME: %w", err) + } + if survivorHasRaw == 0 { + for _, dupID := range duplicateIDs { + res, err := tx.Exec(backfillRawSQL, survivorID, dupID) + if err != nil { + return fmt.Errorf("backfill raw MIME from %d: %w", dupID, err) + } + affected, _ := res.RowsAffected() + if affected > 0 { + result.RawMIMEBackfilled += int(affected) + break + } + } + } + + for _, dupID := range duplicateIDs { + if _, err := tx.Exec(softDeleteSQL, batchID, dupID); err != nil { + return fmt.Errorf("soft-delete duplicate %d: %w", dupID, err) + } + } + return nil + }) + return result, err +} + +func (s *Store) GetAllRawMIMECandidates( + sourceIDs ...int64, +) ([]ContentHashCandidate, error) { + query := ` + SELECT m.id, m.source_id, s.source_type, s.identifier, + m.source_message_id, + COALESCE(m.subject, ''), m.sent_at, m.archived_at, + (SELECT COUNT(*) FROM message_labels ml + WHERE ml.message_id = m.id) AS label_count, + COALESCE(m.is_from_me, 0) AS is_from_me, + CAST(EXISTS ( + SELECT 1 FROM message_labels ml2 + JOIN labels l ON l.id = ml2.label_id + WHERE ml2.message_id = m.id + AND (l.source_label_id = 'SENT' OR UPPER(l.name) = 'SENT') + ) AS INTEGER) AS has_sent_label, + COALESCE(( + SELECT p_from.email_address + FROM message_recipients mr_from + JOIN participants p_from + ON p_from.id = mr_from.participant_id + WHERE mr_from.message_id = m.id + AND mr_from.recipient_type = 'from' + LIMIT 1 + ), '') AS from_email + FROM messages m + JOIN sources 
s ON s.id = m.source_id + JOIN message_raw mr ON mr.message_id = m.id + WHERE ` + LiveMessagesWhere("m", true) + var args []any + if len(sourceIDs) > 0 { + placeholders := make([]string, len(sourceIDs)) + for i, id := range sourceIDs { + placeholders[i] = "?" + args = append(args, id) + } + query += " AND m.source_id IN (" + strings.Join(placeholders, ",") + ")" + } + query += " ORDER BY m.id" + + rows, err := s.db.Query(query, args...) + if err != nil { + return nil, fmt.Errorf("get all raw MIME candidates: %w", err) + } + defer func() { _ = rows.Close() }() + + var candidates []ContentHashCandidate + for rows.Next() { + var c ContentHashCandidate + var sentAt, archivedAt sql.NullTime + var isFromMe, hasSent int + if err := rows.Scan( + &c.ID, &c.SourceID, &c.SourceType, &c.SourceIdentifier, + &c.SourceMessageID, &c.Subject, &sentAt, &archivedAt, + &c.LabelCount, &isFromMe, &hasSent, &c.FromEmail, + ); err != nil { + return nil, err + } + if sentAt.Valid { + c.SentAt = sentAt.Time + } + if archivedAt.Valid { + c.ArchivedAt = archivedAt.Time + } + c.IsFromMe = isFromMe == 1 + c.HasSentLabel = hasSent == 1 + candidates = append(candidates, c) + } + return candidates, rows.Err() +} + +func (s *Store) StreamMessageRaw( + messageIDs []int64, + fn func(messageID int64, rawData []byte, compression string), +) error { + const chunkSize = 500 + for start := 0; start < len(messageIDs); start += chunkSize { + end := min(start+chunkSize, len(messageIDs)) + chunk := messageIDs[start:end] + + placeholders := make([]string, len(chunk)) + args := make([]any, len(chunk)) + for i, id := range chunk { + placeholders[i] = "?" + args[i] = id + } + + query := "SELECT message_id, raw_data, compression FROM message_raw WHERE message_id IN (" + + strings.Join(placeholders, ",") + ")" + rows, err := s.db.Query(query, args...) 
+ if err != nil { + return fmt.Errorf("stream message raw: %w", err) + } + + for rows.Next() { + var msgID int64 + var rawData []byte + var compression sql.NullString + if err := rows.Scan(&msgID, &rawData, &compression); err != nil { + _ = rows.Close() + return err + } + comp := "" + if compression.Valid { + comp = compression.String + } + fn(msgID, rawData, comp) + } + if err := rows.Err(); err != nil { + _ = rows.Close() + return err + } + _ = rows.Close() + } + return nil +} + +// UndoDedup restores soft-deleted duplicates from a dedup batch by +// clearing deleted_at and delete_batch_id. Merge side effects (labels +// copied to survivors, raw MIME backfilled onto survivors) are not +// reversed — those changes are additive enrichment that leaves +// survivors strictly better off. +func (s *Store) UndoDedup(batchID string) (int64, error) { + result, err := s.db.Exec(` + UPDATE messages + SET deleted_at = NULL, delete_batch_id = NULL + WHERE delete_batch_id = ? + `, batchID) + if err != nil { + return 0, fmt.Errorf("undo dedup: %w", err) + } + return result.RowsAffected() +} + +// DeleteDedupedBatch permanently deletes all hidden rows associated with a +// dedup batch. Only deletes rows where deleted_at IS NOT NULL AND +// delete_batch_id = batchID. Returns the number of rows deleted. +// +// This is irreversible. Caller is responsible for backups. +// Attachments cascade-delete from the metadata row; on-disk blobs are +// content-addressed and survive until separate cleanup. +func (s *Store) DeleteDedupedBatch(batchID string) (int64, error) { + result, err := s.db.Exec(` + DELETE FROM messages + WHERE delete_batch_id = ? AND deleted_at IS NOT NULL + `, batchID) + if err != nil { + return 0, fmt.Errorf("delete dedup batch %q: %w", batchID, err) + } + return result.RowsAffected() +} + +// DeleteAllDeduped permanently deletes every dedup-hidden row regardless of +// batch. Returns the number of rows deleted and the number of distinct +// batches affected. 
+// +// The delete is gated on the positive marker `delete_batch_id IS NOT NULL` +// in addition to `deleted_at IS NOT NULL` so that the contract is "permanently +// remove rows the dedup pipeline soft-hid." If a future feature ever adds +// another soft-delete semantics that writes deleted_at without a batch ID +// (e.g. a "trash" view, a per-message user hide), this command will leave +// those rows alone — they are not dedup-hidden and have no business being +// purged by the local dedup hard-delete rung. +// +// This is irreversible. Caller is responsible for backups. +// Attachments cascade-delete from the metadata row; on-disk blobs are +// content-addressed and survive until separate cleanup. +func (s *Store) DeleteAllDeduped() (deleted int64, distinctBatches int64, err error) { + committed := false + tx, err := s.db.Begin() + if err != nil { + return 0, 0, fmt.Errorf("delete all dedup-hidden: begin tx: %w", err) + } + defer func() { + if !committed { + _ = tx.Rollback() + } + }() + + if err = tx.QueryRow(` + SELECT COUNT(DISTINCT delete_batch_id) + FROM messages + WHERE deleted_at IS NOT NULL AND delete_batch_id IS NOT NULL + `).Scan(&distinctBatches); err != nil { + return 0, 0, fmt.Errorf("delete all dedup-hidden: count batches: %w", err) + } + + result, err := tx.Exec(` + DELETE FROM messages + WHERE deleted_at IS NOT NULL AND delete_batch_id IS NOT NULL + `) + if err != nil { + return 0, 0, fmt.Errorf("delete all dedup-hidden: delete: %w", err) + } + deleted, err = result.RowsAffected() + if err != nil { + return 0, 0, fmt.Errorf("delete all dedup-hidden: rows affected: %w", err) + } + + if err = tx.Commit(); err != nil { + return 0, 0, fmt.Errorf("delete all dedup-hidden: commit: %w", err) + } + committed = true + return deleted, distinctBatches, nil +} + +func (s *Store) CountActiveMessages(sourceIDs ...int64) (int64, error) { + query := "SELECT COUNT(*) FROM messages WHERE " + LiveMessagesWhere("", true) + var args []any + if len(sourceIDs) > 0 { + 
placeholders := make([]string, len(sourceIDs)) + for i, id := range sourceIDs { + placeholders[i] = "?" + args = append(args, id) + } + query += " AND source_id IN (" + strings.Join(placeholders, ",") + ")" + } + var count int64 + err := s.db.QueryRow(query, args...).Scan(&count) + return count, err +} + +func (s *Store) CountMessagesWithoutRFC822ID(sourceIDs ...int64) (int64, error) { + q := `SELECT COUNT(*) FROM messages m + JOIN message_raw mr ON mr.message_id = m.id + WHERE (m.rfc822_message_id IS NULL OR m.rfc822_message_id = '') + AND ` + LiveMessagesWhere("m", true) + var args []any + if len(sourceIDs) > 0 { + placeholders := make([]string, len(sourceIDs)) + for i, id := range sourceIDs { + placeholders[i] = "?" + args = append(args, id) + } + q += " AND m.source_id IN (" + strings.Join(placeholders, ",") + ")" + } + var count int64 + err := s.db.QueryRow(q, args...).Scan(&count) + return count, err +} + +func (s *Store) BackfillRFC822IDs( + sourceIDs []int64, + progress func(done, total int64), +) (updated int64, failed int64, err error) { + scopeClause := "" + var scopeArgs []any + if len(sourceIDs) > 0 { + placeholders := make([]string, len(sourceIDs)) + for i, id := range sourceIDs { + placeholders[i] = "?" 
+ scopeArgs = append(scopeArgs, id) + } + scopeClause = " AND m.source_id IN (" + strings.Join(placeholders, ",") + ")" + } + + var total int64 + countQ := `SELECT COUNT(*) FROM messages m + JOIN message_raw mr ON mr.message_id = m.id + WHERE (m.rfc822_message_id IS NULL OR m.rfc822_message_id = '') + AND ` + LiveMessagesWhere("m", true) + scopeClause + err = s.db.QueryRow(countQ, scopeArgs...).Scan(&total) + if err != nil { + return 0, 0, fmt.Errorf("count backfill candidates: %w", err) + } + if total == 0 { + return 0, 0, nil + } + + const batchSize = 1000 + lastID := int64(0) + + for { + batchQ := `SELECT m.id FROM messages m + JOIN message_raw mr ON mr.message_id = m.id + WHERE (m.rfc822_message_id IS NULL OR m.rfc822_message_id = '') + AND ` + LiveMessagesWhere("m", true) + ` + AND m.id > ?` + scopeClause + ` + ORDER BY m.id + LIMIT ?` + batchArgs := append([]any{lastID}, scopeArgs...) + batchArgs = append(batchArgs, batchSize) + rows, err := s.db.Query(batchQ, batchArgs...) + if err != nil { + return updated, failed, fmt.Errorf("fetch backfill batch: %w", err) + } + + var batchIDs []int64 + for rows.Next() { + var id int64 + if err := rows.Scan(&id); err != nil { + _ = rows.Close() + return updated, failed, err + } + batchIDs = append(batchIDs, id) + } + _ = rows.Close() + if err := rows.Err(); err != nil { + return updated, failed, err + } + if len(batchIDs) == 0 { + break + } + + updates := make([]struct { + id int64 + normalizedID string + }, 0, len(batchIDs)) + // Failed rows are not updated and are not retried in this run: + // because lastID advances past every batch element below, the + // next batch query (m.id > lastID) skips them. The selection + // filter (rfc822_message_id IS NULL OR '') will pick them up + // again on the next BackfillRFC822IDs invocation. 
+ seen := make(map[int64]bool, len(batchIDs)) + streamErr := s.StreamMessageRaw(batchIDs, func(id int64, rawData []byte, compression string) { + seen[id] = true + raw := rawData + if compression == "zlib" { + r, err := zlib.NewReader(bytes.NewReader(rawData)) + if err != nil { + failed++ + return + } + decompressed, err := io.ReadAll(r) + _ = r.Close() + if err != nil { + failed++ + return + } + raw = decompressed + } + parsed, err := mime.Parse(raw) + if err != nil || parsed.MessageID == "" { + failed++ + return + } + normalizedID := strings.TrimSpace(parsed.MessageID) + normalizedID = strings.Trim(normalizedID, "<>") + if normalizedID == "" { + failed++ + return + } + updates = append(updates, struct { + id int64 + normalizedID string + }{ + id: id, + normalizedID: normalizedID, + }) + }) + if streamErr != nil { + return updated, failed, fmt.Errorf("stream raw for backfill batch: %w", streamErr) + } + // Rows whose message_raw row went missing between the batch + // SELECT and the stream are counted as failed so totals reconcile. + for _, id := range batchIDs { + if !seen[id] { + failed++ + } + } + + var batchUpdated int64 + err = s.withTx(func(tx *loggedTx) error { + for _, update := range updates { + if _, err := tx.Exec( + "UPDATE messages SET rfc822_message_id = ? 
WHERE id = ?", + update.normalizedID, update.id, + ); err != nil { + return fmt.Errorf("update message %d: %w", update.id, err) + } + batchUpdated++ + } + return nil + }) + if err != nil { + return updated, failed, fmt.Errorf( + "apply backfill batch ending at %d: %w", + batchIDs[len(batchIDs)-1], err, + ) + } + updated += batchUpdated + + lastID = batchIDs[len(batchIDs)-1] + if progress != nil { + progress(updated+failed, total) + } + } + return updated, failed, nil +} diff --git a/internal/store/dedup_delete_test.go b/internal/store/dedup_delete_test.go new file mode 100644 index 00000000..a584bece --- /dev/null +++ b/internal/store/dedup_delete_test.go @@ -0,0 +1,185 @@ +package store_test + +import ( + "testing" + + "github.com/wesm/msgvault/internal/testutil" + "github.com/wesm/msgvault/internal/testutil/storetest" +) + +// TestDeleteDedupedBatch_DeletesHiddenRows verifies that DeleteDedupedBatch removes only the +// rows associated with the given batch ID and that ON DELETE CASCADE removes +// child rows (message_labels). +func TestDeleteDedupedBatch_DeletesHiddenRows(t *testing.T) { + f := storetest.New(t) + idKeep := newRFC822Message(t, f, "keep", "rfc822-delete-a") + idDrop := newRFC822Message(t, f, "drop", "rfc822-delete-a") + + labels := f.EnsureLabels( + map[string]string{"INBOX": "Inbox", "SENT": "Sent"}, "system", + ) + testutil.MustNoErr(t, f.Store.LinkMessageLabel(idDrop, labels["INBOX"]), "link INBOX") + testutil.MustNoErr(t, f.Store.LinkMessageLabel(idDrop, labels["SENT"]), "link SENT") + + _, err := f.Store.MergeDuplicates(idKeep, []int64{idDrop}, "batch-delete") + testutil.MustNoErr(t, err, "MergeDuplicates") + + // idDrop should be hidden before delete. + assertDedupDeleted(t, f.Store, idDrop, true) + + deleted, err := f.Store.DeleteDedupedBatch("batch-delete") + testutil.MustNoErr(t, err, "DeleteDedupedBatch") + if deleted != 1 { + t.Errorf("DeleteDedupedBatch deleted = %d, want 1", deleted) + } + + // Row should be gone. 
+ var count int + err = f.Store.DB().QueryRow( + "SELECT COUNT(*) FROM messages WHERE id = ?", idDrop, + ).Scan(&count) + testutil.MustNoErr(t, err, "query messages after delete") + if count != 0 { + t.Errorf("message %d still present after delete", idDrop) + } + + // Child message_labels rows should cascade-delete. + err = f.Store.DB().QueryRow( + "SELECT COUNT(*) FROM message_labels WHERE message_id = ?", idDrop, + ).Scan(&count) + testutil.MustNoErr(t, err, "query message_labels after delete") + if count != 0 { + t.Errorf("message_labels for %d still present after delete (%d rows)", idDrop, count) + } + + // Survivor should be untouched. + assertDedupDeleted(t, f.Store, idKeep, false) +} + +// TestDeleteDedupedBatch_UnknownBatch verifies that DeleteDedupedBatch with a non-existent +// batch ID returns 0 without error. +func TestDeleteDedupedBatch_UnknownBatch(t *testing.T) { + f := storetest.New(t) + _ = newRFC822Message(t, f, "msg-a", "rfc822-only") + + deleted, err := f.Store.DeleteDedupedBatch("no-such-batch") + testutil.MustNoErr(t, err, "DeleteDedupedBatch unknown batch") + if deleted != 0 { + t.Errorf("DeleteDedupedBatch deleted = %d, want 0", deleted) + } +} + +// TestDeleteAllDeduped_MultiplesBatches verifies that DeleteAllDeduped removes +// rows from all batches and reports the correct counts. 
+func TestDeleteAllDeduped_MultipleBatches(t *testing.T) { + f := storetest.New(t) + + // batch-alpha hides one message + idKeepA := newRFC822Message(t, f, "keep-a", "rfc822-multi-a") + idDropA := newRFC822Message(t, f, "drop-a", "rfc822-multi-a") + _, err := f.Store.MergeDuplicates(idKeepA, []int64{idDropA}, "batch-alpha") + testutil.MustNoErr(t, err, "MergeDuplicates alpha") + + // batch-beta hides one message + idKeepB := newRFC822Message(t, f, "keep-b", "rfc822-multi-b") + idDropB := newRFC822Message(t, f, "drop-b", "rfc822-multi-b") + _, err = f.Store.MergeDuplicates(idKeepB, []int64{idDropB}, "batch-beta") + testutil.MustNoErr(t, err, "MergeDuplicates beta") + + deleted, batches, err := f.Store.DeleteAllDeduped() + testutil.MustNoErr(t, err, "DeleteAllDeduped") + if deleted != 2 { + t.Errorf("DeleteAllDeduped deleted = %d, want 2", deleted) + } + if batches != 2 { + t.Errorf("DeleteAllDeduped distinctBatches = %d, want 2", batches) + } + + // All four messages should still exist (survivors untouched). + var count int + err = f.Store.DB().QueryRow("SELECT COUNT(*) FROM messages").Scan(&count) + testutil.MustNoErr(t, err, "count messages after DeleteAllDeduped") + if count != 2 { + t.Errorf("messages count = %d, want 2 (survivors only)", count) + } +} + +// TestDeleteAllDeduped_PreservesBatchlessSoftDelete verifies that a row with +// deleted_at set but no delete_batch_id is *not* purged by DeleteAllDeduped. +// The contract is "permanently remove rows the dedup pipeline soft-hid", +// keyed on the positive delete_batch_id marker. A future feature that writes +// deleted_at for any other reason (trash view, per-message hide) must not +// have its rows silently destroyed by the dedup hard-delete rung. +func TestDeleteAllDeduped_PreservesBatchlessSoftDelete(t *testing.T) { + f := storetest.New(t) + + // One real dedup batch — should be purged. 
+ idKeep := newRFC822Message(t, f, "keep", "rfc822-batchless") + idDrop := newRFC822Message(t, f, "drop", "rfc822-batchless") + _, err := f.Store.MergeDuplicates(idKeep, []int64{idDrop}, "batch-real") + testutil.MustNoErr(t, err, "MergeDuplicates") + + // One row soft-hidden without a batch ID — simulates a future + // non-dedup soft-delete writer. Should survive DeleteAllDeduped. + idBatchless := newRFC822Message(t, f, "batchless", "rfc822-other") + _, err = f.Store.DB().Exec( + "UPDATE messages SET deleted_at = CURRENT_TIMESTAMP WHERE id = ?", + idBatchless, + ) + testutil.MustNoErr(t, err, "set batchless deleted_at") + + deleted, batches, err := f.Store.DeleteAllDeduped() + testutil.MustNoErr(t, err, "DeleteAllDeduped") + if deleted != 1 { + t.Errorf("DeleteAllDeduped deleted = %d, want 1 (only the batched row)", deleted) + } + if batches != 1 { + t.Errorf("DeleteAllDeduped distinctBatches = %d, want 1", batches) + } + + // The batchless row must still exist after the purge. + var count int + err = f.Store.DB().QueryRow( + "SELECT COUNT(*) FROM messages WHERE id = ?", idBatchless, + ).Scan(&count) + testutil.MustNoErr(t, err, "query batchless row after delete") + if count != 1 { + t.Errorf("batchless soft-deleted row %d was purged; DeleteAllDeduped must only touch dedup-batched rows", idBatchless) + } +} + +// TestDeleteAllDeduped_Empty verifies that DeleteAllDeduped with no hidden rows +// returns 0/0 without error. +func TestDeleteAllDeduped_Empty(t *testing.T) { + f := storetest.New(t) + _ = newRFC822Message(t, f, "visible", "rfc822-vis") + + deleted, batches, err := f.Store.DeleteAllDeduped() + testutil.MustNoErr(t, err, "DeleteAllDeduped empty") + if deleted != 0 { + t.Errorf("deleted = %d, want 0", deleted) + } + if batches != 0 { + t.Errorf("distinctBatches = %d, want 0", batches) + } +} + +// TestDeleteDedupedBatch_ThenUndoNoOps verifies that calling UndoDedup after DeleteDedupedBatch +// returns 0 (the rows no longer exist) without error. 
+func TestDeleteDedupedBatch_ThenUndoNoOps(t *testing.T) { + f := storetest.New(t) + idKeep := newRFC822Message(t, f, "keep", "rfc822-undo-noop") + idDrop := newRFC822Message(t, f, "drop", "rfc822-undo-noop") + + _, err := f.Store.MergeDuplicates(idKeep, []int64{idDrop}, "batch-noop") + testutil.MustNoErr(t, err, "MergeDuplicates") + + _, err = f.Store.DeleteDedupedBatch("batch-noop") + testutil.MustNoErr(t, err, "DeleteDedupedBatch") + + restored, err := f.Store.UndoDedup("batch-noop") + testutil.MustNoErr(t, err, "UndoDedup after delete") + if restored != 0 { + t.Errorf("UndoDedup after delete restored = %d, want 0", restored) + } +} diff --git a/internal/store/dedup_test.go b/internal/store/dedup_test.go new file mode 100644 index 00000000..a3dd62de --- /dev/null +++ b/internal/store/dedup_test.go @@ -0,0 +1,377 @@ +package store_test + +import ( + "database/sql" + "fmt" + "testing" + + "github.com/wesm/msgvault/internal/store" + "github.com/wesm/msgvault/internal/testutil" + "github.com/wesm/msgvault/internal/testutil/storetest" +) + +func newRFC822Message( + t *testing.T, f *storetest.Fixture, sourceMessageID, rfc822ID string, +) int64 { + t.Helper() + id, err := f.Store.UpsertMessage(&store.Message{ + ConversationID: f.ConvID, + SourceID: f.Source.ID, + SourceMessageID: sourceMessageID, + RFC822MessageID: sql.NullString{ + String: rfc822ID, Valid: rfc822ID != "", + }, + MessageType: "email", + SizeEstimate: 1000, + }) + testutil.MustNoErr(t, err, "UpsertMessage") + return id +} + +func TestStore_FindDuplicatesByRFC822ID(t *testing.T) { + f := storetest.New(t) + idA := newRFC822Message(t, f, "src-a", "rfc822-shared") + idB := newRFC822Message(t, f, "src-b", "rfc822-shared") + _ = newRFC822Message(t, f, "src-c", "rfc822-unique") + + groups, err := f.Store.FindDuplicatesByRFC822ID() + testutil.MustNoErr(t, err, "FindDuplicatesByRFC822ID") + if len(groups) != 1 { + t.Fatalf("groups = %d, want 1", len(groups)) + } + if groups[0].RFC822MessageID != "rfc822-shared" 
{ + t.Errorf("key = %q, want rfc822-shared", groups[0].RFC822MessageID) + } + if groups[0].Count != 2 { + t.Errorf("count = %d, want 2", groups[0].Count) + } + + _, err = f.Store.MergeDuplicates(idA, []int64{idB}, "batch-test") + testutil.MustNoErr(t, err, "MergeDuplicates") + + groups, err = f.Store.FindDuplicatesByRFC822ID() + testutil.MustNoErr(t, err, "FindDuplicatesByRFC822ID after merge") + if len(groups) != 0 { + t.Errorf("groups after merge = %d, want 0", len(groups)) + } +} + +func TestStore_GetDuplicateGroupMessages_SentLabel(t *testing.T) { + f := storetest.New(t) + idInbox := newRFC822Message(t, f, "inbox-copy", "rfc822-sent") + idSent := newRFC822Message(t, f, "sent-copy", "rfc822-sent") + + labels := f.EnsureLabels( + map[string]string{"SENT": "Sent", "INBOX": "Inbox"}, "system", + ) + testutil.MustNoErr(t, f.Store.LinkMessageLabel(idInbox, labels["INBOX"]), "link INBOX") + testutil.MustNoErr(t, f.Store.LinkMessageLabel(idSent, labels["SENT"]), "link SENT") + + rows, err := f.Store.GetDuplicateGroupMessages("rfc822-sent") + testutil.MustNoErr(t, err, "GetDuplicateGroupMessages") + if len(rows) != 2 { + t.Fatalf("rows = %d, want 2", len(rows)) + } + + var sentRow, inboxRow *store.DuplicateMessageRow + for i := range rows { + switch rows[i].ID { + case idSent: + sentRow = &rows[i] + case idInbox: + inboxRow = &rows[i] + } + } + if sentRow == nil || inboxRow == nil { + t.Fatalf("missing rows: sent=%v inbox=%v", sentRow, inboxRow) + } + if !sentRow.HasSentLabel { + t.Errorf("sent row: HasSentLabel = false, want true") + } + if inboxRow.HasSentLabel { + t.Errorf("inbox row: HasSentLabel = true, want false") + } +} + +func TestStore_MergeDuplicates_UnionsLabels(t *testing.T) { + f := storetest.New(t) + idKeep := newRFC822Message(t, f, "keep", "rfc822-merge") + idDrop := newRFC822Message(t, f, "drop", "rfc822-merge") + + labels := f.EnsureLabels( + map[string]string{"INBOX": "Inbox", "IMPORTANT": "Important", "WORK": "Work"}, "user", + ) + 
testutil.MustNoErr(t, f.Store.LinkMessageLabel(idKeep, labels["INBOX"]), "link INBOX to keep") + testutil.MustNoErr(t, f.Store.LinkMessageLabel(idDrop, labels["IMPORTANT"]), "link IMPORTANT to drop") + testutil.MustNoErr(t, f.Store.LinkMessageLabel(idDrop, labels["WORK"]), "link WORK to drop") + + result, err := f.Store.MergeDuplicates(idKeep, []int64{idDrop}, "batch-labels") + testutil.MustNoErr(t, err, "MergeDuplicates") + if result.LabelsTransferred != 2 { + t.Errorf("labelsTransferred = %d, want 2", result.LabelsTransferred) + } + + f.AssertLabelCount(idKeep, 3) + assertDedupDeleted(t, f.Store, idDrop, true) + + restored, err := f.Store.UndoDedup("batch-labels") + testutil.MustNoErr(t, err, "UndoDedup") + if restored != 1 { + t.Errorf("restored = %d, want 1", restored) + } + assertDedupDeleted(t, f.Store, idDrop, false) +} + +func assertDedupDeleted( + t *testing.T, st *store.Store, msgID int64, wantDeleted bool, +) { + t.Helper() + var deletedAt sql.NullTime + err := st.DB().QueryRow( + "SELECT deleted_at FROM messages WHERE id = ?", msgID, + ).Scan(&deletedAt) + testutil.MustNoErr(t, err, "query deleted_at") + if wantDeleted && !deletedAt.Valid { + t.Errorf("message %d: deleted_at should be set", msgID) + } + if !wantDeleted && deletedAt.Valid { + t.Errorf("message %d: deleted_at should be NULL", msgID) + } +} + +func TestStore_BackfillRFC822IDs_EmptyTable(t *testing.T) { + f := storetest.New(t) + count, err := f.Store.CountMessagesWithoutRFC822ID() + testutil.MustNoErr(t, err, "CountMessagesWithoutRFC822ID") + if count != 0 { + t.Errorf("empty-table count = %d, want 0", count) + } + + updated, _, err := f.Store.BackfillRFC822IDs(nil, nil) + testutil.MustNoErr(t, err, "BackfillRFC822IDs") + if updated != 0 { + t.Errorf("updated = %d, want 0", updated) + } +} + +func TestStore_CountActiveMessages(t *testing.T) { + f := storetest.New(t) + _ = newRFC822Message(t, f, "a", "id-a") + idB := newRFC822Message(t, f, "b", "id-b") + + total, err := 
f.Store.CountActiveMessages() + testutil.MustNoErr(t, err, "CountActiveMessages") + if total != 2 { + t.Errorf("active = %d, want 2", total) + } + + _, err = f.Store.MergeDuplicates( + newRFC822Message(t, f, "c", "id-c"), + []int64{idB}, + "batch-count", + ) + testutil.MustNoErr(t, err, "MergeDuplicates") + + total, err = f.Store.CountActiveMessages() + testutil.MustNoErr(t, err, "CountActiveMessages after merge") + if total != 2 { + t.Errorf("active after merge = %d, want 2", total) + } +} + +func TestStore_BackfillRFC822IDs_ParsesFromRawMIME(t *testing.T) { + f := storetest.New(t) + + id := newRFC822Message(t, f, "needs-backfill", "") + + rawMIME := []byte("From: alice@example.com\r\nTo: bob@example.com\r\nMessage-ID: \r\nSubject: Backfill test\r\n\r\nBody text") + testutil.MustNoErr(t, + f.Store.UpsertMessageRaw(id, rawMIME), + "UpsertMessageRaw", + ) + + count, err := f.Store.CountMessagesWithoutRFC822ID() + testutil.MustNoErr(t, err, "CountMessagesWithoutRFC822ID") + if count != 1 { + t.Fatalf("count without rfc822 = %d, want 1", count) + } + + updated, _, err := f.Store.BackfillRFC822IDs(nil, nil) + testutil.MustNoErr(t, err, "BackfillRFC822IDs") + if updated != 1 { + t.Fatalf("updated = %d, want 1", updated) + } + + var rfc822ID string + err = f.Store.DB().QueryRow( + "SELECT rfc822_message_id FROM messages WHERE id = ?", id, + ).Scan(&rfc822ID) + testutil.MustNoErr(t, err, "scan rfc822_message_id") + if rfc822ID != "unique-123@example.com" { + t.Errorf("rfc822_message_id = %q, want unique-123@example.com", rfc822ID) + } + + count, err = f.Store.CountMessagesWithoutRFC822ID() + testutil.MustNoErr(t, err, "CountMessagesWithoutRFC822ID after backfill") + if count != 0 { + t.Errorf("count after backfill = %d, want 0", count) + } +} + +func TestStore_BackfillRFC822IDs_DoesNotOvercountRolledBackBatch(t *testing.T) { + f := storetest.New(t) + + idA := newRFC822Message(t, f, "needs-backfill-a", "") + idB := newRFC822Message(t, f, "needs-backfill-b", "") + + rawA := 
[]byte("From: alice@example.com\r\nMessage-ID: \r\n\r\nBody") + rawB := []byte("From: bob@example.com\r\nMessage-ID: \r\n\r\nBody") + testutil.MustNoErr(t, f.Store.UpsertMessageRaw(idA, rawA), "UpsertMessageRaw A") + testutil.MustNoErr(t, f.Store.UpsertMessageRaw(idB, rawB), "UpsertMessageRaw B") + + _, err := f.Store.DB().Exec(fmt.Sprintf(` + CREATE TRIGGER fail_backfill_second_message + BEFORE UPDATE OF rfc822_message_id ON messages + WHEN NEW.id = %d + BEGIN + SELECT RAISE(FAIL, 'forced backfill failure'); + END + `, idB)) + testutil.MustNoErr(t, err, "create trigger") + + updated, failed, err := f.Store.BackfillRFC822IDs(nil, nil) + if err == nil { + t.Fatal("expected backfill error, got nil") + } + if updated != 0 { + t.Fatalf("updated = %d, want 0 after rollback", updated) + } + if failed != 0 { + t.Fatalf("failed = %d, want 0", failed) + } + + var count int64 + err = f.Store.DB().QueryRow(` + SELECT COUNT(*) FROM messages + WHERE rfc822_message_id IS NOT NULL AND rfc822_message_id != '' + `).Scan(&count) + testutil.MustNoErr(t, err, "count backfilled rows") + if count != 0 { + t.Fatalf("backfilled rows = %d, want 0 after rollback", count) + } +} + +func TestStore_MergeDuplicates_BackfillsRawMIME(t *testing.T) { + f := storetest.New(t) + + idSurvivor := newRFC822Message(t, f, "survivor", "rfc822-mime-backfill") + idDuplicate := newRFC822Message(t, f, "duplicate", "rfc822-mime-backfill") + + rawData := []byte("From: alice@example.com\r\nSubject: Test\r\n\r\nBody") + testutil.MustNoErr(t, + f.Store.UpsertMessageRaw(idDuplicate, rawData), + "UpsertMessageRaw on duplicate", + ) + + _, err := f.Store.GetMessageRaw(idSurvivor) + if err == nil { + t.Fatal("survivor should not have raw MIME before merge") + } + + result, err := f.Store.MergeDuplicates( + idSurvivor, []int64{idDuplicate}, "batch-mime", + ) + testutil.MustNoErr(t, err, "MergeDuplicates") + if result.RawMIMEBackfilled != 1 { + t.Errorf("RawMIMEBackfilled = %d, want 1", result.RawMIMEBackfilled) + } + + 
got, err := f.Store.GetMessageRaw(idSurvivor) + testutil.MustNoErr(t, err, "GetMessageRaw survivor after merge") + if len(got) == 0 { + t.Error("survivor raw MIME should not be empty after backfill") + } +} + +// TestStore_GetDuplicateGroupMessages_PreservesFromCase verifies that the +// FromEmail field returned by GetDuplicateGroupMessages preserves the +// original case of the sender's address. The query layer must NOT +// blanket-lowercase the address — synthetic identifiers like Matrix +// MXIDs (`@Alice:matrix.org`) and chat handles are case-sensitive in +// the rest of the identity subsystem (NormalizeIdentifierForCompare +// preserves case for non-email shapes), so any pre-lowering in SQL +// would prevent dedup's per-source identity match from finding a +// stored case-mixed identity. Regression test for iter12 codex Medium. +func TestStore_GetDuplicateGroupMessages_PreservesFromCase(t *testing.T) { + f := storetest.New(t) + + mxid := "@Alice:matrix.org" + pid := f.EnsureParticipant(mxid, "", "") + + id := newRFC822Message(t, f, "msg-mxid", "rfc822-mxid") + + if _, err := f.Store.DB().Exec( + f.Store.Rebind(`INSERT INTO message_recipients + (message_id, participant_id, recipient_type) + VALUES (?, ?, 'from')`), + id, pid, + ); err != nil { + t.Fatalf("insert from recipient: %v", err) + } + + rows, err := f.Store.GetDuplicateGroupMessages("rfc822-mxid") + testutil.MustNoErr(t, err, "GetDuplicateGroupMessages") + if len(rows) != 1 { + t.Fatalf("rows = %d, want 1", len(rows)) + } + if rows[0].FromEmail != mxid { + t.Errorf("FromEmail = %q, want %q (case must be preserved)", rows[0].FromEmail, mxid) + } +} + +// TestStore_GetAllRawMIMECandidates_PreservesFromCase mirrors +// TestStore_GetDuplicateGroupMessages_PreservesFromCase but covers the +// content-hash candidate path. Both queries had the same SQL `LOWER()` +// problem before iter12; both fixes need regression coverage so a +// future refactor that reintroduces lowercasing in either query is +// caught. 
Iter13 claude follow-up. +func TestStore_GetAllRawMIMECandidates_PreservesFromCase(t *testing.T) { + f := storetest.New(t) + + mxid := "@Bob:matrix.org" + pid := f.EnsureParticipant(mxid, "", "") + + id := newRFC822Message(t, f, "msg-mxid-raw", "rfc822-mxid-raw") + + if _, err := f.Store.DB().Exec( + f.Store.Rebind(`INSERT INTO message_recipients + (message_id, participant_id, recipient_type) + VALUES (?, ?, 'from')`), + id, pid, + ); err != nil { + t.Fatalf("insert from recipient: %v", err) + } + + // GetAllRawMIMECandidates only returns messages that have a raw + // MIME row, so synthesize one. + testutil.MustNoErr(t, + f.Store.UpsertMessageRaw(id, []byte("From: "+mxid+"\r\n\r\nbody")), + "UpsertMessageRaw", + ) + + cands, err := f.Store.GetAllRawMIMECandidates() + testutil.MustNoErr(t, err, "GetAllRawMIMECandidates") + var got *store.ContentHashCandidate + for i := range cands { + if cands[i].ID == id { + got = &cands[i] + break + } + } + if got == nil { + t.Fatalf("test message %d not in candidates: %+v", id, cands) + } + if got.FromEmail != mxid { + t.Errorf("FromEmail = %q, want %q (case must be preserved)", got.FromEmail, mxid) + } +} diff --git a/internal/store/identifier_match.go b/internal/store/identifier_match.go new file mode 100644 index 00000000..68ee28ba --- /dev/null +++ b/internal/store/identifier_match.go @@ -0,0 +1,102 @@ +package store + +import ( + "fmt" + "strings" +) + +// EqualIdentifier reports whether two identifiers refer to the same +// row under the comparison rules used by AddAccountIdentity / +// RemoveAccountIdentity / MigrateLegacyIdentityConfig: email-shaped +// tokens compare case-insensitively, everything else compares +// case-sensitively. +// +// Use this when a caller has already loaded identity rows and needs +// to find the row that corresponds to a user-supplied identifier in +// memory — e.g., to read prior signals before calling +// AddAccountIdentity. 
Routing through this function keeps the CLI's +// in-memory matching consistent with the SQL-side LOWER() compare so +// case-mismatched re-adds do not silently bypass "already confirmed" +// UX. +func EqualIdentifier(a, b string) bool { + if looksLikeEmail(a) || looksLikeEmail(b) { + return strings.EqualFold(a, b) + } + return a == b +} + +// NormalizeIdentifierForCompare returns the comparison-canonical form +// of an identifier under the same rules as EqualIdentifier and the +// SQL-side LOWER() shape: email-shaped tokens are lowercased, +// everything else is returned unchanged. +// +// Use this when building a map keyed by identifier (e.g., the dedup +// engine's per-source identity lookup) so the same value can be used +// to insert and to look up. Calling NormalizeIdentifierForCompare on +// both the stored side and the lookup side keeps Matrix MXIDs and +// other case-sensitive synthetic identifiers intact while still +// matching email-shaped identities case-insensitively. +func NormalizeIdentifierForCompare(s string) string { + if looksLikeEmail(s) { + return strings.ToLower(s) + } + return s +} + +// identifierMatch carries the comparison rule for one identifier in +// account_identities. It exists so the email-vs-other branch lives in +// one place: every call site that compares an identifier against +// stored rows builds an identifierMatch and consumes its SQL fragment +// + bind value. +// +// Email-shaped tokens (per looksLikeEmail) compare case-insensitively +// via LOWER() in SQL; everything else compares case-sensitively. +// Case-folding always happens in SQL, never in Go, so stored rows +// retain whatever casing the user supplied at first write. +// +// Use: +// +// match := newIdentifierMatch(addr) +// row := tx.QueryRow( +// `SELECT source_signal FROM account_identities +// WHERE source_id = ? 
AND `+match.WhereClause("address"), +// sourceID, match.BindValue(), +// ) +type identifierMatch struct { + isEmailShaped bool + raw string +} + +// newIdentifierMatch classifies raw via looksLikeEmail and packages +// the classification with the raw input. Empty input is permitted +// (it falls into the non-email branch); the caller is responsible +// for any empty-input gating that has to happen before SQL is run. +func newIdentifierMatch(raw string) identifierMatch { + return identifierMatch{ + isEmailShaped: looksLikeEmail(raw), + raw: raw, + } +} + +// WhereClause renders an equality predicate of the form +// "LOWER() = LOWER(?)" for email-shaped identifiers and +// " = ?" for everything else. The placeholder always binds +// to BindValue. +// +// SECURITY: column is interpolated into SQL without escaping. Every +// caller in this package supplies a hard-coded column name (today, +// "address" at all three sites). Do NOT pass user input as column +// — there is no SQL-injection guard here. +func (m identifierMatch) WhereClause(column string) string { + if m.isEmailShaped { + return fmt.Sprintf("LOWER(%s) = LOWER(?)", column) + } + return fmt.Sprintf("%s = ?", column) +} + +// BindValue returns the value to bind for the placeholder in +// WhereClause. Equal to the raw input — case-folding happens in SQL +// via LOWER, not in Go, so the stored row keeps original casing. +func (m identifierMatch) BindValue() any { + return m.raw +} diff --git a/internal/store/identifier_match_test.go b/internal/store/identifier_match_test.go new file mode 100644 index 00000000..299485c9 --- /dev/null +++ b/internal/store/identifier_match_test.go @@ -0,0 +1,145 @@ +package store + +import "testing" + +// TestNormalizeIdentifierForCompare locks down the identity-map +// canonicalization rule used by the dedup engine's per-source +// identity lookup. Email-shaped tokens lowercase; everything else +// passes through. 
Calling it on both sides of a map insertion and +// lookup gives the same case-aware semantics as EqualIdentifier +// without paying for pairwise comparison on the hot path. +func TestNormalizeIdentifierForCompare(t *testing.T) { + tests := []struct { + name string + in string + want string + }{ + {"email_lower", "foo@x.com", "foo@x.com"}, + {"email_mixed", "Foo@X.COM", "foo@x.com"}, + {"matrix_mxid_preserves_case", "@Alice:matrix.org", "@Alice:matrix.org"}, + {"handle_preserves_case", "AliceHandle", "AliceHandle"}, + {"phone_preserves", "+15551234567", "+15551234567"}, + {"empty", "", ""}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if got := NormalizeIdentifierForCompare(tc.in); got != tc.want { + t.Errorf("NormalizeIdentifierForCompare(%q) = %q, want %q", tc.in, got, tc.want) + } + }) + } +} + +// TestEqualIdentifier asserts that the in-memory comparison rule +// matches the SQL-side LOWER() rule encoded by identifierMatch: +// email-shaped tokens compare case-insensitively, everything else +// compares case-sensitively. The CLI uses this to look up prior +// rows in already-loaded identity slices before calling +// AddAccountIdentity, which is what surfaces the "already confirmed" +// UX message correctly when the user re-supplies an email with +// different casing. 
+func TestEqualIdentifier(t *testing.T) { + tests := []struct { + name string + a, b string + want bool + }{ + {"email_same_case", "foo@x.com", "foo@x.com", true}, + {"email_mixed_case", "Foo@X.COM", "foo@x.com", true}, + {"email_distinct", "alice@x.com", "bob@x.com", false}, + {"non_email_same", "AliceHandle", "AliceHandle", true}, + {"non_email_case_diff", "AliceHandle", "alicehandle", false}, + {"matrix_mxid_case_diff", "@Alice:matrix.org", "@alice:matrix.org", false}, + {"phone_same", "+15551234567", "+15551234567", true}, + {"empty_both", "", "", true}, + {"one_email_one_handle", "foo@x.com", "AliceHandle", false}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if got := EqualIdentifier(tc.a, tc.b); got != tc.want { + t.Errorf("EqualIdentifier(%q, %q) = %v, want %v", tc.a, tc.b, got, tc.want) + } + }) + } +} + +// TestIdentifierMatch_TableDriven asserts the SQL-composition contract +// of newIdentifierMatch for representative inputs. Email-shaped tokens +// produce a LOWER()-wrapped predicate; everything else produces a +// case-sensitive predicate. BindValue is always the raw input. +// +// The classification rule is "@ not at index 0 AND right side contains +// a dot" — see looksLikeEmail. This test treats that rule as the +// contract; TestLooksLikeEmail tests the predicate directly. 
+func TestIdentifierMatch_TableDriven(t *testing.T) { + tests := []struct { + name string + input string + wantWhere string + }{ + {"email", "foo@x.com", "LOWER(address) = LOWER(?)"}, + {"email_mixed_case", "Foo@X.COM", "LOWER(address) = LOWER(?)"}, + {"matrix_mxid", "@alice:matrix.org", "address = ?"}, + {"bare_handle", "AliceHandle", "address = ?"}, + {"phone", "+15551234567", "address = ?"}, + {"email_no_dot", "alice@localhost", "address = ?"}, + {"empty", "", "address = ?"}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + m := newIdentifierMatch(tc.input) + if got := m.WhereClause("address"); got != tc.wantWhere { + t.Errorf("WhereClause(%q) = %q, want %q", tc.input, got, tc.wantWhere) + } + if got := m.BindValue(); got != tc.input { + t.Errorf("BindValue() = %q, want %q (raw)", got, tc.input) + } + }) + } +} + +// TestIdentifierMatch_WhereClauseAcceptsCustomColumn asserts the helper +// is column-name-driven so call sites can specify their own column +// (today every site uses "address", but the contract supports more). +func TestIdentifierMatch_WhereClauseAcceptsCustomColumn(t *testing.T) { + m := newIdentifierMatch("foo@x.com") + if got := m.WhereClause("normalized"); got != "LOWER(normalized) = LOWER(?)" { + t.Errorf("WhereClause(\"normalized\") = %q", got) + } + m2 := newIdentifierMatch("AliceHandle") + if got := m2.WhereClause("col"); got != "col = ?" { + t.Errorf("WhereClause(\"col\") = %q", got) + } +} + +// TestLooksLikeEmail asserts the email-shape predicate directly. The +// regression cases (iter2→iter3 Matrix MXID misclassification) are +// the load-bearing rows here: a future refactor that loosens the +// predicate to "@ contains" must fail this test. 
+func TestLooksLikeEmail(t *testing.T) { + tests := []struct { + name string + input string + want bool + }{ + {"plain_email", "foo@x.com", true}, + {"mixed_case", "Foo@X.COM", true}, + {"subdomain", "foo@mail.x.com", true}, + {"matrix_mxid", "@alice:matrix.org", false}, + {"matrix_mxid_with_subdomain", "@alice:server.matrix.org", false}, + {"bare_handle", "AliceHandle", false}, + {"phone_e164", "+15551234567", false}, + {"empty", "", false}, + {"email_no_dot", "alice@localhost", false}, + {"trailing_at", "alice@", false}, + {"leading_at_only", "@", false}, + {"single_char", "a", false}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if got := looksLikeEmail(tc.input); got != tc.want { + t.Errorf("looksLikeEmail(%q) = %v, want %v", tc.input, got, tc.want) + } + }) + } +} diff --git a/internal/store/live_messages.go b/internal/store/live_messages.go new file mode 100644 index 00000000..04f5ccb3 --- /dev/null +++ b/internal/store/live_messages.go @@ -0,0 +1,53 @@ +package store + +import "fmt" + +// liveMessages* cover the four (alias, hideDeletedFromSource) +// combinations on hot read paths; pre-computing them avoids a +// fmt.Sprintf allocation per call on list/search/count paths. +const ( + liveMessagesUnaliasedDedupOnly = "deleted_at IS NULL" + liveMessagesUnaliasedFull = "deleted_at IS NULL AND deleted_from_source_at IS NULL" + liveMessagesMDedupOnly = "m.deleted_at IS NULL" + liveMessagesMFull = "m.deleted_at IS NULL AND m.deleted_from_source_at IS NULL" +) + +// LiveMessagesWhere returns the SQL predicate that selects live +// messages. Dedup-hidden rows (deleted_at IS NOT NULL) are filtered +// always — they must not appear in normal user-facing reads. +// Source-deleted rows (deleted_from_source_at IS NOT NULL) are +// filtered only when hideDeletedFromSource is true; archive views +// may intentionally show source-deleted rows, but they always hide +// dedup losers. 
+// +// Pass the table alias used in the surrounding query (use "" if the +// query has no alias). +// +// Predicate shape: +// +// hideDeletedFromSource=false → .deleted_at IS NULL +// hideDeletedFromSource=true → .deleted_at IS NULL AND +// .deleted_from_source_at IS NULL +// +// The two common alias values ("" and "m") with both boolean values +// are returned from package-level constants to keep this allocation- +// free on hot paths. Other aliases fall back to fmt.Sprintf. +func LiveMessagesWhere(alias string, hideDeletedFromSource bool) string { + switch { + case alias == "" && !hideDeletedFromSource: + return liveMessagesUnaliasedDedupOnly + case alias == "" && hideDeletedFromSource: + return liveMessagesUnaliasedFull + case alias == "m" && !hideDeletedFromSource: + return liveMessagesMDedupOnly + case alias == "m" && hideDeletedFromSource: + return liveMessagesMFull + } + if hideDeletedFromSource { + return fmt.Sprintf( + "%s.deleted_at IS NULL AND %s.deleted_from_source_at IS NULL", + alias, alias, + ) + } + return fmt.Sprintf("%s.deleted_at IS NULL", alias) +} diff --git a/internal/store/live_messages_test.go b/internal/store/live_messages_test.go new file mode 100644 index 00000000..0fff7486 --- /dev/null +++ b/internal/store/live_messages_test.go @@ -0,0 +1,45 @@ +package store_test + +import ( + "testing" + + "github.com/wesm/msgvault/internal/store" +) + +func TestLiveMessagesWhere_NoAlias(t *testing.T) { + got := store.LiveMessagesWhere("", true) + want := "deleted_at IS NULL AND deleted_from_source_at IS NULL" + if got != want { + t.Errorf("LiveMessagesWhere(%q) = %q, want %q", "", got, want) + } +} + +func TestLiveMessagesWhere_WithAlias(t *testing.T) { + got := store.LiveMessagesWhere("m", true) + want := "m.deleted_at IS NULL AND m.deleted_from_source_at IS NULL" + if got != want { + t.Errorf("LiveMessagesWhere(%q) = %q, want %q", "m", got, want) + } +} + +func TestLiveMessagesWhere_TableDriven(t *testing.T) { + cases := []struct { + alias string 
+ hideDeletedFromSource bool + want string + }{ + {"", true, "deleted_at IS NULL AND deleted_from_source_at IS NULL"}, + {"", false, "deleted_at IS NULL"}, + {"m", true, "m.deleted_at IS NULL AND m.deleted_from_source_at IS NULL"}, + {"m", false, "m.deleted_at IS NULL"}, + {"msg", true, "msg.deleted_at IS NULL AND msg.deleted_from_source_at IS NULL"}, + {"msg", false, "msg.deleted_at IS NULL"}, + } + for _, tc := range cases { + got := store.LiveMessagesWhere(tc.alias, tc.hideDeletedFromSource) + if got != tc.want { + t.Errorf("LiveMessagesWhere(%q, %v) = %q, want %q", + tc.alias, tc.hideDeletedFromSource, got, tc.want) + } + } +} diff --git a/internal/store/messages.go b/internal/store/messages.go index 6ff6e60e..df5274e3 100644 --- a/internal/store/messages.go +++ b/internal/store/messages.go @@ -852,20 +852,20 @@ func (s *Store) MarkMessagesDeletedByGmailIDBatch(gmailIDs []string) error { // CountMessagesForSource returns the count of messages for a specific source (account). func (s *Store) CountMessagesForSource(sourceID int64) (int64, error) { var count int64 - err := s.db.QueryRow(` - SELECT COUNT(*) FROM messages WHERE source_id = ? AND deleted_from_source_at IS NULL - `, sourceID).Scan(&count) + err := s.db.QueryRow(fmt.Sprintf(` + SELECT COUNT(*) FROM messages WHERE source_id = ? AND %s + `, LiveMessagesWhere("", true)), sourceID).Scan(&count) return count, err } // CountMessagesWithRaw returns the count of messages that have raw MIME stored. func (s *Store) CountMessagesWithRaw(sourceID int64) (int64, error) { var count int64 - err := s.db.QueryRow(` + err := s.db.QueryRow(fmt.Sprintf(` SELECT COUNT(*) FROM messages m JOIN message_raw mr ON m.id = mr.message_id - WHERE m.source_id = ? AND m.deleted_from_source_at IS NULL - `, sourceID).Scan(&count) + WHERE m.source_id = ? 
AND %s + `, LiveMessagesWhere("m", true)), sourceID).Scan(&count) return count, err } @@ -873,12 +873,13 @@ func (s *Store) CountMessagesWithRaw(sourceID int64) (int64, error) { // Uses reservoir sampling with random offsets for O(limit) performance on large tables, // falling back to ORDER BY RANDOM() for small tables where the overhead isn't significant. func (s *Store) GetRandomMessageIDs(sourceID int64, limit int) ([]int64, error) { + live := LiveMessagesWhere("", true) // Get total count first var total int64 - err := s.db.QueryRow(` + err := s.db.QueryRow(fmt.Sprintf(` SELECT COUNT(*) FROM messages - WHERE source_id = ? AND deleted_from_source_at IS NULL - `, sourceID).Scan(&total) + WHERE source_id = ? AND %s + `, live), sourceID).Scan(&total) if err != nil { return nil, err } @@ -890,12 +891,12 @@ func (s *Store) GetRandomMessageIDs(sourceID int64, limit int) ([]int64, error) // For small tables or when limit >= total, use simple ORDER BY RANDOM() // The threshold of 10000 balances query overhead vs. scan cost if total < 10000 || int64(limit) >= total { - rows, err := s.db.Query(` + rows, err := s.db.Query(fmt.Sprintf(` SELECT id FROM messages - WHERE source_id = ? AND deleted_from_source_at IS NULL + WHERE source_id = ? AND %s ORDER BY RANDOM() LIMIT ? - `, sourceID, limit) + `, live), sourceID, limit) if err != nil { return nil, err } @@ -925,12 +926,12 @@ func (s *Store) GetRandomMessageIDs(sourceID int64, limit int) ([]int64, error) offset := rng.Int63n(total) var id int64 - err := s.db.QueryRow(` + err := s.db.QueryRow(fmt.Sprintf(` SELECT id FROM messages - WHERE source_id = ? AND deleted_from_source_at IS NULL + WHERE source_id = ? AND %s ORDER BY id LIMIT 1 OFFSET ? 
- `, sourceID, offset).Scan(&id) + `, live), sourceID, offset).Scan(&id) if err != nil { if err == sql.ErrNoRows { continue // Race condition with deletions, retry diff --git a/internal/store/migrate_legacy_identity.go b/internal/store/migrate_legacy_identity.go new file mode 100644 index 00000000..7e908f4d --- /dev/null +++ b/internal/store/migrate_legacy_identity.go @@ -0,0 +1,248 @@ +package store + +import ( + "database/sql" + "fmt" + "strings" +) + +const migrationLegacyIdentity = "legacy_identity_to_per_account" + +// MigrateLegacyIdentityConfig migrates a list of legacy global identity +// addresses into per-account confirmed records. It runs at most once: +// subsequent calls are no-ops, marked by the +// "legacy_identity_to_per_account" entry in applied_migrations. An +// empty or blank-only address list still marks the migration applied so +// a later config change does not re-run the migration unexpectedly. +// +// Returns (applied bool, deferred bool, sourceCount int, addressCount int, err error). +// +// applied: true if this call performed the migration; false if +// already applied or no addresses to migrate. +// deferred: true when legacy addresses are configured but no +// sources exist yet, so the migration is parked until +// the user adds an account. Distinguishable from the +// "already applied" / "no addresses" no-ops. +// sourceCount: number of accounts that received identity records. +// addressCount: number of distinct addresses migrated (per source). +// +// Migration semantics: every existing source receives a copy of every +// legacy address. After this call, the legacy [identity] config block +// is no longer load-bearing; the dedup engine should read from +// account_identities instead. 
func (s *Store) MigrateLegacyIdentityConfig(addresses []string) (applied, deferred bool, sourceCount, addressCount int, err error) {
	// Fast no-op path: the sentinel row means a previous run already
	// migrated (or deliberately consumed) the legacy addresses.
	already, err := s.IsMigrationApplied(migrationLegacyIdentity)
	if err != nil {
		return false, false, 0, 0, err
	}
	if already {
		return false, false, 0, 0, nil
	}

	// Normalize addresses: trim whitespace, drop empties, deduplicate
	// using the same case-aware rule as the rest of the identity
	// subsystem (NormalizeIdentifierForCompare). Preserves first-seen
	// casing for storage so synthetic identifiers (Matrix MXIDs, chat
	// handles) keep their original case.
	seen := make(map[string]struct{}, len(addresses))
	var normalized []string
	for _, addr := range addresses {
		a := strings.TrimSpace(addr)
		if a == "" {
			continue
		}
		key := NormalizeIdentifierForCompare(a)
		if _, dup := seen[key]; dup {
			continue
		}
		seen[key] = struct{}{}
		normalized = append(normalized, a)
	}

	// Nothing to migrate: still mark the sentinel so a later config
	// change does not re-run the migration unexpectedly.
	if len(normalized) == 0 {
		if err := s.MarkMigrationApplied(migrationLegacyIdentity); err != nil {
			return false, false, 0, 0, err
		}
		return false, false, 0, 0, nil
	}

	sources, err := s.ListSources("")
	if err != nil {
		return false, false, 0, 0, fmt.Errorf("list sources for identity migration: %w", err)
	}

	// If the user has legacy [identity] addresses configured but no
	// sources exist yet (typical at init-db time, or before the first
	// `add-account`), defer the migration. Marking it applied now would
	// permanently drop the addresses on the floor: the next account the
	// user adds would never receive them. Leave the sentinel unmarked
	// and let the next command run after a source exists pick it up.
	//
	// Report the post-normalization address count so the deferred
	// notice doesn't overstate (raw input may include blanks/dupes).
	if len(sources) == 0 {
		return false, true, 0, len(normalized), nil
	}

	// Legacy [identity] block holds email-shaped addresses only. If no
	// existing source has an email-shaped identity column, defer rather
	// than mark the migration applied — otherwise the addresses get
	// permanently dropped on the floor when the user later adds an
	// email source. Same defer contract as the no-sources branch above.
	eligibleSources := 0
	for _, src := range sources {
		if sourceTypeUsesEmailIdentity(src.SourceType) {
			eligibleSources++
		}
	}
	if eligibleSources == 0 {
		return false, true, 0, len(normalized), nil
	}

	if err := s.withTx(func(tx *loggedTx) error {
		// Re-check the sentinel inside the transaction: another process
		// may have applied the migration between the outer
		// IsMigrationApplied call and this tx. Treat "row exists" as a
		// clean no-op.
		var appliedMarker string
		err := tx.QueryRow(
			`SELECT name FROM applied_migrations WHERE name = ?`,
			migrationLegacyIdentity,
		).Scan(&appliedMarker)
		switch {
		case err == nil:
			return nil
		case err != sql.ErrNoRows:
			// NOTE(review): == comparison assumes the driver surfaces
			// sql.ErrNoRows unwrapped; errors.Is would be more robust —
			// confirm loggedTx.QueryRow semantics.
			return fmt.Errorf("check migration %q in tx: %w", migrationLegacyIdentity, err)
		}

		for _, src := range sources {
			// Legacy [identity] block holds email-shaped addresses only.
			// Skip non-email source types (whatsapp, imessage, sms,
			// google_voice*) so phone-keyed sources don't get email
			// identities written to them, which would distort dedup
			// sent-copy detection.
			if !sourceTypeUsesEmailIdentity(src.SourceType) {
				continue
			}
			for _, addr := range normalized {
				// Comparison rule (email-shaped → case-insensitive;
				// everything else → case-sensitive) is shared with
				// AddAccountIdentity via identifierMatch — see
				// identifier_match.go.
				match := newIdentifierMatch(addr)
				var existing string
				qerr := tx.QueryRow(
					`SELECT source_signal FROM account_identities
					 WHERE source_id = ? AND `+match.WhereClause("address"),
					src.ID, match.BindValue(),
				).Scan(&existing)
				switch {
				case qerr == sql.ErrNoRows:
					// No prior row: insert with the migration signal.
					_, txErr := tx.Exec(
						`INSERT INTO account_identities (source_id, address, source_signal)
						 VALUES (?, ?, ?)`,
						src.ID, addr, "config_migration",
					)
					if txErr != nil {
						return fmt.Errorf("insert identity (source=%d, addr=%s): %w", src.ID, addr, txErr)
					}
				case qerr != nil:
					return fmt.Errorf("read existing identity (source=%d, addr=%s): %w", src.ID, addr, qerr)
				default:
					// Row exists: merge "config_migration" into its
					// signal set; write back only if the set changed.
					merged := mergeSignalSet(existing, "config_migration")
					if merged != existing {
						_, uerr := tx.Exec(
							`UPDATE account_identities
							 SET source_signal = ?
							 WHERE source_id = ? AND `+match.WhereClause("address"),
							merged, src.ID, match.BindValue(),
						)
						if uerr != nil {
							return fmt.Errorf("update identity (source=%d, addr=%s): %w", src.ID, addr, uerr)
						}
					}
				}
			}
		}

		// Mark the sentinel in the same tx as the data writes so a
		// crash cannot leave the data half-migrated but marked done.
		// InsertOrIgnore keeps this idempotent across dialects.
		_, txErr := tx.Exec(
			s.dialect.InsertOrIgnore(`INSERT OR IGNORE INTO applied_migrations (name) VALUES (?)`),
			migrationLegacyIdentity,
		)
		return txErr
	}); err != nil {
		return false, false, 0, 0, fmt.Errorf("migrate legacy identity config: %w", err)
	}

	return true, false, eligibleSources, len(normalized), nil
}

// StartupMigrationResult describes the outcome of RunStartupMigrations
// so callers can log accurately. Notice is the user-facing string to
// print to stderr (empty when nothing happened).
type StartupMigrationResult struct {
	// Applied is true when the legacy identity migration actually
	// inserted per-account identity rows on this call.
	Applied bool
	// Deferred is true when the migration was parked because no
	// source exists yet; addresses remain in the legacy config and
	// will migrate on the next command after a source is created.
	Deferred bool
	// SourceCount is the number of sources the addresses were
	// distributed across (only meaningful when Applied).
	SourceCount int
	// AddressCount is the post-normalization count of legacy
	// addresses (meaningful for both Applied and Deferred).
	AddressCount int
	// Notice is the user-facing string the caller should print.
	// Empty when there was nothing to report.
	Notice string
}

// RunStartupMigrations runs all one-time data migrations that should execute
// on every command launch. It is idempotent: already-applied migrations are
// skipped. legacyIdentityAddresses comes from cfg.Identity.Addresses.
//
// The returned StartupMigrationResult's Notice field is non-empty when the
// migration was performed (Applied) or when legacy addresses are parked
// because no source exists yet (Deferred). Caller should print Notice to
// stderr. The structured fields let the caller log the deferred and applied
// paths distinctly.
func (s *Store) RunStartupMigrations(legacyIdentityAddresses []string) (StartupMigrationResult, error) {
	applied, deferred, sources, addrs, err := s.MigrateLegacyIdentityConfig(legacyIdentityAddresses)
	if err != nil {
		return StartupMigrationResult{}, err
	}
	res := StartupMigrationResult{
		Applied:      applied,
		Deferred:     deferred,
		SourceCount:  sources,
		AddressCount: addrs,
	}
	switch {
	case deferred:
		res.Notice = fmt.Sprintf(
			"Notice: legacy [identity] config has %d address(es) but no accounts exist yet.\n"+
				"The migration will run on the next command after you add an account\n"+
				"(e.g. 'msgvault add-account ...').",
			addrs,
		)
	case applied:
		res.Notice = fmt.Sprintf(
			"Migrated legacy [identity] config to per-account identities (%d addresses across %d accounts).\n"+
				"Run 'msgvault identity list' to review per-account identities;\n"+
				"the [identity] block in config.toml is no longer used.",
			addrs, sources,
		)
	}
	return res, nil
}
// sourceTypeUsesEmailIdentity reports whether a source type's identity
// column holds email-shaped addresses. The legacy [identity] migration
// uses it to skip phone/handle-keyed sources (whatsapp, imessage,
// google_voice*, sms) so email addresses are never written to them.
func sourceTypeUsesEmailIdentity(sourceType string) bool {
	for _, emailType := range [...]string{"gmail", "imap", "o365", "mbox", "hey", "apple-mail"} {
		if sourceType == emailType {
			return true
		}
	}
	return false
}
+ for _, srcID := range []int64{f.Source.ID, src2.ID} { + ids, listErr := st.ListAccountIdentities(srcID) + testutil.MustNoErr(t, listErr, "ListAccountIdentities") + if len(ids) != 3 { + t.Errorf("source %d: got %d identities, want 3", srcID, len(ids)) + } + for _, id := range ids { + if id.SourceSignal != "config_migration" { + t.Errorf("source_signal = %q, want config_migration", id.SourceSignal) + } + } + } +} + +func TestMigrateLegacyIdentityConfig_MergesExistingSignal(t *testing.T) { + f := storetest.New(t) + st := f.Store + + testutil.MustNoErr(t, st.AddAccountIdentity(f.Source.ID, "alice@example.com", "account-identifier"), "AddAccountIdentity") + + applied, _, _, _, err := st.MigrateLegacyIdentityConfig([]string{"alice@example.com"}) + testutil.MustNoErr(t, err, "MigrateLegacyIdentityConfig") + if !applied { + t.Fatal("applied should be true on first run") + } + + ids, err := st.ListAccountIdentities(f.Source.ID) + testutil.MustNoErr(t, err, "ListAccountIdentities") + if len(ids) != 1 { + t.Fatalf("got %d identities, want 1", len(ids)) + } + if ids[0].SourceSignal != "account-identifier,config_migration" { + t.Errorf("source_signal = %q, want account-identifier,config_migration", ids[0].SourceSignal) + } +} + +func TestMigrateLegacyIdentityConfig_SecondCallNoOp(t *testing.T) { + f := storetest.New(t) + st := f.Store + + addresses := []string{"alice@example.com"} + + _, _, _, _, err := st.MigrateLegacyIdentityConfig(addresses) + testutil.MustNoErr(t, err, "first migration") + + applied, _, sources, addrs, err := st.MigrateLegacyIdentityConfig(addresses) + testutil.MustNoErr(t, err, "second migration") + + if applied { + t.Error("applied should be false on second call") + } + if sources != 0 || addrs != 0 { + t.Errorf("second call counts = (%d, %d), want (0, 0)", sources, addrs) + } +} + +func TestMigrateLegacyIdentityConfig_DeferredUntilSourceExists(t *testing.T) { + st := testutil.NewTestStore(t) + + applied, deferred, sources, addrs, err := 
st.MigrateLegacyIdentityConfig([]string{"alice@example.com"}) + testutil.MustNoErr(t, err, "first migration") + if applied { + t.Fatal("applied should be false before any sources exist") + } + if !deferred { + t.Fatal("deferred should be true when addresses exist but no sources") + } + // On the deferred path we report the post-normalization address + // count so the user-facing notice doesn't overstate (raw input may + // include blanks/dupes). Sources is still 0 because nothing was + // written. + if sources != 0 || addrs != 1 { + t.Fatalf("counts = (%d, %d), want (0, 1)", sources, addrs) + } + + _, err = st.GetOrCreateSource("gmail", "alice@example.com") + testutil.MustNoErr(t, err, "GetOrCreateSource") + + applied, deferred, sources, addrs, err = st.MigrateLegacyIdentityConfig([]string{"alice@example.com"}) + testutil.MustNoErr(t, err, "second migration") + if !applied { + t.Fatal("applied should be true after a source exists") + } + if deferred { + t.Fatal("deferred should be false once a source exists") + } + if sources != 1 || addrs != 1 { + t.Fatalf("counts = (%d, %d), want (1, 1)", sources, addrs) + } +} + +func TestMigrateLegacyIdentityConfig_EmptyAddresses(t *testing.T) { + f := storetest.New(t) + st := f.Store + + applied, _, sources, addrs, err := st.MigrateLegacyIdentityConfig(nil) + testutil.MustNoErr(t, err, "MigrateLegacyIdentityConfig empty") + + if applied { + t.Error("applied should be false for empty address list") + } + if sources != 0 || addrs != 0 { + t.Errorf("counts = (%d, %d), want (0, 0)", sources, addrs) + } + + // Migration should be marked so it won't re-run. 
+ wasMigrated, err := st.IsMigrationApplied("legacy_identity_to_per_account") + testutil.MustNoErr(t, err, "IsMigrationApplied") + if !wasMigrated { + t.Error("migration sentinel should be set even for empty address list") + } +} + +func TestMigrateLegacyIdentityConfig_TrimsWhitespace(t *testing.T) { + f := storetest.New(t) + st := f.Store + + _, _, _, _, err := st.MigrateLegacyIdentityConfig([]string{" ME@Example.COM "}) + testutil.MustNoErr(t, err, "MigrateLegacyIdentityConfig") + + ids, err := st.ListAccountIdentities(f.Source.ID) + testutil.MustNoErr(t, err, "ListAccountIdentities") + if len(ids) != 1 { + t.Fatalf("got %d identities, want 1", len(ids)) + } + if ids[0].Address != "ME@Example.COM" { + t.Errorf("address = %q, want ME@Example.COM", ids[0].Address) + } +} + +func TestMigrateLegacyIdentityConfig_PreservesCase(t *testing.T) { + f := storetest.New(t) + st := f.Store + + applied, _, _, _, err := st.MigrateLegacyIdentityConfig([]string{"Alice@Example.com"}) + testutil.MustNoErr(t, err, "MigrateLegacyIdentityConfig") + if !applied { + t.Fatal("expected applied=true on first run") + } + + rows, err := st.ListAccountIdentities(f.Source.ID) + testutil.MustNoErr(t, err, "ListAccountIdentities") + if len(rows) != 1 { + t.Fatalf("got %d identities, want 1", len(rows)) + } + if rows[0].Address != "Alice@Example.com" { + t.Errorf("address = %q, want Alice@Example.com", rows[0].Address) + } +} + +// TestMigrateLegacyIdentityConfig_DedupesEmailCaseVariants verifies that +// the migration's input-list dedupe applies the same case-aware rule as +// the rest of the identity subsystem. Email-shaped variants like +// `Alice@Example.com` and `alice@example.com` should collapse to a single +// row per source. Synthetic identifiers (Matrix MXIDs, chat handles) +// remain case-sensitive and are NOT collapsed by dedupe. 
+func TestMigrateLegacyIdentityConfig_DedupesEmailCaseVariants(t *testing.T) { + f := storetest.New(t) + st := f.Store + + // Email variants: should dedupe to one row, preserving first-seen case. + // Synthetic identifier variants: should NOT dedupe — they're stored + // case-sensitively in the rest of the system. + addresses := []string{ + "Alice@Example.com", + "alice@example.com", + "ALICE@EXAMPLE.COM", + "@user:matrix.org", + "@User:matrix.org", + } + + applied, _, _, addrs, err := st.MigrateLegacyIdentityConfig(addresses) + testutil.MustNoErr(t, err, "MigrateLegacyIdentityConfig") + if !applied { + t.Fatal("expected applied=true on first run") + } + // Want: 1 email (first-seen), 2 distinct MXIDs. + if addrs != 3 { + t.Errorf("addrs = %d, want 3 (1 email collapse + 2 distinct MXIDs)", addrs) + } + + rows, err := st.ListAccountIdentities(f.Source.ID) + testutil.MustNoErr(t, err, "ListAccountIdentities") + if len(rows) != 3 { + t.Fatalf("got %d identities, want 3: %+v", len(rows), rows) + } + got := make(map[string]bool, len(rows)) + for _, r := range rows { + got[r.Address] = true + } + for _, want := range []string{"Alice@Example.com", "@user:matrix.org", "@User:matrix.org"} { + if !got[want] { + t.Errorf("missing identity %q (have %v)", want, got) + } + } +} diff --git a/internal/store/migrations.go b/internal/store/migrations.go new file mode 100644 index 00000000..008fcf37 --- /dev/null +++ b/internal/store/migrations.go @@ -0,0 +1,30 @@ +package store + +import ( + "fmt" +) + +// IsMigrationApplied reports whether the named one-time data migration +// has already run. +func (s *Store) IsMigrationApplied(name string) (bool, error) { + var count int + err := s.db.QueryRow( + `SELECT COUNT(*) FROM applied_migrations WHERE name = ?`, name, + ).Scan(&count) + if err != nil { + return false, fmt.Errorf("check migration %q: %w", name, err) + } + return count > 0, nil +} + +// MarkMigrationApplied records that a migration has run. Idempotent. 
+func (s *Store) MarkMigrationApplied(name string) error { + _, err := s.db.Exec( + s.dialect.InsertOrIgnore(`INSERT OR IGNORE INTO applied_migrations (name) VALUES (?)`), + name, + ) + if err != nil { + return fmt.Errorf("mark migration %q applied: %w", name, err) + } + return nil +} diff --git a/internal/store/migrations_test.go b/internal/store/migrations_test.go new file mode 100644 index 00000000..74b43300 --- /dev/null +++ b/internal/store/migrations_test.go @@ -0,0 +1,46 @@ +package store_test + +import ( + "testing" + + "github.com/wesm/msgvault/internal/testutil" + "github.com/wesm/msgvault/internal/testutil/storetest" +) + +func TestIsMigrationApplied_NotApplied(t *testing.T) { + f := storetest.New(t) + + applied, err := f.Store.IsMigrationApplied("test_migration") + testutil.MustNoErr(t, err, "IsMigrationApplied") + if applied { + t.Error("migration should not be applied yet") + } +} + +func TestMarkAndCheckMigrationApplied(t *testing.T) { + f := storetest.New(t) + + testutil.MustNoErr(t, f.Store.MarkMigrationApplied("test_migration"), "MarkMigrationApplied") + + applied, err := f.Store.IsMigrationApplied("test_migration") + testutil.MustNoErr(t, err, "IsMigrationApplied") + if !applied { + t.Error("migration should be marked as applied") + } +} + +func TestMarkMigrationApplied_Idempotent(t *testing.T) { + f := storetest.New(t) + + for range 2 { + if err := f.Store.MarkMigrationApplied("test_migration"); err != nil { + t.Fatalf("MarkMigrationApplied: %v", err) + } + } + + applied, err := f.Store.IsMigrationApplied("test_migration") + testutil.MustNoErr(t, err, "IsMigrationApplied") + if !applied { + t.Error("migration should be marked as applied after two calls") + } +} diff --git a/internal/store/schema.sql b/internal/store/schema.sql index e84a13a4..ff1c2b50 100644 --- a/internal/store/schema.sql +++ b/internal/store/schema.sql @@ -365,3 +365,56 @@ CREATE INDEX IF NOT EXISTS idx_message_labels_label ON message_labels(label_id); -- Sync CREATE INDEX IF 
NOT EXISTS idx_sync_runs_source ON sync_runs(source_id, started_at DESC); + +-- ============================================================================ +-- COLLECTIONS +-- ============================================================================ + +-- Collections (named groupings of sources treated as a single logical archive) +CREATE TABLE IF NOT EXISTS collections ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL UNIQUE, + description TEXT NOT NULL DEFAULT '', + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP +); + +-- Collection membership (many sources per collection) +CREATE TABLE IF NOT EXISTS collection_sources ( + collection_id INTEGER NOT NULL REFERENCES collections(id) ON DELETE CASCADE, + source_id INTEGER NOT NULL REFERENCES sources(id) ON DELETE CASCADE, + PRIMARY KEY (collection_id, source_id) +); + +CREATE INDEX IF NOT EXISTS idx_collection_sources_source_id + ON collection_sources(source_id); + +-- ============================================================================ +-- ACCOUNT IDENTITIES +-- ============================================================================ + +-- Confirmed per-account "me" identities used by sent-message detection +-- in dedup. Identity is account-scoped: an address confirmed for one +-- source does not imply it is "me" in any other source. +CREATE TABLE IF NOT EXISTS account_identities ( + source_id INTEGER NOT NULL REFERENCES sources(id) ON DELETE CASCADE, + address TEXT NOT NULL, -- case-preserved + source_signal TEXT NOT NULL DEFAULT '', -- sorted comma-separated signal set, e.g. 
'manual' or 'account-identifier,manual' + confirmed_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (source_id, address) +); + +CREATE INDEX IF NOT EXISTS idx_account_identities_address + ON account_identities(address); + +-- ============================================================================ +-- APPLIED MIGRATIONS +-- ============================================================================ + +-- Marks one-time data migrations that have already run. Schema DDL is +-- idempotent via IF NOT EXISTS; this table is for *data* migrations +-- (e.g. moving legacy config into per-account records) that must run +-- exactly once. +CREATE TABLE IF NOT EXISTS applied_migrations ( + name TEXT PRIMARY KEY, + applied_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP +); diff --git a/internal/store/sources.go b/internal/store/sources.go index 13a6c6d6..258368c5 100644 --- a/internal/store/sources.go +++ b/internal/store/sources.go @@ -2,9 +2,38 @@ package store import ( "context" + "database/sql" + "errors" "fmt" ) +// ErrSourceNotFound is returned by GetSourceByID when no source row +// matches the given ID. Wrapped via fmt.Errorf("...: %w", ...) so +// callers can use errors.Is to distinguish absence from real DB +// errors. +var ErrSourceNotFound = errors.New("source not found") + +// GetSourceByID returns the source with the given ID, or +// ErrSourceNotFound (wrapped) if no row matches. +func (s *Store) GetSourceByID(id int64) (*Source, error) { + row := s.db.QueryRow(` + SELECT id, source_type, identifier, display_name, google_user_id, + last_sync_at, sync_cursor, sync_config, oauth_app, + created_at, updated_at + FROM sources + WHERE id = ? 
+ `, id) + + source, err := scanSource(row) + if err == sql.ErrNoRows { + return nil, fmt.Errorf("source %d: %w", id, ErrSourceNotFound) + } + if err != nil { + return nil, fmt.Errorf("get source by id: %w", err) + } + return source, nil +} + // GetSourcesByIdentifier returns all sources matching an identifier, // regardless of source_type. Use this when the identifier may be // shared across source types (e.g., gmail + mbox import). diff --git a/internal/store/sources_test.go b/internal/store/sources_test.go index b36b92e1..047f67eb 100644 --- a/internal/store/sources_test.go +++ b/internal/store/sources_test.go @@ -320,6 +320,31 @@ func TestStore_AttachmentPathsUniqueToSource(t *testing.T) { } } +func TestStore_GetSourceByID(t *testing.T) { + f := storetest.New(t) + + got, err := f.Store.GetSourceByID(f.Source.ID) + testutil.MustNoErr(t, err, "GetSourceByID") + if got == nil { + t.Fatal("expected non-nil source") + } + if got.ID != f.Source.ID { + t.Errorf("ID = %d, want %d", got.ID, f.Source.ID) + } + if got.Identifier != f.Source.Identifier { + t.Errorf("Identifier = %q, want %q", got.Identifier, f.Source.Identifier) + } +} + +func TestStore_GetSourceByID_NotFound(t *testing.T) { + f := storetest.New(t) + + _, err := f.Store.GetSourceByID(99999) + if err == nil { + t.Fatal("expected error for non-existent ID, got nil") + } +} + func TestStore_IsAttachmentPathReferenced(t *testing.T) { f := storetest.New(t) @@ -423,3 +448,84 @@ func TestInitSchema_MigratesOAuthAppColumn(t *testing.T) { t.Errorf("OAuthApp = %v, want {acme true}", sources[0].OAuthApp) } } + +// TestInitSchema_AddsDeletedAtToLegacyMessagesTable verifies the +// upgrade-path migration: a database whose `messages` table already has +// every other column the embedded schema indexes reference, but is +// missing the dedup-hide column `deleted_at`, gets the column added by +// InitSchema. 
Without the ALTER, every read path that references +// `deleted_at` (LiveMessagesWhere, the dedup engine, the cache +// staleness check) fails on upgraded databases with "no such column". +func TestInitSchema_AddsDeletedAtToLegacyMessagesTable(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "legacy.db") + st, err := store.Open(dbPath) + if err != nil { + t.Fatalf("open store: %v", err) + } + t.Cleanup(func() { _ = st.Close() }) + + // Build a messages table that has every column the embedded + // schema's CREATE INDEX statements reference (sender_id, + // deleted_from_source_at, message_type, …) but DOES NOT have the + // new dedup-hide columns (`deleted_at`, `delete_batch_id`). + // Approximates a legacy DB just before this branch landed. + if _, err := st.DB().Exec(` + CREATE TABLE messages ( + id INTEGER PRIMARY KEY, + source_id INTEGER NOT NULL, + source_message_id TEXT, + conversation_id INTEGER, + subject TEXT, + snippet TEXT, + sent_at DATETIME, + received_at DATETIME, + internal_date DATETIME, + size_estimate INTEGER, + has_attachments BOOLEAN, + is_from_me BOOLEAN, + archived_at DATETIME, + rfc822_message_id TEXT, + sender_id INTEGER, + message_type TEXT NOT NULL DEFAULT 'email', + attachment_count INTEGER DEFAULT 0, + deleted_from_source_at DATETIME + ) + `); err != nil { + t.Fatalf("create legacy messages table: %v", err) + } + + if _, err := st.DB().Exec(` + INSERT INTO messages (id, source_id, source_message_id, sent_at) + VALUES (1, 1, 'msg1', datetime('now')) + `); err != nil { + t.Fatalf("insert legacy message: %v", err) + } + + // Run InitSchema — should add deleted_at and delete_batch_id via + // ALTER TABLE migrations (and silently no-op the columns that + // already exist, like deleted_from_source_at). + if err := st.InitSchema(); err != nil { + t.Fatalf("InitSchema on legacy DB: %v", err) + } + + // Confirm the canonical live-messages predicate runs without + // "no such column": this is the failure mode codex flagged. 
The + // query uses both deleted_at and deleted_from_source_at. + var n int + if err := st.DB().QueryRow( + "SELECT COUNT(*) FROM messages WHERE " + store.LiveMessagesWhere("", true), + ).Scan(&n); err != nil { + t.Fatalf("post-migration live count: %v", err) + } + if n != 1 { + t.Errorf("post-migration live count = %d, want 1", n) + } + + // Confirm delete_batch_id is also queryable post-migration so + // DeleteAllDeduped's distinct-batch count works on upgraded DBs. + if _, err := st.DB().Exec( + "SELECT COUNT(DISTINCT delete_batch_id) FROM messages", + ); err != nil { + t.Fatalf("post-migration delete_batch_id query: %v", err) + } +} diff --git a/internal/store/store.go b/internal/store/store.go index dd7077d0..dd5dcc9f 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -570,6 +570,7 @@ func (s *Store) InitSchema() error { {`ALTER TABLE messages ADD COLUMN message_type TEXT NOT NULL DEFAULT 'email'`, "message_type"}, {`ALTER TABLE messages ADD COLUMN attachment_count INTEGER DEFAULT 0`, "attachment_count"}, {`ALTER TABLE messages ADD COLUMN deleted_from_source_at DATETIME`, "deleted_from_source_at"}, + {`ALTER TABLE messages ADD COLUMN deleted_at DATETIME`, "deleted_at"}, {`ALTER TABLE messages ADD COLUMN delete_batch_id TEXT`, "delete_batch_id"}, {`ALTER TABLE conversations ADD COLUMN title TEXT`, "title"}, {`ALTER TABLE conversations ADD COLUMN conversation_type TEXT NOT NULL DEFAULT 'email_thread'`, "conversation_type"}, @@ -592,14 +593,20 @@ func (s *Store) InitSchema() error { if !s.dialect.IsNoSuchModuleError(err) { return fmt.Errorf("init FTS schema: %w", err) } - // Module not compiled in; availability stays false. - return nil + // Module not compiled in; availability stays false. Fall + // through so the rest of schema init still runs. } } // Probe availability through the dialect so it works uniformly for // backends that carry FTS inside their main schema. 
s.fts5Available = s.dialect.FTSAvailable(s.db.DB) + + // Ensure the default "All" collection exists and contains every source. + if err := s.EnsureDefaultCollection(); err != nil { + return fmt.Errorf("ensure default collection: %w", err) + } + return nil } @@ -622,22 +629,132 @@ type Stats struct { } // GetStats returns statistics about the database. +// Delegates to GetStatsForScope with no scope filter (global counts). func (s *Store) GetStats() (*Stats, error) { + return s.GetStatsForScope(nil) +} + +// GetStatsForScope returns statistics scoped to the given source IDs. +// When sourceIDs is nil or empty, returns global counts. +// All message-derived counts (threads, attachments, labels) exclude +// dedup-hidden and source-deleted messages via LiveMessagesWhere. +// DatabaseSize is always the global file size — it cannot be decomposed per source. +func (s *Store) GetStatsForScope(sourceIDs []int64) (*Stats, error) { stats := &Stats{} - queries := []struct { + var queries []struct { query string + args []any dest *int64 - }{ - {"SELECT COUNT(*) FROM messages", &stats.MessageCount}, - {"SELECT COUNT(*) FROM conversations", &stats.ThreadCount}, - {"SELECT COUNT(*) FROM attachments", &stats.AttachmentCount}, - {"SELECT COUNT(*) FROM labels", &stats.LabelCount}, - {"SELECT COUNT(*) FROM sources", &stats.SourceCount}, + } + + if len(sourceIDs) == 0 { + // Unscoped: global catalog counts, matching pre-slice-3 semantics. + // All message-linked counts apply LiveMessagesWhere so dedup-hidden + // and source-deleted rows aren't reported as live rows. 
+ queries = []struct { + query string + args []any + dest *int64 + }{ + { + "SELECT COUNT(*) FROM messages WHERE " + LiveMessagesWhere("", true), + nil, + &stats.MessageCount, + }, + { + "SELECT COUNT(*) FROM conversations WHERE EXISTS (" + + "SELECT 1 FROM messages m WHERE m.conversation_id = conversations.id AND " + LiveMessagesWhere("m", true) + + ")", + nil, + &stats.ThreadCount, + }, + { + "SELECT COUNT(*) FROM attachments a WHERE EXISTS (" + + "SELECT 1 FROM messages m WHERE m.id = a.message_id AND " + LiveMessagesWhere("m", true) + + ")", + nil, + &stats.AttachmentCount, + }, + { + "SELECT COUNT(*) FROM labels l WHERE EXISTS (" + + "SELECT 1 FROM message_labels ml JOIN messages m ON m.id = ml.message_id WHERE ml.label_id = l.id AND " + LiveMessagesWhere("m", true) + + ")", + nil, + &stats.LabelCount, + }, + { + "SELECT COUNT(*) FROM sources", + nil, + &stats.SourceCount, + }, + } + } else { + // Build the IN (?, ?, ...) placeholder list. TrimSuffix is panic-safe + // for any len(sourceIDs); the outer guard already routes empty slices + // to the unscoped branch, but this avoids a negative slice index if + // the guard is ever refactored. 
+ placeholders := strings.TrimSuffix(strings.Repeat("?,", len(sourceIDs)), ",") + + inClause := "source_id IN (" + placeholders + ")" + args := make([]any, len(sourceIDs)) + for i, id := range sourceIDs { + args[i] = id + } + cloneArgs := func() []any { + out := make([]any, len(args)) + copy(out, args) + return out + } + + queries = []struct { + query string + args []any + dest *int64 + }{ + { + "SELECT COUNT(*) FROM messages WHERE " + LiveMessagesWhere("", true) + " AND " + inClause, + cloneArgs(), + &stats.MessageCount, + }, + { + "SELECT COUNT(DISTINCT conversation_id) FROM messages WHERE " + LiveMessagesWhere("", true) + " AND " + inClause, + cloneArgs(), + &stats.ThreadCount, + }, + { + "SELECT COUNT(*) FROM attachments a WHERE EXISTS (" + + "SELECT 1 FROM messages m WHERE m.id = a.message_id AND " + LiveMessagesWhere("m", true) + + " AND m." + inClause + ")", + cloneArgs(), + &stats.AttachmentCount, + }, + { + "SELECT COUNT(DISTINCT ml.label_id) FROM message_labels ml " + + "JOIN messages m ON m.id = ml.message_id WHERE " + LiveMessagesWhere("m", true) + + " AND m." + inClause, + cloneArgs(), + &stats.LabelCount, + }, + } + // SourceCount reflects the scope: how many distinct accounts are + // represented. Dedupe defensively in case a caller passes a + // slice with repeats. + seen := make(map[int64]struct{}, len(sourceIDs)) + for _, id := range sourceIDs { + seen[id] = struct{}{} + } + stats.SourceCount = int64(len(seen)) } for _, q := range queries { - if err := s.db.QueryRow(q.query).Scan(q.dest); err != nil { + var row *sql.Row + if len(q.args) > 0 { + row = s.db.QueryRow(q.query, q.args...) + } else { + row = s.db.QueryRow(q.query) + } + if err := row.Scan(q.dest); err != nil { if s.dialect.IsNoSuchTableError(err) { continue } @@ -645,7 +762,7 @@ func (s *Store) GetStats() (*Stats, error) { } } - // Get database file size + // DatabaseSize is always the global file size; scoped stats cannot decompose it. 
if info, err := os.Stat(s.dbPath); err == nil { stats.DatabaseSize = info.Size() } diff --git a/internal/store/store_test.go b/internal/store/store_test.go index 563ec12d..e5e985c0 100644 --- a/internal/store/store_test.go +++ b/internal/store/store_test.go @@ -866,6 +866,38 @@ func TestStore_GetStats_WithData(t *testing.T) { } } +func TestStore_GetStats_ExcludesDedupHidden(t *testing.T) { + f := storetest.New(t) + ids := f.CreateMessages(3) + + // Soft-delete one via dedup (deleted_at). + _, err := f.Store.DB().Exec( + f.Store.Rebind("UPDATE messages SET deleted_at = CURRENT_TIMESTAMP WHERE id = ?"), ids[0]) + testutil.MustNoErr(t, err, "set deleted_at") + + stats, err := f.Store.GetStats() + testutil.MustNoErr(t, err, "GetStats()") + if stats.MessageCount != 2 { + t.Errorf("MessageCount = %d, want 2 (dedup-hidden row excluded)", stats.MessageCount) + } +} + +func TestStore_GetStats_ExcludesSourceDeleted(t *testing.T) { + f := storetest.New(t) + ids := f.CreateMessages(3) + + // Mark one as deleted from source. + _, err := f.Store.DB().Exec( + f.Store.Rebind("UPDATE messages SET deleted_from_source_at = CURRENT_TIMESTAMP WHERE id = ?"), ids[1]) + testutil.MustNoErr(t, err, "set deleted_from_source_at") + + stats, err := f.Store.GetStats() + testutil.MustNoErr(t, err, "GetStats()") + if stats.MessageCount != 2 { + t.Errorf("MessageCount = %d, want 2 (source-deleted row excluded)", stats.MessageCount) + } +} + func TestStore_GetStats_ClosedDB(t *testing.T) { st := testutil.NewTestStore(t) @@ -1580,3 +1612,114 @@ func TestStore_PersistMessage_Upsert(t *testing.T) { t.Errorf("body_text = %q, want %q", bodyText.String, "updated body") } } + +// --- GetStatsForScope tests --- + +// makeSecondSource creates a second source and conversation in the same store as f. 
+func makeSecondSource(t *testing.T, f *storetest.Fixture, identifier string) (*store.Source, int64) { + t.Helper() + src, err := f.Store.GetOrCreateSource("gmail", identifier) + testutil.MustNoErr(t, err, "GetOrCreateSource "+identifier) + convID, err := f.Store.EnsureConversation(src.ID, "thread-b-1", "Thread B") + testutil.MustNoErr(t, err, "EnsureConversation "+identifier) + return src, convID +} + +// createMessagesForSource inserts count messages under srcID/convID and returns their IDs. +func createMessagesForSource(t *testing.T, st *store.Store, srcID, convID int64, prefix string, count int) []int64 { + t.Helper() + ids := make([]int64, 0, count) + for i := 0; i < count; i++ { + id, err := st.UpsertMessage(&store.Message{ + ConversationID: convID, + SourceID: srcID, + SourceMessageID: fmt.Sprintf("%s-msg-%d", prefix, i), + MessageType: "email", + SizeEstimate: 1000, + }) + testutil.MustNoErr(t, err, fmt.Sprintf("UpsertMessage %s-%d", prefix, i)) + ids = append(ids, id) + } + return ids +} + +func TestStore_GetStatsForScope_SingleSource(t *testing.T) { + f := storetest.New(t) + srcB, convB := makeSecondSource(t, f, "b@example.com") + + createMessagesForSource(t, f.Store, f.Source.ID, f.ConvID, "a", 3) + createMessagesForSource(t, f.Store, srcB.ID, convB, "b", 2) + + // Scoped to source A only. + statsA, err := f.Store.GetStatsForScope([]int64{f.Source.ID}) + testutil.MustNoErr(t, err, "GetStatsForScope A") + if statsA.MessageCount != 3 { + t.Errorf("MessageCount (A only) = %d, want 3", statsA.MessageCount) + } + if statsA.SourceCount != 1 { + t.Errorf("SourceCount (A only) = %d, want 1", statsA.SourceCount) + } + + // Scoped to both sources. 
+ statsAB, err := f.Store.GetStatsForScope([]int64{f.Source.ID, srcB.ID}) + testutil.MustNoErr(t, err, "GetStatsForScope A+B") + if statsAB.MessageCount != 5 { + t.Errorf("MessageCount (A+B) = %d, want 5", statsAB.MessageCount) + } + if statsAB.SourceCount != 2 { + t.Errorf("SourceCount (A+B) = %d, want 2", statsAB.SourceCount) + } + + // Unscoped (nil) should count all messages across both sources. + statsAll, err := f.Store.GetStatsForScope(nil) + testutil.MustNoErr(t, err, "GetStatsForScope nil") + if statsAll.MessageCount != 5 { + t.Errorf("MessageCount (nil/global) = %d, want 5", statsAll.MessageCount) + } + if statsAll.SourceCount != 2 { + t.Errorf("SourceCount (nil/global) = %d, want 2", statsAll.SourceCount) + } +} + +func TestStore_GetStatsForScope_ExcludesDedupHidden(t *testing.T) { + f := storetest.New(t) + srcB, convB := makeSecondSource(t, f, "b-dedup@example.com") + + idsA := createMessagesForSource(t, f.Store, f.Source.ID, f.ConvID, "a-dedup", 2) + createMessagesForSource(t, f.Store, srcB.ID, convB, "b-dedup", 1) + + // Soft-delete one message in source A via dedup (deleted_at). + _, err := f.Store.DB().Exec( + f.Store.Rebind("UPDATE messages SET deleted_at = CURRENT_TIMESTAMP WHERE id = ?"), idsA[0]) + testutil.MustNoErr(t, err, "set deleted_at") + + // Scoped to A: should see only the live message. + statsA, err := f.Store.GetStatsForScope([]int64{f.Source.ID}) + testutil.MustNoErr(t, err, "GetStatsForScope A") + if statsA.MessageCount != 1 { + t.Errorf("MessageCount (A scoped) = %d, want 1 (dedup-hidden excluded)", statsA.MessageCount) + } + + // Unscoped: should also exclude the dedup-hidden message (2 live, not 3). 
+ statsAll, err := f.Store.GetStatsForScope(nil) + testutil.MustNoErr(t, err, "GetStatsForScope nil") + if statsAll.MessageCount != 2 { + t.Errorf("MessageCount (nil/global) = %d, want 2 (dedup-hidden excluded)", statsAll.MessageCount) + } +} + +func TestStore_GetStatsForScope_ExcludesSourceDeleted(t *testing.T) { + f := storetest.New(t) + ids := createMessagesForSource(t, f.Store, f.Source.ID, f.ConvID, "a-srcdeleted", 2) + + // Mark one as deleted from source. + _, err := f.Store.DB().Exec( + f.Store.Rebind("UPDATE messages SET deleted_from_source_at = CURRENT_TIMESTAMP WHERE id = ?"), ids[0]) + testutil.MustNoErr(t, err, "set deleted_from_source_at") + + stats, err := f.Store.GetStatsForScope([]int64{f.Source.ID}) + testutil.MustNoErr(t, err, "GetStatsForScope") + if stats.MessageCount != 1 { + t.Errorf("MessageCount = %d, want 1 (source-deleted excluded)", stats.MessageCount) + } +} diff --git a/internal/store/subset.go b/internal/store/subset.go index 9cf6c854..6df36d1d 100644 --- a/internal/store/subset.go +++ b/internal/store/subset.go @@ -231,12 +231,12 @@ func verifyForeignKeys(db *sql.DB) error { func copyData(tx *sql.Tx, rowCount int) (*CopyResult, error) { result := &CopyResult{} - if _, err := tx.Exec(` + if _, err := tx.Exec(fmt.Sprintf(` CREATE TEMP TABLE selected_messages AS SELECT id FROM src.messages - WHERE deleted_from_source_at IS NULL + WHERE %s ORDER BY COALESCE(sent_at, received_at, internal_date) - DESC, id DESC LIMIT ?`, rowCount); err != nil { + DESC, id DESC LIMIT ?`, LiveMessagesWhere("", true)), rowCount); err != nil { return nil, fmt.Errorf("select messages: %w", err) } diff --git a/internal/store/sync.go b/internal/store/sync.go index bd10799d..1daff77f 100644 --- a/internal/store/sync.go +++ b/internal/store/sync.go @@ -3,6 +3,7 @@ package store import ( "database/sql" "fmt" + "log/slog" "time" ) @@ -321,6 +322,28 @@ func (s *Store) GetOrCreateSource(sourceType, identifier string) (*Source, error } newSource.ID, _ = 
result.LastInsertId() + // Add to the default "All" collection if it exists. + // + // This runs as a separate Exec rather than inside a transaction + // with the source insert. If this Exec fails, the source row is + // committed but the All membership is missing — and the next + // EnsureDefaultCollection call (which runs in InitSchema on every + // process launch) re-adds every source not yet linked. Self-heals + // on next CLI invocation; until then collection-scoped reads of + // All would miss this source. Acceptable for a single-user tool; + // a future refactor can fold this into a withTx. + if _, err := s.db.Exec( + `INSERT OR IGNORE INTO collection_sources (collection_id, source_id) + SELECT id, ? FROM collections WHERE name = ?`, + newSource.ID, DefaultCollectionName, + ); err != nil { + slog.Warn("failed to add source to default collection (self-heals on next InitSchema)", + "source_id", newSource.ID, + "identifier", identifier, + "error", err, + ) + } + return newSource, nil } diff --git a/internal/sync/sync_test.go b/internal/sync/sync_test.go index a20bf68d..8b6dd64c 100644 --- a/internal/sync/sync_test.go +++ b/internal/sync/sync_test.go @@ -1549,7 +1549,9 @@ func TestIncrementalSyncMixedOperations(t *testing.T) { assertDeletedFromSource(t, env.Store, "existing-1", true) assertMessageHasLabel(t, env.Store, "existing-2", "STARRED") - assertMessageCount(t, env.Store, 3) // 2 original (1 deleted but still counted) + 1 new + // GetStats now applies the live-message predicate: source-deleted rows are + // excluded. Count is 1 surviving original + 1 new = 2. 
+ assertMessageCount(t, env.Store, 2) } // TestDeriveThreadKey verifies the MIME-based thread key derivation used for diff --git a/internal/tui/model.go b/internal/tui/model.go index f5ef14c3..414e7028 100644 --- a/internal/tui/model.go +++ b/internal/tui/model.go @@ -1456,7 +1456,7 @@ func (m Model) confirmDeletion() (tea.Model, tea.Cmd) { // Show success m.modal = modalDeleteResult - m.modalResult = fmt.Sprintf("Staged %d messages for deletion.\nBatch ID: %s\nRun 'msgvault delete-staged' to execute.", + m.modalResult = fmt.Sprintf("Staged %d messages for deletion.\nBatch ID: %s\nInspect: msgvault delete-staged --list\nExecute: MSGVAULT_ENABLE_REMOTE_DELETE=1 msgvault delete-staged", len(m.pendingManifest.GmailIDs), m.pendingManifest.ID) // Clear selection diff --git a/internal/tui/view.go b/internal/tui/view.go index ee75b8fa..a90c299f 100644 --- a/internal/tui/view.go +++ b/internal/tui/view.go @@ -1240,7 +1240,8 @@ func (m Model) renderDeleteConfirmModal() string { sb.WriteString("\n\n") _, _ = fmt.Fprintf(&sb, "Stage %d messages for deletion?\n\n", len(m.pendingManifest.GmailIDs)) sb.WriteString("This creates a deletion batch. Messages will NOT be\n") - sb.WriteString("deleted until you run 'msgvault delete-staged'.\n\n") + sb.WriteString("deleted until you run 'msgvault delete-staged'\n") + sb.WriteString("with MSGVAULT_ENABLE_REMOTE_DELETE=1 set.\n\n") if m.pendingManifest.Filters.Account == "" { sb.WriteString("! Account not set. 
Use --account when executing.\n\n") } diff --git a/internal/vector/embed/testsupport_test.go b/internal/vector/embed/testsupport_test.go index 7ec9bb4e..4a4e7ff4 100644 --- a/internal/vector/embed/testsupport_test.go +++ b/internal/vector/embed/testsupport_test.go @@ -92,6 +92,7 @@ func newWorkerFixture(t *testing.T, n int) *workerFixture { CREATE TABLE messages ( id INTEGER PRIMARY KEY, subject TEXT, + deleted_at DATETIME, deleted_from_source_at DATETIME ); CREATE TABLE message_bodies ( diff --git a/internal/vector/hybrid/engine_test.go b/internal/vector/hybrid/engine_test.go index 00dedd98..309105b4 100644 --- a/internal/vector/hybrid/engine_test.go +++ b/internal/vector/hybrid/engine_test.go @@ -62,6 +62,7 @@ CREATE TABLE messages ( has_attachments INTEGER DEFAULT 0, size_estimate INTEGER, sent_at DATETIME, + deleted_at DATETIME, deleted_from_source_at DATETIME ); CREATE TABLE message_bodies ( diff --git a/internal/vector/sqlitevec/backend.go b/internal/vector/sqlitevec/backend.go index a54436a2..97c8eba1 100644 --- a/internal/vector/sqlitevec/backend.go +++ b/internal/vector/sqlitevec/backend.go @@ -13,6 +13,7 @@ import ( "time" sqlite3 "github.com/mattn/go-sqlite3" + "github.com/wesm/msgvault/internal/store" "github.com/wesm/msgvault/internal/vector" ) @@ -296,8 +297,15 @@ func isUniqueConstraintErr(err error) bool { // retry if interrupted. Runs under a single vectors.db transaction so // the seed itself is atomic. func (b *Backend) seedPending(ctx context.Context, gen vector.GenerationID, now int64) error { + // Embedding-seeding: skip dedup-hidden and remote-deleted rows + // using the canonical live-message predicate + // (store.LiveMessagesWhere). Dedup Execute does not remove + // vector-store rows by design: if a message is embedded then later + // soft-deleted, the embedding stays in the vector store and + // query-time live filtering (dropDeletedFromSource, + // filteredMessageIDs) enforces the live-message contract. 
rows, err := b.mainDB.QueryContext(ctx, - `SELECT id FROM messages WHERE deleted_from_source_at IS NULL`) + fmt.Sprintf(`SELECT id FROM messages WHERE %s`, store.LiveMessagesWhere("", true))) if err != nil { return fmt.Errorf("select messages: %w", err) } @@ -788,9 +796,9 @@ func (b *Backend) resolveFilter(ctx context.Context, filter vector.Filter) (stri // case at O(embedded count) rather than leaving the caller short. const deletedOverfetchFactor = 2 -// dropDeletedFromSource takes ANN hits and returns the subset whose -// message rows are still live (deleted_from_source_at IS NULL) in -// main.db, preserving the input order. Used by Search on the empty- +// dropDeletedFromSource takes ANN hits and returns the subset that +// are live messages (deleted_at IS NULL AND deleted_from_source_at IS NULL) +// in main.db, preserving the input order. Used by Search on the empty- // filter fast path so that pure-vector/find_similar callers don't pay // the cost of materializing the full live-corpus id list just to // enforce the deletion check. @@ -806,9 +814,9 @@ func (b *Backend) dropDeletedFromSource(ctx context.Context, hits []vector.Hit) if err != nil { return nil, fmt.Errorf("encode hit ids: %w", err) } - q := `SELECT id FROM messages + q := fmt.Sprintf(`SELECT id FROM messages WHERE id IN (SELECT value FROM json_each(?)) - AND deleted_from_source_at IS NULL` + AND %s`, store.LiveMessagesWhere("", true)) rows, err := b.mainDB.QueryContext(ctx, q, string(blob)) if err != nil { return nil, fmt.Errorf("live-hit filter: %w", err) @@ -837,7 +845,7 @@ func (b *Backend) dropDeletedFromSource(ctx context.Context, hits []vector.Hit) // filteredMessageIDs runs the filter against the main DB and returns // matching message IDs. See spec §5.3. 
 func (b *Backend) filteredMessageIDs(ctx context.Context, f vector.Filter) ([]int64, error) {
-	clauses := []string{"m.deleted_from_source_at IS NULL"}
+	clauses := []string{store.LiveMessagesWhere("m", true)}
 	var args []any
 
 	if len(f.SourceIDs) > 0 {
diff --git a/internal/vector/sqlitevec/backend_test.go b/internal/vector/sqlitevec/backend_test.go
index 92c4189d..faec8800 100644
--- a/internal/vector/sqlitevec/backend_test.go
+++ b/internal/vector/sqlitevec/backend_test.go
@@ -381,6 +381,56 @@ func TestBackend_CreateGeneration_SkipsDeletedMessages(t *testing.T) {
 	}
 }
 
+// TestBackend_SeedPending_SkipsDedupHidden verifies that seedPending
+// omits messages soft-deleted by dedup (deleted_at IS NOT NULL).
+func TestBackend_SeedPending_SkipsDedupHidden(t *testing.T) {
+	// seedPending is exercised indirectly via CreateGeneration below.
+	ctx := context.Background()
+
+	db, err := sql.Open("sqlite3", ":memory:")
+	if err != nil {
+		t.Fatalf("open main: %v", err)
+	}
+	t.Cleanup(func() { _ = db.Close() })
+	if _, err := db.Exec(`CREATE TABLE messages (
+		id INTEGER PRIMARY KEY,
+		deleted_at DATETIME,
+		deleted_from_source_at DATETIME
+	)`); err != nil {
+		t.Fatalf("create messages: %v", err)
+	}
+	// Insert one live and one dedup-hidden message.
+ if _, err := db.Exec(`INSERT INTO messages (id) VALUES (1)`); err != nil { + t.Fatalf("insert live: %v", err) + } + if _, err := db.Exec(`INSERT INTO messages (id, deleted_at) VALUES (2, CURRENT_TIMESTAMP)`); err != nil { + t.Fatalf("insert dedup-hidden: %v", err) + } + + b, err := Open(ctx, Options{ + Path: t.TempDir() + "/vectors.db", + Dimension: 768, + MainDB: db, + }) + if err != nil { + t.Fatalf("Open: %v", err) + } + t.Cleanup(func() { _ = b.Close() }) + + gid, err := b.CreateGeneration(ctx, "m", 768) + if err != nil { + t.Fatalf("CreateGeneration: %v", err) + } + var n int + if err := b.db.QueryRowContext(ctx, + `SELECT COUNT(*) FROM pending_embeddings WHERE generation_id = ?`, gid).Scan(&n); err != nil { + t.Fatalf("count pending: %v", err) + } + if n != 1 { + t.Errorf("pending count = %d, want 1 (dedup-hidden message must be excluded)", n) + } +} + // TestBackend_Upsert_WritesEmbeddingAndVector verifies Upsert's // contract: it writes the embeddings row and the dimension-specific // vec0 row, and explicitly does NOT touch pending_embeddings. The @@ -1589,3 +1639,135 @@ func TestBackend_LoadVector_NoActive(t *testing.T) { t.Fatalf("want ErrNoActiveGeneration, got %v", err) } } + +// TestBackend_Search_ExcludesDedupHidden confirms that Search excludes +// messages hidden by dedup (deleted_at IS NOT NULL), not just those +// deleted from source. Uses a minimal main DB without FTS5. +func TestBackend_Search_ExcludesDedupHidden(t *testing.T) { + ctx := context.Background() + + // Minimal main DB: two messages, one dedup-hidden. No FTS5 required. 
+ db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatalf("open main: %v", err) + } + t.Cleanup(func() { _ = db.Close() }) + if _, err := db.Exec(`CREATE TABLE messages ( + id INTEGER PRIMARY KEY, + deleted_at DATETIME, + deleted_from_source_at DATETIME + )`); err != nil { + t.Fatalf("create messages: %v", err) + } + if _, err := db.Exec( + `INSERT INTO messages (id, deleted_at) VALUES (1, NULL), (2, '2026-01-01 00:00:00')`); err != nil { + t.Fatalf("insert messages: %v", err) + } + + b, err := Open(ctx, Options{ + Path: t.TempDir() + "/vectors.db", + Dimension: 768, + MainDB: db, + }) + if err != nil { + t.Fatalf("Open: %v", err) + } + t.Cleanup(func() { _ = b.Close() }) + + gid, err := b.CreateGeneration(ctx, "m", 768) + if err != nil { + t.Fatalf("CreateGeneration: %v", err) + } + chunks := []vector.Chunk{ + {MessageID: 1, Vector: unitVec(768, 0)}, + {MessageID: 2, Vector: unitVec(768, 0)}, + } + if err := b.Upsert(ctx, gid, chunks); err != nil { + t.Fatalf("Upsert: %v", err) + } + + hits, err := b.Search(ctx, gid, unitVec(768, 0), 10, vector.Filter{}) + if err != nil { + t.Fatalf("Search: %v", err) + } + got := make(map[int64]bool, len(hits)) + for _, h := range hits { + got[h.MessageID] = true + } + if !got[1] { + t.Errorf("msg 1 missing (live message must appear)") + } + if got[2] { + t.Errorf("msg 2 present (deleted_at IS NOT NULL, must be excluded)") + } +} + +// TestBackend_FilteredMessageIDs_ExcludesDedupHidden confirms that +// filteredMessageIDs excludes messages with deleted_at set. +// Uses a minimal main DB without FTS5. +func TestBackend_FilteredMessageIDs_ExcludesDedupHidden(t *testing.T) { + ctx := context.Background() + + // Minimal main DB with source_id for SourceIDs filter. 
+ db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatalf("open main: %v", err) + } + t.Cleanup(func() { _ = db.Close() }) + if _, err := db.Exec(`CREATE TABLE messages ( + id INTEGER PRIMARY KEY, + source_id INTEGER, + deleted_at DATETIME, + deleted_from_source_at DATETIME + )`); err != nil { + t.Fatalf("create messages: %v", err) + } + // Three messages: 1 live, 2 dedup-hidden, 3 source-deleted. + if _, err := db.Exec(` + INSERT INTO messages (id, source_id, deleted_at, deleted_from_source_at) VALUES + (1, 1, NULL, NULL), + (2, 1, '2026-01-01 00:00:00', NULL), + (3, 1, NULL, '2026-01-01 00:00:00')`); err != nil { + t.Fatalf("insert messages: %v", err) + } + + b, err := Open(ctx, Options{ + Path: t.TempDir() + "/vectors.db", + Dimension: 768, + MainDB: db, + }) + if err != nil { + t.Fatalf("Open: %v", err) + } + t.Cleanup(func() { _ = b.Close() }) + + // Upsert vectors for all three messages directly. + gid, err := b.CreateGeneration(ctx, "m", 768) + if err != nil { + t.Fatalf("CreateGeneration: %v", err) + } + chunks := []vector.Chunk{ + {MessageID: 1, Vector: unitVec(768, 0)}, + {MessageID: 2, Vector: unitVec(768, 0)}, + {MessageID: 3, Vector: unitVec(768, 0)}, + } + if err := b.Upsert(ctx, gid, chunks); err != nil { + t.Fatalf("Upsert: %v", err) + } + + // Filtered search via a non-empty filter triggers filteredMessageIDs. 
+ hits, err := b.Search(ctx, gid, unitVec(768, 0), 10, vector.Filter{SourceIDs: []int64{1}}) + if err != nil { + t.Fatalf("Search with filter: %v", err) + } + got := make(map[int64]bool, len(hits)) + for _, h := range hits { + got[h.MessageID] = true + } + if got[2] { + t.Errorf("msg 2 present (deleted_at, must be excluded)") + } + if got[3] { + t.Errorf("msg 3 present (deleted_from_source_at, must be excluded)") + } +} diff --git a/internal/vector/sqlitevec/backend_testhelpers_test.go b/internal/vector/sqlitevec/backend_testhelpers_test.go index 11e31b05..8c9a6458 100644 --- a/internal/vector/sqlitevec/backend_testhelpers_test.go +++ b/internal/vector/sqlitevec/backend_testhelpers_test.go @@ -26,6 +26,7 @@ func openMainDBWithOneMessage(t *testing.T) *sql.DB { t.Cleanup(func() { _ = db.Close() }) if _, err := db.Exec(`CREATE TABLE messages ( id INTEGER PRIMARY KEY, + deleted_at DATETIME, deleted_from_source_at DATETIME )`); err != nil { t.Fatalf("create messages: %v", err) @@ -48,6 +49,7 @@ func openBackendWithOneDeletedMessage(t *testing.T) *Backend { t.Cleanup(func() { _ = db.Close() }) if _, err := db.Exec(`CREATE TABLE messages ( id INTEGER PRIMARY KEY, + deleted_at DATETIME, deleted_from_source_at DATETIME )`); err != nil { t.Fatalf("create messages: %v", err) @@ -103,6 +105,7 @@ CREATE TABLE messages ( has_attachments INTEGER DEFAULT 0, size_estimate INTEGER, sent_at DATETIME, + deleted_at DATETIME, deleted_from_source_at DATETIME ); CREATE VIRTUAL TABLE messages_fts USING fts5(subject, body, content='', contentless_delete=1); diff --git a/internal/vector/sqlitevec/fused.go b/internal/vector/sqlitevec/fused.go index bf47ddb5..46a0ecdb 100644 --- a/internal/vector/sqlitevec/fused.go +++ b/internal/vector/sqlitevec/fused.go @@ -12,6 +12,7 @@ import ( "sort" "strings" + "github.com/wesm/msgvault/internal/store" "github.com/wesm/msgvault/internal/vector" ) @@ -129,7 +130,7 @@ WITH filtered AS ( SELECT m.id FROM messages m - WHERE m.deleted_from_source_at IS NULL 
+ WHERE %s AND (:source_ids IS NULL OR m.source_id IN (SELECT value FROM json_each(:source_ids))) %s %s @@ -187,7 +188,7 @@ SELECT message_id, rrf_score, bm25_score, vector_score, FROM fused ORDER BY rrf_score DESC, message_id ASC LIMIT :limit -`, senderGroupSQL, toGroupSQL, ccGroupSQL, bccGroupSQL, labelGroupSQL, +`, store.LiveMessagesWhere("m", true), senderGroupSQL, toGroupSQL, ccGroupSQL, bccGroupSQL, labelGroupSQL, kPlus1, req.KPerSignal, vecTable, kPlus1, req.KPerSignal) var queryVecArg any @@ -515,6 +516,11 @@ func (b *Backend) batchGetSubjects(ctx context.Context, ids []int64) (map[int64] placeholders[i] = "?" args[i] = id } + // Liveness is already enforced upstream in the `filtered` CTE used + // for ranking; re-filtering here would silently drop the subject + // for any hit whose row is soft-deleted between ranking and + // hydration, leaving the caller with a ranked hit and an empty + // subject. Hydrate whatever was ranked. q := fmt.Sprintf(`SELECT id, COALESCE(subject, '') FROM messages WHERE id IN (%s)`, strings.Join(placeholders, ",")) rows, err := b.mainDB.QueryContext(ctx, q, args...) 
diff --git a/internal/vector/sqlitevec/fused_test.go b/internal/vector/sqlitevec/fused_test.go index 1ffb7475..2aa05f1e 100644 --- a/internal/vector/sqlitevec/fused_test.go +++ b/internal/vector/sqlitevec/fused_test.go @@ -341,6 +341,7 @@ CREATE TABLE messages ( has_attachments INTEGER DEFAULT 0, size_estimate INTEGER, sent_at DATETIME, + deleted_at DATETIME, deleted_from_source_at DATETIME ); CREATE VIRTUAL TABLE messages_fts USING fts5(subject, body, content='', contentless_delete=1); diff --git a/internal/whatsapp/importer.go b/internal/whatsapp/importer.go index 7eea2998..8221f0d1 100644 --- a/internal/whatsapp/importer.go +++ b/internal/whatsapp/importer.go @@ -67,6 +67,7 @@ func (imp *Importer) Import(ctx context.Context, waDBPath string, opts ImportOpt if opts.DisplayName != "" { _ = imp.store.UpdateSourceDisplayName(source.ID, opts.DisplayName) } + summary.SourceID = source.ID // Start a sync run for tracking. syncID, err := imp.store.StartSync(source.ID, "whatsapp_import") diff --git a/internal/whatsapp/types.go b/internal/whatsapp/types.go index 1fbc74b8..1e755390 100644 --- a/internal/whatsapp/types.go +++ b/internal/whatsapp/types.go @@ -123,6 +123,7 @@ func DefaultOptions() ImportOptions { // ImportSummary holds statistics from a completed import. type ImportSummary struct { Duration time.Duration + SourceID int64 ChatsProcessed int64 MessagesProcessed int64 MessagesAdded int64