diff --git a/signer/plugin.go b/signer/plugin.go index 6d8615d7..7a01051f 100644 --- a/signer/plugin.go +++ b/signer/plugin.go @@ -375,3 +375,49 @@ func (s *pluginPrimitiveSigner) Sign(payload []byte) ([]byte, []*x509.Certificat func (s *pluginPrimitiveSigner) KeySpec() (signature.KeySpec, error) { return s.keySpec, nil } + +// NewPluginPrimitiveSigner returns a signature.Signer that delegates raw +// signature generation to a plugin. It is intended for callers that need raw +// signature bytes (PKCS#7) rather than a JWS/COSE envelope. +func NewPluginPrimitiveSigner(ctx context.Context, p plugin.SignPlugin, keyID string, keySpec signature.KeySpec, pluginConfig map[string]string) (signature.Signer, error) { + if p == nil { + return nil, errors.New("nil plugin") + } + if keyID == "" { + return nil, errors.New("keyID not specified") + } + if _, err := proto.HashAlgorithmFromKeySpec(keySpec); err != nil { + return nil, fmt.Errorf("invalid keySpec: %w", err) + } + return &pluginPrimitiveSigner{ + ctx: ctx, + plugin: p, + keyID: keyID, + keySpec: keySpec, + pluginConfig: pluginConfig, + }, nil +} + +// KeySpecFromPlugin retrieves and validates the key specification from a +// plugin by calling DescribeKey. It enforces that the plugin's response +// references the same keyID that was requested. +func KeySpecFromPlugin(ctx context.Context, p plugin.SignPlugin, keyID string, pluginConfig map[string]string) (signature.KeySpec, error) { + if p == nil { + return signature.KeySpec{}, errors.New("nil plugin") + } + if keyID == "" { + return signature.KeySpec{}, errors.New("keyID not specified") + } + resp, err := p.DescribeKey(ctx, &plugin.DescribeKeyRequest{ + ContractVersion: plugin.ContractVersion, + KeyID: keyID, + PluginConfig: pluginConfig, + }) + if err != nil { + return signature.KeySpec{}, fmt.Errorf("failed to describe key %q: %w", keyID, err) + } + if resp.KeyID != keyID { + return signature.KeySpec{}, fmt.Errorf("keyID in describeKey response %q does not match request %q", resp.KeyID, keyID) + } + return proto.DecodeKeySpec(resp.KeySpec) +} diff --git a/signer/plugin_test.go b/signer/plugin_test.go index a98caa07..34ea1b8b 100644 --- a/signer/plugin_test.go +++ b/signer/plugin_test.go @@ -60,15 +60,17 @@ func init() { } type mockPlugin struct { - failEnvelope bool - wantEnvelope bool - invalidSig bool - invalidCertChain bool - invalidDescriptor bool - annotations map[string]string - key crypto.PrivateKey - certs []*x509.Certificate - keySpec signature.KeySpec + failEnvelope bool + wantEnvelope bool + invalidSig bool + invalidCertChain bool + invalidDescriptor bool + describeKeyErr error + describeKeyIDOverride string + annotations map[string]string + key crypto.PrivateKey + certs []*x509.Certificate + keySpec signature.KeySpec } func getDescriptorFunc(throwError bool) func(hashAlgo digest.Algorithm) (ocispec.Descriptor, error) { @@ -108,8 +110,12 @@ func (p *mockPlugin) GetMetadata(ctx context.Context, req *proto.GetMetadataRequ // DescribeKey returns the KeySpec of a key. func (p *mockPlugin) DescribeKey(ctx context.Context, req *proto.DescribeKeyRequest) (*proto.DescribeKeyResponse, error) { + if p.describeKeyErr != nil { + return nil, p.describeKeyErr + } ks, _ := proto.EncodeKeySpec(p.keySpec) return &proto.DescribeKeyResponse{ + KeyID: p.describeKeyIDOverride, KeySpec: ks, }, nil } @@ -492,3 +498,157 @@ func basicSignTest(t *testing.T, ps *PluginSigner, envelopeType string, data []b } basicVerification(t, data, envelopeType, mockPlugin.certs[len(mockPlugin.certs)-1], &validMetadata) } + +func TestNewPluginPrimitiveSigner(t *testing.T) { + ctx := context.Background() + mp := newMockPlugin(defaultKeyCert.key, defaultKeyCert.certs, defaultKeySpec) + + s, err := NewPluginPrimitiveSigner(ctx, mp, "testKeyID", defaultKeySpec, nil) + if err != nil { + t.Fatalf("NewPluginPrimitiveSigner() error: %v", err) + } + + // verify KeySpec + ks, err := s.KeySpec() + if err != nil { + t.Fatalf("KeySpec() error: %v", err) + } + if ks != defaultKeySpec { + t.Fatalf("KeySpec() = %v, want %v", ks, defaultKeySpec) + } + + // verify Sign + sig, certs, err := s.Sign([]byte("payload")) + if err != nil { + t.Fatalf("Sign() error: %v", err) + } + if len(sig) == 0 { + t.Fatal("Sign() returned empty signature") + } + if len(certs) == 0 { + t.Fatal("Sign() returned no certificates") + } +} + +func TestNewPluginPrimitiveSigner_NilPlugin(t *testing.T) { + _, err := NewPluginPrimitiveSigner(context.Background(), nil, "testKeyID", defaultKeySpec, nil) + if err == nil { + t.Fatal("expected error for nil plugin, got nil") + } + if !strings.Contains(err.Error(), "nil plugin") { + t.Errorf("unexpected error: %v", err) + } +} + +func TestNewPluginPrimitiveSigner_EmptyKeyID(t *testing.T) { + mp := newMockPlugin(defaultKeyCert.key, defaultKeyCert.certs, defaultKeySpec) + _, err := NewPluginPrimitiveSigner(context.Background(), mp, "", defaultKeySpec, nil) + if err == nil { + t.Fatal("expected error for empty keyID, got nil") + } + if !strings.Contains(err.Error(), "keyID") { + t.Errorf("unexpected error: %v", err) + } +} + +func TestNewPluginPrimitiveSigner_InvalidKeySpec(t *testing.T) { + mp := newMockPlugin(defaultKeyCert.key, defaultKeyCert.certs, defaultKeySpec) + _, err := NewPluginPrimitiveSigner(context.Background(), mp, "testKeyID", signature.KeySpec{}, nil) + if err == nil { + t.Fatal("expected error for invalid keySpec, got nil") + } + if !strings.Contains(err.Error(), "invalid keySpec") { + t.Errorf("unexpected error: %v", err) + } +} + +func TestPluginPrimitiveSigner_SignError(t *testing.T) { + mp := newMockPlugin(defaultKeyCert.key, defaultKeyCert.certs, defaultKeySpec) + mp.invalidCertChain = true // plugin returns unparseable cert chain + s, err := NewPluginPrimitiveSigner(context.Background(), mp, "testKeyID", defaultKeySpec, nil) + if err != nil { + t.Fatalf("NewPluginPrimitiveSigner() error: %v", err) + } + if _, _, err := s.Sign([]byte("payload")); err == nil { + t.Fatal("expected Sign() to fail when plugin returns invalid cert chain, got nil") + } +} + +func TestPluginPrimitiveSigner_ECDSA(t *testing.T) { + var ecPair *keyCertPair + var ecKeySpec signature.KeySpec + for _, p := range keyCertPairCollections { + ks, err := signature.ExtractKeySpec(p.certs[0]) + if err == nil && ks.Type == signature.KeyTypeEC { + ecPair = p + ecKeySpec = ks + break + } + } + if ecPair == nil { + t.Skip("no EC keyCertPair available in test fixtures") + } + mp := newMockPlugin(ecPair.key, ecPair.certs, ecKeySpec) + s, err := NewPluginPrimitiveSigner(context.Background(), mp, "testKeyID", ecKeySpec, nil) + if err != nil { + t.Fatalf("NewPluginPrimitiveSigner() error: %v", err) + } + sig, certs, err := s.Sign([]byte("payload")) + if err != nil { + t.Fatalf("Sign() error: %v", err) + } + if len(sig) == 0 || len(certs) == 0 { + t.Fatal("Sign() returned empty result") + } +} + +func TestKeySpecFromPlugin(t *testing.T) { + ctx := context.Background() + mp := newMockPlugin(defaultKeyCert.key, defaultKeyCert.certs, defaultKeySpec) + mp.describeKeyIDOverride = "testKeyID" + + got, err := KeySpecFromPlugin(ctx, mp, "testKeyID", nil) + if err != nil { + t.Fatalf("KeySpecFromPlugin() error: %v", err) + } + if got != defaultKeySpec { + t.Fatalf("KeySpecFromPlugin() = %v, want %v", got, defaultKeySpec) + } +} + +func TestKeySpecFromPlugin_NilPlugin(t *testing.T) { + if _, err := KeySpecFromPlugin(context.Background(), nil, "testKeyID", nil); err == nil { + t.Fatal("expected error for nil plugin, got nil") + } +} + +func TestKeySpecFromPlugin_EmptyKeyID(t *testing.T) { + mp := newMockPlugin(defaultKeyCert.key, defaultKeyCert.certs, defaultKeySpec) + if _, err := KeySpecFromPlugin(context.Background(), mp, "", nil); err == nil { + t.Fatal("expected error for empty keyID, got nil") + } +} + +func TestKeySpecFromPlugin_KeyIDMismatch(t *testing.T) { + mp := newMockPlugin(defaultKeyCert.key, defaultKeyCert.certs, defaultKeySpec) + mp.describeKeyIDOverride = "differentKeyID" + _, err := KeySpecFromPlugin(context.Background(), mp, "testKeyID", nil) + if err == nil { + t.Fatal("expected keyID mismatch error, got nil") + } + if !strings.Contains(err.Error(), "does not match") { + t.Errorf("unexpected error: %v", err) + } +} + +func TestKeySpecFromPlugin_DescribeKeyError(t *testing.T) { + mp := newMockPlugin(defaultKeyCert.key, defaultKeyCert.certs, defaultKeySpec) + mp.describeKeyErr = errors.New("simulated describeKey failure") + _, err := KeySpecFromPlugin(context.Background(), mp, "testKeyID", nil) + if err == nil { + t.Fatal("expected describeKey error to propagate, got nil") + } + if !strings.Contains(err.Error(), "failed to describe key") { + t.Errorf("expected wrapped error, got: %v", err) + } +}