diff --git a/.golangci.yaml b/.golangci.yaml index 53c87702..4b35ec42 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -21,7 +21,7 @@ linters: main: files: - $all - - '!$test' + - "!$test" allow: - $gostd - github.com/gocarina/gocsv @@ -57,6 +57,8 @@ linters: - github.com/openfga/openfga - github.com/stretchr - go.uber.org/mock/gomock + - github.com/spf13/cobra + - github.com/spf13/viper funlen: lines: 120 statements: 80 diff --git a/.mise.toml b/.mise.toml new file mode 100644 index 00000000..966f88a7 --- /dev/null +++ b/.mise.toml @@ -0,0 +1,2 @@ +[tools] +go = "1.26.2" diff --git a/README.md b/README.md index 4ab46e96..7f5acf1d 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,7 @@ A cross-platform CLI to interact with an OpenFGA server - [Building from Source](#building-from-source) - [Usage](#usage) - [Configuration](#configuration) + - [Custom Headers](#custom-headers) - [Commands](#commands) - [Stores](#stores) - [List All Stores](#list-stores) @@ -151,6 +152,7 @@ For any command that interacts with an OpenFGA server, these configuration value | Token Audience | `--api-audience` | `FGA_API_AUDIENCE` | `api-audience` | | Store ID | `--store-id` | `FGA_STORE_ID` | `store-id` | | Authorization Model ID | `--model-id` | `FGA_MODEL_ID` | `model-id` | +| Custom Headers | `--custom-headers` | `FGA_CUSTOM_HEADERS` | `custom-headers` | If you are authenticating with a shared secret, you should specify the API Token value. If you are authenticating using OAuth, you should specify the Client ID, Client Secret, API Audience and Token Issuer. For example: @@ -164,6 +166,37 @@ api-token-issuer: auth.fga.dev store-id: 01H0H015178Y2V4CX10C2KGHF4 ``` +#### Custom Headers + +You can add custom HTTP headers to all requests sent to the API using the `--custom-headers` flag. Headers are specified in `: ` format, and the flag can be repeated to add multiple headers. + +##### Flag +```shell +--custom-headers "Header-Name: header-value" +``` + +##### Example +```shell +fga store list --custom-headers "X-Custom-Header: value1" --custom-headers "X-Request-ID: abc123" +``` + +##### Configuration + +Custom headers can also be configured via the CLI environment variable or the configuration file: + +| Name | Flag | CLI | ~/.fga.yaml | +|----------------|----------------------|------------------------|---------------------| +| Custom Headers | `--custom-headers` | `FGA_CUSTOM_HEADERS` | `custom-headers` | + +Example `~/.fga.yaml`: +```yaml +api-url: https://api.fga.example +store-id: 01H0H015178Y2V4CX10C2KGHF4 +custom-headers: + - "X-Custom-Header: value1" + - "X-Request-ID: abc123" +``` + ### Commands #### Stores diff --git a/cmd/root.go b/cmd/root.go index 46fa3cd1..d00a65e5 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -65,9 +65,10 @@ func init() { rootCmd.PersistentFlags().String("api-token", "", "API Token. Will be sent in as a Bearer in the Authorization header") rootCmd.PersistentFlags().String("api-token-issuer", "", "API Token Issuer. API responsible for issuing the API Token. Used in the Client Credentials flow") //nolint:lll rootCmd.PersistentFlags().String("api-audience", "", "API Audience. Used when performing the Client Credentials flow") - rootCmd.PersistentFlags().String("client-id", "", "Client ID. Sent to the Token Issuer during the Client Credentials flow") //nolint:lll - rootCmd.PersistentFlags().String("client-secret", "", "Client Secret. Sent to the Token Issuer during the Client Credentials flow") //nolint:lll - rootCmd.PersistentFlags().StringArray("api-scopes", []string{}, "API Scopes (repeat option for multiple values). Used in the Client Credentials flow") //nolint:lll + rootCmd.PersistentFlags().String("client-id", "", "Client ID. Sent to the Token Issuer during the Client Credentials flow") //nolint:lll + rootCmd.PersistentFlags().String("client-secret", "", "Client Secret. Sent to the Token Issuer during the Client Credentials flow") //nolint:lll + rootCmd.PersistentFlags().StringArray("api-scopes", []string{}, "API Scopes (repeat option for multiple values). Used in the Client Credentials flow") //nolint:lll + rootCmd.PersistentFlags().StringArray("custom-headers", []string{}, "Custom HTTP headers in 'Header: value' format (repeat option for multiple values)") //nolint:lll rootCmd.PersistentFlags().Bool("debug", false, "Enable debug mode - can print more detailed information for debugging") _ = rootCmd.Flags().MarkHidden("debug") diff --git a/go.mod b/go.mod index a3572449..a4df06a4 100644 --- a/go.mod +++ b/go.mod @@ -2,7 +2,7 @@ module github.com/openfga/cli go 1.25.0 -toolchain go1.26.1 +toolchain go1.26.2 require ( github.com/gocarina/gocsv v0.0.0-20240520201108-78e41c74b4b1 diff --git a/internal/cmdutils/bind-viper-to-flags.go b/internal/cmdutils/bind-viper-to-flags.go index 8f032851..013f2a97 100644 --- a/internal/cmdutils/bind-viper-to-flags.go +++ b/internal/cmdutils/bind-viper-to-flags.go @@ -17,8 +17,6 @@ limitations under the License. package cmdutils import ( - "fmt" - "github.com/spf13/cobra" "github.com/spf13/pflag" "github.com/spf13/viper" @@ -30,9 +28,9 @@ func BindViperToFlags(cmd *cobra.Command, viperInstance *viper.Viper) { configName := flag.Name if !flag.Changed && viperInstance.IsSet(configName) { - value := viperInstance.Get(configName) - err := cmd.Flags().Set(flag.Name, fmt.Sprintf("%v", value)) - cobra.CheckErr(err) + for _, strVal := range viperInstance.GetStringSlice(configName) { + cobra.CheckErr(cmd.Flags().Set(flag.Name, strVal)) + } } }) diff --git a/internal/cmdutils/bind-viper-to-flags_test.go b/internal/cmdutils/bind-viper-to-flags_test.go new file mode 100644 index 00000000..5e9ef68f --- /dev/null +++ b/internal/cmdutils/bind-viper-to-flags_test.go @@ -0,0 +1,105 @@ +/* +Copyright © 2023 OpenFGA + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package cmdutils + +import ( + "testing" + + "github.com/spf13/cobra" + "github.com/spf13/viper" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestBindViperToFlags(t *testing.T) { + t.Parallel() + + const flagName = "header" + + testcases := []struct { + name string + value any + expected []string + }{ + { + name: "slice value produces one flag value per element", + value: []any{ + "X-Custom-Header: value1", + "X-Request-ID: abc123", + }, + expected: []string{"X-Custom-Header: value1", "X-Request-ID: abc123"}, + }, + { + name: "single element slice", + value: []any{"X-Custom-Header: value1"}, + expected: []string{"X-Custom-Header: value1"}, + }, + { + name: "empty slice leaves flag untouched", + value: []any{}, + expected: []string{}, + }, + { + name: "typed string slice produces one flag value per element", + value: []string{"X-Custom-Header: value1", "X-Request-ID: abc123"}, + expected: []string{"X-Custom-Header: value1", "X-Request-ID: abc123"}, + }, + { + name: "typed int slice produces one flag value per element", + value: []int{1, 2, 3}, + expected: []string{"1", "2", "3"}, + }, + { + name: "scalar string produces single flag value", + value: "https://api.fga.example", + expected: []string{"https://api.fga.example"}, + }, + { + name: "space separated scalar string (env var style) splits into multiple flag values", + value: "X-Custom-Header:value1 X-Request-ID:abc123", + expected: []string{"X-Custom-Header:value1", "X-Request-ID:abc123"}, + }, + { + name: "boolean value is stringified", + value: true, + expected: []string{"true"}, + }, + { + name: "integer value is stringified", + value: 42, + expected: []string{"42"}, + }, + } + + for _, test := range testcases { + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + cmd := &cobra.Command{Use: "root"} + cmd.Flags().StringArray(flagName, nil, "") + + viperInstance := viper.New() + viperInstance.Set(flagName, test.value) + + BindViperToFlags(cmd, viperInstance) + + got, err := cmd.Flags().GetStringArray(flagName) + require.NoError(t, err) + assert.Equal(t, test.expected, got) + }) + } +} diff --git a/internal/cmdutils/get-client-config.go b/internal/cmdutils/get-client-config.go index f6af6032..94bb1498 100644 --- a/internal/cmdutils/get-client-config.go +++ b/internal/cmdutils/get-client-config.go @@ -44,6 +44,7 @@ func GetClientConfig(cmd *cobra.Command) fga.ClientConfig { clientCredentialsClientID, _ := cmd.Flags().GetString("client-id") clientCredentialsClientSecret, _ := cmd.Flags().GetString("client-secret") clientCredentialsScopes, _ := cmd.Flags().GetStringArray("api-scopes") + customHeaders, _ := cmd.Flags().GetStringArray("custom-headers") debug, _ := cmd.Flags().GetBool("debug") return fga.ClientConfig{ @@ -56,6 +57,7 @@ func GetClientConfig(cmd *cobra.Command) fga.ClientConfig { ClientID: clientCredentialsClientID, ClientSecret: clientCredentialsClientSecret, APIScopes: clientCredentialsScopes, + CustomHeaders: customHeaders, Debug: debug, } } diff --git a/internal/fga/fga.go b/internal/fga/fga.go index efa7a130..0e56ac9d 100644 --- a/internal/fga/fga.go +++ b/internal/fga/fga.go @@ -18,6 +18,8 @@ limitations under the License. package fga import ( + "errors" + "fmt" "strings" openfga "github.com/openfga/go-sdk" @@ -32,7 +34,11 @@ const ( MinSdkWaitInMs = 500 ) -var userAgent = "openfga-cli/" + build.Version +var ( + userAgent = "openfga-cli/" + build.Version + + ErrInvalidHeaderFormat = errors.New("expected format \"Header-Name:value\"") +) type ClientConfig struct { ApiUrl string `json:"api_url,omitempty"` //nolint:revive,stylecheck @@ -44,11 +50,17 @@ type ClientConfig struct { APIScopes []string `json:"api_scopes,omitempty"` ClientID string `json:"client_id,omitempty"` ClientSecret string `json:"client_secret,omitempty"` //nolint:gosec + CustomHeaders []string `json:"custom_headers,omitempty"` Debug bool `json:"debug,omitempty"` } func (c ClientConfig) GetFgaClient() (*client.OpenFgaClient, error) { - fgaClient, err := client.NewSdkClient(c.getClientConfig()) + clientConfig, err := c.getClientConfig() + if err != nil { + return nil, err + } + + fgaClient, err := client.NewSdkClient(clientConfig) if err != nil { return nil, err //nolint:wrapcheck } @@ -84,7 +96,12 @@ func (c ClientConfig) getCredentials() *credentials.Credentials { } } -func (c ClientConfig) getClientConfig() *client.ClientConfiguration { +func (c ClientConfig) getClientConfig() (*client.ClientConfiguration, error) { + customHeaders, err := c.getCustomHeaders() + if err != nil { + return nil, fmt.Errorf("invalid custom headers configuration: %w", err) + } + return &client.ClientConfiguration{ ApiUrl: c.ApiUrl, StoreId: c.StoreID, @@ -95,6 +112,24 @@ func (c ClientConfig) getClientConfig() *client.ClientConfiguration { MaxRetry: MaxSdkRetry, MinWaitInMs: MinSdkWaitInMs, }, - Debug: c.Debug, + Debug: c.Debug, + DefaultHeaders: customHeaders, + }, nil +} + +func (c ClientConfig) getCustomHeaders() (map[string]string, error) { + headers := make(map[string]string, len(c.CustomHeaders)) + + for _, header := range c.CustomHeaders { + name, value, _ := strings.Cut(header, ":") + + name, value = strings.TrimSpace(name), strings.TrimSpace(value) + if name == "" { + return nil, fmt.Errorf("invalid custom header %q: %w", header, ErrInvalidHeaderFormat) + } + + headers[name] = value } + + return headers, nil } diff --git a/internal/fga/fga_test.go b/internal/fga/fga_test.go new file mode 100644 index 00000000..ee90e26a --- /dev/null +++ b/internal/fga/fga_test.go @@ -0,0 +1,186 @@ +/* +Copyright © 2023 OpenFGA + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package fga + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/openfga/go-sdk/client" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGetCustomHeaders(t *testing.T) { + t.Parallel() + + testcases := []struct { + name string + headers []string + expected map[string]string + err error + }{ + { + name: "no headers", + headers: []string{}, + expected: map[string]string{}, + }, + { + name: "single valid header", + headers: []string{"X-Custom: value1"}, + expected: map[string]string{ + "X-Custom": "value1", + }, + }, + { + name: "multiple valid headers", + headers: []string{"X-Custom: value1", "X-Request-ID: abc123"}, + expected: map[string]string{ + "X-Custom": "value1", + "X-Request-ID": "abc123", + }, + }, + { + name: "colon in value is preserved", + headers: []string{"X-Custom: host:port"}, + expected: map[string]string{ + "X-Custom": "host:port", + }, + }, + { + name: "whitespace is trimmed", + headers: []string{" X-Custom : value1 "}, + expected: map[string]string{ + "X-Custom": "value1", + }, + }, + { + name: "empty value is valid", + headers: []string{"X-Custom: "}, + expected: map[string]string{ + "X-Custom": "", + }, + }, + { + name: "no colon allowed", + headers: []string{"X-Custom"}, + expected: map[string]string{ + "X-Custom": "", + }, + }, + { + name: "empty string returns error", + headers: []string{""}, + err: ErrInvalidHeaderFormat, + }, + { + name: "empty header name returns error", + headers: []string{": value"}, + err: ErrInvalidHeaderFormat, + }, + { + name: "valid header before invalid stops at first error", + headers: []string{"X-Good: ok", ""}, + err: ErrInvalidHeaderFormat, + }, + } + + for _, test := range testcases { + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + cfg := ClientConfig{CustomHeaders: test.headers} + result, err := cfg.getCustomHeaders() + + if test.err != nil { + require.Error(t, err) + assert.ErrorIs(t, err, test.err) + } else { + require.NoError(t, err) + assert.Equal(t, test.expected, result) + } + }) + } +} + +func TestCustomHeadersSentInRequest(t *testing.T) { + t.Parallel() + + testcases := []struct { + name string + customHeaders []string + expectedHeaders map[string]string + }{ + { + name: "single header is sent", + customHeaders: []string{"X-Custom-Header: value1"}, + expectedHeaders: map[string]string{"X-Custom-Header": "value1"}, + }, + { + name: "multiple headers are sent", + customHeaders: []string{"X-Custom-Header: value1", "X-Request-ID: abc123"}, + expectedHeaders: map[string]string{ + "X-Custom-Header": "value1", + "X-Request-ID": "abc123", + }, + }, + { + name: "no custom headers", + customHeaders: []string{}, + expectedHeaders: map[string]string{}, + }, + } + + for _, test := range testcases { + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + headersCh := make(chan http.Header, 1) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + headersCh <- r.Header.Clone() + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"stores": []}`)) + })) + defer server.Close() + + cfg := ClientConfig{ + ApiUrl: server.URL, + StoreID: "01H0H015178Y2V4CX10C2KGHF4", + CustomHeaders: test.customHeaders, + } + + fgaClient, err := cfg.GetFgaClient() + require.NoError(t, err) + + _, err = fgaClient.ListStores(context.Background()). + Options(client.ClientListStoresOptions{}). + Execute() + require.NoError(t, err) + + capturedHeaders := <-headersCh + for name, value := range test.expectedHeaders { + assert.Equal(t, value, capturedHeaders.Get(name), + "expected header %s to have value %q", name, value) + } + }) + } +} diff --git a/internal/requests/rampup.go b/internal/requests/rampup.go index b72210c6..9f453411 100644 --- a/internal/requests/rampup.go +++ b/internal/requests/rampup.go @@ -30,7 +30,7 @@ func RampUpAPIRequests( //nolint:cyclop semaphore = make(chan struct{}, maxInFlight) waitGroup sync.WaitGroup ticker = time.NewTicker(rampupPeriodDuration) - requestIndex int32 + requestIndex atomic.Int32 ) // if the ramp up period is 0, go to max rps directly @@ -65,7 +65,7 @@ func RampUpAPIRequests( //nolint:cyclop } for i := 0; i < int(limiter.Limit()); i++ { //nolint:intrange - idx := atomic.AddInt32(&requestIndex, 1) - 1 + idx := requestIndex.Add(1) - 1 if idx >= requestsLen { waitGroup.Wait() @@ -102,7 +102,7 @@ func RampUpAPIRequests( //nolint:cyclop } for i := 0; i < int(limiter.Limit()); i++ { //nolint:intrange - idx := atomic.AddInt32(&requestIndex, 1) - 1 + idx := requestIndex.Add(1) - 1 if idx >= requestsLen { waitGroup.Wait() diff --git a/internal/requests/rampup_test.go b/internal/requests/rampup_test.go index ce74ad2e..c065c050 100644 --- a/internal/requests/rampup_test.go +++ b/internal/requests/rampup_test.go @@ -47,7 +47,7 @@ func TestRampUpAPIRequests_RampUpRate(t *testing.T) { defer cancel() var ( - callCount int32 + callCount atomic.Int32 mutex sync.Mutex ) @@ -55,7 +55,7 @@ func TestRampUpAPIRequests_RampUpRate(t *testing.T) { for i := range requestsList { requestsList[i] = func() error { mutex.Lock() - atomic.AddInt32(&callCount, 1) + callCount.Add(1) mutex.Unlock() return nil