diff --git a/engine/internal/rdsrefresh/dblab.go b/engine/internal/rdsrefresh/dblab.go index 780c0f75..2c29bb27 100644 --- a/engine/internal/rdsrefresh/dblab.go +++ b/engine/internal/rdsrefresh/dblab.go @@ -180,6 +180,10 @@ type SourceConfigUpdate struct { Password string // RDSIAMDBInstance is the RDS DB instance identifier for IAM auth. When empty, this field is omitted from the config update. RDSIAMDBInstance string + // DumpParallelJobs sets the -j flag for pg_dump. When zero, the existing value is preserved. + DumpParallelJobs int + // RestoreParallelJobs sets the -j flag for pg_restore. When zero, the existing value is preserved. + RestoreParallelJobs int } // UpdateSourceConfig updates the source database connection in DBLab config. @@ -198,6 +202,16 @@ func (c *DBLabClient) UpdateSourceConfig(ctx context.Context, update SourceConfi proj.RDSIAMDBInstance = &update.RDSIAMDBInstance } + if update.DumpParallelJobs > 0 { + dumpJobs := int64(update.DumpParallelJobs) + proj.DumpParallelJobs = &dumpJobs + } + + if update.RestoreParallelJobs > 0 { + restoreJobs := int64(update.RestoreParallelJobs) + proj.RestoreParallelJobs = &restoreJobs + } + nested := map[string]interface{}{} // defensive error check: StoreJSON only fails if target is not an addressable struct, diff --git a/engine/internal/rdsrefresh/dblab_test.go b/engine/internal/rdsrefresh/dblab_test.go index b3e48758..878b8184 100644 --- a/engine/internal/rdsrefresh/dblab_test.go +++ b/engine/internal/rdsrefresh/dblab_test.go @@ -191,6 +191,70 @@ func TestDBLabClientUpdateSourceConfig(t *testing.T) { assert.Nil(t, receivedConfig.RDSIAMDBInstance) }) + t.Run("successful with parallelism settings", func(t *testing.T) { + var receivedConfig models.ConfigProjection + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var nested map[string]interface{} + err := json.NewDecoder(r.Body).Decode(&nested) + require.NoError(t, err) + + err = 
projection.LoadJSON(&receivedConfig, nested, projection.LoadOptions{ + Groups: []string{"default", "sensitive"}, + }) + require.NoError(t, err) + + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + client, err := NewDBLabClient(&DBLabConfig{APIEndpoint: server.URL, Token: "test-token"}) + require.NoError(t, err) + + err = client.UpdateSourceConfig(context.Background(), SourceConfigUpdate{ + Host: "clone-host.rds.amazonaws.com", Port: 5432, DBName: "postgres", + Username: "dbuser", Password: "dbpass", + DumpParallelJobs: 4, RestoreParallelJobs: 8, + }) + require.NoError(t, err) + + require.NotNil(t, receivedConfig.DumpParallelJobs) + assert.Equal(t, int64(4), *receivedConfig.DumpParallelJobs) + require.NotNil(t, receivedConfig.RestoreParallelJobs) + assert.Equal(t, int64(8), *receivedConfig.RestoreParallelJobs) + }) + + t.Run("omits parallelism when zero", func(t *testing.T) { + var receivedConfig models.ConfigProjection + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var nested map[string]interface{} + err := json.NewDecoder(r.Body).Decode(&nested) + require.NoError(t, err) + + err = projection.LoadJSON(&receivedConfig, nested, projection.LoadOptions{ + Groups: []string{"default", "sensitive"}, + }) + require.NoError(t, err) + + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + client, err := NewDBLabClient(&DBLabConfig{APIEndpoint: server.URL, Token: "test-token"}) + require.NoError(t, err) + + err = client.UpdateSourceConfig(context.Background(), SourceConfigUpdate{ + Host: "host.rds.amazonaws.com", Port: 5432, DBName: "postgres", + Username: "dbuser", Password: "dbpass", + DumpParallelJobs: 0, RestoreParallelJobs: 0, + }) + require.NoError(t, err) + + assert.Nil(t, receivedConfig.DumpParallelJobs) + assert.Nil(t, receivedConfig.RestoreParallelJobs) + }) + t.Run("error on non-2xx status", func(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r 
*http.Request) { w.WriteHeader(http.StatusBadRequest) diff --git a/engine/internal/rdsrefresh/parallelism.go b/engine/internal/rdsrefresh/parallelism.go new file mode 100644 index 00000000..b6ffe332 --- /dev/null +++ b/engine/internal/rdsrefresh/parallelism.go @@ -0,0 +1,148 @@ +/* +2026 © PostgresAI +*/ + +package rdsrefresh + +import ( + "fmt" + "runtime" + "strconv" + "strings" + + "gitlab.com/postgres-ai/database-lab/v3/pkg/log" +) + +const ( + // rdsInstanceClassPrefix is stripped to derive the instance size. + rdsInstanceClassPrefix = "db." + + // minParallelJobs is the minimum parallelism level. + minParallelJobs = 1 +) + +// instanceSizeVCPUs maps AWS instance size suffixes to their typical vCPU count. +// this mapping is consistent across most instance families (m5, m6g, r5, r6g, c5, etc.). +// graviton and intel/amd variants of the same size have the same vCPU count. +var instanceSizeVCPUs = map[string]int{ + "micro": 1, + "small": 1, + "medium": 2, + "large": 2, + "xlarge": 4, + "2xlarge": 8, + "3xlarge": 12, + "4xlarge": 16, + "6xlarge": 24, + "8xlarge": 32, + "9xlarge": 36, + "10xlarge": 40, + "12xlarge": 48, + "16xlarge": 64, + "18xlarge": 72, + "24xlarge": 96, + "32xlarge": 128, + "48xlarge": 192, + "metal": 96, +} + +// ParallelismConfig holds the computed parallelism levels for dump and restore. +type ParallelismConfig struct { + DumpJobs int + RestoreJobs int +} + +// ResolveParallelism determines the optimal parallelism levels for pg_dump and pg_restore. +// dump parallelism is based on the vCPU count of the RDS clone instance class. +// restore parallelism is based on the vCPU count of the local machine. +// local vCPU detection uses runtime.NumCPU(), which works on Linux +// (the target platform for DBLab Engine). 
+func ResolveParallelism(cfg *Config) (*ParallelismConfig, error) { + dumpJobs, err := resolveRDSInstanceVCPUs(cfg.RDSClone.InstanceClass) + if err != nil { + return nil, fmt.Errorf("failed to resolve RDS instance vCPUs: %w", err) + } + + restoreJobs := resolveLocalVCPUs() + + log.Msg("auto-parallelism: dump jobs =", dumpJobs, "(RDS clone vCPUs), restore jobs =", restoreJobs, "(local vCPUs)") + + return &ParallelismConfig{ + DumpJobs: dumpJobs, + RestoreJobs: restoreJobs, + }, nil +} + +// resolveRDSInstanceVCPUs estimates the vCPU count for the given RDS instance class +// by parsing the instance size suffix (e.g. "xlarge" from "db.m5.xlarge"). +// the mapping covers standard AWS size naming used across RDS instance families. +// if the size is not recognized, it attempts to parse a numeric multiplier prefix +// (e.g. "2xlarge" → 8 vCPUs). +func resolveRDSInstanceVCPUs(instanceClass string) (int, error) { + size, err := extractInstanceSize(instanceClass) + if err != nil { + return 0, err + } + + if vcpus, ok := instanceSizeVCPUs[size]; ok { + return vcpus, nil + } + + // handle unlisted NUMxlarge sizes by parsing the multiplier + vcpus, err := parseXlargeMultiplier(size) + if err != nil { + return 0, fmt.Errorf("unknown instance size %q in class %q", size, instanceClass) + } + + return vcpus, nil +} + +// extractInstanceSize extracts the size component from an RDS instance class. +// for example, "db.m5.xlarge" → "xlarge", "db.r6g.2xlarge" → "2xlarge". +func extractInstanceSize(instanceClass string) (string, error) { + if !strings.HasPrefix(instanceClass, rdsInstanceClassPrefix) { + return "", fmt.Errorf("invalid RDS instance class %q: expected %q prefix", instanceClass, rdsInstanceClassPrefix) + } + + withoutPrefix := strings.TrimPrefix(instanceClass, rdsInstanceClassPrefix) + + // format is "family.size", e.g. 
"m5.xlarge" or "r6g.2xlarge" + parts := strings.SplitN(withoutPrefix, ".", 2) + + const expectedParts = 2 + if len(parts) != expectedParts || parts[1] == "" { + return "", fmt.Errorf("invalid RDS instance class %q: expected format db.<family>.<size>", instanceClass) + } + + return parts[1], nil +} + +// parseXlargeMultiplier handles NUMxlarge patterns not in the static map. +// for example, "5xlarge" → 5 * 4 = 20 vCPUs. +func parseXlargeMultiplier(size string) (int, error) { + idx := strings.Index(size, "xlarge") + if idx <= 0 { + return 0, fmt.Errorf("not an xlarge variant: %q", size) + } + + multiplier, err := strconv.Atoi(size[:idx]) + if err != nil { + return 0, fmt.Errorf("invalid multiplier in %q: %w", size, err) + } + + const vcpusPerXlarge = 4 + + return multiplier * vcpusPerXlarge, nil +} + +// resolveLocalVCPUs returns the number of logical CPUs available on the local machine. +// uses runtime.NumCPU() which reads from /proc/cpuinfo on Linux +// (the target platform for DBLab Engine). +func resolveLocalVCPUs() int { + cpus := runtime.NumCPU() + if cpus < minParallelJobs { + return minParallelJobs + } + + return cpus +} diff --git a/engine/internal/rdsrefresh/parallelism_test.go b/engine/internal/rdsrefresh/parallelism_test.go new file mode 100644 index 00000000..ffb73b65 --- /dev/null +++ b/engine/internal/rdsrefresh/parallelism_test.go @@ -0,0 +1,140 @@ +/* +2026 © PostgresAI +*/ + +package rdsrefresh + +import ( + "runtime" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestExtractInstanceSize(t *testing.T) { + testCases := []struct { + instanceClass string + expectedSize string + expectErr bool + }{ + {instanceClass: "db.m5.xlarge", expectedSize: "xlarge"}, + {instanceClass: "db.t3.medium", expectedSize: "medium"}, + {instanceClass: "db.r6g.2xlarge", expectedSize: "2xlarge"}, + {instanceClass: "db.m5.metal", expectedSize: "metal"}, + {instanceClass: "db.t3.micro", expectedSize: "micro"}, + {instanceClass:
"db.r6g.16xlarge", expectedSize: "16xlarge"}, + {instanceClass: "m5.xlarge", expectErr: true}, + {instanceClass: "db.m5", expectErr: true}, + {instanceClass: "db.", expectErr: true}, + {instanceClass: "", expectErr: true}, + } + + for _, tc := range testCases { + t.Run(tc.instanceClass, func(t *testing.T) { + size, err := extractInstanceSize(tc.instanceClass) + + if tc.expectErr { + require.Error(t, err) + return + } + + require.NoError(t, err) + assert.Equal(t, tc.expectedSize, size) + }) + } +} + +func TestResolveRDSInstanceVCPUs(t *testing.T) { + testCases := []struct { + instanceClass string + expectedVCPUs int + expectErr bool + }{ + {instanceClass: "db.t3.micro", expectedVCPUs: 1}, + {instanceClass: "db.t3.small", expectedVCPUs: 1}, + {instanceClass: "db.t3.medium", expectedVCPUs: 2}, + {instanceClass: "db.m5.large", expectedVCPUs: 2}, + {instanceClass: "db.m5.xlarge", expectedVCPUs: 4}, + {instanceClass: "db.r6g.2xlarge", expectedVCPUs: 8}, + {instanceClass: "db.r6g.4xlarge", expectedVCPUs: 16}, + {instanceClass: "db.r6g.8xlarge", expectedVCPUs: 32}, + {instanceClass: "db.r6g.16xlarge", expectedVCPUs: 64}, + {instanceClass: "db.m5.24xlarge", expectedVCPUs: 96}, + {instanceClass: "db.m5.metal", expectedVCPUs: 96}, + {instanceClass: "db.m5.5xlarge", expectedVCPUs: 20}, + {instanceClass: "invalid", expectErr: true}, + {instanceClass: "db.m5", expectErr: true}, + {instanceClass: "db.m5.unknown", expectErr: true}, + } + + for _, tc := range testCases { + t.Run(tc.instanceClass, func(t *testing.T) { + vcpus, err := resolveRDSInstanceVCPUs(tc.instanceClass) + + if tc.expectErr { + require.Error(t, err) + return + } + + require.NoError(t, err) + assert.Equal(t, tc.expectedVCPUs, vcpus) + }) + } +} + +func TestParseXlargeMultiplier(t *testing.T) { + testCases := []struct { + size string + expectedVCPUs int + expectErr bool + }{ + {size: "2xlarge", expectedVCPUs: 8}, + {size: "4xlarge", expectedVCPUs: 16}, + {size: "5xlarge", expectedVCPUs: 20}, + {size: "xlarge", 
expectErr: true}, + {size: "large", expectErr: true}, + {size: "abcxlarge", expectErr: true}, + } + + for _, tc := range testCases { + t.Run(tc.size, func(t *testing.T) { + vcpus, err := parseXlargeMultiplier(tc.size) + + if tc.expectErr { + require.Error(t, err) + return + } + + require.NoError(t, err) + assert.Equal(t, tc.expectedVCPUs, vcpus) + }) + } +} + +func TestResolveLocalVCPUs(t *testing.T) { + vcpus := resolveLocalVCPUs() + + assert.Equal(t, runtime.NumCPU(), vcpus) + assert.GreaterOrEqual(t, vcpus, minParallelJobs) +} + +func TestResolveParallelism(t *testing.T) { + t.Run("resolves both dump and restore jobs", func(t *testing.T) { + cfg := &Config{RDSClone: RDSCloneConfig{InstanceClass: "db.m5.xlarge"}} + + result, err := ResolveParallelism(cfg) + + require.NoError(t, err) + assert.Equal(t, 4, result.DumpJobs) + assert.Equal(t, runtime.NumCPU(), result.RestoreJobs) + }) + + t.Run("returns error for invalid instance class", func(t *testing.T) { + cfg := &Config{RDSClone: RDSCloneConfig{InstanceClass: "invalid"}} + + _, err := ResolveParallelism(cfg) + + require.Error(t, err) + }) +} diff --git a/engine/internal/rdsrefresh/refresher.go b/engine/internal/rdsrefresh/refresher.go index bb65eeb2..85b8abe4 100644 --- a/engine/internal/rdsrefresh/refresher.go +++ b/engine/internal/rdsrefresh/refresher.go @@ -59,14 +59,15 @@ func NewRefresherWithStateFile(ctx context.Context, cfg *Config, stateFile *Stat // Run executes the full refresh workflow: // 1. Verifies DBLab is healthy and not already refreshing -// 2. Gets source database info -// 3. Finds the latest RDS snapshot -// 4. Creates a temporary RDS clone from the RDS snapshot -// 5. Waits for the RDS clone to be available -// 6. Updates DBLab config with the RDS clone endpoint -// 7. Triggers DBLab full refresh -// 8. Waits for refresh to complete -// 9. Deletes the temporary RDS clone +// 2. Resolves parallelism levels (RDS clone vCPUs for dump, local vCPUs for restore) +// 3. 
Gets source database info +// 4. Finds the latest RDS snapshot +// 5. Creates a temporary RDS clone from the RDS snapshot +// 6. Waits for the RDS clone to be available +// 7. Updates DBLab config with the RDS clone endpoint and parallelism +// 8. Triggers DBLab full refresh +// 9. Waits for refresh to complete +// 10. Deletes the temporary RDS clone func (r *Refresher) Run(ctx context.Context) *RefreshResult { result := &RefreshResult{ StartTime: time.Now(), @@ -96,7 +97,17 @@ func (r *Refresher) Run(ctx context.Context) *RefreshResult { return result } - // step 2: get source info + // step 2: resolve parallelism levels + log.Msg("resolving parallelism levels...") + + parallelism, err := ResolveParallelism(r.cfg) + if err != nil { + log.Warn("failed to auto-detect parallelism, using defaults:", err) + + parallelism = &ParallelismConfig{DumpJobs: 0, RestoreJobs: 0} + } + + // step 3: get source info log.Msg("checking source database...") sourceInfo, err := r.rds.GetSourceInfo(ctx) @@ -107,7 +118,7 @@ func (r *Refresher) Run(ctx context.Context) *RefreshResult { log.Msg("source:", sourceInfo) - // step 3: find latest RDS snapshot + // step 4: find latest RDS snapshot log.Msg("finding latest RDS snapshot...") snapshotID, err := r.rds.FindLatestSnapshot(ctx) @@ -119,7 +130,7 @@ func (r *Refresher) Run(ctx context.Context) *RefreshResult { result.SnapshotID = snapshotID log.Msg("using RDS snapshot:", snapshotID) - // step 4: create temporary RDS clone + // step 5: create temporary RDS clone log.Msg("creating RDS clone from RDS snapshot...") // write state file before clone creation for crash recovery @@ -166,7 +177,7 @@ func (r *Refresher) Run(ctx context.Context) *RefreshResult { } }() - // step 5: wait for RDS clone to be available + // step 6: wait for RDS clone to be available log.Msg("waiting for RDS clone (10-30 min)...") if err := r.rds.WaitForCloneAvailable(ctx, clone); err != nil { @@ -177,16 +188,18 @@ func (r *Refresher) Run(ctx context.Context) 
*RefreshResult { result.CloneEndpoint = clone.Endpoint log.Msg("RDS clone ready:", fmt.Sprintf("%s:%d", clone.Endpoint, clone.Port)) - // step 6: update DBLab config with RDS clone endpoint + // step 7: update DBLab config with RDS clone endpoint and parallelism log.Msg("updating DBLab config...") if err := r.dblab.UpdateSourceConfig(ctx, SourceConfigUpdate{ - Host: clone.Endpoint, - Port: int(clone.Port), - DBName: r.cfg.Source.DBName, - Username: r.cfg.Source.Username, - Password: r.cfg.Source.Password, - RDSIAMDBInstance: clone.Identifier, + Host: clone.Endpoint, + Port: int(clone.Port), + DBName: r.cfg.Source.DBName, + Username: r.cfg.Source.Username, + Password: r.cfg.Source.Password, + RDSIAMDBInstance: clone.Identifier, + DumpParallelJobs: parallelism.DumpJobs, + RestoreParallelJobs: parallelism.RestoreJobs, }); err != nil { result.Error = fmt.Errorf("failed to update DBLab config: %w", err) return result @@ -194,7 +207,7 @@ func (r *Refresher) Run(ctx context.Context) *RefreshResult { log.Msg("DBLab config updated successfully") - // step 7: trigger DBLab full refresh + // step 8: trigger DBLab full refresh log.Msg("triggering DBLab full refresh...") if err := r.dblab.TriggerFullRefresh(ctx); err != nil { @@ -204,7 +217,7 @@ func (r *Refresher) Run(ctx context.Context) *RefreshResult { log.Msg("full refresh triggered, waiting for completion...") - // step 8: wait for refresh to complete + // step 9: wait for refresh to complete pollInterval := r.cfg.DBLab.PollInterval.Duration() timeout := r.cfg.DBLab.Timeout.Duration() @@ -264,6 +277,14 @@ func (r *Refresher) DryRun(ctx context.Context) error { log.Msg("would use RDS snapshot:", snapshotID) log.Msg("would create RDS clone with instance class:", r.cfg.RDSClone.InstanceClass) + // check parallelism + parallelism, err := ResolveParallelism(r.cfg) + if err != nil { + log.Warn("could not auto-detect parallelism:", err) + } else { + log.Msg("auto-parallelism: dump jobs =", parallelism.DumpJobs, ", restore jobs 
=", parallelism.RestoreJobs) + } + log.Msg("=== DRY RUN COMPLETE - all checks passed ===") return nil