Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
197 changes: 156 additions & 41 deletions kms/capi/capi.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ const (
HashArg = "sha1"
StoreLocationArg = "store-location" // 'machine', 'user', etc
StoreNameArg = "store" // 'MY', 'CA', 'ROOT', etc
FriendlyNameArg = "friendly-name"
DescriptionArg = "description"
IntermediateStoreLocationArg = "intermediate-store-location"
IntermediateStoreNameArg = "intermediate-store"
KeyIDArg = "key-id"
Expand Down Expand Up @@ -91,6 +93,8 @@ type uriAttributes struct {
subjectCN string
serialNumber *big.Int
issuerName string
friendlyName string
description string
keySpec string
skipFindCertificateKey bool
pin string
Expand Down Expand Up @@ -132,6 +136,8 @@ func parseURI(rawuri string) (*uriAttributes, error) {
subjectCN: u.Get(SubjectCNArg),
serialNumber: serialNumber,
issuerName: u.Get(IssuerNameArg),
friendlyName: u.Get(FriendlyNameArg),
description: u.Get(DescriptionArg),
keySpec: u.Get(KeySpec),
skipFindCertificateKey: u.GetBool(SkipFindCertificateKey),
pin: u.Pin(),
Expand Down Expand Up @@ -392,11 +398,15 @@ func (k *CAPIKMS) getCertContext(u *uriAttributes) (*windows.CertContext, error)
0,
0,
certStoreLocation,
uintptr(unsafe.Pointer(wide(u.storeName))))
uintptr(unsafe.Pointer(wide(u.storeName))),
)
if err != nil {
return nil, fmt.Errorf("CertOpenStore for the %q store %q returned: %w", u.storeLocation, u.storeName, err)
}

// if issuer + any of the other fields in the list below is provided, then attempt a second certificate lookup when
// lookup by KeyID fails (not found).
canLookupByIssuer := u.issuerName != "" && (u.serialNumber != nil || u.subjectCN != "" || u.friendlyName != "" || u.description != "")
var handle *windows.CertContext

switch {
Expand All @@ -421,44 +431,9 @@ func (k *CAPIKMS) getCertContext(u *uriAttributes) (*windows.CertContext, error)
}
case len(u.keyID) > 0:
if handle, err = findCertificateBySubjectKeyID(st, u.keyID); err != nil {
return nil, err
}
case u.issuerName != "" && (u.serialNumber != nil || u.subjectCN != ""):
var prevCert *windows.CertContext
for {
handle, err = findCertificateInStore(st,
encodingX509ASN|encodingPKCS7,
0,
findIssuerStr,
uintptr(unsafe.Pointer(wide(u.issuerName))), prevCert)
if err != nil {
return nil, fmt.Errorf("findCertificateInStore failed: %w", err)
}

if handle == nil {
return nil, apiv1.NotFoundError{Message: fmt.Sprintf("certificate with %s=%q not found", IssuerNameArg, u.issuerName)}
}

x509Cert, err := certContextToX509(handle)
if err != nil {
return nil, fmt.Errorf("could not unmarshal certificate to DER: %w", err)
if !errors.Is(err, apiv1.NotFoundError{}) || !canLookupByIssuer {
return nil, err
}

switch {
case u.serialNumber != nil:
// TODO: Replace this search with a CERT_ID + CERT_ISSUER_SERIAL_NUMBER search instead
// https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/ns-wincrypt-cert_id
// https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/ns-wincrypt-cert_issuer_serial_number
if x509Cert.SerialNumber.Cmp(u.serialNumber) == 0 {
return handle, nil
}
case len(u.subjectCN) > 0:
if x509Cert.Subject.CommonName == u.subjectCN {
return handle, nil
}
}

prevCert = handle
}
case u.containerName != "":
key, err := k.GetPublicKey(&apiv1.GetPublicKeyRequest{
Expand All @@ -474,13 +449,75 @@ func (k *CAPIKMS) getCertContext(u *uriAttributes) (*windows.CertContext, error)
return nil, fmt.Errorf("error generating SubjectKeyID: %w", err)
}
if handle, err = findCertificateBySubjectKeyID(st, keyID); err != nil {
return nil, err
if !errors.Is(err, apiv1.NotFoundError{}) || !canLookupByIssuer {
return nil, err
}
}
default:
}

if handle != nil {
return handle, err
}

if !canLookupByIssuer {
return nil, fmt.Errorf("%q, %q, or %q and one of %q or %q is required to find a certificate", HashArg, KeyIDArg, IssuerNameArg, SerialNumberArg, SubjectCNArg)
}

return handle, err
// lookup certificate by issuer + another field (serial, CN, friendlyName, description)
var prevCert *windows.CertContext
for {
handle, err = findCertificateInStore(st,
encodingX509ASN|encodingPKCS7,
0,
findIssuerStr,
uintptr(unsafe.Pointer(wide(u.issuerName))), prevCert)
if err != nil {
return nil, fmt.Errorf("findCertificateInStore failed: %w", err)
}

if handle == nil {
return nil, apiv1.NotFoundError{Message: fmt.Sprintf("certificate with %s=%q not found", IssuerNameArg, u.issuerName)}
}

x509Cert, err := certContextToX509(handle)
if err != nil {
return nil, fmt.Errorf("could not unmarshal certificate to DER: %w", err)
}

switch {
case u.serialNumber != nil:
// TODO: Replace this search with a CERT_ID + CERT_ISSUER_SERIAL_NUMBER search instead
// https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/ns-wincrypt-cert_id
// https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/ns-wincrypt-cert_issuer_serial_number
if x509Cert.SerialNumber.Cmp(u.serialNumber) == 0 {
return handle, nil
}
case len(u.subjectCN) > 0:
if x509Cert.Subject.CommonName == u.subjectCN {
return handle, nil
}
case len(u.friendlyName) > 0:
val, err := cryptFindCertificateFriendlyName(handle)
if err != nil {
return nil, fmt.Errorf("cryptFindCertificateFriendlyName failed: %w", err)
}

if val == u.friendlyName {
return handle, nil
}
case len(u.description) > 0:
val, err := cryptFindCertificateDescription(handle)
if err != nil {
return nil, fmt.Errorf("cryptFindCertificateDescription failed: %w", err)
}

if val == u.description {
return handle, nil
}
}

prevCert = handle
}
}

// CreateSigner returns a crypto.Signer that will sign using the key passed in via the URI.
Expand Down Expand Up @@ -784,6 +821,74 @@ func (k *CAPIKMS) LoadCertificateChain(req *apiv1.LoadCertificateChainRequest) (
return chain, nil
}

// FindCertificatesByIssuer returns all certificates in the Windows certificate
// store that were issued by the given issuer. The URI must contain the "issuer"
// field; "store-location" and "store" are optional (defaulting to "user" and "My").
// When subjectRaw is non-empty, only certificates whose raw DER-encoded Subject
// matches are included.
func (k *CAPIKMS) FindCertificatesByIssuer(req *apiv1.LoadCertificateRequest, subjectRaw []byte) ([]*x509.Certificate, error) {
u, err := parseURI(req.Name)
if err != nil {
return nil, err
}
if u.issuerName == "" {
return nil, fmt.Errorf("%q is required", IssuerNameArg)
}

var certStoreLocation uint32
switch u.storeLocation {
case UserStoreLocation:
certStoreLocation = certStoreCurrentUser
case MachineStoreLocation:
certStoreLocation = certStoreLocalMachine
default:
return nil, fmt.Errorf("invalid cert store location %q", u.storeLocation)
}

st, err := windows.CertOpenStore(
certStoreProvSystem,
0,
0,
certStoreLocation,
uintptr(unsafe.Pointer(wide(u.storeName))),
)
if err != nil {
return nil, fmt.Errorf("CertOpenStore for the %q store %q returned: %w", u.storeLocation, u.storeName, err)
}

var (
certs []*x509.Certificate
prevCert *windows.CertContext
)
for {
certHandle, err := findCertificateInStore(st,
encodingX509ASN|encodingPKCS7,
0,
findIssuerStr,
uintptr(unsafe.Pointer(wide(u.issuerName))), prevCert)
if err != nil {
return nil, fmt.Errorf("findCertificateInStore failed: %w", err)
}
if certHandle == nil {
// prevCert was freed by the last findCertificateInStore call per Windows API contract.
break
}

x509Cert, err := certContextToX509(certHandle)
if err != nil {
windows.CertFreeCertificateContext(certHandle)
return nil, fmt.Errorf("could not unmarshal certificate: %w", err)
}

if len(subjectRaw) == 0 || bytes.Equal(x509Cert.RawSubject, subjectRaw) {
certs = append(certs, x509Cert)
}
prevCert = certHandle // freed on next findCertificateInStore call
}

return certs, nil
}

func (k *CAPIKMS) StoreCertificate(req *apiv1.StoreCertificateRequest) error {
u, err := parseURI(req.Name)
if err != nil {
Expand Down Expand Up @@ -818,6 +923,14 @@ func (k *CAPIKMS) StoreCertificate(req *apiv1.StoreCertificateRequest) error {
cryptFindCertificateKeyProvInfo(certContext)
}

if u.friendlyName != "" {
cryptSetCertificateFriendlyName(certContext, u.friendlyName)
}

if u.description != "" {
cryptSetCertificateDescription(certContext, u.description)
}

st, err := windows.CertOpenStore(
certStoreProvSystem,
0,
Expand Down Expand Up @@ -853,6 +966,8 @@ func (k *CAPIKMS) StoreCertificateChain(req *apiv1.StoreCertificateChainRequest)
HashArg: []string{fp},
StoreLocationArg: []string{u.storeLocation},
StoreNameArg: []string{u.storeName},
FriendlyNameArg: []string{u.friendlyName},
DescriptionArg: []string{u.description},
SkipFindCertificateKey: []string{strconv.FormatBool(u.skipFindCertificateKey)},
}).String(),
Certificate: leaf,
Expand Down
Loading
Loading