diff --git a/internal/services/sep45_service.go b/internal/services/sep45_service.go new file mode 100644 index 000000000..b67babf4e --- /dev/null +++ b/internal/services/sep45_service.go @@ -0,0 +1,378 @@ +package services + +import ( + "context" + "crypto/rand" + "encoding/base64" + "encoding/binary" + "fmt" + "net/url" + "strings" + + "github.com/stellar/go/clients/stellartoml" + "github.com/stellar/go/keypair" + "github.com/stellar/go/strkey" + "github.com/stellar/go/txnbuild" + "github.com/stellar/go/xdr" + "github.com/stellar/stellar-rpc/protocol" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/sdpcontext" + "github.com/stellar/stellar-disbursement-platform-backend/internal/stellar" + "github.com/stellar/stellar-disbursement-platform-backend/internal/utils" +) + +// The number of ledgers after which the server-signed authorization entry expires. +const signatureExpirationLedgers = 10 + +//go:generate mockery --name=SEP45Service --case=underscore --structname=MockSEP45Service --filename=sep45_service_mock.go --inpackage +type SEP45Service interface { + // CreateChallenge creates a new challenge for the given contract account and home domain. + CreateChallenge(ctx context.Context, req SEP45ChallengeRequest) (*SEP45ChallengeResponse, error) + // ValidateChallenge validates the given challenge and returns a JWT if valid. + ValidateChallenge(ctx context.Context, req SEP45ValidationRequest) (*SEP45ValidationResponse, error) +} + +type sep45Service struct { + rpcClient stellar.RPCClient + tomlClient stellartoml.ClientInterface + networkPassphrase string + contractID xdr.ContractId + signingKP *keypair.Full + signingPKBytes []byte + clientAttributionRequired bool + allowHTTPRetry bool + baseURL string +} + +type SEP45ChallengeRequest struct { + Account string `json:"account" query:"account"` + HomeDomain string `json:"home_domain" query:"home_domain"` + ClientDomain *string `json:"client_domain,omitempty" query:"client_domain"` +} + +func (r SEP45ChallengeRequest) Validate() error { + if strings.TrimSpace(r.Account) == "" { + return fmt.Errorf("account is required") + } + if !strkey.IsValidContractAddress(r.Account) { + return fmt.Errorf("account must be a valid contract address") + } + if strings.TrimSpace(r.HomeDomain) == "" { + return fmt.Errorf("home_domain is required") + } + return nil +} + +type SEP45ChallengeResponse struct { + AuthorizationEntries string `json:"authorization_entries"` + NetworkPassphrase string `json:"network_passphrase"` +} + +type SEP45ValidationRequest struct { + AuthorizationEntries string `json:"authorization_entries" form:"authorization_entries"` +} + +type SEP45ValidationResponse struct { + Token string `json:"token"` +} + +type SEP45ServiceOptions struct { + RPCClient stellar.RPCClient + TOMLClient stellartoml.ClientInterface + NetworkPassphrase string + WebAuthVerifyContractID string + ServerSigningKeypair *keypair.Full + BaseURL string + ClientAttributionRequired bool + AllowHTTPRetry bool +} + +func NewSEP45Service(opts SEP45ServiceOptions) (SEP45Service, error) { + if opts.RPCClient == nil { + return nil, fmt.Errorf("rpc client cannot be nil") + } + if strings.TrimSpace(opts.NetworkPassphrase) == "" { + return nil, fmt.Errorf("network passphrase cannot be empty") + } + if strings.TrimSpace(opts.WebAuthVerifyContractID) == "" { + return nil, fmt.Errorf("web_auth_verify contract ID cannot be empty") + } + if opts.ServerSigningKeypair == nil { + return nil, fmt.Errorf("server signing keypair cannot be nil") + } + if strings.TrimSpace(opts.BaseURL) == "" { + return nil, fmt.Errorf("base URL cannot be empty") + } + + signingKP := opts.ServerSigningKeypair + signingPKBytes, err := strkey.Decode(strkey.VersionByteAccountID, signingKP.Address()) + if err != nil { + return nil, fmt.Errorf("decoding signing public key: %w", err) + } + + rawContractID, err := strkey.Decode(strkey.VersionByteContract, opts.WebAuthVerifyContractID) + if err != nil { + return nil, fmt.Errorf("decoding contract ID: %w", err) + } + var contractID xdr.ContractId + copy(contractID[:], rawContractID) + + tomlClient := opts.TOMLClient + if tomlClient == nil { + tomlClient = stellartoml.DefaultClient + } + + return &sep45Service{ + rpcClient: opts.RPCClient, + tomlClient: tomlClient, + networkPassphrase: opts.NetworkPassphrase, + contractID: contractID, + signingKP: signingKP, + signingPKBytes: signingPKBytes, + clientAttributionRequired: opts.ClientAttributionRequired, + allowHTTPRetry: opts.AllowHTTPRetry, + baseURL: opts.BaseURL, + }, nil +} + +func (s *sep45Service) CreateChallenge(ctx context.Context, req SEP45ChallengeRequest) (*SEP45ChallengeResponse, error) { + if err := req.Validate(); err != nil { + return nil, err + } + + webAuthDomain := s.getWebAuthDomain(ctx) + if strings.TrimSpace(webAuthDomain) == "" { + return nil, fmt.Errorf("unable to determine web_auth_domain") + } + + account := strings.TrimSpace(req.Account) + homeDomain := strings.TrimSpace(req.HomeDomain) + if homeDomain == "" { + return nil, fmt.Errorf("home_domain is required") + } + + if !s.isValidHomeDomain(homeDomain) { + return nil, fmt.Errorf("invalid home_domain must match %s", s.getBaseDomain()) + } + + clientDomain := "" + if req.ClientDomain != nil { + clientDomain = strings.TrimSpace(*req.ClientDomain) + } + if s.clientAttributionRequired && clientDomain == "" { + return nil, fmt.Errorf("client_domain is required") + } + + var clientDomainAccount string + if clientDomain != "" { + key, err := s.fetchSigningKeyFromClientDomain(clientDomain) + if err != nil { + return nil, fmt.Errorf("fetching signing key for client_domain %s: %w", clientDomain, err) + } + clientDomainAccount = key + } + + // TODO(philip): We generate a random nonce right now and don't store it anywhere. + // This is also the case with the SEP-10 implementation, so we should address them together. + nonce, err := generateNonce() + if err != nil { + return nil, fmt.Errorf("generating nonce: %w", err) + } + + // Build the invocation arguments for the web_auth_verify contract function, ensuring + // that fields are in lexicographical order. + fields := []xdr.ScMapEntry{ + utils.NewSymbolStringEntry("account", account), + } + if clientDomain != "" { + fields = append(fields, + utils.NewSymbolStringEntry("client_domain", clientDomain), + utils.NewSymbolStringEntry("client_domain_account", clientDomainAccount), + ) + } + fields = append(fields, + utils.NewSymbolStringEntry("home_domain", homeDomain), + utils.NewSymbolStringEntry("nonce", nonce), + utils.NewSymbolStringEntry("web_auth_domain", webAuthDomain), + utils.NewSymbolStringEntry("web_auth_domain_account", s.signingKP.Address()), + ) + + scMap := xdr.ScMap(fields) + arg, err := xdr.NewScVal(xdr.ScValTypeScvMap, &scMap) + if err != nil { + return nil, fmt.Errorf("building invocation arguments: %w", err) + } + args := xdr.ScVec{arg} + + hostFunction := xdr.HostFunction{ + Type: xdr.HostFunctionTypeHostFunctionTypeInvokeContract, + InvokeContract: &xdr.InvokeContractArgs{ + ContractAddress: xdr.ScAddress{ + Type: xdr.ScAddressTypeScAddressTypeContract, + ContractId: &s.contractID, + }, + FunctionName: "web_auth_verify", + Args: args, + }, + } + + txParams := txnbuild.TransactionParams{ + // The challenge transaction's source account must be different than the server signing account + // so that there is an authorization entry generated for the server signing account. + SourceAccount: &txnbuild.SimpleAccount{ + AccountID: keypair.MustRandom().Address(), + Sequence: 0, + }, + BaseFee: int64(txnbuild.MinBaseFee), + Preconditions: txnbuild.Preconditions{ + TimeBounds: txnbuild.NewTimeout(300), + }, + Operations: []txnbuild.Operation{&txnbuild.InvokeHostFunction{ + SourceAccount: s.signingKP.Address(), + HostFunction: hostFunction, + }}, + } + + tx, err := txnbuild.NewTransaction(txParams) + if err != nil { + return nil, fmt.Errorf("building transaction: %w", err) + } + + base64EncodedTx, err := tx.Base64() + if err != nil { + return nil, fmt.Errorf("encoding transaction: %w", err) + } + + // Simulate the transaction to obtain the authorization entries. + // + // There should be an entry for: + // 1. The server signing account. + // 2. The client contract account (corresponding to the `account` argument). + // 3. The client domain account (if applicable). + simResult, simErr := s.rpcClient.SimulateTransaction(ctx, protocol.SimulateTransactionRequest{ + Transaction: base64EncodedTx, + }) + if simErr != nil { + return nil, fmt.Errorf("simulating transaction: %w", simErr) + } + + authEntries, err := s.signServerAuthEntry(ctx, simResult) + if err != nil { + return nil, err + } + + rawEntries, err := authEntries.MarshalBinary() + if err != nil { + return nil, fmt.Errorf("encoding authorization entries: %w", err) + } + + return &SEP45ChallengeResponse{ + AuthorizationEntries: base64.StdEncoding.EncodeToString(rawEntries), + NetworkPassphrase: s.networkPassphrase, + }, nil +} + +func (s *sep45Service) ValidateChallenge(ctx context.Context, req SEP45ValidationRequest) (*SEP45ValidationResponse, error) { + return nil, fmt.Errorf("challenge validation is not implemented") +} + +func (s *sep45Service) signServerAuthEntry(ctx context.Context, result *stellar.SimulationResult) (xdr.SorobanAuthorizationEntries, error) { + if result == nil || len(result.Response.Results) == 0 { + return nil, fmt.Errorf("missing simulation results") + } + authXDR := result.Response.Results[0].AuthXDR + if authXDR == nil { + return nil, fmt.Errorf("missing authorization entries") + } + + ledgerNumber, err := s.rpcClient.GetLatestLedgerSequence(ctx) + if err != nil { + return nil, fmt.Errorf("fetching latest ledger: %w", err) + } + validUntil := ledgerNumber + uint32(signatureExpirationLedgers) + + signedEntries := make(xdr.SorobanAuthorizationEntries, 0, len(*authXDR)) + for _, entryB64 := range *authXDR { + var entry xdr.SorobanAuthorizationEntry + if err := xdr.SafeUnmarshalBase64(entryB64, &entry); err != nil { + return nil, fmt.Errorf("unmarshalling authorization entry: %w", err) + } + + signedEntry, err := utils.SignAuthEntry(entry, validUntil, s.signingKP, s.networkPassphrase) + if err != nil { + return nil, fmt.Errorf("signing authorization entry: %w", err) + } + signedEntries = append(signedEntries, signedEntry) + } + + return signedEntries, nil +} + +func (s *sep45Service) fetchSigningKeyFromClientDomain(clientDomain string) (string, error) { + resp, err := s.tomlClient.GetStellarToml(clientDomain) + if err != nil && s.allowHTTPRetry { + if client, ok := s.tomlClient.(*stellartoml.Client); ok { + fallback := *client + fallback.UseHTTP = true + resp, err = fallback.GetStellarToml(clientDomain) + } else { + fallback := &stellartoml.Client{UseHTTP: true} + resp, err = fallback.GetStellarToml(clientDomain) + } + } + if err != nil { + return "", fmt.Errorf("fetching stellar.toml for %s: %w", clientDomain, err) + } + if resp == nil || strings.TrimSpace(resp.SigningKey) == "" { + return "", fmt.Errorf("stellar.toml at %s missing SIGNING_KEY", clientDomain) + } + if !strkey.IsValidEd25519PublicKey(resp.SigningKey) { + return "", fmt.Errorf("stellar.toml SIGNING_KEY at %s is invalid", clientDomain) + } + return resp.SigningKey, nil +} + +func generateNonce() (string, error) { + var buf [4]byte + if _, err := rand.Read(buf[:]); err != nil { + return "", fmt.Errorf("generating nonce: %w", err) + } + return fmt.Sprintf("%d", binary.BigEndian.Uint32(buf[:])), nil +} + +// TODO(philip): Below methods are shared with sep10_service.go so they can be moved to a common utility package later. + +func (s *sep45Service) getWebAuthDomain(ctx context.Context) string { + currentTenant, err := sdpcontext.GetTenantFromContext(ctx) + if err == nil && currentTenant != nil && currentTenant.BaseURL != nil { + parsedURL, parseErr := url.Parse(*currentTenant.BaseURL) + if parseErr == nil { + return parsedURL.Host + } + } + return s.getBaseDomain() +} + +func (s *sep45Service) getBaseDomain() string { + parsed, err := url.Parse(s.baseURL) + if err != nil { + return "" + } + return parsed.Host +} + +func (s *sep45Service) isValidHomeDomain(homeDomain string) bool { + baseDomain := s.getBaseDomain() + if baseDomain == "" || homeDomain == "" { + return false + } + + baseDomainLower := strings.ToLower(baseDomain) + homeDomainLower := strings.ToLower(homeDomain) + + if homeDomainLower == baseDomainLower { + return true + } + + return strings.HasSuffix(homeDomainLower, "."+baseDomainLower) +} diff --git a/internal/services/sep45_service_test.go b/internal/services/sep45_service_test.go new file mode 100644 index 000000000..944e98386 --- /dev/null +++ b/internal/services/sep45_service_test.go @@ -0,0 +1,430 @@ +package services + +import ( + "context" + "encoding/base64" + "testing" + + "github.com/stellar/go/clients/stellartoml" + "github.com/stellar/go/keypair" + "github.com/stellar/go/network" + "github.com/stellar/go/strkey" + "github.com/stellar/go/xdr" + "github.com/stellar/stellar-rpc/protocol" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + "github.com/stellar/stellar-disbursement-platform-backend/internal/stellar" + stellarMocks "github.com/stellar/stellar-disbursement-platform-backend/internal/stellar/mocks" + "github.com/stellar/stellar-disbursement-platform-backend/internal/utils" +) + +const ( + testWebAuthVerifyContractID = "CD3LA6RKF5D2FN2R2L57MWXLBRSEWWENE74YBEFZSSGNJRJGICFGQXMX" + testClientContractAddress = "CCYU2FUIMK23K34U3SWCN2O2JVI6JBGUGQUILYK7GRPCIDABVVTCS7R4" +) + +var ( + testWebAuthVerifyContract = decodeContractID(testWebAuthVerifyContractID) + testClientContractID = decodeContractID(testClientContractAddress) +) + +func Test_SEP45ChallengeRequest_Validate(t *testing.T) { + validAccount := testClientContractAddress + + testCases := []struct { + name string + req SEP45ChallengeRequest + expectError bool + errMsg string + }{ + { + name: "valid contract address", + req: SEP45ChallengeRequest{ + Account: validAccount, + HomeDomain: "home.example.com", + }, + }, + { + name: "missing home domain", + req: SEP45ChallengeRequest{ + Account: validAccount, + }, + expectError: true, + errMsg: "home_domain is required", + }, + { + name: "missing account", + req: SEP45ChallengeRequest{}, + expectError: true, + errMsg: "account is required", + }, + { + name: "invalid account type", + req: SEP45ChallengeRequest{ + Account: keypair.MustRandom().Address(), + }, + expectError: true, + errMsg: "account must be a valid contract address", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := tc.req.Validate() + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.errMsg) + } else { + require.NoError(t, err) + } + }) + } +} + +func Test_SEP45Service_CreateChallenge(t *testing.T) { + testCases := []struct { + name string + build func(t *testing.T) (context.Context, SEP45ServiceOptions, SEP45ChallengeRequest, func(*testing.T, *SEP45ChallengeResponse)) + expectError bool + errContains string + }{ + { + name: "valid challenge request with client domain", + build: func(t *testing.T) (context.Context, SEP45ServiceOptions, SEP45ChallengeRequest, func(*testing.T, *SEP45ChallengeResponse)) { + t.Helper() + + ctx := context.Background() + serverKP := keypair.MustRandom() + clientDomainKP := keypair.MustRandom() + clientContractAddress := testClientContractAddress + clientContractID := testClientContractID + + rpcMock := stellarMocks.NewMockRPCClient(t) + tomlMock := &stellartoml.MockClient{} + + clientDomain := "wallet.example.com" + homeDomain := "home.example.com" + baseHost := "example.com" + + argEntries := xdr.ScMap{ + utils.NewSymbolStringEntry("account", clientContractAddress), + utils.NewSymbolStringEntry("client_domain", clientDomain), + utils.NewSymbolStringEntry("client_domain_account", clientDomainKP.Address()), + utils.NewSymbolStringEntry("home_domain", homeDomain), + utils.NewSymbolStringEntry("nonce", "nonce-value"), + utils.NewSymbolStringEntry("web_auth_domain", baseHost), + utils.NewSymbolStringEntry("web_auth_domain_account", serverKP.Address()), + } + + serverAccountID := xdr.MustAddress(serverKP.Address()) + clientDomainAccountID := xdr.MustAddress(clientDomainKP.Address()) + authEntries := marshalAuthorizationEntries(t, []xdr.SorobanAuthorizationEntry{ + makeAuthorizationEntry(t, testWebAuthVerifyContract, xdr.ScAddress{Type: xdr.ScAddressTypeScAddressTypeAccount, AccountId: &serverAccountID}, argEntries), + makeAuthorizationEntry(t, testWebAuthVerifyContract, xdr.ScAddress{Type: xdr.ScAddressTypeScAddressTypeContract, ContractId: &clientContractID}, argEntries), + makeAuthorizationEntry(t, testWebAuthVerifyContract, xdr.ScAddress{Type: xdr.ScAddressTypeScAddressTypeAccount, AccountId: &clientDomainAccountID}, argEntries), + }) + + var capturedTx string + rpcMock. + On("SimulateTransaction", mock.Anything, mock.MatchedBy(func(req protocol.SimulateTransactionRequest) bool { + capturedTx = req.Transaction + return true + })). + Return(&stellar.SimulationResult{ + Response: protocol.SimulateTransactionResponse{ + Results: []protocol.SimulateHostFunctionResult{{AuthXDR: &authEntries}}, + }, + }, (*stellar.SimulationError)(nil)). + Once() + rpcMock. + On("GetLatestLedgerSequence", mock.Anything). + Return(uint32(100), nil). + Once() + + tomlMock. + On("GetStellarToml", clientDomain). + Return(&stellartoml.Response{SigningKey: clientDomainKP.Address()}, nil). + Once() + + clientDomainCopy := clientDomain + assertFn := func(t *testing.T, resp *SEP45ChallengeResponse) { + require.Equal(t, network.TestNetworkPassphrase, resp.NetworkPassphrase) + + rawEntries, err := base64.StdEncoding.DecodeString(resp.AuthorizationEntries) + require.NoError(t, err) + + var signedEntries xdr.SorobanAuthorizationEntries + require.NoError(t, signedEntries.UnmarshalBinary(rawEntries)) + require.Len(t, signedEntries, 3) + require.Equal(t, xdr.Uint32(100+signatureExpirationLedgers), signedEntries[0].Credentials.Address.SignatureExpirationLedger) + + sigVec, ok := signedEntries[0].Credentials.Address.Signature.GetVec() + require.True(t, ok) + require.NotNil(t, sigVec) + require.NotZero(t, len(*sigVec)) + + argsMap := extractInvokeArgs(t, capturedTx) + assert.Equal(t, clientContractAddress, argsMap["account"]) + assert.Equal(t, clientDomain, argsMap["client_domain"]) + assert.Equal(t, clientDomainKP.Address(), argsMap["client_domain_account"]) + assert.Equal(t, homeDomain, argsMap["home_domain"]) + assert.Equal(t, baseHost, argsMap["web_auth_domain"]) + assert.Equal(t, serverKP.Address(), argsMap["web_auth_domain_account"]) + assert.NotEmpty(t, argsMap["nonce"]) + } + + return ctx, SEP45ServiceOptions{ + RPCClient: rpcMock, + TOMLClient: tomlMock, + NetworkPassphrase: network.TestNetworkPassphrase, + WebAuthVerifyContractID: testWebAuthVerifyContractID, + ServerSigningKeypair: serverKP, + BaseURL: "https://" + baseHost, + AllowHTTPRetry: true, + ClientAttributionRequired: true, + }, SEP45ChallengeRequest{ + Account: clientContractAddress, + HomeDomain: homeDomain, + ClientDomain: &clientDomainCopy, + }, assertFn + }, + }, + { + name: "invalid account", + expectError: true, + errContains: "account", + build: func(t *testing.T) (context.Context, SEP45ServiceOptions, SEP45ChallengeRequest, func(*testing.T, *SEP45ChallengeResponse)) { + t.Helper() + serverKP := keypair.MustRandom() + ctx := context.Background() + opts := SEP45ServiceOptions{ + RPCClient: stellarMocks.NewMockRPCClient(t), + TOMLClient: stellartoml.DefaultClient, + NetworkPassphrase: network.TestNetworkPassphrase, + WebAuthVerifyContractID: testWebAuthVerifyContractID, + ServerSigningKeypair: serverKP, + BaseURL: "https://home.example.com", + AllowHTTPRetry: true, + } + req := SEP45ChallengeRequest{Account: "invalid-account", HomeDomain: "home.example.com"} + return ctx, opts, req, nil + }, + }, + { + name: "invalid home domain", + expectError: true, + errContains: "home_domain", + build: func(t *testing.T) (context.Context, SEP45ServiceOptions, SEP45ChallengeRequest, func(*testing.T, *SEP45ChallengeResponse)) { + t.Helper() + serverKP := keypair.MustRandom() + clientDomain := "wallet.example.com" + clientDomainCopy := clientDomain + ctx := context.Background() + opts := SEP45ServiceOptions{ + RPCClient: stellarMocks.NewMockRPCClient(t), + TOMLClient: stellartoml.DefaultClient, + NetworkPassphrase: network.TestNetworkPassphrase, + WebAuthVerifyContractID: testWebAuthVerifyContractID, + ServerSigningKeypair: serverKP, + BaseURL: "https://allowed.example.com", + AllowHTTPRetry: true, + } + req := SEP45ChallengeRequest{ + Account: testClientContractAddress, + HomeDomain: "other.example.com", + ClientDomain: &clientDomainCopy, + } + return ctx, opts, req, nil + }, + }, + { + name: "missing home domain", + expectError: true, + errContains: "home_domain", + build: func(t *testing.T) (context.Context, SEP45ServiceOptions, SEP45ChallengeRequest, func(*testing.T, *SEP45ChallengeResponse)) { + t.Helper() + serverKP := keypair.MustRandom() + ctx := context.Background() + opts := SEP45ServiceOptions{ + RPCClient: stellarMocks.NewMockRPCClient(t), + TOMLClient: stellartoml.DefaultClient, + NetworkPassphrase: network.TestNetworkPassphrase, + WebAuthVerifyContractID: testWebAuthVerifyContractID, + ServerSigningKeypair: serverKP, + BaseURL: "https://home.example.com", + AllowHTTPRetry: true, + } + req := SEP45ChallengeRequest{Account: testClientContractAddress} + return ctx, opts, req, nil + }, + }, + { + name: "requires client domain when attribution enforced", + expectError: true, + errContains: "client_domain", + build: func(t *testing.T) (context.Context, SEP45ServiceOptions, SEP45ChallengeRequest, func(*testing.T, *SEP45ChallengeResponse)) { + t.Helper() + serverKP := keypair.MustRandom() + ctx := context.Background() + opts := SEP45ServiceOptions{ + RPCClient: stellarMocks.NewMockRPCClient(t), + TOMLClient: stellartoml.DefaultClient, + NetworkPassphrase: network.TestNetworkPassphrase, + WebAuthVerifyContractID: testWebAuthVerifyContractID, + ServerSigningKeypair: serverKP, + BaseURL: "https://home.example.com", + AllowHTTPRetry: true, + ClientAttributionRequired: true, + } + req := SEP45ChallengeRequest{Account: testClientContractAddress, HomeDomain: "home.example.com"} + return ctx, opts, req, nil + }, + }, + { + name: "errors when client domain signing key missing", + expectError: true, + errContains: "SIGNING_KEY", + build: func(t *testing.T) (context.Context, SEP45ServiceOptions, SEP45ChallengeRequest, func(*testing.T, *SEP45ChallengeResponse)) { + t.Helper() + + ctx := context.Background() + serverKP := keypair.MustRandom() + rpcMock := stellarMocks.NewMockRPCClient(t) + tomlMock := &stellartoml.MockClient{} + + clientDomain := "wallet.example.com" + tomlMock. + On("GetStellarToml", clientDomain). + Return(&stellartoml.Response{SigningKey: ""}, nil). + Once() + + clientDomainCopy := clientDomain + + opts := SEP45ServiceOptions{ + RPCClient: rpcMock, + TOMLClient: tomlMock, + NetworkPassphrase: network.TestNetworkPassphrase, + WebAuthVerifyContractID: testWebAuthVerifyContractID, + ServerSigningKeypair: serverKP, + BaseURL: "https://home.example.com", + AllowHTTPRetry: true, + ClientAttributionRequired: true, + } + req := SEP45ChallengeRequest{ + Account: testClientContractAddress, + HomeDomain: "home.example.com", + ClientDomain: &clientDomainCopy, + } + return ctx, opts, req, nil + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctx, opts, req, assertFn := tc.build(t) + svc, err := NewSEP45Service(opts) + require.NoError(t, err) + + resp, err := svc.CreateChallenge(ctx, req) + if tc.expectError { + require.Error(t, err) + if tc.errContains != "" { + assert.Contains(t, err.Error(), tc.errContains) + } + require.Nil(t, resp) + return + } + + require.NoError(t, err) + require.NotNil(t, resp) + if assertFn != nil { + assertFn(t, resp) + } + }) + } +} + +func extractInvokeArgs(t *testing.T, txB64 string) map[string]string { + t.Helper() + + var env xdr.TransactionEnvelope + require.NoError(t, xdr.SafeUnmarshalBase64(txB64, &env)) + require.Equal(t, xdr.EnvelopeTypeEnvelopeTypeTx, env.Type) + + // Extract the first op + ops := env.V1.Tx.Operations + require.NotEmpty(t, ops) + + // Extract the invoke contract args + invoke := ops[0].Body.MustInvokeHostFunctionOp() + args := invoke.HostFunction.MustInvokeContract().Args + require.NotEmpty(t, args) + + // Put it in a map + argMap := args[0].MustMap() + result := make(map[string]string, len(*argMap)) + for _, entry := range *argMap { + sym := entry.Key.MustSym() + if str, ok := entry.Val.GetStr(); ok { + result[string(sym)] = string(str) + } + } + return result +} + +func marshalAuthorizationEntries(t *testing.T, entries []xdr.SorobanAuthorizationEntry) []string { + t.Helper() + encoded := make([]string, 0, len(entries)) + for _, entry := range entries { + bytes, err := entry.MarshalBinary() + require.NoError(t, err) + encoded = append(encoded, base64.StdEncoding.EncodeToString(bytes)) + } + return encoded +} + +func makeAuthorizationEntry(t *testing.T, contractID xdr.ContractId, address xdr.ScAddress, argEntries xdr.ScMap) xdr.SorobanAuthorizationEntry { + t.Helper() + mapVal, err := xdr.NewScVal(xdr.ScValTypeScvMap, &argEntries) + require.NoError(t, err) + emptyVec := xdr.ScVec{} + emptySignature, err := xdr.NewScVal(xdr.ScValTypeScvVec, &emptyVec) + require.NoError(t, err) + return xdr.SorobanAuthorizationEntry{ + Credentials: xdr.SorobanCredentials{ + Type: xdr.SorobanCredentialsTypeSorobanCredentialsAddress, + Address: &xdr.SorobanAddressCredentials{ + Address: address, + Nonce: 0, + SignatureExpirationLedger: 0, + Signature: emptySignature, + }, + }, + RootInvocation: xdr.SorobanAuthorizedInvocation{ + Function: xdr.SorobanAuthorizedFunction{ + Type: xdr.SorobanAuthorizedFunctionTypeSorobanAuthorizedFunctionTypeContractFn, + ContractFn: &xdr.InvokeContractArgs{ + ContractAddress: xdr.ScAddress{ + Type: xdr.ScAddressTypeScAddressTypeContract, + ContractId: &contractID, + }, + FunctionName: "web_auth_verify", + Args: xdr.ScVec{mapVal}, + }, + }, + }, + } +} + +func decodeContractID(contract string) xdr.ContractId { + raw, err := strkey.Decode(strkey.VersionByteContract, contract) + if err != nil { + panic(err) + } + var id xdr.ContractId + copy(id[:], raw) + return id +} diff --git a/internal/stellar/mocks/rpc_client.go b/internal/stellar/mocks/rpc_client.go index 00a57615b..737e1fae2 100644 --- a/internal/stellar/mocks/rpc_client.go +++ b/internal/stellar/mocks/rpc_client.go @@ -48,6 +48,38 @@ func (_m *MockRPCClient) SimulateTransaction(ctx context.Context, request protoc return r0, r1 } +// GetLatestLedgerSequence provides a mock function with given fields: ctx +func (_m *MockRPCClient) GetLatestLedgerSequence(ctx context.Context) (uint32, error) { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for GetLatestLedgerSequence") + } + + var r0 uint32 + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) (uint32, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) uint32); ok { + r0 = rf(ctx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(uint32) + } + } + + if len(ret) > 1 { + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) + } else { + r1 = ret.Error(1) + } + } + + return r0, r1 +} + // NewMockRPCClient creates a new instance of MockRPCClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. func NewMockRPCClient(t interface { diff --git a/internal/stellar/rpc.go b/internal/stellar/rpc.go index 5bcb432ed..3cc001b9d 100644 --- a/internal/stellar/rpc.go +++ b/internal/stellar/rpc.go @@ -144,4 +144,5 @@ func isResourceError(err string) bool { //go:generate mockery --name=RPCClient --case=underscore --structname=MockRPCClient --filename=rpc_client.go type RPCClient interface { SimulateTransaction(ctx context.Context, request protocol.SimulateTransactionRequest) (*SimulationResult, *SimulationError) + GetLatestLedgerSequence(ctx context.Context) (uint32, error) } diff --git a/internal/stellar/rpc_client.go b/internal/stellar/rpc_client.go index 7e979ff60..87cae940f 100644 --- a/internal/stellar/rpc_client.go +++ b/internal/stellar/rpc_client.go @@ -47,6 +47,19 @@ func (w *RPCClientWrapper) SimulateTransaction(ctx context.Context, request prot }, nil } +func (w *RPCClientWrapper) GetLatestLedgerSequence(ctx context.Context) (uint32, error) { + if w.client == nil { + return 0, errors.New("RPC client not initialized") + } + + resp, err := w.client.GetLatestLedger(ctx) + if err != nil { + return 0, fmt.Errorf("getting latest ledger sequence: %w", err) + } + + return resp.Sequence, nil +} + func NewHTTPClientWithAuth(authHeaderKey, authHeaderValue string) (*http.Client, error) { if authHeaderKey == "" && authHeaderValue == "" { return http.DefaultClient, nil diff --git a/internal/utils/xdr.go b/internal/utils/xdr.go new file mode 100644 index 000000000..90e15ad5c --- /dev/null +++ b/internal/utils/xdr.go @@ -0,0 +1,143 @@ +package utils + +import ( + "crypto/sha256" + "fmt" + + "github.com/stellar/go/keypair" + "github.com/stellar/go/network" + "github.com/stellar/go/strkey" + "github.com/stellar/go/xdr" +) + +// NewSymbolStringEntry constructs an ScMapEntry for the provided key/value pair. +func NewSymbolStringEntry(key, value string) xdr.ScMapEntry { + symbol := xdr.ScSymbol(key) + str := xdr.ScString(value) + return xdr.ScMapEntry{ + Key: xdr.ScVal{ + Type: xdr.ScValTypeScvSymbol, + Sym: &symbol, + }, + Val: xdr.ScVal{ + Type: xdr.ScValTypeScvString, + Str: &str, + }, + } +} + +// BuildAuthorizationPayload produces the hash payload Soroban expects for signature verification. +func BuildAuthorizationPayload(entry xdr.SorobanAuthorizationEntry, networkPassphrase string) ([32]byte, error) { + var zero [32]byte + if entry.Credentials.Address == nil { + return zero, fmt.Errorf("authorization entry missing address credentials") + } + + preimage := xdr.HashIdPreimage{ + Type: xdr.EnvelopeTypeEnvelopeTypeSorobanAuthorization, + SorobanAuthorization: &xdr.HashIdPreimageSorobanAuthorization{ + NetworkId: xdr.Hash(network.ID(networkPassphrase)), + Nonce: entry.Credentials.Address.Nonce, + SignatureExpirationLedger: entry.Credentials.Address.SignatureExpirationLedger, + Invocation: entry.RootInvocation, + }, + } + preimageBytes, err := preimage.MarshalBinary() + if err != nil { + return zero, fmt.Errorf("marshalling authorization preimage: %w", err) + } + payload := sha256.Sum256(preimageBytes) + return payload, nil +} + +// SignAuthEntry signs the authorization entry if it belongs to the provided signing account. +func SignAuthEntry(entry xdr.SorobanAuthorizationEntry, validUntil uint32, signingKP *keypair.Full, networkPassphrase string) (xdr.SorobanAuthorizationEntry, error) { + if entry.Credentials.Type != xdr.SorobanCredentialsTypeSorobanCredentialsAddress { + return entry, nil + } + if entry.Credentials.Address == nil { + return entry, fmt.Errorf("address credentials missing") + } + + addr := entry.Credentials.Address.Address + if addr.Type != xdr.ScAddressTypeScAddressTypeAccount || addr.AccountId == nil { + return entry, nil + } + + serverAccountID := xdr.MustAddress(signingKP.Address()) + if !addr.AccountId.Equals(serverAccountID) { + return entry, nil + } + + encoded, err := entry.MarshalBinary() + if err != nil { + return entry, fmt.Errorf("marshalling authorization entry: %w", err) + } + + var clone xdr.SorobanAuthorizationEntry + if err := clone.UnmarshalBinary(encoded); err != nil { + return entry, fmt.Errorf("cloning authorization entry: %w", err) + } + + clone.Credentials.Address.SignatureExpirationLedger = xdr.Uint32(validUntil) + + payload, err := BuildAuthorizationPayload(clone, networkPassphrase) + if err != nil { + return entry, fmt.Errorf("encoding authorization preimage: %w", err) + } + + signature, err := signingKP.Sign(payload[:]) + if err != nil { + return entry, fmt.Errorf("signing authorization entry: %w", err) + } + if err := signingKP.Verify(payload[:], signature); err != nil { + return entry, fmt.Errorf("signature verification failed: %w", err) + } + + publicKeyRaw, err := strkey.Decode(strkey.VersionByteAccountID, signingKP.Address()) + if err != nil { + return entry, fmt.Errorf("decoding signing public key: %w", err) + } + + pkBytes := xdr.ScBytes(publicKeyRaw) + sigBytes := xdr.ScBytes(signature) + + publicKeySymbol := xdr.ScSymbol("public_key") + signatureSymbol := xdr.ScSymbol("signature") + entries := xdr.ScMap{ + { + Key: xdr.ScVal{ + Type: xdr.ScValTypeScvSymbol, + Sym: &publicKeySymbol, + }, + Val: xdr.ScVal{ + Type: xdr.ScValTypeScvBytes, + Bytes: &pkBytes, + }, + }, + { + Key: xdr.ScVal{ + Type: xdr.ScValTypeScvSymbol, + Sym: &signatureSymbol, + }, + Val: xdr.ScVal{ + Type: xdr.ScValTypeScvBytes, + Bytes: &sigBytes, + }, + }, + } + + mapVal, err := xdr.NewScVal(xdr.ScValTypeScvMap, &entries) + if err != nil { + return entry, fmt.Errorf("building signature map: %w", err) + } + + vector := xdr.ScVec{mapVal} + vecVal, err := xdr.NewScVal(xdr.ScValTypeScvVec, &vector) + if err != nil { + return entry, fmt.Errorf("building signature vector: %w", err) + } + + clone.Credentials.Address.Signature = vecVal + return clone, nil +} diff --git a/internal/utils/xdr_test.go b/internal/utils/xdr_test.go new file mode 100644 index 000000000..dcb64755a --- /dev/null +++ b/internal/utils/xdr_test.go @@ -0,0 +1,181 @@ +package utils + +import ( + "testing" + + "github.com/stellar/go/keypair" + "github.com/stellar/go/network" + "github.com/stellar/go/strkey" + "github.com/stellar/go/xdr" + "github.com/stretchr/testify/require" +) + +func Test_NewSymbolStringEntry(t *testing.T) { + entry := NewSymbolStringEntry("foo", "bar") + require.Equal(t, xdr.ScValTypeScvSymbol, entry.Key.Type) + require.NotNil(t, entry.Key.Sym) + require.Equal(t, xdr.ScSymbol("foo"), *entry.Key.Sym) + + require.Equal(t, xdr.ScValTypeScvString, entry.Val.Type) + require.NotNil(t, entry.Val.Str) + require.Equal(t, xdr.ScString("bar"), *entry.Val.Str) +} + +func Test_BuildAuthorizationPayload(t *testing.T) { + testCases := []struct { + name string + entry xdr.SorobanAuthorizationEntry + expectError string + }{ + { + name: "missing address credentials returns error", + entry: xdr.SorobanAuthorizationEntry{ + Credentials: xdr.SorobanCredentials{Type: xdr.SorobanCredentialsTypeSorobanCredentialsAddress}, + }, + expectError: "authorization entry missing address credentials", + }, + { + name: "returns payload for valid entry", + entry: newTestAuthEntry(t, keypair.MustRandom().Address()), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + payload, err := BuildAuthorizationPayload(tc.entry, network.TestNetworkPassphrase) + if tc.expectError != "" { + require.ErrorContains(t, err, tc.expectError) + return + } + require.NoError(t, err) + require.NotEqual(t, [32]byte{}, payload) + }) + } +} + +func Test_SignAuthEntry(t *testing.T) { + serverKP := keypair.MustRandom() + + testCases := []struct { + name string + buildEntry func(t *testing.T) xdr.SorobanAuthorizationEntry + validUntil uint32 + shouldSign bool + signatureCount int + }{ + { + name: "non-server account remains unchanged", + validUntil: 200, + buildEntry: func(t *testing.T) xdr.SorobanAuthorizationEntry { + return newTestAuthEntry(t, keypair.MustRandom().Address()) + }, + }, + { + name: "server account gets signed", + validUntil: 500, + shouldSign: true, + buildEntry: func(t *testing.T) xdr.SorobanAuthorizationEntry { + return newTestAuthEntry(t, serverKP.Address()) + }, + signatureCount: 1, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + entry := tc.buildEntry(t) + origCopy := entry + result, err := SignAuthEntry(entry, tc.validUntil, serverKP, network.TestNetworkPassphrase) + require.NoError(t, err) + + if !tc.shouldSign { + require.Equal(t, origCopy, result) + return + } + + require.Equal(t, xdr.Uint32(tc.validUntil), result.Credentials.Address.SignatureExpirationLedger) + sigVec, ok := result.Credentials.Address.Signature.GetVec() + require.True(t, ok) + require.NotNil(t, sigVec) + require.Len(t, *sigVec, tc.signatureCount) + + sigMap, ok := (*sigVec)[0].GetMap() + require.True(t, ok) + require.NotNil(t, sigMap) + require.Len(t, *sigMap, 2) + + var ( + extractedPub []byte + extractedSig []byte + ) + for _, entry := range *sigMap { + key := entry.Key.MustSym() + switch string(key) { + case "public_key": + bytesVal, ok := entry.Val.GetBytes() + require.True(t, ok) + extractedPub = append([]byte(nil), bytesVal...) + case "signature": + bytesVal, ok := entry.Val.GetBytes() + require.True(t, ok) + extractedSig = append([]byte(nil), bytesVal...) + } + } + require.NotEmpty(t, extractedPub) + require.NotEmpty(t, extractedSig) + + serverPubRaw, err := strkey.Decode(strkey.VersionByteAccountID, serverKP.Address()) + require.NoError(t, err) + require.Equal(t, serverPubRaw, extractedPub) + + payload, err := BuildAuthorizationPayload(result, network.TestNetworkPassphrase) + require.NoError(t, err) + require.NoError(t, serverKP.Verify(payload[:], extractedSig)) + }) + } +} + +func newTestAuthEntry(t *testing.T, account string) xdr.SorobanAuthorizationEntry { + t.Helper() + + accountID := xdr.MustAddress(account) + accountAddress := xdr.ScAddress{ + Type: xdr.ScAddressTypeScAddressTypeAccount, + AccountId: &accountID, + } + var contractID xdr.ContractId + for i := range contractID { + contractID[i] = byte(i + 1) + } + contractAddress := xdr.ScAddress{ + Type: xdr.ScAddressTypeScAddressTypeContract, + ContractId: &contractID, + } + + emptyVec := xdr.ScVec{} + emptySignature, err := xdr.NewScVal(xdr.ScValTypeScvVec, &emptyVec) + require.NoError(t, err) + + return xdr.SorobanAuthorizationEntry{ + Credentials: xdr.SorobanCredentials{ + Type: xdr.SorobanCredentialsTypeSorobanCredentialsAddress, + Address: &xdr.SorobanAddressCredentials{ + Address: accountAddress, + Nonce: xdr.Int64(42), + SignatureExpirationLedger: xdr.Uint32(10), + Signature: emptySignature, + }, + }, + RootInvocation: xdr.SorobanAuthorizedInvocation{ + Function: xdr.SorobanAuthorizedFunction{ + Type: xdr.SorobanAuthorizedFunctionTypeSorobanAuthorizedFunctionTypeContractFn, + ContractFn: &xdr.InvokeContractArgs{ + ContractAddress: contractAddress, + FunctionName: xdr.ScSymbol("web_auth_verify"), + Args: xdr.ScVec{}, + }, + }, + SubInvocations: nil, + }, + } +}