// Copyright 2010 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package tls import ( "bytes" "context" "crypto/rsa" "crypto/x509" "encoding/base64" "encoding/binary" "encoding/pem" "errors" "fmt" "io" "math/big" "net" "os" "os/exec" "path/filepath" "reflect" "runtime" "strconv" "strings" "testing" "time" ) // Note: see comment in handshake_test.go for details of how the reference // tests work. // opensslInputEvent enumerates possible inputs that can be sent to an `openssl // s_client` process. type opensslInputEvent int const ( // opensslRenegotiate causes OpenSSL to request a renegotiation of the // connection. opensslRenegotiate opensslInputEvent = iota // opensslSendBanner causes OpenSSL to send the contents of // opensslSentinel on the connection. opensslSendSentinel // opensslKeyUpdate causes OpenSSL to send a key update message to the // client and request one back. opensslKeyUpdate ) const opensslSentinel = "SENTINEL\n" type opensslInput chan opensslInputEvent func (i opensslInput) Read(buf []byte) (n int, err error) { for event := range i { switch event { case opensslRenegotiate: return copy(buf, []byte("R\n")), nil case opensslKeyUpdate: return copy(buf, []byte("K\n")), nil case opensslSendSentinel: return copy(buf, []byte(opensslSentinel)), nil default: panic("unknown event") } } return 0, io.EOF } // opensslOutputSink is an io.Writer that receives the stdout and stderr from an // `openssl` process and sends a value to handshakeComplete or readKeyUpdate // when certain messages are seen. type opensslOutputSink struct { handshakeComplete chan struct{} readKeyUpdate chan struct{} all []byte line []byte } func newOpensslOutputSink() *opensslOutputSink { return &opensslOutputSink{make(chan struct{}), make(chan struct{}), nil, nil} } // opensslEndOfHandshake is a message that the “openssl s_server” tool will // print when a handshake completes if run with “-state”. 
const opensslEndOfHandshake = "SSL_accept:SSLv3/TLS write finished"

// opensslReadKeyUpdate is a message that the “openssl s_server” tool will
// print when a KeyUpdate message is received if run with “-state”.
const opensslReadKeyUpdate = "SSL_accept:TLSv1.3 read client key update"

// Write records data in the full transcript and in the pending-line buffer,
// then consumes every complete line from that buffer. A line equal to
// opensslEndOfHandshake or opensslReadKeyUpdate triggers a send on the
// corresponding channel; that send blocks until something receives from the
// channel. Write always reports len(data) bytes written and a nil error.
func (o *opensslOutputSink) Write(data []byte) (n int, err error) {
	o.line = append(o.line, data...)
	o.all = append(o.all, data...)

	for {
		line, next, ok := bytes.Cut(o.line, []byte("\n"))
		if !ok {
			// No newline yet: keep the partial line for the next Write.
			break
		}

		if bytes.Equal([]byte(opensslEndOfHandshake), line) {
			o.handshakeComplete <- struct{}{}
		}
		if bytes.Equal([]byte(opensslReadKeyUpdate), line) {
			o.readKeyUpdate <- struct{}{}
		}
		o.line = next
	}

	return len(data), nil
}

// String returns everything written to the sink so far.
func (o *opensslOutputSink) String() string {
	return string(o.all)
}

// clientTest represents a test of the TLS client handshake against a reference
// implementation.
type clientTest struct {
	// name is a freeform string identifying the test and the file in which
	// the expected results will be stored.
	name string
	// args, if not empty, contains a series of arguments for the
	// command to run for the reference server.
	args []string
	// config, if not nil, contains a custom Config to use for this test.
	config *Config
	// cert, if not empty, contains a DER-encoded certificate for the
	// reference server.
	cert []byte
	// key, if not nil, contains either a *rsa.PrivateKey, ed25519.PrivateKey or
	// *ecdsa.PrivateKey which is the private key for the reference server.
	key any
	// extensions, if not nil, contains a list of extension data to be returned
	// from the ServerHello. The data should be in standard TLS format with
	// a 2-byte uint16 type, 2-byte data length, followed by the extension data.
	extensions [][]byte
	// validate, if not nil, is a function that will be called with the
	// ConnectionState of the resulting connection. It returns a non-nil
	// error if the ConnectionState is unacceptable.
	validate func(ConnectionState) error
	// numRenegotiations is the number of times that the connection will be
	// renegotiated.
	numRenegotiations int
	// renegotiationExpectedToFail, if not zero, is the number of the
	// renegotiation attempt that is expected to fail.
	renegotiationExpectedToFail int
	// checkRenegotiationError, if not nil, is called with any error
	// arising from renegotiation. It can map expected errors to nil to
	// ignore them.
	checkRenegotiationError func(renegotiationNum int, err error) error
	// sendKeyUpdate will cause the server to send a KeyUpdate message.
	sendKeyUpdate bool
}

// serverCommand is the base command line for the reference server; per-test
// arguments are appended to it in connFromCommand.
var serverCommand = []string{"openssl", "s_server", "-no_ticket", "-num_tickets", "0"}

// connFromCommand starts the reference server process, connects to it and
// returns a recordingConn for the connection. The stdin return value is an
// opensslInput for the stdin of the child process. It must be closed before
// Waiting for child.
//
// The certificate, key and (if any) serverinfo data are written to temporary
// files that are removed when this function returns. It panics if the key
// cannot be marshalled, or if the test needs renegotiation or KeyUpdate but
// "-state" is not among the arguments.
func (test *clientTest) connFromCommand() (conn *recordingConn, child *exec.Cmd, stdin opensslInput, stdout *opensslOutputSink, err error) {
	cert := testRSACertificate
	if len(test.cert) > 0 {
		cert = test.cert
	}
	certPath := tempFile(string(cert))
	defer os.Remove(certPath)

	var key any = testRSAPrivateKey
	if test.key != nil {
		key = test.key
	}
	derBytes, err := x509.MarshalPKCS8PrivateKey(key)
	if err != nil {
		panic(err)
	}

	var pemOut bytes.Buffer
	pem.Encode(&pemOut, &pem.Block{Type: "PRIVATE KEY", Bytes: derBytes})

	keyPath := tempFile(pemOut.String())
	defer os.Remove(keyPath)

	var command []string
	command = append(command, serverCommand...)
	command = append(command, test.args...)
	command = append(command, "-cert", certPath, "-certform", "DER", "-key", keyPath)
	// serverPort contains the port that OpenSSL will listen on. OpenSSL
	// can't take "0" as an argument here so we have to pick a number and
	// hope that it's not in use on the machine. Since this only occurs
	// when -update is given and thus when there's a human watching the
	// test, this isn't too bad.
	const serverPort = 24323
	command = append(command, "-accept", strconv.Itoa(serverPort))

	if len(test.extensions) > 0 {
		// Each extension becomes one PEM block whose type names the
		// extension number taken from the first two bytes of the data.
		var serverInfo bytes.Buffer
		for _, ext := range test.extensions {
			pem.Encode(&serverInfo, &pem.Block{
				Type:  fmt.Sprintf("SERVERINFO FOR EXTENSION %d", binary.BigEndian.Uint16(ext)),
				Bytes: ext,
			})
		}

		serverInfoPath := tempFile(serverInfo.String())
		defer os.Remove(serverInfoPath)
		command = append(command, "-serverinfo", serverInfoPath)
	}

	if test.numRenegotiations > 0 || test.sendKeyUpdate {
		// run waits on the sink's channels, which are only signalled by
		// the messages OpenSSL prints when given -state.
		found := false
		for _, flag := range command[1:] {
			if flag == "-state" {
				found = true
				break
			}
		}

		if !found {
			panic("-state flag missing to OpenSSL, you need this if testing renegotiation or KeyUpdate")
		}
	}

	cmd := exec.Command(command[0], command[1:]...)
	stdin = opensslInput(make(chan opensslInputEvent))
	cmd.Stdin = stdin
	out := newOpensslOutputSink()
	cmd.Stdout = out
	cmd.Stderr = out
	if err := cmd.Start(); err != nil {
		return nil, nil, nil, nil, err
	}

	// OpenSSL does print an "ACCEPT" banner, but it does so *before*
	// opening the listening socket, so we can't use that to wait until it
	// has started listening. Thus we are forced to poll until we get a
	// connection.
	var tcpConn net.Conn
	for i := uint(0); i < 5; i++ {
		tcpConn, err = net.DialTCP("tcp", nil, &net.TCPAddr{
			IP:   net.IPv4(127, 0, 0, 1),
			Port: serverPort,
		})
		if err == nil {
			break
		}
		// Back off: 5ms, 10ms, 20ms, 40ms, 80ms.
		time.Sleep((1 << i) * 5 * time.Millisecond)
	}
	if err != nil {
		close(stdin)
		cmd.Process.Kill()
		err = fmt.Errorf("error connecting to the OpenSSL server: %v (%v)\n\n%s", err, cmd.Wait(), out)
		return nil, nil, nil, nil, err
	}

	record := &recordingConn{
		Conn: tcpConn,
	}

	return record, cmd, stdin, out, nil
}

// dataPath returns the path of the file holding the recorded flows for this
// test: testdata/Client-<name>.
func (test *clientTest) dataPath() string {
	return filepath.Join("testdata", "Client-"+test.name)
}

// loadData opens the file at dataPath and parses it with parseTestData.
func (test *clientTest) loadData() (flows [][]byte, err error) {
	in, err := os.Open(test.dataPath())
	if err != nil {
		return nil, err
	}
	defer in.Close()
	return parseTestData(in)
}

// run executes the test. If write is true, it starts the reference server via
// connFromCommand, drives a client against it over a recording connection and
// writes the recorded flows to dataPath. Otherwise it loads the flows from
// dataPath and replays them over a local pipe: even-indexed flows must match
// what the client sends byte for byte, odd-indexed flows are written back to
// the client.
func (test *clientTest) run(t *testing.T, write bool) {
	var clientConn, serverConn net.Conn
	var recordingConn *recordingConn
	var childProcess *exec.Cmd
	var stdin opensslInput
	var stdout *opensslOutputSink

	if write {
		var err error
		recordingConn, childProcess, stdin, stdout, err = test.connFromCommand()
		if err != nil {
			t.Fatalf("Failed to start subcommand: %s", err)
		}
		clientConn = recordingConn
		defer func() {
			if t.Failed() {
				t.Logf("OpenSSL output:\n\n%s", stdout.all)
			}
		}()
	} else {
		clientConn, serverConn = localPipe(t)
	}

	// doneChan is closed when the client goroutine below exits. The deferred
	// function closes the client connection and then waits for that exit.
	doneChan := make(chan bool)
	defer func() {
		clientConn.Close()
		<-doneChan
	}()
	go func() {
		defer close(doneChan)

		config := test.config
		if config == nil {
			config = testConfig
		}
		client := Client(clientConn, config)
		defer client.Close()

		if _, err := client.Write([]byte("hello\n")); err != nil {
			t.Errorf("Client.Write failed: %s", err)
			return
		}

		for i := 1; i <= test.numRenegotiations; i++ {
			// The initial handshake will generate a
			// handshakeComplete signal which needs to be quashed.
			if i == 1 && write {
				<-stdout.handshakeComplete
			}

			// OpenSSL will try to interleave application data and
			// a renegotiation if we send both concurrently.
			// Therefore: ask OpenSSL to start a renegotiation, run
			// a goroutine to call client.Read and thus process the
			// renegotiation request, watch for OpenSSL's stdout to
			// indicate that the handshake is complete and,
			// finally, have OpenSSL write something to cause
			// client.Read to complete.
			if write {
				stdin <- opensslRenegotiate
			}

			signalChan := make(chan struct{})

			go func() {
				defer close(signalChan)

				buf := make([]byte, 256)
				n, err := client.Read(buf)

				if test.checkRenegotiationError != nil {
					newErr := test.checkRenegotiationError(i, err)
					// An error that the callback maps to nil was
					// expected; there is nothing further to check.
					if err != nil && newErr == nil {
						return
					}
					err = newErr
				}

				if err != nil {
					t.Errorf("Client.Read failed after renegotiation #%d: %s", i, err)
					return
				}

				buf = buf[:n]
				if !bytes.Equal([]byte(opensslSentinel), buf) {
					t.Errorf("Client.Read returned %q, but wanted %q", string(buf), opensslSentinel)
				}

				if expected := i + 1; client.handshakes != expected {
					t.Errorf("client should have recorded %d handshakes, but believes that %d have occurred", expected, client.handshakes)
				}
			}()

			if write && test.renegotiationExpectedToFail != i {
				<-stdout.handshakeComplete
				stdin <- opensslSendSentinel
			}
			<-signalChan
		}

		if test.sendKeyUpdate {
			if write {
				<-stdout.handshakeComplete
				stdin <- opensslKeyUpdate
			}

			doneRead := make(chan struct{})

			go func() {
				defer close(doneRead)

				buf := make([]byte, 256)
				n, err := client.Read(buf)

				if err != nil {
					t.Errorf("Client.Read failed after KeyUpdate: %s", err)
					return
				}

				buf = buf[:n]
				if !bytes.Equal([]byte(opensslSentinel), buf) {
					t.Errorf("Client.Read returned %q, but wanted %q", string(buf), opensslSentinel)
				}
			}()

			if write {
				// There's no real reason to wait for the client KeyUpdate to
				// send data with the new server keys, except that s_server
				// drops writes if they are sent at the wrong time.
				<-stdout.readKeyUpdate
				stdin <- opensslSendSentinel
			}
			<-doneRead

			if _, err := client.Write([]byte("hello again\n")); err != nil {
				t.Errorf("Client.Write failed: %s", err)
				return
			}
		}

		if test.validate != nil {
			if err := test.validate(client.ConnectionState()); err != nil {
				t.Errorf("validate callback returned error: %s", err)
			}
		}

		// If the server sent us an alert after our last flight, give it a
		// chance to arrive.
		if write && test.renegotiationExpectedToFail == 0 {
			if err := peekError(client); err != nil {
				t.Errorf("final Read returned an error: %s", err)
			}
		}
	}()

	if !write {
		flows, err := test.loadData()
		if err != nil {
			t.Fatalf("%s: failed to load data from %s: %v", test.name, test.dataPath(), err)
		}
		for i, b := range flows {
			// Odd-indexed flows are sent to the client.
			if i%2 == 1 {
				if *fast {
					serverConn.SetWriteDeadline(time.Now().Add(1 * time.Second))
				} else {
					serverConn.SetWriteDeadline(time.Now().Add(1 * time.Minute))
				}
				serverConn.Write(b)
				continue
			}
			// Even-indexed flows are what the client is expected to send.
			bb := make([]byte, len(b))
			if *fast {
				serverConn.SetReadDeadline(time.Now().Add(1 * time.Second))
			} else {
				serverConn.SetReadDeadline(time.Now().Add(1 * time.Minute))
			}
			_, err := io.ReadFull(serverConn, bb)
			if err != nil {
				t.Fatalf("%s, flow %d: %s", test.name, i+1, err)
			}
			if !bytes.Equal(b, bb) {
				t.Fatalf("%s, flow %d: mismatch on read: got:%x want:%x", test.name, i+1, bb, b)
			}
		}
	}

	<-doneChan
	if !write {
		serverConn.Close()
	}

	if write {
		path := test.dataPath()
		out, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
		if err != nil {
			t.Fatalf("Failed to create output file: %s", err)
		}
		defer out.Close()
		recordingConn.Close()
		// stdin is closed before the child is killed and waited for, as
		// connFromCommand's documentation requires.
		close(stdin)
		childProcess.Process.Kill()
		childProcess.Wait()
		if len(recordingConn.flows) < 3 {
			t.Fatalf("Client connection didn't work")
		}
		recordingConn.WriteTo(out)
		t.Logf("Wrote %s\n", path)
	}
}

// peekError does a read with a short timeout to check if the next read would
// cause an error, for example if there is an alert waiting on the wire.
func peekError(conn net.Conn) error { conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) if n, err := conn.Read(make([]byte, 1)); n != 0 { return errors.New("unexpectedly read data") } else if err != nil { if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() { return err } } return nil } func runClientTestForVersion(t *testing.T, template *clientTest, version, option string) { // Make a deep copy of the template before going parallel. test := *template if template.config != nil { test.config = template.config.Clone() } test.name = version + "-" + test.name test.args = append([]string{option}, test.args...) runTestAndUpdateIfNeeded(t, version, test.run, false) } func runClientTestTLS10(t *testing.T, template *clientTest) { runClientTestForVersion(t, template, "TLSv10", "-tls1") } func runClientTestTLS11(t *testing.T, template *clientTest) { runClientTestForVersion(t, template, "TLSv11", "-tls1_1") } func runClientTestTLS12(t *testing.T, template *clientTest) { runClientTestForVersion(t, template, "TLSv12", "-tls1_2") } func runClientTestTLS13(t *testing.T, template *clientTest) { runClientTestForVersion(t, template, "TLSv13", "-tls1_3") } func TestHandshakeClientRSARC4(t *testing.T) { test := &clientTest{ name: "RSA-RC4", args: []string{"-cipher", "RC4-SHA"}, } runClientTestTLS10(t, test) runClientTestTLS11(t, test) runClientTestTLS12(t, test) } func TestHandshakeClientRSAAES128GCM(t *testing.T) { test := &clientTest{ name: "AES128-GCM-SHA256", args: []string{"-cipher", "AES128-GCM-SHA256"}, } runClientTestTLS12(t, test) } func TestHandshakeClientRSAAES256GCM(t *testing.T) { test := &clientTest{ name: "AES256-GCM-SHA384", args: []string{"-cipher", "AES256-GCM-SHA384"}, } runClientTestTLS12(t, test) } func TestHandshakeClientECDHERSAAES(t *testing.T) { test := &clientTest{ name: "ECDHE-RSA-AES", args: []string{"-cipher", "ECDHE-RSA-AES128-SHA"}, } runClientTestTLS10(t, test) runClientTestTLS11(t, test) runClientTestTLS12(t, test) } func 
TestHandshakeClientECDHEECDSAAES(t *testing.T) { test := &clientTest{ name: "ECDHE-ECDSA-AES", args: []string{"-cipher", "ECDHE-ECDSA-AES128-SHA"}, cert: testECDSACertificate, key: testECDSAPrivateKey, } runClientTestTLS10(t, test) runClientTestTLS11(t, test) runClientTestTLS12(t, test) } func TestHandshakeClientECDHEECDSAAESGCM(t *testing.T) { test := &clientTest{ name: "ECDHE-ECDSA-AES-GCM", args: []string{"-cipher", "ECDHE-ECDSA-AES128-GCM-SHA256"}, cert: testECDSACertificate, key: testECDSAPrivateKey, } runClientTestTLS12(t, test) } func TestHandshakeClientAES256GCMSHA384(t *testing.T) { test := &clientTest{ name: "ECDHE-ECDSA-AES256-GCM-SHA384", args: []string{"-cipher", "ECDHE-ECDSA-AES256-GCM-SHA384"}, cert: testECDSACertificate, key: testECDSAPrivateKey, } runClientTestTLS12(t, test) } func TestHandshakeClientAES128CBCSHA256(t *testing.T) { test := &clientTest{ name: "AES128-SHA256", args: []string{"-cipher", "AES128-SHA256"}, } runClientTestTLS12(t, test) } func TestHandshakeClientECDHERSAAES128CBCSHA256(t *testing.T) { test := &clientTest{ name: "ECDHE-RSA-AES128-SHA256", args: []string{"-cipher", "ECDHE-RSA-AES128-SHA256"}, } runClientTestTLS12(t, test) } func TestHandshakeClientECDHEECDSAAES128CBCSHA256(t *testing.T) { test := &clientTest{ name: "ECDHE-ECDSA-AES128-SHA256", args: []string{"-cipher", "ECDHE-ECDSA-AES128-SHA256"}, cert: testECDSACertificate, key: testECDSAPrivateKey, } runClientTestTLS12(t, test) } func TestHandshakeClientX25519(t *testing.T) { config := testConfig.Clone() config.CurvePreferences = []CurveID{X25519} test := &clientTest{ name: "X25519-ECDHE", args: []string{"-cipher", "ECDHE-RSA-AES128-GCM-SHA256", "-curves", "X25519"}, config: config, } runClientTestTLS12(t, test) runClientTestTLS13(t, test) } func TestHandshakeClientP256(t *testing.T) { config := testConfig.Clone() config.CurvePreferences = []CurveID{CurveP256} test := &clientTest{ name: "P256-ECDHE", args: []string{"-cipher", "ECDHE-RSA-AES128-GCM-SHA256", "-curves", 
"P-256"}, config: config, } runClientTestTLS12(t, test) runClientTestTLS13(t, test) } func TestHandshakeClientHelloRetryRequest(t *testing.T) { config := testConfig.Clone() config.CurvePreferences = []CurveID{X25519, CurveP256} test := &clientTest{ name: "HelloRetryRequest", args: []string{"-cipher", "ECDHE-RSA-AES128-GCM-SHA256", "-curves", "P-256"}, config: config, } runClientTestTLS13(t, test) } func TestHandshakeClientECDHERSAChaCha20(t *testing.T) { config := testConfig.Clone() config.CipherSuites = []uint16{TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305} test := &clientTest{ name: "ECDHE-RSA-CHACHA20-POLY1305", args: []string{"-cipher", "ECDHE-RSA-CHACHA20-POLY1305"}, config: config, } runClientTestTLS12(t, test) } func TestHandshakeClientECDHEECDSAChaCha20(t *testing.T) { config := testConfig.Clone() config.CipherSuites = []uint16{TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305} test := &clientTest{ name: "ECDHE-ECDSA-CHACHA20-POLY1305", args: []string{"-cipher", "ECDHE-ECDSA-CHACHA20-POLY1305"}, config: config, cert: testECDSACertificate, key: testECDSAPrivateKey, } runClientTestTLS12(t, test) } func TestHandshakeClientAES128SHA256(t *testing.T) { test := &clientTest{ name: "AES128-SHA256", args: []string{"-ciphersuites", "TLS_AES_128_GCM_SHA256"}, } runClientTestTLS13(t, test) } func TestHandshakeClientAES256SHA384(t *testing.T) { test := &clientTest{ name: "AES256-SHA384", args: []string{"-ciphersuites", "TLS_AES_256_GCM_SHA384"}, } runClientTestTLS13(t, test) } func TestHandshakeClientCHACHA20SHA256(t *testing.T) { test := &clientTest{ name: "CHACHA20-SHA256", args: []string{"-ciphersuites", "TLS_CHACHA20_POLY1305_SHA256"}, } runClientTestTLS13(t, test) } func TestHandshakeClientECDSATLS13(t *testing.T) { test := &clientTest{ name: "ECDSA", cert: testECDSACertificate, key: testECDSAPrivateKey, } runClientTestTLS13(t, test) } func TestHandshakeClientEd25519(t *testing.T) { test := &clientTest{ name: "Ed25519", cert: testEd25519Certificate, key: testEd25519PrivateKey, } 
runClientTestTLS12(t, test) runClientTestTLS13(t, test) config := testConfig.Clone() cert, _ := X509KeyPair([]byte(clientEd25519CertificatePEM), []byte(clientEd25519KeyPEM)) config.Certificates = []Certificate{cert} test = &clientTest{ name: "ClientCert-Ed25519", args: []string{"-Verify", "1"}, config: config, } runClientTestTLS12(t, test) runClientTestTLS13(t, test) } func TestHandshakeClientCertRSA(t *testing.T) { config := testConfig.Clone() cert, _ := X509KeyPair([]byte(clientCertificatePEM), []byte(clientKeyPEM)) config.Certificates = []Certificate{cert} test := &clientTest{ name: "ClientCert-RSA-RSA", args: []string{"-cipher", "AES128", "-Verify", "1"}, config: config, } runClientTestTLS10(t, test) runClientTestTLS12(t, test) test = &clientTest{ name: "ClientCert-RSA-ECDSA", args: []string{"-cipher", "ECDHE-ECDSA-AES128-SHA", "-Verify", "1"}, config: config, cert: testECDSACertificate, key: testECDSAPrivateKey, } runClientTestTLS10(t, test) runClientTestTLS12(t, test) runClientTestTLS13(t, test) test = &clientTest{ name: "ClientCert-RSA-AES256-GCM-SHA384", args: []string{"-cipher", "ECDHE-RSA-AES256-GCM-SHA384", "-Verify", "1"}, config: config, cert: testRSACertificate, key: testRSAPrivateKey, } runClientTestTLS12(t, test) } func TestHandshakeClientCertECDSA(t *testing.T) { config := testConfig.Clone() cert, _ := X509KeyPair([]byte(clientECDSACertificatePEM), []byte(clientECDSAKeyPEM)) config.Certificates = []Certificate{cert} test := &clientTest{ name: "ClientCert-ECDSA-RSA", args: []string{"-cipher", "AES128", "-Verify", "1"}, config: config, } runClientTestTLS10(t, test) runClientTestTLS12(t, test) runClientTestTLS13(t, test) test = &clientTest{ name: "ClientCert-ECDSA-ECDSA", args: []string{"-cipher", "ECDHE-ECDSA-AES128-SHA", "-Verify", "1"}, config: config, cert: testECDSACertificate, key: testECDSAPrivateKey, } runClientTestTLS10(t, test) runClientTestTLS12(t, test) } // TestHandshakeClientCertRSAPSS tests rsa_pss_rsae_sha256 signatures from both // 
client and server certificates. It also serves from both sides a certificate // signed itself with RSA-PSS, mostly to check that crypto/x509 chain validation // works. func TestHandshakeClientCertRSAPSS(t *testing.T) { cert, err := x509.ParseCertificate(testRSAPSSCertificate) if err != nil { panic(err) } rootCAs := x509.NewCertPool() rootCAs.AddCert(cert) config := testConfig.Clone() // Use GetClientCertificate to bypass the client certificate selection logic. config.GetClientCertificate = func(*CertificateRequestInfo) (*Certificate, error) { return &Certificate{ Certificate: [][]byte{testRSAPSSCertificate}, PrivateKey: testRSAPrivateKey, }, nil } config.RootCAs = rootCAs test := &clientTest{ name: "ClientCert-RSA-RSAPSS", args: []string{"-cipher", "AES128", "-Verify", "1", "-client_sigalgs", "rsa_pss_rsae_sha256", "-sigalgs", "rsa_pss_rsae_sha256"}, config: config, cert: testRSAPSSCertificate, key: testRSAPrivateKey, } runClientTestTLS12(t, test) runClientTestTLS13(t, test) } func TestHandshakeClientCertRSAPKCS1v15(t *testing.T) { config := testConfig.Clone() cert, _ := X509KeyPair([]byte(clientCertificatePEM), []byte(clientKeyPEM)) config.Certificates = []Certificate{cert} test := &clientTest{ name: "ClientCert-RSA-RSAPKCS1v15", args: []string{"-cipher", "AES128", "-Verify", "1", "-client_sigalgs", "rsa_pkcs1_sha256", "-sigalgs", "rsa_pkcs1_sha256"}, config: config, } runClientTestTLS12(t, test) } func TestClientKeyUpdate(t *testing.T) { test := &clientTest{ name: "KeyUpdate", args: []string{"-state"}, sendKeyUpdate: true, } runClientTestTLS13(t, test) } func TestResumption(t *testing.T) { t.Run("TLSv12", func(t *testing.T) { testResumption(t, VersionTLS12) }) t.Run("TLSv13", func(t *testing.T) { testResumption(t, VersionTLS13) }) } func testResumption(t *testing.T, version uint16) { if testing.Short() { t.Skip("skipping in -short mode") } serverConfig := &Config{ MaxVersion: version, CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA, 
TLS_ECDHE_RSA_WITH_RC4_128_SHA}, Certificates: testConfig.Certificates, } issuer, err := x509.ParseCertificate(testRSACertificateIssuer) if err != nil { panic(err) } rootCAs := x509.NewCertPool() rootCAs.AddCert(issuer) clientConfig := &Config{ MaxVersion: version, CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA}, ClientSessionCache: NewLRUClientSessionCache(32), RootCAs: rootCAs, ServerName: "example.golang", } testResumeState := func(test string, didResume bool) { _, hs, err := testHandshake(t, clientConfig, serverConfig) if err != nil { t.Fatalf("%s: handshake failed: %s", test, err) } if hs.DidResume != didResume { t.Fatalf("%s resumed: %v, expected: %v", test, hs.DidResume, didResume) } if didResume && (hs.PeerCertificates == nil || hs.VerifiedChains == nil) { t.Fatalf("expected non-nil certificates after resumption. Got peerCertificates: %#v, verifiedCertificates: %#v", hs.PeerCertificates, hs.VerifiedChains) } if got, want := hs.ServerName, clientConfig.ServerName; got != want { t.Errorf("%s: server name %s, want %s", test, got, want) } } getTicket := func() []byte { return clientConfig.ClientSessionCache.(*lruSessionCache).q.Front().Value.(*lruSessionCacheEntry).state.sessionTicket } deleteTicket := func() { ticketKey := clientConfig.ClientSessionCache.(*lruSessionCache).q.Front().Value.(*lruSessionCacheEntry).sessionKey clientConfig.ClientSessionCache.Put(ticketKey, nil) } corruptTicket := func() { clientConfig.ClientSessionCache.(*lruSessionCache).q.Front().Value.(*lruSessionCacheEntry).state.masterSecret[0] ^= 0xff } randomKey := func() [32]byte { var k [32]byte if _, err := io.ReadFull(serverConfig.rand(), k[:]); err != nil { t.Fatalf("Failed to read new SessionTicketKey: %s", err) } return k } testResumeState("Handshake", false) ticket := getTicket() testResumeState("Resume", true) if !bytes.Equal(ticket, getTicket()) && version != VersionTLS13 { t.Fatal("first ticket doesn't match ticket after resumption") } if bytes.Equal(ticket, getTicket()) && 
version == VersionTLS13 { t.Fatal("ticket didn't change after resumption") } // An old session ticket can resume, but the server will provide a ticket encrypted with a fresh key. serverConfig.Time = func() time.Time { return time.Now().Add(24*time.Hour + time.Minute) } testResumeState("ResumeWithOldTicket", true) if bytes.Equal(ticket[:ticketKeyNameLen], getTicket()[:ticketKeyNameLen]) { t.Fatal("old first ticket matches the fresh one") } // Now the session tickey key is expired, so a full handshake should occur. serverConfig.Time = func() time.Time { return time.Now().Add(24*8*time.Hour + time.Minute) } testResumeState("ResumeWithExpiredTicket", false) if bytes.Equal(ticket, getTicket()) { t.Fatal("expired first ticket matches the fresh one") } serverConfig.Time = func() time.Time { return time.Now() } // reset the time back key1 := randomKey() serverConfig.SetSessionTicketKeys([][32]byte{key1}) testResumeState("InvalidSessionTicketKey", false) testResumeState("ResumeAfterInvalidSessionTicketKey", true) key2 := randomKey() serverConfig.SetSessionTicketKeys([][32]byte{key2, key1}) ticket = getTicket() testResumeState("KeyChange", true) if bytes.Equal(ticket, getTicket()) { t.Fatal("new ticket wasn't included while resuming") } testResumeState("KeyChangeFinish", true) // Age the session ticket a bit, but not yet expired. serverConfig.Time = func() time.Time { return time.Now().Add(24*time.Hour + time.Minute) } testResumeState("OldSessionTicket", true) ticket = getTicket() // Expire the session ticket, which would force a full handshake. serverConfig.Time = func() time.Time { return time.Now().Add(24*8*time.Hour + time.Minute) } testResumeState("ExpiredSessionTicket", false) if bytes.Equal(ticket, getTicket()) { t.Fatal("new ticket wasn't provided after old ticket expired") } // Age the session ticket a bit at a time, but don't expire it. 
d := 0 * time.Hour for i := 0; i < 13; i++ { d += 12 * time.Hour serverConfig.Time = func() time.Time { return time.Now().Add(d) } testResumeState("OldSessionTicket", true) } // Expire it (now a little more than 7 days) and make sure a full // handshake occurs for TLS 1.2. Resumption should still occur for // TLS 1.3 since the client should be using a fresh ticket sent over // by the server. d += 12 * time.Hour serverConfig.Time = func() time.Time { return time.Now().Add(d) } if version == VersionTLS13 { testResumeState("ExpiredSessionTicket", true) } else { testResumeState("ExpiredSessionTicket", false) } if bytes.Equal(ticket, getTicket()) { t.Fatal("new ticket wasn't provided after old ticket expired") } // Reset serverConfig to ensure that calling SetSessionTicketKeys // before the serverConfig is used works. serverConfig = &Config{ MaxVersion: version, CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA}, Certificates: testConfig.Certificates, } serverConfig.SetSessionTicketKeys([][32]byte{key2}) testResumeState("FreshConfig", true) // In TLS 1.3, cross-cipher suite resumption is allowed as long as the KDF // hash matches. Also, Config.CipherSuites does not apply to TLS 1.3. 
if version != VersionTLS13 { clientConfig.CipherSuites = []uint16{TLS_ECDHE_RSA_WITH_RC4_128_SHA} testResumeState("DifferentCipherSuite", false) testResumeState("DifferentCipherSuiteRecovers", true) } deleteTicket() testResumeState("WithoutSessionTicket", false) // Session resumption should work when using client certificates deleteTicket() serverConfig.ClientCAs = rootCAs serverConfig.ClientAuth = RequireAndVerifyClientCert clientConfig.Certificates = serverConfig.Certificates testResumeState("InitialHandshake", false) testResumeState("WithClientCertificates", true) serverConfig.ClientAuth = NoClientCert // Tickets should be removed from the session cache on TLS handshake // failure, and the client should recover from a corrupted PSK testResumeState("FetchTicketToCorrupt", false) corruptTicket() _, _, err = testHandshake(t, clientConfig, serverConfig) if err == nil { t.Fatalf("handshake did not fail with a corrupted client secret") } testResumeState("AfterHandshakeFailure", false) clientConfig.ClientSessionCache = nil testResumeState("WithoutSessionCache", false) } func TestLRUClientSessionCache(t *testing.T) { // Initialize cache of capacity 4. cache := NewLRUClientSessionCache(4) cs := make([]ClientSessionState, 6) keys := []string{"0", "1", "2", "3", "4", "5", "6"} // Add 4 entries to the cache and look them up. for i := 0; i < 4; i++ { cache.Put(keys[i], &cs[i]) } for i := 0; i < 4; i++ { if s, ok := cache.Get(keys[i]); !ok || s != &cs[i] { t.Fatalf("session cache failed lookup for added key: %s", keys[i]) } } // Add 2 more entries to the cache. First 2 should be evicted. for i := 4; i < 6; i++ { cache.Put(keys[i], &cs[i]) } for i := 0; i < 2; i++ { if s, ok := cache.Get(keys[i]); ok || s != nil { t.Fatalf("session cache should have evicted key: %s", keys[i]) } } // Touch entry 2. LRU should evict 3 next. 
cache.Get(keys[2])
	cache.Put(keys[0], &cs[0])
	if s, ok := cache.Get(keys[3]); ok || s != nil {
		t.Fatalf("session cache should have evicted key 3")
	}

	// Update entry 0 in place.
	cache.Put(keys[0], &cs[3])
	if s, ok := cache.Get(keys[0]); !ok || s != &cs[3] {
		t.Fatalf("session cache failed update for key 0")
	}

	// Calling Put with a nil entry deletes the key.
	cache.Put(keys[0], nil)
	if _, ok := cache.Get(keys[0]); ok {
		t.Fatalf("session cache failed to delete key 0")
	}

	// Delete entry 2. LRU should keep 4 and 5
	// NOTE(review): the failure message below names "key 4" although the
	// key deleted and looked up here is keys[2] — confirm and correct.
	cache.Put(keys[2], nil)
	if _, ok := cache.Get(keys[2]); ok {
		t.Fatalf("session cache failed to delete key 4")
	}

	for i := 4; i < 6; i++ {
		if s, ok := cache.Get(keys[i]); !ok || s != &cs[i] {
			t.Fatalf("session cache should not have deleted key: %s", keys[i])
		}
	}
}

// TestKeyLogTLS12 performs a TLS 1.2 handshake over a local pipe with a
// KeyLogWriter installed on both sides and checks that each side logged a
// single CLIENT_RANDOM line of the expected length whose nonce field is 64
// zero hex digits.
func TestKeyLogTLS12(t *testing.T) {
	var serverBuf, clientBuf bytes.Buffer

	clientConfig := testConfig.Clone()
	clientConfig.KeyLogWriter = &clientBuf
	clientConfig.MaxVersion = VersionTLS12

	serverConfig := testConfig.Clone()
	serverConfig.KeyLogWriter = &serverBuf
	serverConfig.MaxVersion = VersionTLS12

	c, s := localPipe(t)
	done := make(chan bool)

	go func() {
		defer close(done)

		if err := Server(s, serverConfig).Handshake(); err != nil {
			t.Errorf("server: %s", err)
			return
		}
		s.Close()
	}()

	if err := Client(c, clientConfig).Handshake(); err != nil {
		t.Fatalf("client: %s", err)
	}

	c.Close()
	<-done

	// checkKeylogLine validates the complete key log output of one side.
	// The length check also rejects output holding more than one line.
	checkKeylogLine := func(side, loggedLine string) {
		if len(loggedLine) == 0 {
			t.Fatalf("%s: no keylog line was produced", side)
		}
		const expectedLen = 13 /* "CLIENT_RANDOM" */ +
			1 /* space */ +
			32*2 /* hex client nonce */ +
			1 /* space */ +
			48*2 /* hex master secret */ +
			1 /* new line */
		if len(loggedLine) != expectedLen {
			t.Fatalf("%s: keylog line has incorrect length (want %d, got %d): %q", side, expectedLen, len(loggedLine), loggedLine)
		}
		if !strings.HasPrefix(loggedLine, "CLIENT_RANDOM "+strings.Repeat("0", 64)+" ") {
			t.Fatalf("%s: keylog line has incorrect structure or nonce: %q", side, loggedLine)
		}
	}

	checkKeylogLine("client", clientBuf.String())
	checkKeylogLine("server", serverBuf.String())
}

// TestKeyLogTLS13 performs a handshake with a KeyLogWriter installed on both
// sides and checks that each side logged exactly four lines. Neither config
// sets MaxVersion here, so the negotiated version is whatever testConfig
// allows (presumably TLS 1.3, going by the test name).
func TestKeyLogTLS13(t *testing.T) {
	var serverBuf, clientBuf bytes.Buffer

	clientConfig := testConfig.Clone()
	clientConfig.KeyLogWriter = &clientBuf

	serverConfig := testConfig.Clone()
	serverConfig.KeyLogWriter = &serverBuf

	c, s := localPipe(t)
	done := make(chan bool)

	go func() {
		defer close(done)

		if err := Server(s, serverConfig).Handshake(); err != nil {
			t.Errorf("server: %s", err)
			return
		}
		s.Close()
	}()

	if err := Client(c, clientConfig).Handshake(); err != nil {
		t.Fatalf("client: %s", err)
	}

	c.Close()
	<-done

	// checkKeylogLines counts the newline-separated entries one side logged.
	checkKeylogLines := func(side, loggedLines string) {
		loggedLines = strings.TrimSpace(loggedLines)
		lines := strings.Split(loggedLines, "\n")
		if len(lines) != 4 {
			t.Errorf("Expected the %s to log 4 lines, got %d", side, len(lines))
		}
	}

	checkKeylogLines("client", clientBuf.String())
	checkKeylogLines("server", serverBuf.String())
}

// TestHandshakeClientALPNMatch runs the reference client test with ALPN
// enabled on both sides and checks that the negotiated protocol is "proto1".
func TestHandshakeClientALPNMatch(t *testing.T) {
	config := testConfig.Clone()
	config.NextProtos = []string{"proto2", "proto1"}

	test := &clientTest{
		name: "ALPN",
		// Note that this needs OpenSSL 1.0.2 because that is the first
		// version that supports the -alpn flag.
		args:   []string{"-alpn", "proto1,proto2"},
		config: config,
		validate: func(state ConnectionState) error {
			// The server's preferences should override the client.
			if state.NegotiatedProtocol != "proto1" {
				return fmt.Errorf("Got protocol %q, wanted proto1", state.NegotiatedProtocol)
			}
			return nil
		},
	}
	runClientTestTLS12(t, test)
	runClientTestTLS13(t, test)
}

// TestServerSelectingUnconfiguredApplicationProtocol feeds the client a
// hand-built ServerHello that selects an ALPN protocol the client never
// offered and checks that the client handshake fails with the matching error.
func TestServerSelectingUnconfiguredApplicationProtocol(t *testing.T) {
	// This checks that the server can't select an application protocol that the
	// client didn't offer.

	c, s := localPipe(t)
	errChan := make(chan error, 1)

	go func() {
		client := Client(c, &Config{
			ServerName:   "foo",
			CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
			NextProtos:   []string{"http", "something-else"},
		})
		errChan <- client.Handshake()
	}()

	// Consume the client's first record: a 5-byte record header followed by
	// the payload whose length is in the header's last two bytes.
	var header [5]byte
	if _, err := io.ReadFull(s, header[:]); err != nil {
		t.Fatal(err)
	}
	recordLen := int(header[3])<<8 | int(header[4])

	record := make([]byte, recordLen)
	if _, err := io.ReadFull(s, record); err != nil {
		t.Fatal(err)
	}

	serverHello := &serverHelloMsg{
		vers:         VersionTLS12,
		random:       make([]byte, 32),
		cipherSuite:  TLS_RSA_WITH_AES_128_GCM_SHA256,
		alpnProtocol: "how-about-this",
	}
	serverHelloBytes := serverHello.marshal()

	s.Write([]byte{
		byte(recordTypeHandshake),
		byte(VersionTLS12 >> 8),
		byte(VersionTLS12 & 0xff),
		byte(len(serverHelloBytes) >> 8),
		byte(len(serverHelloBytes)),
	})
	s.Write(serverHelloBytes)
	s.Close()

	// Check for nil first: calling err.Error() on a nil error would panic
	// instead of reporting that the handshake unexpectedly succeeded.
	err := <-errChan
	if err == nil {
		t.Fatal("Expected error about unadvertised ALPN protocol but handshake succeeded")
	}
	if !strings.Contains(err.Error(), "server selected unadvertised ALPN protocol") {
		t.Fatalf("Expected error about unadvertised ALPN protocol but got %q", err)
	}
}

// sctsBase64 contains data from `openssl s_client -serverinfo 18 -connect ritter.vg:443`
const sctsBase64 = "ABIBaQFnAHUApLkJkLQYWBSHuxOizGdwCjw1mAT5G9+443fNDsgN3BAAAAFHl5nuFgAABAMARjBEAiAcS4JdlW5nW9sElUv2zvQyPoZ6ejKrGGB03gjaBZFMLwIgc1Qbbn+hsH0RvObzhS+XZhr3iuQQJY8S9G85D9KeGPAAdgBo9pj4H2SCvjqM7rkoHUz8cVFdZ5PURNEKZ6y7T0/7xAAAAUeX4bVwAAAEAwBHMEUCIDIhFDgG2HIuADBkGuLobU5a4dlCHoJLliWJ1SYT05z6AiEAjxIoZFFPRNWMGGIjskOTMwXzQ1Wh2e7NxXE1kd1J0QsAdgDuS723dc5guuFCaR+r4Z5mow9+X7By2IMAxHuJeqj9ywAAAUhcZIqHAAAEAwBHMEUCICmJ1rBT09LpkbzxtUC+Hi7nXLR0J+2PmwLp+sJMuqK+AiEAr0NkUnEVKVhAkccIFpYDqHOlZaBsuEhWWrYpg2RtKp0="

func TestHandshakClientSCTs(t *testing.T) {
	config := testConfig.Clone()

	scts, err := base64.StdEncoding.DecodeString(sctsBase64)
	if err != nil {
		t.Fatal(err)
	}

	// Note that this needs OpenSSL 1.0.2 because that is the first
	// version that supports the -serverinfo flag.
test := &clientTest{
		name:       "SCT",
		config:     config,
		extensions: [][]byte{scts},
		validate: func(state ConnectionState) error {
			expectedSCTs := [][]byte{
				scts[8:125],
				scts[127:245],
				scts[247:],
			}
			if n := len(state.SignedCertificateTimestamps); n != len(expectedSCTs) {
				return fmt.Errorf("Got %d scts, wanted %d", n, len(expectedSCTs))
			}
			for i, expected := range expectedSCTs {
				if sct := state.SignedCertificateTimestamps[i]; !bytes.Equal(sct, expected) {
					return fmt.Errorf("SCT #%d contained %x, expected %x", i, sct, expected)
				}
			}
			return nil
		},
	}
	runClientTestTLS12(t, test)

	// TLS 1.3 moved SCTs to the Certificate extensions and -serverinfo only
	// supports ServerHello extensions.
}

// TestRenegotiationRejected runs the reference client test with one
// renegotiation request against a config that leaves Renegotiation unset and
// checks that the request fails with a "no renegotiation" error.
func TestRenegotiationRejected(t *testing.T) {
	config := testConfig.Clone()
	test := &clientTest{
		name:                        "RenegotiationRejected",
		args:                        []string{"-state"},
		config:                      config,
		numRenegotiations:           1,
		renegotiationExpectedToFail: 1,
		checkRenegotiationError: func(renegotiationNum int, err error) error {
			if err == nil {
				return errors.New("expected error from renegotiation but got nil")
			}
			if !strings.Contains(err.Error(), "no renegotiation") {
				return fmt.Errorf("expected renegotiation to be rejected but got %q", err)
			}
			return nil
		},
	}
	runClientTestTLS12(t, test)
}

// TestRenegotiateOnce runs the reference client test with a single
// renegotiation under RenegotiateOnceAsClient.
func TestRenegotiateOnce(t *testing.T) {
	config := testConfig.Clone()
	config.Renegotiation = RenegotiateOnceAsClient

	test := &clientTest{
		name:              "RenegotiateOnce",
		args:              []string{"-state"},
		config:            config,
		numRenegotiations: 1,
	}

	runClientTestTLS12(t, test)
}

// TestRenegotiateTwice runs the reference client test with two
// renegotiations under RenegotiateFreelyAsClient.
func TestRenegotiateTwice(t *testing.T) {
	config := testConfig.Clone()
	config.Renegotiation = RenegotiateFreelyAsClient

	test := &clientTest{
		name:              "RenegotiateTwice",
		args:              []string{"-state"},
		config:            config,
		numRenegotiations: 2,
	}

	runClientTestTLS12(t, test)
}

// TestRenegotiateTwiceRejected runs the reference client test with two
// renegotiations under RenegotiateOnceAsClient: the first must report no
// error and the second must be rejected with a "no renegotiation" error.
func TestRenegotiateTwiceRejected(t *testing.T) {
	config := testConfig.Clone()
	config.Renegotiation = RenegotiateOnceAsClient

	test := &clientTest{
		name:                        "RenegotiateTwiceRejected",
		args:                        []string{"-state"},
		config:                      config,
		numRenegotiations:           2,
		renegotiationExpectedToFail: 2,
		checkRenegotiationError: func(renegotiationNum int, err error) error {
			// The first renegotiation is allowed, so its error (if any)
			// is passed through unchanged.
			if renegotiationNum == 1 {
				return err
			}

			if err == nil {
				return errors.New("expected error from renegotiation but got nil")
			}
			if !strings.Contains(err.Error(), "no renegotiation") {
				return fmt.Errorf("expected renegotiation to be rejected but got %q", err)
			}
			return nil
		},
	}

	runClientTestTLS12(t, test)
}

// TestHandshakeClientExportKeyingMaterial checks, for TLS 1.0, 1.2 and 1.3,
// that ExportKeyingMaterial succeeds after the handshake and returns the
// requested 42 bytes.
func TestHandshakeClientExportKeyingMaterial(t *testing.T) {
	test := &clientTest{
		name:   "ExportKeyingMaterial",
		config: testConfig.Clone(),
		validate: func(state ConnectionState) error {
			if km, err := state.ExportKeyingMaterial("test", nil, 42); err != nil {
				return fmt.Errorf("ExportKeyingMaterial failed: %v", err)
			} else if len(km) != 42 {
				return fmt.Errorf("Got %d bytes from ExportKeyingMaterial, wanted %d", len(km), 42)
			}
			return nil
		},
	}
	runClientTestTLS10(t, test)
	runClientTestTLS12(t, test)
	runClientTestTLS13(t, test)
}

// hostnameInSNITests maps a Config.ServerName value (in) to the server name
// expected in the resulting ClientHello (out).
var hostnameInSNITests = []struct {
	in, out string
}{
	// Opaque string
	{"", ""},
	{"localhost", "localhost"},
	{"foo, bar, baz and qux", "foo, bar, baz and qux"},

	// DNS hostname
	{"golang.org", "golang.org"},
	{"golang.org.", "golang.org"},

	// Literal IPv4 address
	{"1.2.3.4", ""},

	// Literal IPv6 address
	{"::1", ""},
	{"::1%lo0", ""}, // with zone identifier
	{"[::1]", ""},   // as per RFC 5952 we allow the [] style as IPv6 literal
	{"[::1%lo0]", ""},
}

// TestHostnameInSNI starts a client handshake for every entry of
// hostnameInSNITests, reads the first record the client sends, parses it as
// a ClientHello and compares its server name with the expected value.
func TestHostnameInSNI(t *testing.T) {
	for _, tt := range hostnameInSNITests {
		c, s := localPipe(t)

		// Only the ClientHello matters; the handshake result is not checked.
		go func(host string) {
			Client(c, &Config{ServerName: host, InsecureSkipVerify: true}).Handshake()
		}(tt.in)

		var header [5]byte
		if _, err := io.ReadFull(s, header[:]); err != nil {
			t.Fatal(err)
		}
		recordLen := int(header[3])<<8 | int(header[4])

		record := make([]byte, recordLen)
		if _, err := io.ReadFull(s, record[:]); err != nil {
			t.Fatal(err)
		}

		c.Close()
		s.Close()

		var m clientHelloMsg
		if !m.unmarshal(record) {
			t.Errorf("unmarshaling ClientHello for %q failed", tt.in)
			continue
		}
		if tt.in != tt.out && m.serverName == tt.in {
			t.Errorf("prohibited %q found in ClientHello: %x", tt.in, record)
		}
		if m.serverName != tt.out {
			t.Errorf("expected %q not found in ClientHello: %x", tt.out, record)
		}
	}
}

// TestServerSelectingUnconfiguredCipherSuite feeds the client a hand-built
// ServerHello that selects a cipher suite the client never offered and
// checks that the client handshake fails with an "unconfigured cipher" error.
func TestServerSelectingUnconfiguredCipherSuite(t *testing.T) {
	// This checks that the server can't select a cipher suite that the
	// client didn't offer. See #13174.

	c, s := localPipe(t)
	errChan := make(chan error, 1)

	go func() {
		client := Client(c, &Config{
			ServerName:   "foo",
			CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
		})
		errChan <- client.Handshake()
	}()

	var header [5]byte
	if _, err := io.ReadFull(s, header[:]); err != nil {
		t.Fatal(err)
	}
	recordLen := int(header[3])<<8 | int(header[4])

	record := make([]byte, recordLen)
	if _, err := io.ReadFull(s, record); err != nil {
		t.Fatal(err)
	}

	// Create a ServerHello that selects a different cipher suite than the
	// sole one that the client offered.
	serverHello := &serverHelloMsg{
		vers:        VersionTLS12,
		random:      make([]byte, 32),
		cipherSuite: TLS_RSA_WITH_AES_256_GCM_SHA384,
	}
	serverHelloBytes := serverHello.marshal()

	s.Write([]byte{
		byte(recordTypeHandshake),
		byte(VersionTLS12 >> 8),
		byte(VersionTLS12 & 0xff),
		byte(len(serverHelloBytes) >> 8),
		byte(len(serverHelloBytes)),
	})
	s.Write(serverHelloBytes)
	s.Close()

	// Check for nil first: calling err.Error() on a nil error would panic
	// instead of reporting that the handshake unexpectedly succeeded.
	err := <-errChan
	if err == nil {
		t.Fatal("Expected error about unconfigured cipher suite but handshake succeeded")
	}
	if !strings.Contains(err.Error(), "unconfigured cipher") {
		t.Fatalf("Expected error about unconfigured cipher suite but got %q", err)
	}
}

// TestVerifyConnection runs testVerifyConnection for TLS 1.2 and TLS 1.3.
func TestVerifyConnection(t *testing.T) {
	t.Run("TLSv12", func(t *testing.T) { testVerifyConnection(t, VersionTLS12) })
	t.Run("TLSv13", func(t *testing.T) { testVerifyConnection(t, VersionTLS13) })
}

// testVerifyConnection installs VerifyConnection callbacks on client and
// server for several ClientAuth settings. For each setting it performs a
// full handshake followed by a resumption, checking the ConnectionState each
// callback receives and that each callback ran once per handshake.
func testVerifyConnection(t *testing.T, version uint16) {
	// checkFields validates the ConnectionState fields that are expected to
	// be the same in every scenario. called is the number of times the
	// calling callback has run so far; errorType prefixes error messages.
	checkFields := func(c ConnectionState, called *int, errorType string) error {
		if c.Version != version {
			return fmt.Errorf("%s: got Version %v, want %v", errorType, c.Version, version)
		}
		if c.HandshakeComplete {
			return fmt.Errorf("%s: got HandshakeComplete, want false", errorType)
		}
		if c.ServerName != "example.golang" {
			return fmt.Errorf("%s: got ServerName %s, want %s", errorType, c.ServerName, "example.golang")
		}
		if c.NegotiatedProtocol != "protocol1" {
			return fmt.Errorf("%s: got NegotiatedProtocol %s, want %s", errorType, c.NegotiatedProtocol, "protocol1")
		}
		if c.CipherSuite == 0 {
			return fmt.Errorf("%s: got CipherSuite 0, want non-zero", errorType)
		}
		wantDidResume := false
		if *called == 2 { // if this is the second time, then it should be a resumption
			wantDidResume = true
		}
		if c.DidResume != wantDidResume {
			return fmt.Errorf("%s: got DidResume %t, want %t", errorType, c.DidResume, wantDidResume)
		}
		return nil
	}

	tests := []struct {
		name            string
		configureServer func(*Config, *int)
		configureClient func(*Config, *int)
	}{
		{
			name: "RequireAndVerifyClientCert",
			configureServer: func(config *Config, called *int) {
				config.ClientAuth = RequireAndVerifyClientCert
				config.VerifyConnection = func(c ConnectionState) error {
					*called++
					if l := len(c.PeerCertificates); l != 1 {
						return fmt.Errorf("server: got len(PeerCertificates) = %d, wanted 1", l)
					}
					if len(c.VerifiedChains) == 0 {
						return fmt.Errorf("server: got len(VerifiedChains) = 0, wanted non-zero")
					}
					return checkFields(c, called, "server")
				}
			},
			configureClient: func(config *Config, called *int) {
				config.VerifyConnection = func(c ConnectionState) error {
					*called++
					if l := len(c.PeerCertificates); l != 1 {
						return fmt.Errorf("client: got len(PeerCertificates) = %d, wanted 1", l)
					}
					if len(c.VerifiedChains) == 0 {
						return fmt.Errorf("client: got len(VerifiedChains) = 0, wanted non-zero")
					}
					if c.DidResume {
						return nil
						// The SCTs and OCSP Response are dropped on resumption.
						// See http://golang.org/issue/39075.
					}
					if len(c.OCSPResponse) == 0 {
						return fmt.Errorf("client: got len(OCSPResponse) = 0, wanted non-zero")
					}
					if len(c.SignedCertificateTimestamps) == 0 {
						return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero")
					}
					return checkFields(c, called, "client")
				}
			},
		},
		{
			name: "InsecureSkipVerify",
			configureServer: func(config *Config, called *int) {
				config.ClientAuth = RequireAnyClientCert
				config.InsecureSkipVerify = true
				config.VerifyConnection = func(c ConnectionState) error {
					*called++
					if l := len(c.PeerCertificates); l != 1 {
						return fmt.Errorf("server: got len(PeerCertificates) = %d, wanted 1", l)
					}
					if c.VerifiedChains != nil {
						return fmt.Errorf("server: got Verified Chains %v, want nil", c.VerifiedChains)
					}
					return checkFields(c, called, "server")
				}
			},
			configureClient: func(config *Config, called *int) {
				config.InsecureSkipVerify = true
				config.VerifyConnection = func(c ConnectionState) error {
					*called++
					if l := len(c.PeerCertificates); l != 1 {
						return fmt.Errorf("client: got len(PeerCertificates) = %d, wanted 1", l)
					}
					if c.VerifiedChains != nil {
						// This is the client-side callback, so the
						// failure is attributed to the client.
						return fmt.Errorf("client: got Verified Chains %v, want nil", c.VerifiedChains)
					}
					if c.DidResume {
						return nil
						// The SCTs and OCSP Response are dropped on resumption.
						// See http://golang.org/issue/39075.
					}
					if len(c.OCSPResponse) == 0 {
						return fmt.Errorf("client: got len(OCSPResponse) = 0, wanted non-zero")
					}
					if len(c.SignedCertificateTimestamps) == 0 {
						return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero")
					}
					return checkFields(c, called, "client")
				}
			},
		},
		{
			name: "NoClientCert",
			configureServer: func(config *Config, called *int) {
				config.ClientAuth = NoClientCert
				config.VerifyConnection = func(c ConnectionState) error {
					*called++
					return checkFields(c, called, "server")
				}
			},
			configureClient: func(config *Config, called *int) {
				config.VerifyConnection = func(c ConnectionState) error {
					*called++
					return checkFields(c, called, "client")
				}
			},
		},
		{
			name: "RequestClientCert",
			configureServer: func(config *Config, called *int) {
				config.ClientAuth = RequestClientCert
				config.VerifyConnection = func(c ConnectionState) error {
					*called++
					return checkFields(c, called, "server")
				}
			},
			configureClient: func(config *Config, called *int) {
				config.Certificates = nil // clear the client cert
				config.VerifyConnection = func(c ConnectionState) error {
					*called++
					if l := len(c.PeerCertificates); l != 1 {
						return fmt.Errorf("client: got len(PeerCertificates) = %d, wanted 1", l)
					}
					if len(c.VerifiedChains) == 0 {
						return fmt.Errorf("client: got len(VerifiedChains) = 0, wanted non-zero")
					}
					if c.DidResume {
						return nil
						// The SCTs and OCSP Response are dropped on resumption.
						// See http://golang.org/issue/39075.
					}
					if len(c.OCSPResponse) == 0 {
						return fmt.Errorf("client: got len(OCSPResponse) = 0, wanted non-zero")
					}
					if len(c.SignedCertificateTimestamps) == 0 {
						return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero")
					}
					return checkFields(c, called, "client")
				}
			},
		},
	}
	for _, test := range tests {
		issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
		if err != nil {
			t.Fatal(err)
		}
		rootCAs := x509.NewCertPool()
		rootCAs.AddCert(issuer)

		var serverCalled, clientCalled int

		serverConfig := &Config{
			MaxVersion:   version,
			Certificates: []Certificate{testConfig.Certificates[0]},
			ClientCAs:    rootCAs,
			NextProtos:   []string{"protocol1"},
		}
		serverConfig.Certificates[0].SignedCertificateTimestamps = [][]byte{[]byte("dummy sct 1"), []byte("dummy sct 2")}
		serverConfig.Certificates[0].OCSPStaple = []byte("dummy ocsp")
		test.configureServer(serverConfig, &serverCalled)

		clientConfig := &Config{
			MaxVersion:         version,
			ClientSessionCache: NewLRUClientSessionCache(32),
			RootCAs:            rootCAs,
			ServerName:         "example.golang",
			Certificates:       []Certificate{testConfig.Certificates[0]},
			NextProtos:         []string{"protocol1"},
		}
		test.configureClient(clientConfig, &clientCalled)

		// testHandshakeState performs one handshake with the configs above
		// and checks the resumption state and the callback counters.
		testHandshakeState := func(name string, didResume bool) {
			_, hs, err := testHandshake(t, clientConfig, serverConfig)
			if err != nil {
				t.Fatalf("%s: handshake failed: %s", name, err)
			}
			if hs.DidResume != didResume {
				t.Errorf("%s: resumed: %v, expected: %v", name, hs.DidResume, didResume)
			}
			wantCalled := 1
			if didResume {
				wantCalled = 2 // resumption would mean this is the second time it was called in this test
			}
			if clientCalled != wantCalled {
				t.Errorf("%s: expected client VerifyConnection called %d times, did %d times", name, wantCalled, clientCalled)
			}
			if serverCalled != wantCalled {
				t.Errorf("%s: expected server VerifyConnection called %d times, did %d times", name, wantCalled, serverCalled)
			}
		}
		testHandshakeState(fmt.Sprintf("%s-FullHandshake", test.name), false)
		testHandshakeState(fmt.Sprintf("%s-Resumption", test.name), true)
	}
}

// TestVerifyPeerCertificate runs testVerifyPeerCertificate for TLS 1.2 and
// TLS 1.3.
func TestVerifyPeerCertificate(t *testing.T) {
	t.Run("TLSv12", func(t *testing.T) { testVerifyPeerCertificate(t, VersionTLS12) })
	t.Run("TLSv13", func(t *testing.T) { testVerifyPeerCertificate(t, VersionTLS13) })
}

// testVerifyPeerCertificate exercises combinations of VerifyPeerCertificate
// and VerifyConnection callbacks on client and server. Each table entry
// configures both sides, one handshake is performed, and the entry's
// validate function inspects the handshake errors and whether the callbacks
// ran.
func testVerifyPeerCertificate(t *testing.T, version uint16) {
	issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
	if err != nil {
		t.Fatal(err)
	}

	rootCAs := x509.NewCertPool()
	rootCAs.AddCert(issuer)

	// now pins Config.Time to a fixed instant for both sides.
	now := func() time.Time { return time.Unix(1476984729, 0) }

	sentinelErr := errors.New("TestVerifyPeerCertificate")

	// verifyPeerCertificateCallback checks the callback arguments for a
	// verified peer and records that it ran.
	verifyPeerCertificateCallback := func(called *bool, rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
		if l := len(rawCerts); l != 1 {
			return fmt.Errorf("got len(rawCerts) = %d, wanted 1", l)
		}
		if len(validatedChains) == 0 {
			return errors.New("got len(validatedChains) = 0, wanted non-zero")
		}
		*called = true
		return nil
	}
	// verifyConnectionCallback checks the ConnectionState for a verified
	// peer and records that it ran. The OCSP response is only required on
	// the client side.
	verifyConnectionCallback := func(called *bool, isClient bool, c ConnectionState) error {
		if l := len(c.PeerCertificates); l != 1 {
			return fmt.Errorf("got len(PeerCertificates) = %d, wanted 1", l)
		}
		if len(c.VerifiedChains) == 0 {
			return fmt.Errorf("got len(VerifiedChains) = 0, wanted non-zero")
		}
		if isClient && len(c.OCSPResponse) == 0 {
			return fmt.Errorf("got len(OCSPResponse) = 0, wanted non-zero")
		}
		*called = true
		return nil
	}

	tests := []struct {
		configureServer func(*Config, *bool)
		configureClient func(*Config, *bool)
		validate        func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error)
	}{
		{
			configureServer: func(config *Config, called *bool) {
				config.InsecureSkipVerify = false
				config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
					return verifyPeerCertificateCallback(called, rawCerts, validatedChains)
				}
			},
			configureClient: func(config *Config, called *bool) {
				config.InsecureSkipVerify = false
				config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
					return verifyPeerCertificateCallback(called, rawCerts, validatedChains)
				}
			},
			validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
				if clientErr != nil {
					t.Errorf("test[%d]: client handshake failed: %v", testNo, clientErr)
				}
				if serverErr != nil {
					t.Errorf("test[%d]: server handshake failed: %v", testNo, serverErr)
				}
				if !clientCalled {
					t.Errorf("test[%d]: client did not call callback", testNo)
				}
				if !serverCalled {
					t.Errorf("test[%d]: server did not call callback", testNo)
				}
			},
		},
		{
			configureServer: func(config *Config, called *bool) {
				config.InsecureSkipVerify = false
				config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
					return sentinelErr
				}
			},
			configureClient: func(config *Config, called *bool) {
				config.VerifyPeerCertificate = nil
			},
			validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
				if serverErr != sentinelErr {
					t.Errorf("#%d: got server error %v, wanted sentinelErr", testNo, serverErr)
				}
			},
		},
		{
			configureServer: func(config *Config, called *bool) {
				config.InsecureSkipVerify = false
			},
			configureClient: func(config *Config, called *bool) {
				config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
					return sentinelErr
				}
			},
			validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
				if clientErr != sentinelErr {
					t.Errorf("#%d: got client error %v, wanted sentinelErr", testNo, clientErr)
				}
			},
		},
		{
			configureServer: func(config *Config, called *bool) {
				config.InsecureSkipVerify = false
			},
			configureClient: func(config *Config, called *bool) {
				config.InsecureSkipVerify = true
				config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
					if l := len(rawCerts); l != 1 {
						return fmt.Errorf("got len(rawCerts) = %d, wanted 1", l)
					}
					// With InsecureSkipVerify set, this
					// callback should still be called but
					// validatedChains must be empty.
					if l := len(validatedChains); l != 0 {
						return fmt.Errorf("got len(validatedChains) = %d, wanted zero", l)
					}
					*called = true
					return nil
				}
			},
			validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
				if clientErr != nil {
					t.Errorf("test[%d]: client handshake failed: %v", testNo, clientErr)
				}
				if serverErr != nil {
					t.Errorf("test[%d]: server handshake failed: %v", testNo, serverErr)
				}
				if !clientCalled {
					t.Errorf("test[%d]: client did not call callback", testNo)
				}
			},
		},
		{
			configureServer: func(config *Config, called *bool) {
				config.InsecureSkipVerify = false
				config.VerifyConnection = func(c ConnectionState) error {
					return verifyConnectionCallback(called, false, c)
				}
			},
			configureClient: func(config *Config, called *bool) {
				config.InsecureSkipVerify = false
				config.VerifyConnection = func(c ConnectionState) error {
					return verifyConnectionCallback(called, true, c)
				}
			},
			validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
				if clientErr != nil {
					t.Errorf("test[%d]: client handshake failed: %v", testNo, clientErr)
				}
				if serverErr != nil {
					t.Errorf("test[%d]: server handshake failed: %v", testNo, serverErr)
				}
				if !clientCalled {
					t.Errorf("test[%d]: client did not call callback", testNo)
				}
				if !serverCalled {
					t.Errorf("test[%d]: server did not call callback", testNo)
				}
			},
		},
		{
			configureServer: func(config *Config, called *bool) {
				config.InsecureSkipVerify = false
				config.VerifyConnection = func(c ConnectionState) error {
					return sentinelErr
				}
			},
			configureClient: func(config *Config, called *bool) {
				config.InsecureSkipVerify = false
				config.VerifyConnection = nil
			},
			validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
				if serverErr != sentinelErr {
					t.Errorf("#%d: got server error %v, wanted sentinelErr", testNo, serverErr)
				}
			},
		},
		{
			configureServer: func(config *Config, called *bool) {
				config.InsecureSkipVerify = false
				config.VerifyConnection = nil
			},
			configureClient: func(config *Config, called *bool) {
				config.InsecureSkipVerify = false
				config.VerifyConnection = func(c ConnectionState) error {
					return sentinelErr
				}
			},
			validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
				if clientErr != sentinelErr {
					t.Errorf("#%d: got client error %v, wanted sentinelErr", testNo, clientErr)
				}
			},
		},
		{
			configureServer: func(config *Config, called *bool) {
				config.InsecureSkipVerify = false
				config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
					return verifyPeerCertificateCallback(called, rawCerts, validatedChains)
				}
				config.VerifyConnection = func(c ConnectionState) error {
					return sentinelErr
				}
			},
			configureClient: func(config *Config, called *bool) {
				config.InsecureSkipVerify = false
				config.VerifyPeerCertificate = nil
				config.VerifyConnection = nil
			},
			validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
				if serverErr != sentinelErr {
					t.Errorf("#%d: got server error %v, wanted sentinelErr", testNo, serverErr)
				}
				if !serverCalled {
					t.Errorf("test[%d]: server did not call callback", testNo)
				}
			},
		},
		{
			configureServer: func(config *Config, called *bool) {
				config.InsecureSkipVerify = false
				config.VerifyPeerCertificate = nil
				config.VerifyConnection = nil
			},
			configureClient: func(config *Config, called *bool) {
				config.InsecureSkipVerify = false
				config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
					return verifyPeerCertificateCallback(called, rawCerts, validatedChains)
				}
				config.VerifyConnection = func(c ConnectionState) error {
					return sentinelErr
				}
			},
			validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
				if clientErr != sentinelErr {
					t.Errorf("#%d: got client error %v, wanted sentinelErr", testNo, clientErr)
				}
				if !clientCalled {
					t.Errorf("test[%d]: client did not call callback", testNo)
				}
			},
		},
	}

	for i, test := range tests {
		c, s := localPipe(t)
		done := make(chan error)

		var clientCalled, serverCalled bool

		go func() {
			config := testConfig.Clone()
			config.ServerName = "example.golang"
			config.ClientAuth = RequireAndVerifyClientCert
			config.ClientCAs = rootCAs
			config.Time = now
			config.MaxVersion = version
			config.Certificates = make([]Certificate, 1)
			config.Certificates[0].Certificate = [][]byte{testRSACertificate}
			config.Certificates[0].PrivateKey = testRSAPrivateKey
			config.Certificates[0].SignedCertificateTimestamps = [][]byte{[]byte("dummy sct 1"), []byte("dummy sct 2")}
			config.Certificates[0].OCSPStaple = []byte("dummy ocsp")
			test.configureServer(config, &serverCalled)

			// Use a goroutine-local error rather than writing to the
			// enclosing function's err variable from this goroutine.
			err := Server(s, config).Handshake()
			s.Close()
			done <- err
		}()

		config := testConfig.Clone()
		config.ServerName = "example.golang"
		config.RootCAs = rootCAs
		config.Time = now
		config.MaxVersion = version
		test.configureClient(config, &clientCalled)
		clientErr := Client(c, config).Handshake()
		c.Close()
		serverErr := <-done

		test.validate(t, i, clientCalled, serverCalled, clientErr, serverErr)
	}
}

// brokenConn wraps a net.Conn and causes all Writes after a certain number to
// fail with brokenConnErr.
type brokenConn struct {
	net.Conn

	// breakAfter is the number of successful writes that will be allowed
	// before all subsequent writes fail.
	breakAfter int

	// numWrites is the number of writes that have been done.
	numWrites int
}

// brokenConnErr is the error that brokenConn returns once exhausted.
var brokenConnErr = errors.New("too many writes to brokenConn")

// Write forwards data to the wrapped connection until breakAfter writes have
// been made; after that it writes nothing and returns brokenConnErr.
func (b *brokenConn) Write(data []byte) (int, error) {
	if b.numWrites >= b.breakAfter {
		return 0, brokenConnErr
	}

	b.numWrites++
	return b.Conn.Write(data)
}

func TestFailedWrite(t *testing.T) {
	// Test that a write error during the handshake is returned.
for _, breakAfter := range []int{0, 1} {
		clientSide, serverSide := localPipe(t)
		serverDone := make(chan bool)

		// Run the server half; its handshake result is not checked.
		go func() {
			Server(serverSide, testConfig).Handshake()
			serverSide.Close()
			serverDone <- true
		}()

		failing := &brokenConn{Conn: clientSide, breakAfter: breakAfter}
		if err := Client(failing, testConfig).Handshake(); err != brokenConnErr {
			t.Errorf("#%d: expected error from brokenConn but got %q", breakAfter, err)
		}
		failing.Close()

		<-serverDone
	}
}

// writeCountingConn wraps a net.Conn and counts the number of Write calls.
type writeCountingConn struct {
	net.Conn

	// numWrites is the number of writes that have been done.
	numWrites int
}

// Write records the call and then forwards data to the wrapped connection.
func (w *writeCountingConn) Write(data []byte) (int, error) {
	w.numWrites++
	return w.Conn.Write(data)
}

// TestBuffering runs testBuffering once per supported protocol version.
func TestBuffering(t *testing.T) {
	versions := []struct {
		name    string
		version uint16
	}{
		{"TLSv12", VersionTLS12},
		{"TLSv13", VersionTLS13},
	}
	for _, v := range versions {
		v := v
		t.Run(v.name, func(t *testing.T) { testBuffering(t, v.version) })
	}
}

// testBuffering completes a handshake over write-counting connections and
// checks how many Write calls each side needed, with the server capped at
// the given protocol version.
func testBuffering(t *testing.T, version uint16) {
	clientPipe, serverPipe := localPipe(t)
	clientWCC := &writeCountingConn{Conn: clientPipe}
	serverWCC := &writeCountingConn{Conn: serverPipe}

	serverDone := make(chan bool)
	go func() {
		serverConfig := testConfig.Clone()
		serverConfig.MaxVersion = version
		Server(serverWCC, serverConfig).Handshake()
		serverWCC.Close()
		serverDone <- true
	}()

	if err := Client(clientWCC, testConfig).Handshake(); err != nil {
		t.Fatal(err)
	}
	clientWCC.Close()
	<-serverDone

	// The client is expected to use two writes for either version; only
	// the server's expected count depends on the version.
	wantClient, wantServer := 2, 2
	if version == VersionTLS13 {
		wantServer = 1
	}

	if got := clientWCC.numWrites; got != wantClient {
		t.Errorf("expected client handshake to complete with %d writes, but saw %d", wantClient, got)
	}
	if got := serverWCC.numWrites; got != wantServer {
		t.Errorf("expected server handshake to complete with %d writes, but saw %d", wantServer, got)
	}
}

// TestAlertFlushing gives the server a private key whose exponent does not
// match its certificate, expects the client to see an internal-error alert,
// and checks that the server produced exactly one write.
func TestAlertFlushing(t *testing.T) {
	clientPipe, serverPipe := localPipe(t)
	clientWCC := &writeCountingConn{Conn: clientPipe}
	serverWCC := &writeCountingConn{Conn: serverPipe}

	// Cause a signature-time error
	brokenKey := rsa.PrivateKey{
		PublicKey: testRSAPrivateKey.PublicKey,
		D:         big.NewInt(42),
	}
	serverConfig := testConfig.Clone()
	serverConfig.Certificates = []Certificate{{
		Certificate: [][]byte{testRSACertificate},
		PrivateKey:  &brokenKey,
	}}

	serverDone := make(chan bool)
	go func() {
		Server(serverWCC, serverConfig).Handshake()
		serverWCC.Close()
		serverDone <- true
	}()

	const expectedError = "remote error: tls: internal error"
	err := Client(clientWCC, testConfig).Handshake()
	switch {
	case err == nil:
		t.Fatal("client unexpectedly returned no error")
	case !strings.Contains(err.Error(), expectedError):
		t.Fatalf("expected to find %q in error but error was %q", expectedError, err.Error())
	}
	clientWCC.Close()
	<-serverDone

	if got := serverWCC.numWrites; got != 1 {
		t.Errorf("expected server handshake to complete with one write, but saw %d", got)
	}
}

func TestHandshakeRace(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping in -short mode")
	}
	t.Parallel()
	// This test races a Read and Write to try and complete a handshake in
	// order to provide some evidence that there are no races or deadlocks
	// in the handshake locking.
for i := 0; i < 32; i++ {
		c, s := localPipe(t)

		// Server side: handshake, read one byte and echo it back.
		go func() {
			server := Server(s, testConfig)
			if err := server.Handshake(); err != nil {
				panic(err)
			}

			var request [1]byte
			if n, err := server.Read(request[:]); err != nil || n != 1 {
				panic(err)
			}

			server.Write(request[:])
			server.Close()
		}()

		startWrite := make(chan struct{})
		startRead := make(chan struct{})
		readDone := make(chan struct{}, 1)

		client := Client(c, testConfig)
		go func() {
			<-startWrite
			var request [1]byte
			client.Write(request[:])
		}()

		go func() {
			<-startRead
			var reply [1]byte
			if _, err := io.ReadFull(client, reply[:]); err != nil {
				panic(err)
			}
			c.Close()
			readDone <- struct{}{}
		}()

		// Alternate which of the two client goroutines is released first.
		if i&1 == 1 {
			startWrite <- struct{}{}
			startRead <- struct{}{}
		} else {
			startRead <- struct{}{}
			startWrite <- struct{}{}
		}
		<-readDone
	}
}

// getClientCertificateTests lists the scenarios run by
// testGetClientCertificate. setup adjusts the client and server configs,
// expectedClientError is the exact client handshake error text expected
// (empty for success), and verify inspects the server's ConnectionState.
var getClientCertificateTests = []struct {
	setup               func(*Config, *Config)
	expectedClientError string
	verify              func(*testing.T, int, *ConnectionState)
}{
	{
		func(clientConfig, serverConfig *Config) {
			// Returning a Certificate with no certificate data
			// should result in an empty message being sent to the
			// server.
			serverConfig.ClientCAs = nil
			clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) {
				if len(cri.SignatureSchemes) == 0 {
					panic("empty SignatureSchemes")
				}
				if len(cri.AcceptableCAs) != 0 {
					panic("AcceptableCAs should have been empty")
				}
				return new(Certificate), nil
			}
		},
		"",
		func(t *testing.T, testNum int, cs *ConnectionState) {
			if l := len(cs.PeerCertificates); l != 0 {
				t.Errorf("#%d: expected no certificates but got %d", testNum, l)
			}
		},
	},
	{
		func(clientConfig, serverConfig *Config) {
			// With TLS 1.1, the SignatureSchemes should be
			// synthesised from the supported certificate types.
			clientConfig.MaxVersion = VersionTLS11
			clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) {
				if len(cri.SignatureSchemes) == 0 {
					panic("empty SignatureSchemes")
				}
				return new(Certificate), nil
			}
		},
		"",
		func(t *testing.T, testNum int, cs *ConnectionState) {
			if l := len(cs.PeerCertificates); l != 0 {
				t.Errorf("#%d: expected no certificates but got %d", testNum, l)
			}
		},
	},
	{
		func(clientConfig, serverConfig *Config) {
			// Returning an error should abort the handshake with
			// that error.
			clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) {
				return nil, errors.New("GetClientCertificate")
			}
		},
		"GetClientCertificate",
		func(t *testing.T, testNum int, cs *ConnectionState) {
		},
	},
	{
		func(clientConfig, serverConfig *Config) {
			clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) {
				if len(cri.AcceptableCAs) == 0 {
					panic("empty AcceptableCAs")
				}
				cert := &Certificate{
					Certificate: [][]byte{testRSACertificate},
					PrivateKey:  testRSAPrivateKey,
				}
				return cert, nil
			}
		},
		"",
		func(t *testing.T, testNum int, cs *ConnectionState) {
			if len(cs.VerifiedChains) == 0 {
				t.Errorf("#%d: expected some verified chains, but found none", testNum)
			}
		},
	},
}

// TestGetClientCertificate runs testGetClientCertificate for TLS 1.2 and
// TLS 1.3.
func TestGetClientCertificate(t *testing.T) {
	t.Run("TLSv12", func(t *testing.T) { testGetClientCertificate(t, VersionTLS12) })
	t.Run("TLSv13", func(t *testing.T) { testGetClientCertificate(t, VersionTLS13) })
}

// testGetClientCertificate runs every getClientCertificateTests scenario
// against a server using VerifyClientCertIfGiven, then compares the client
// handshake error with the expected one and hands the server's
// ConnectionState to the scenario's verify function.
func testGetClientCertificate(t *testing.T, version uint16) {
	issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
	if err != nil {
		t.Fatal(err)
	}

	for i, test := range getClientCertificateTests {
		serverConfig := testConfig.Clone()
		serverConfig.ClientAuth = VerifyClientCertIfGiven
		serverConfig.RootCAs = x509.NewCertPool()
		serverConfig.RootCAs.AddCert(issuer)
		serverConfig.ClientCAs = serverConfig.RootCAs
		serverConfig.Time = func() time.Time { return time.Unix(1476984729, 0) }
		serverConfig.MaxVersion = version

		clientConfig := testConfig.Clone()
		clientConfig.MaxVersion = version

		test.setup(clientConfig, serverConfig)

		type serverResult struct {
			cs  ConnectionState
			err error
		}

		c, s := localPipe(t)
		done := make(chan serverResult)

		go func() {
			defer s.Close()
			server := Server(s, serverConfig)
			err := server.Handshake()

			var cs ConnectionState
			if err == nil {
				cs = server.ConnectionState()
			}
			done <- serverResult{cs, err}
		}()

		clientErr := Client(c, clientConfig).Handshake()
		c.Close()

		result := <-done

		if clientErr != nil {
			if len(test.expectedClientError) == 0 {
				t.Errorf("#%d: client error: %v", i, clientErr)
			} else if got := clientErr.Error(); got != test.expectedClientError {
				t.Errorf("#%d: expected client error %q, but got %q", i, test.expectedClientError, got)
			} else {
				test.verify(t, i, &result.cs)
			}
		} else if len(test.expectedClientError) > 0 {
			t.Errorf("#%d: expected client error %q, but got no error", i, test.expectedClientError)
		} else if err := result.err; err != nil {
			t.Errorf("#%d: server error: %v", i, err)
		} else {
			test.verify(t, i, &result.cs)
		}
	}
}

// TestRSAPSSKeyError decodes an RSASSA-PSS certificate and passes if it
// either fails to parse or parses without a *rsa.PublicKey public key.
func TestRSAPSSKeyError(t *testing.T) {
	// crypto/tls does not support the rsa_pss_pss_* SignatureSchemes. If support for
	// public keys with OID RSASSA-PSS is added to crypto/x509, they will be misused with
	// the rsa_pss_rsae_* SignatureSchemes. Assert that RSASSA-PSS certificates don't
	// parse, or that they don't carry *rsa.PublicKey keys.

	// The PEM armor must keep its line breaks: pem.Decode does not find a
	// BEGIN marker that is preceded only by a space.
	b, _ := pem.Decode([]byte(`
-----BEGIN CERTIFICATE-----
MIIDZTCCAhygAwIBAgIUCF2x0FyTgZG0CC9QTDjGWkB5vgEwPgYJKoZIhvcNAQEK
MDGgDTALBglghkgBZQMEAgGhGjAYBgkqhkiG9w0BAQgwCwYJYIZIAWUDBAIBogQC
AgDeMBIxEDAOBgNVBAMMB1JTQS1QU1MwHhcNMTgwNjI3MjI0NDM2WhcNMTgwNzI3
MjI0NDM2WjASMRAwDgYDVQQDDAdSU0EtUFNTMIIBIDALBgkqhkiG9w0BAQoDggEP
ADCCAQoCggEBANxDm0f76JdI06YzsjB3AmmjIYkwUEGxePlafmIASFjDZl/elD0Z
/a7xLX468b0qGxLS5al7XCcEprSdsDR6DF5L520+pCbpfLyPOjuOvGmk9KzVX4x5
b05YXYuXdsQ0Kjxcx2i3jjCday6scIhMJVgBZxTEyMj1thPQM14SHzKCd/m6HmCL
QmswpH2yMAAcBRWzRpp/vdH5DeOJEB3aelq7094no731mrLUCHRiZ1htq8BDB3ou
czwqgwspbqZ4dnMXl2MvfySQ5wJUxQwILbiuAKO2lVVPUbFXHE9pgtznNoPvKwQT
JNcX8ee8WIZc2SEGzofjk3NpjR+2ADB2u3sCAwEAAaNTMFEwHQYDVR0OBBYEFNEz
AdyJ2f+fU+vSCS6QzohnOnprMB8GA1UdIwQYMBaAFNEzAdyJ2f+fU+vSCS6Qzohn
OnprMA8GA1UdEwEB/wQFMAMBAf8wPgYJKoZIhvcNAQEKMDGgDTALBglghkgBZQME
AgGhGjAYBgkqhkiG9w0BAQgwCwYJYIZIAWUDBAIBogQCAgDeA4IBAQCjEdrR5aab
sZmCwrMeKidXgfkmWvfuLDE+TCbaqDZp7BMWcMQXT9O0UoUT5kqgKj2ARm2pEW0Z
H3Z1vj3bbds72qcDIJXp+l0fekyLGeCrX/CbgnMZXEP7+/+P416p34ChR1Wz4dU1
KD3gdsUuTKKeMUog3plxlxQDhRQmiL25ygH1LmjLd6dtIt0GVRGr8lj3euVeprqZ
bZ3Uq5eLfsn8oPgfC57gpO6yiN+UURRTlK3bgYvLh4VWB3XXk9UaQZ7Mq1tpXjoD
HYFybkWzibkZp4WRo+Fa28rirH+/wHt0vfeN7UCceURZEx4JaxIIfe4ku7uDRhJi
RwBA9Xk1KBNF
-----END CERTIFICATE-----`))
	if b == nil {
		t.Fatal("Failed to decode certificate")
	}

	cert, err := x509.ParseCertificate(b.Bytes)
	if err != nil {
		// A parse failure is one of the two acceptable outcomes.
		return
	}
	if _, ok := cert.PublicKey.(*rsa.PublicKey); ok {
		t.Error("A RSASSA-PSS certificate was parsed like a PKCS#1 v1.5 one, and it will be mistakenly used with rsa_pss_rsae_* signature algorithms")
	}
}

func TestCloseClientConnectionOnIdleServer(t *testing.T) {
	clientConn, serverConn := localPipe(t)
	client := Client(clientConn, testConfig.Clone())
	go func() {
		var b [1]byte
		serverConn.Read(b[:])
		client.Close()
	}()
	client.SetWriteDeadline(time.Now().Add(time.Minute))
	err := client.Handshake()
	if err != nil {
		if err, ok := err.(net.Error); ok && err.Timeout() {
			t.Errorf("Expected a closed network connection
error but got '%s'", err.Error()) } } else { t.Errorf("Error expected, but no error returned") } } func testDowngradeCanary(t *testing.T, clientVersion, serverVersion uint16) error { defer func() { testingOnlyForceDowngradeCanary = false }() testingOnlyForceDowngradeCanary = true clientConfig := testConfig.Clone() clientConfig.MaxVersion = clientVersion serverConfig := testConfig.Clone() serverConfig.MaxVersion = serverVersion _, _, err := testHandshake(t, clientConfig, serverConfig) return err } func TestDowngradeCanary(t *testing.T) { if err := testDowngradeCanary(t, VersionTLS13, VersionTLS12); err == nil { t.Errorf("downgrade from TLS 1.3 to TLS 1.2 was not detected") } if testing.Short() { t.Skip("skipping the rest of the checks in short mode") } if err := testDowngradeCanary(t, VersionTLS13, VersionTLS11); err == nil { t.Errorf("downgrade from TLS 1.3 to TLS 1.1 was not detected") } if err := testDowngradeCanary(t, VersionTLS13, VersionTLS10); err == nil { t.Errorf("downgrade from TLS 1.3 to TLS 1.0 was not detected") } if err := testDowngradeCanary(t, VersionTLS12, VersionTLS11); err == nil { t.Errorf("downgrade from TLS 1.2 to TLS 1.1 was not detected") } if err := testDowngradeCanary(t, VersionTLS12, VersionTLS10); err == nil { t.Errorf("downgrade from TLS 1.2 to TLS 1.0 was not detected") } if err := testDowngradeCanary(t, VersionTLS13, VersionTLS13); err != nil { t.Errorf("server unexpectedly sent downgrade canary for TLS 1.3") } if err := testDowngradeCanary(t, VersionTLS12, VersionTLS12); err != nil { t.Errorf("client didn't ignore expected TLS 1.2 canary") } if err := testDowngradeCanary(t, VersionTLS11, VersionTLS11); err != nil { t.Errorf("client unexpectedly reacted to a canary in TLS 1.1") } if err := testDowngradeCanary(t, VersionTLS10, VersionTLS10); err != nil { t.Errorf("client unexpectedly reacted to a canary in TLS 1.0") } } func TestResumptionKeepsOCSPAndSCT(t *testing.T) { t.Run("TLSv12", func(t *testing.T) { 
testResumptionKeepsOCSPAndSCT(t, VersionTLS12) }) t.Run("TLSv13", func(t *testing.T) { testResumptionKeepsOCSPAndSCT(t, VersionTLS13) }) } func testResumptionKeepsOCSPAndSCT(t *testing.T, ver uint16) { issuer, err := x509.ParseCertificate(testRSACertificateIssuer) if err != nil { t.Fatalf("failed to parse test issuer") } roots := x509.NewCertPool() roots.AddCert(issuer) clientConfig := &Config{ MaxVersion: ver, ClientSessionCache: NewLRUClientSessionCache(32), ServerName: "example.golang", RootCAs: roots, } serverConfig := testConfig.Clone() serverConfig.MaxVersion = ver serverConfig.Certificates[0].OCSPStaple = []byte{1, 2, 3} serverConfig.Certificates[0].SignedCertificateTimestamps = [][]byte{{4, 5, 6}} _, ccs, err := testHandshake(t, clientConfig, serverConfig) if err != nil { t.Fatalf("handshake failed: %s", err) } // after a new session we expect to see OCSPResponse and // SignedCertificateTimestamps populated as usual if !bytes.Equal(ccs.OCSPResponse, serverConfig.Certificates[0].OCSPStaple) { t.Errorf("client ConnectionState contained unexpected OCSPResponse: wanted %v, got %v", serverConfig.Certificates[0].OCSPStaple, ccs.OCSPResponse) } if !reflect.DeepEqual(ccs.SignedCertificateTimestamps, serverConfig.Certificates[0].SignedCertificateTimestamps) { t.Errorf("client ConnectionState contained unexpected SignedCertificateTimestamps: wanted %v, got %v", serverConfig.Certificates[0].SignedCertificateTimestamps, ccs.SignedCertificateTimestamps) } // if the server doesn't send any SCTs, repopulate the old SCTs oldSCTs := serverConfig.Certificates[0].SignedCertificateTimestamps serverConfig.Certificates[0].SignedCertificateTimestamps = nil _, ccs, err = testHandshake(t, clientConfig, serverConfig) if err != nil { t.Fatalf("handshake failed: %s", err) } if !ccs.DidResume { t.Fatalf("expected session to be resumed") } // after a resumed session we also expect to see OCSPResponse // and SignedCertificateTimestamps populated if !bytes.Equal(ccs.OCSPResponse, 
serverConfig.Certificates[0].OCSPStaple) { t.Errorf("client ConnectionState contained unexpected OCSPResponse after resumption: wanted %v, got %v", serverConfig.Certificates[0].OCSPStaple, ccs.OCSPResponse) } if !reflect.DeepEqual(ccs.SignedCertificateTimestamps, oldSCTs) { t.Errorf("client ConnectionState contained unexpected SignedCertificateTimestamps after resumption: wanted %v, got %v", oldSCTs, ccs.SignedCertificateTimestamps) } // Only test overriding the SCTs for TLS 1.2, since in 1.3 // the server won't send the message containing them if ver == VersionTLS13 { return } // if the server changes the SCTs it sends, they should override the saved SCTs serverConfig.Certificates[0].SignedCertificateTimestamps = [][]byte{{7, 8, 9}} _, ccs, err = testHandshake(t, clientConfig, serverConfig) if err != nil { t.Fatalf("handshake failed: %s", err) } if !ccs.DidResume { t.Fatalf("expected session to be resumed") } if !reflect.DeepEqual(ccs.SignedCertificateTimestamps, serverConfig.Certificates[0].SignedCertificateTimestamps) { t.Errorf("client ConnectionState contained unexpected SignedCertificateTimestamps after resumption: wanted %v, got %v", serverConfig.Certificates[0].SignedCertificateTimestamps, ccs.SignedCertificateTimestamps) } } // TestClientHandshakeContextCancellation tests that cancelling // the context given to the client side conn.HandshakeContext // interrupts the in-progress handshake. func TestClientHandshakeContextCancellation(t *testing.T) { c, s := localPipe(t) ctx, cancel := context.WithCancel(context.Background()) unblockServer := make(chan struct{}) defer close(unblockServer) go func() { cancel() <-unblockServer _ = s.Close() }() cli := Client(c, testConfig) // Initiates client side handshake, which will block until the client hello is read // by the server, unless the cancellation works. 
err := cli.HandshakeContext(ctx) if err == nil { t.Fatal("Client handshake did not error when the context was canceled") } if err != context.Canceled { t.Errorf("Unexpected client handshake error: %v", err) } if runtime.GOARCH == "wasm" { t.Skip("conn.Close does not error as expected when called multiple times on WASM") } err = cli.Close() if err == nil { t.Error("Client connection was not closed when the context was canceled") } }