crypto/x509: simplify candidate chain filtering

Use slices.DeleteFunc to remove chains with invalid policies and
incompatible key usage, instead of iterating over the chains and
reconstructing the slice.

Change-Id: I8ad2bc1ac2469d0d18b2c090e3d4f702b1b577cb
Reviewed-on: https://go-review.googlesource.com/c/go/+/708415
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Daniel McCarney <daniel@binaryparadox.net>
Reviewed-by: David Chase <drchase@google.com>
This commit is contained in:
Roland Shoemaker 2025-10-01 08:57:00 -07:00
parent 29046398bb
commit 0d3dab9b1d
2 changed files with 27 additions and 36 deletions

View file

@ -17,6 +17,7 @@ import (
"net/url" "net/url"
"reflect" "reflect"
"runtime" "runtime"
"slices"
"strings" "strings"
"time" "time"
"unicode/utf8" "unicode/utf8"
@ -801,7 +802,7 @@ func (c *Certificate) isValid(certType int, currentChain []*Certificate, opts *V
// Certificates other than c in the returned chains should not be modified. // Certificates other than c in the returned chains should not be modified.
// //
// WARNING: this function doesn't do any revocation checking. // WARNING: this function doesn't do any revocation checking.
func (c *Certificate) Verify(opts VerifyOptions) (chains [][]*Certificate, err error) { func (c *Certificate) Verify(opts VerifyOptions) ([][]*Certificate, error) {
// Platform-specific verification needs the ASN.1 contents so // Platform-specific verification needs the ASN.1 contents so
// this makes the behavior consistent across platforms. // this makes the behavior consistent across platforms.
if len(c.Raw) == 0 { if len(c.Raw) == 0 {
@ -843,15 +844,15 @@ func (c *Certificate) Verify(opts VerifyOptions) (chains [][]*Certificate, err e
} }
} }
err = c.isValid(leafCertificate, nil, &opts) err := c.isValid(leafCertificate, nil, &opts)
if err != nil { if err != nil {
return return nil, err
} }
if len(opts.DNSName) > 0 { if len(opts.DNSName) > 0 {
err = c.VerifyHostname(opts.DNSName) err = c.VerifyHostname(opts.DNSName)
if err != nil { if err != nil {
return return nil, err
} }
} }
@ -865,26 +866,12 @@ func (c *Certificate) Verify(opts VerifyOptions) (chains [][]*Certificate, err e
} }
} }
chains = make([][]*Certificate, 0, len(candidateChains)) anyKeyUsage := false
var invalidPoliciesChains int
for _, candidate := range candidateChains {
if !policiesValid(candidate, opts) {
invalidPoliciesChains++
continue
}
chains = append(chains, candidate)
}
if len(chains) == 0 {
return nil, CertificateInvalidError{c, NoValidChains, "all candidate chains have invalid policies"}
}
for _, eku := range opts.KeyUsages { for _, eku := range opts.KeyUsages {
if eku == ExtKeyUsageAny { if eku == ExtKeyUsageAny {
// If any key usage is acceptable, no need to check the chain for // The presence of anyExtendedKeyUsage overrides any other key usage.
// key usages. anyKeyUsage = true
return chains, nil break
} }
} }
@ -892,34 +879,38 @@ func (c *Certificate) Verify(opts VerifyOptions) (chains [][]*Certificate, err e
opts.KeyUsages = []ExtKeyUsage{ExtKeyUsageServerAuth} opts.KeyUsages = []ExtKeyUsage{ExtKeyUsageServerAuth}
} }
candidateChains = chains var invalidPoliciesChains int
chains = chains[:0]
var incompatibleKeyUsageChains int var incompatibleKeyUsageChains int
for _, candidate := range candidateChains { candidateChains = slices.DeleteFunc(candidateChains, func(chain []*Certificate) bool {
if !checkChainForKeyUsage(candidate, opts.KeyUsages) { if !policiesValid(chain, opts) {
invalidPoliciesChains++
return true
}
// If any key usage is acceptable, no need to check the chain for
// key usages.
if !anyKeyUsage && !checkChainForKeyUsage(chain, opts.KeyUsages) {
incompatibleKeyUsageChains++ incompatibleKeyUsageChains++
continue return true
}
chains = append(chains, candidate)
} }
return false
})
if len(chains) == 0 { if len(candidateChains) == 0 {
var details []string var details []string
if incompatibleKeyUsageChains > 0 { if incompatibleKeyUsageChains > 0 {
if invalidPoliciesChains == 0 { if invalidPoliciesChains == 0 {
return nil, CertificateInvalidError{c, IncompatibleUsage, ""} return nil, CertificateInvalidError{c, IncompatibleUsage, ""}
} }
details = append(details, fmt.Sprintf("%d chains with incompatible key usage", incompatibleKeyUsageChains)) details = append(details, fmt.Sprintf("%d candidate chains with incompatible key usage", incompatibleKeyUsageChains))
} }
if invalidPoliciesChains > 0 { if invalidPoliciesChains > 0 {
details = append(details, fmt.Sprintf("%d chains with invalid policies", invalidPoliciesChains)) details = append(details, fmt.Sprintf("%d candidate chains with invalid policies", invalidPoliciesChains))
} }
err = CertificateInvalidError{c, NoValidChains, strings.Join(details, ", ")} err = CertificateInvalidError{c, NoValidChains, strings.Join(details, ", ")}
return nil, err return nil, err
} }
return chains, nil return candidateChains, nil
} }
func appendToFreshChain(chain []*Certificate, cert *Certificate) []*Certificate { func appendToFreshChain(chain []*Certificate, cert *Certificate) []*Certificate {

View file

@ -3031,7 +3031,7 @@ func TestInvalidPolicyWithAnyKeyUsage(t *testing.T) {
testOID3 := mustNewOIDFromInts([]uint64{1, 2, 840, 113554, 4, 1, 72585, 2, 3}) testOID3 := mustNewOIDFromInts([]uint64{1, 2, 840, 113554, 4, 1, 72585, 2, 3})
root, intermediate, leaf := loadTestCert(t, "testdata/policy_root.pem"), loadTestCert(t, "testdata/policy_intermediate_require.pem"), loadTestCert(t, "testdata/policy_leaf.pem") root, intermediate, leaf := loadTestCert(t, "testdata/policy_root.pem"), loadTestCert(t, "testdata/policy_intermediate_require.pem"), loadTestCert(t, "testdata/policy_leaf.pem")
expectedErr := "x509: no valid chains built: all candidate chains have invalid policies" expectedErr := "x509: no valid chains built: 1 candidate chains with invalid policies"
roots, intermediates := NewCertPool(), NewCertPool() roots, intermediates := NewCertPool(), NewCertPool()
roots.AddCert(root) roots.AddCert(root)