Source file src/crypto/tls/handshake_client_test.go

     1  // Copyright 2010 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package tls
     6  
     7  import (
     8  	"bytes"
     9  	"context"
    10  	"crypto/rsa"
    11  	"crypto/x509"
    12  	"encoding/base64"
    13  	"encoding/binary"
    14  	"encoding/pem"
    15  	"errors"
    16  	"fmt"
    17  	"io"
    18  	"math/big"
    19  	"net"
    20  	"os"
    21  	"os/exec"
    22  	"path/filepath"
    23  	"reflect"
    24  	"runtime"
    25  	"strconv"
    26  	"strings"
    27  	"testing"
    28  	"time"
    29  )
    30  
    31  // Note: see comment in handshake_test.go for details of how the reference
    32  // tests work.
    33  
    34  // opensslInputEvent enumerates possible inputs that can be sent to an `openssl
    35  // s_client` process.
    36  type opensslInputEvent int
    37  
    38  const (
    39  	// opensslRenegotiate causes OpenSSL to request a renegotiation of the
    40  	// connection.
    41  	opensslRenegotiate opensslInputEvent = iota
    42  
    43  	// opensslSendBanner causes OpenSSL to send the contents of
    44  	// opensslSentinel on the connection.
    45  	opensslSendSentinel
    46  
    47  	// opensslKeyUpdate causes OpenSSL to send a key update message to the
    48  	// client and request one back.
    49  	opensslKeyUpdate
    50  )
    51  
    52  const opensslSentinel = "SENTINEL\n"
    53  
    54  type opensslInput chan opensslInputEvent
    55  
    56  func (i opensslInput) Read(buf []byte) (n int, err error) {
    57  	for event := range i {
    58  		switch event {
    59  		case opensslRenegotiate:
    60  			return copy(buf, []byte("R\n")), nil
    61  		case opensslKeyUpdate:
    62  			return copy(buf, []byte("K\n")), nil
    63  		case opensslSendSentinel:
    64  			return copy(buf, []byte(opensslSentinel)), nil
    65  		default:
    66  			panic("unknown event")
    67  		}
    68  	}
    69  
    70  	return 0, io.EOF
    71  }
    72  
    73  // opensslOutputSink is an io.Writer that receives the stdout and stderr from an
    74  // `openssl` process and sends a value to handshakeComplete or readKeyUpdate
    75  // when certain messages are seen.
    76  type opensslOutputSink struct {
    77  	handshakeComplete chan struct{}
    78  	readKeyUpdate     chan struct{}
    79  	all               []byte
    80  	line              []byte
    81  }
    82  
    83  func newOpensslOutputSink() *opensslOutputSink {
    84  	return &opensslOutputSink{make(chan struct{}), make(chan struct{}), nil, nil}
    85  }
    86  
    87  // opensslEndOfHandshake is a message that the “openssl s_server” tool will
    88  // print when a handshake completes if run with “-state”.
    89  const opensslEndOfHandshake = "SSL_accept:SSLv3/TLS write finished"
    90  
    91  // opensslReadKeyUpdate is a message that the “openssl s_server” tool will
    92  // print when a KeyUpdate message is received if run with “-state”.
    93  const opensslReadKeyUpdate = "SSL_accept:TLSv1.3 read client key update"
    94  
    95  func (o *opensslOutputSink) Write(data []byte) (n int, err error) {
    96  	o.line = append(o.line, data...)
    97  	o.all = append(o.all, data...)
    98  
    99  	for {
   100  		line, next, ok := bytes.Cut(o.line, []byte("\n"))
   101  		if !ok {
   102  			break
   103  		}
   104  
   105  		if bytes.Equal([]byte(opensslEndOfHandshake), line) {
   106  			o.handshakeComplete <- struct{}{}
   107  		}
   108  		if bytes.Equal([]byte(opensslReadKeyUpdate), line) {
   109  			o.readKeyUpdate <- struct{}{}
   110  		}
   111  		o.line = next
   112  	}
   113  
   114  	return len(data), nil
   115  }
   116  
   117  func (o *opensslOutputSink) String() string {
   118  	return string(o.all)
   119  }
   120  
// clientTest represents a test of the TLS client handshake against a reference
// implementation. When recording (run with write == true) the reference is an
// `openssl s_server` process started by connFromCommand; otherwise previously
// recorded flows are replayed.
type clientTest struct {
	// name is a freeform string identifying the test and the file in which
	// the expected results will be stored (testdata/Client-<name>, see
	// dataPath). runClientTestForVersion prefixes it with the TLS version.
	name string
	// args, if not empty, contains a series of arguments for the
	// command to run for the reference server. They are appended after
	// serverCommand.
	args []string
	// config, if not nil, contains a custom Config to use for this test.
	// Otherwise testConfig is used.
	config *Config
	// cert, if not empty, contains a DER-encoded certificate for the
	// reference server. Otherwise testRSACertificate is used.
	cert []byte
	// key, if not nil, contains either a *rsa.PrivateKey, ed25519.PrivateKey or
	// *ecdsa.PrivateKey which is the private key for the reference server.
	// Otherwise testRSAPrivateKey is used. It must be accepted by
	// x509.MarshalPKCS8PrivateKey.
	key any
	// extensions, if not nil, contains a list of extension data to be returned
	// from the ServerHello. The data should be in standard TLS format with
	// a 2-byte uint16 type, 2-byte data length, followed by the extension data.
	// It is handed to OpenSSL through a "-serverinfo" PEM file.
	extensions [][]byte
	// validate, if not nil, is a function that will be called with the
	// ConnectionState of the resulting connection. It returns a non-nil
	// error if the ConnectionState is unacceptable.
	validate func(ConnectionState) error
	// numRenegotiations is the number of times that the connection will be
	// renegotiated. A non-zero value requires "-state" in args, because
	// run waits for OpenSSL's handshake-complete state message.
	numRenegotiations int
	// renegotiationExpectedToFail, if not zero, is the number of the
	// renegotiation attempt that is expected to fail.
	renegotiationExpectedToFail int
	// checkRenegotiationError, if not nil, is called with any error
	// arising from renegotiation. It can map expected errors to nil to
	// ignore them.
	checkRenegotiationError func(renegotiationNum int, err error) error
	// sendKeyUpdate will cause the server to send a KeyUpdate message.
	// It also requires "-state" in args.
	sendKeyUpdate bool
}
   159  
   160  var serverCommand = []string{"openssl", "s_server", "-no_ticket", "-num_tickets", "0"}
   161  
   162  // connFromCommand starts the reference server process, connects to it and
   163  // returns a recordingConn for the connection. The stdin return value is an
   164  // opensslInput for the stdin of the child process. It must be closed before
   165  // Waiting for child.
   166  func (test *clientTest) connFromCommand() (conn *recordingConn, child *exec.Cmd, stdin opensslInput, stdout *opensslOutputSink, err error) {
   167  	cert := testRSACertificate
   168  	if len(test.cert) > 0 {
   169  		cert = test.cert
   170  	}
   171  	certPath := tempFile(string(cert))
   172  	defer os.Remove(certPath)
   173  
   174  	var key any = testRSAPrivateKey
   175  	if test.key != nil {
   176  		key = test.key
   177  	}
   178  	derBytes, err := x509.MarshalPKCS8PrivateKey(key)
   179  	if err != nil {
   180  		panic(err)
   181  	}
   182  
   183  	var pemOut bytes.Buffer
   184  	pem.Encode(&pemOut, &pem.Block{Type: "PRIVATE KEY", Bytes: derBytes})
   185  
   186  	keyPath := tempFile(pemOut.String())
   187  	defer os.Remove(keyPath)
   188  
   189  	var command []string
   190  	command = append(command, serverCommand...)
   191  	command = append(command, test.args...)
   192  	command = append(command, "-cert", certPath, "-certform", "DER", "-key", keyPath)
   193  	// serverPort contains the port that OpenSSL will listen on. OpenSSL
   194  	// can't take "0" as an argument here so we have to pick a number and
   195  	// hope that it's not in use on the machine. Since this only occurs
   196  	// when -update is given and thus when there's a human watching the
   197  	// test, this isn't too bad.
   198  	const serverPort = 24323
   199  	command = append(command, "-accept", strconv.Itoa(serverPort))
   200  
   201  	if len(test.extensions) > 0 {
   202  		var serverInfo bytes.Buffer
   203  		for _, ext := range test.extensions {
   204  			pem.Encode(&serverInfo, &pem.Block{
   205  				Type:  fmt.Sprintf("SERVERINFO FOR EXTENSION %d", binary.BigEndian.Uint16(ext)),
   206  				Bytes: ext,
   207  			})
   208  		}
   209  		serverInfoPath := tempFile(serverInfo.String())
   210  		defer os.Remove(serverInfoPath)
   211  		command = append(command, "-serverinfo", serverInfoPath)
   212  	}
   213  
   214  	if test.numRenegotiations > 0 || test.sendKeyUpdate {
   215  		found := false
   216  		for _, flag := range command[1:] {
   217  			if flag == "-state" {
   218  				found = true
   219  				break
   220  			}
   221  		}
   222  
   223  		if !found {
   224  			panic("-state flag missing to OpenSSL, you need this if testing renegotiation or KeyUpdate")
   225  		}
   226  	}
   227  
   228  	cmd := exec.Command(command[0], command[1:]...)
   229  	stdin = opensslInput(make(chan opensslInputEvent))
   230  	cmd.Stdin = stdin
   231  	out := newOpensslOutputSink()
   232  	cmd.Stdout = out
   233  	cmd.Stderr = out
   234  	if err := cmd.Start(); err != nil {
   235  		return nil, nil, nil, nil, err
   236  	}
   237  
   238  	// OpenSSL does print an "ACCEPT" banner, but it does so *before*
   239  	// opening the listening socket, so we can't use that to wait until it
   240  	// has started listening. Thus we are forced to poll until we get a
   241  	// connection.
   242  	var tcpConn net.Conn
   243  	for i := uint(0); i < 5; i++ {
   244  		tcpConn, err = net.DialTCP("tcp", nil, &net.TCPAddr{
   245  			IP:   net.IPv4(127, 0, 0, 1),
   246  			Port: serverPort,
   247  		})
   248  		if err == nil {
   249  			break
   250  		}
   251  		time.Sleep((1 << i) * 5 * time.Millisecond)
   252  	}
   253  	if err != nil {
   254  		close(stdin)
   255  		cmd.Process.Kill()
   256  		err = fmt.Errorf("error connecting to the OpenSSL server: %v (%v)\n\n%s", err, cmd.Wait(), out)
   257  		return nil, nil, nil, nil, err
   258  	}
   259  
   260  	record := &recordingConn{
   261  		Conn: tcpConn,
   262  	}
   263  
   264  	return record, cmd, stdin, out, nil
   265  }
   266  
   267  func (test *clientTest) dataPath() string {
   268  	return filepath.Join("testdata", "Client-"+test.name)
   269  }
   270  
   271  func (test *clientTest) loadData() (flows [][]byte, err error) {
   272  	in, err := os.Open(test.dataPath())
   273  	if err != nil {
   274  		return nil, err
   275  	}
   276  	defer in.Close()
   277  	return parseTestData(in)
   278  }
   279  
   280  func (test *clientTest) run(t *testing.T, write bool) {
   281  	var clientConn, serverConn net.Conn
   282  	var recordingConn *recordingConn
   283  	var childProcess *exec.Cmd
   284  	var stdin opensslInput
   285  	var stdout *opensslOutputSink
   286  
   287  	if write {
   288  		var err error
   289  		recordingConn, childProcess, stdin, stdout, err = test.connFromCommand()
   290  		if err != nil {
   291  			t.Fatalf("Failed to start subcommand: %s", err)
   292  		}
   293  		clientConn = recordingConn
   294  		defer func() {
   295  			if t.Failed() {
   296  				t.Logf("OpenSSL output:\n\n%s", stdout.all)
   297  			}
   298  		}()
   299  	} else {
   300  		clientConn, serverConn = localPipe(t)
   301  	}
   302  
   303  	doneChan := make(chan bool)
   304  	defer func() {
   305  		clientConn.Close()
   306  		<-doneChan
   307  	}()
   308  	go func() {
   309  		defer close(doneChan)
   310  
   311  		config := test.config
   312  		if config == nil {
   313  			config = testConfig
   314  		}
   315  		client := Client(clientConn, config)
   316  		defer client.Close()
   317  
   318  		if _, err := client.Write([]byte("hello\n")); err != nil {
   319  			t.Errorf("Client.Write failed: %s", err)
   320  			return
   321  		}
   322  
   323  		for i := 1; i <= test.numRenegotiations; i++ {
   324  			// The initial handshake will generate a
   325  			// handshakeComplete signal which needs to be quashed.
   326  			if i == 1 && write {
   327  				<-stdout.handshakeComplete
   328  			}
   329  
   330  			// OpenSSL will try to interleave application data and
   331  			// a renegotiation if we send both concurrently.
   332  			// Therefore: ask OpensSSL to start a renegotiation, run
   333  			// a goroutine to call client.Read and thus process the
   334  			// renegotiation request, watch for OpenSSL's stdout to
   335  			// indicate that the handshake is complete and,
   336  			// finally, have OpenSSL write something to cause
   337  			// client.Read to complete.
   338  			if write {
   339  				stdin <- opensslRenegotiate
   340  			}
   341  
   342  			signalChan := make(chan struct{})
   343  
   344  			go func() {
   345  				defer close(signalChan)
   346  
   347  				buf := make([]byte, 256)
   348  				n, err := client.Read(buf)
   349  
   350  				if test.checkRenegotiationError != nil {
   351  					newErr := test.checkRenegotiationError(i, err)
   352  					if err != nil && newErr == nil {
   353  						return
   354  					}
   355  					err = newErr
   356  				}
   357  
   358  				if err != nil {
   359  					t.Errorf("Client.Read failed after renegotiation #%d: %s", i, err)
   360  					return
   361  				}
   362  
   363  				buf = buf[:n]
   364  				if !bytes.Equal([]byte(opensslSentinel), buf) {
   365  					t.Errorf("Client.Read returned %q, but wanted %q", string(buf), opensslSentinel)
   366  				}
   367  
   368  				if expected := i + 1; client.handshakes != expected {
   369  					t.Errorf("client should have recorded %d handshakes, but believes that %d have occurred", expected, client.handshakes)
   370  				}
   371  			}()
   372  
   373  			if write && test.renegotiationExpectedToFail != i {
   374  				<-stdout.handshakeComplete
   375  				stdin <- opensslSendSentinel
   376  			}
   377  			<-signalChan
   378  		}
   379  
   380  		if test.sendKeyUpdate {
   381  			if write {
   382  				<-stdout.handshakeComplete
   383  				stdin <- opensslKeyUpdate
   384  			}
   385  
   386  			doneRead := make(chan struct{})
   387  
   388  			go func() {
   389  				defer close(doneRead)
   390  
   391  				buf := make([]byte, 256)
   392  				n, err := client.Read(buf)
   393  
   394  				if err != nil {
   395  					t.Errorf("Client.Read failed after KeyUpdate: %s", err)
   396  					return
   397  				}
   398  
   399  				buf = buf[:n]
   400  				if !bytes.Equal([]byte(opensslSentinel), buf) {
   401  					t.Errorf("Client.Read returned %q, but wanted %q", string(buf), opensslSentinel)
   402  				}
   403  			}()
   404  
   405  			if write {
   406  				// There's no real reason to wait for the client KeyUpdate to
   407  				// send data with the new server keys, except that s_server
   408  				// drops writes if they are sent at the wrong time.
   409  				<-stdout.readKeyUpdate
   410  				stdin <- opensslSendSentinel
   411  			}
   412  			<-doneRead
   413  
   414  			if _, err := client.Write([]byte("hello again\n")); err != nil {
   415  				t.Errorf("Client.Write failed: %s", err)
   416  				return
   417  			}
   418  		}
   419  
   420  		if test.validate != nil {
   421  			if err := test.validate(client.ConnectionState()); err != nil {
   422  				t.Errorf("validate callback returned error: %s", err)
   423  			}
   424  		}
   425  
   426  		// If the server sent us an alert after our last flight, give it a
   427  		// chance to arrive.
   428  		if write && test.renegotiationExpectedToFail == 0 {
   429  			if err := peekError(client); err != nil {
   430  				t.Errorf("final Read returned an error: %s", err)
   431  			}
   432  		}
   433  	}()
   434  
   435  	if !write {
   436  		flows, err := test.loadData()
   437  		if err != nil {
   438  			t.Fatalf("%s: failed to load data from %s: %v", test.name, test.dataPath(), err)
   439  		}
   440  		for i, b := range flows {
   441  			if i%2 == 1 {
   442  				if *fast {
   443  					serverConn.SetWriteDeadline(time.Now().Add(1 * time.Second))
   444  				} else {
   445  					serverConn.SetWriteDeadline(time.Now().Add(1 * time.Minute))
   446  				}
   447  				serverConn.Write(b)
   448  				continue
   449  			}
   450  			bb := make([]byte, len(b))
   451  			if *fast {
   452  				serverConn.SetReadDeadline(time.Now().Add(1 * time.Second))
   453  			} else {
   454  				serverConn.SetReadDeadline(time.Now().Add(1 * time.Minute))
   455  			}
   456  			_, err := io.ReadFull(serverConn, bb)
   457  			if err != nil {
   458  				t.Fatalf("%s, flow %d: %s", test.name, i+1, err)
   459  			}
   460  			if !bytes.Equal(b, bb) {
   461  				t.Fatalf("%s, flow %d: mismatch on read: got:%x want:%x", test.name, i+1, bb, b)
   462  			}
   463  		}
   464  	}
   465  
   466  	<-doneChan
   467  	if !write {
   468  		serverConn.Close()
   469  	}
   470  
   471  	if write {
   472  		path := test.dataPath()
   473  		out, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
   474  		if err != nil {
   475  			t.Fatalf("Failed to create output file: %s", err)
   476  		}
   477  		defer out.Close()
   478  		recordingConn.Close()
   479  		close(stdin)
   480  		childProcess.Process.Kill()
   481  		childProcess.Wait()
   482  		if len(recordingConn.flows) < 3 {
   483  			t.Fatalf("Client connection didn't work")
   484  		}
   485  		recordingConn.WriteTo(out)
   486  		t.Logf("Wrote %s\n", path)
   487  	}
   488  }
   489  
   490  // peekError does a read with a short timeout to check if the next read would
   491  // cause an error, for example if there is an alert waiting on the wire.
   492  func peekError(conn net.Conn) error {
   493  	conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
   494  	if n, err := conn.Read(make([]byte, 1)); n != 0 {
   495  		return errors.New("unexpectedly read data")
   496  	} else if err != nil {
   497  		if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() {
   498  			return err
   499  		}
   500  	}
   501  	return nil
   502  }
   503  
   504  func runClientTestForVersion(t *testing.T, template *clientTest, version, option string) {
   505  	// Make a deep copy of the template before going parallel.
   506  	test := *template
   507  	if template.config != nil {
   508  		test.config = template.config.Clone()
   509  	}
   510  	test.name = version + "-" + test.name
   511  	test.args = append([]string{option}, test.args...)
   512  
   513  	runTestAndUpdateIfNeeded(t, version, test.run, false)
   514  }
   515  
   516  func runClientTestTLS10(t *testing.T, template *clientTest) {
   517  	runClientTestForVersion(t, template, "TLSv10", "-tls1")
   518  }
   519  
   520  func runClientTestTLS11(t *testing.T, template *clientTest) {
   521  	runClientTestForVersion(t, template, "TLSv11", "-tls1_1")
   522  }
   523  
   524  func runClientTestTLS12(t *testing.T, template *clientTest) {
   525  	runClientTestForVersion(t, template, "TLSv12", "-tls1_2")
   526  }
   527  
   528  func runClientTestTLS13(t *testing.T, template *clientTest) {
   529  	runClientTestForVersion(t, template, "TLSv13", "-tls1_3")
   530  }
   531  
   532  func TestHandshakeClientRSARC4(t *testing.T) {
   533  	test := &clientTest{
   534  		name: "RSA-RC4",
   535  		args: []string{"-cipher", "RC4-SHA"},
   536  	}
   537  	runClientTestTLS10(t, test)
   538  	runClientTestTLS11(t, test)
   539  	runClientTestTLS12(t, test)
   540  }
   541  
   542  func TestHandshakeClientRSAAES128GCM(t *testing.T) {
   543  	test := &clientTest{
   544  		name: "AES128-GCM-SHA256",
   545  		args: []string{"-cipher", "AES128-GCM-SHA256"},
   546  	}
   547  	runClientTestTLS12(t, test)
   548  }
   549  
   550  func TestHandshakeClientRSAAES256GCM(t *testing.T) {
   551  	test := &clientTest{
   552  		name: "AES256-GCM-SHA384",
   553  		args: []string{"-cipher", "AES256-GCM-SHA384"},
   554  	}
   555  	runClientTestTLS12(t, test)
   556  }
   557  
   558  func TestHandshakeClientECDHERSAAES(t *testing.T) {
   559  	test := &clientTest{
   560  		name: "ECDHE-RSA-AES",
   561  		args: []string{"-cipher", "ECDHE-RSA-AES128-SHA"},
   562  	}
   563  	runClientTestTLS10(t, test)
   564  	runClientTestTLS11(t, test)
   565  	runClientTestTLS12(t, test)
   566  }
   567  
   568  func TestHandshakeClientECDHEECDSAAES(t *testing.T) {
   569  	test := &clientTest{
   570  		name: "ECDHE-ECDSA-AES",
   571  		args: []string{"-cipher", "ECDHE-ECDSA-AES128-SHA"},
   572  		cert: testECDSACertificate,
   573  		key:  testECDSAPrivateKey,
   574  	}
   575  	runClientTestTLS10(t, test)
   576  	runClientTestTLS11(t, test)
   577  	runClientTestTLS12(t, test)
   578  }
   579  
   580  func TestHandshakeClientECDHEECDSAAESGCM(t *testing.T) {
   581  	test := &clientTest{
   582  		name: "ECDHE-ECDSA-AES-GCM",
   583  		args: []string{"-cipher", "ECDHE-ECDSA-AES128-GCM-SHA256"},
   584  		cert: testECDSACertificate,
   585  		key:  testECDSAPrivateKey,
   586  	}
   587  	runClientTestTLS12(t, test)
   588  }
   589  
   590  func TestHandshakeClientAES256GCMSHA384(t *testing.T) {
   591  	test := &clientTest{
   592  		name: "ECDHE-ECDSA-AES256-GCM-SHA384",
   593  		args: []string{"-cipher", "ECDHE-ECDSA-AES256-GCM-SHA384"},
   594  		cert: testECDSACertificate,
   595  		key:  testECDSAPrivateKey,
   596  	}
   597  	runClientTestTLS12(t, test)
   598  }
   599  
   600  func TestHandshakeClientAES128CBCSHA256(t *testing.T) {
   601  	test := &clientTest{
   602  		name: "AES128-SHA256",
   603  		args: []string{"-cipher", "AES128-SHA256"},
   604  	}
   605  	runClientTestTLS12(t, test)
   606  }
   607  
   608  func TestHandshakeClientECDHERSAAES128CBCSHA256(t *testing.T) {
   609  	test := &clientTest{
   610  		name: "ECDHE-RSA-AES128-SHA256",
   611  		args: []string{"-cipher", "ECDHE-RSA-AES128-SHA256"},
   612  	}
   613  	runClientTestTLS12(t, test)
   614  }
   615  
   616  func TestHandshakeClientECDHEECDSAAES128CBCSHA256(t *testing.T) {
   617  	test := &clientTest{
   618  		name: "ECDHE-ECDSA-AES128-SHA256",
   619  		args: []string{"-cipher", "ECDHE-ECDSA-AES128-SHA256"},
   620  		cert: testECDSACertificate,
   621  		key:  testECDSAPrivateKey,
   622  	}
   623  	runClientTestTLS12(t, test)
   624  }
   625  
   626  func TestHandshakeClientX25519(t *testing.T) {
   627  	config := testConfig.Clone()
   628  	config.CurvePreferences = []CurveID{X25519}
   629  
   630  	test := &clientTest{
   631  		name:   "X25519-ECDHE",
   632  		args:   []string{"-cipher", "ECDHE-RSA-AES128-GCM-SHA256", "-curves", "X25519"},
   633  		config: config,
   634  	}
   635  
   636  	runClientTestTLS12(t, test)
   637  	runClientTestTLS13(t, test)
   638  }
   639  
   640  func TestHandshakeClientP256(t *testing.T) {
   641  	config := testConfig.Clone()
   642  	config.CurvePreferences = []CurveID{CurveP256}
   643  
   644  	test := &clientTest{
   645  		name:   "P256-ECDHE",
   646  		args:   []string{"-cipher", "ECDHE-RSA-AES128-GCM-SHA256", "-curves", "P-256"},
   647  		config: config,
   648  	}
   649  
   650  	runClientTestTLS12(t, test)
   651  	runClientTestTLS13(t, test)
   652  }
   653  
   654  func TestHandshakeClientHelloRetryRequest(t *testing.T) {
   655  	config := testConfig.Clone()
   656  	config.CurvePreferences = []CurveID{X25519, CurveP256}
   657  
   658  	test := &clientTest{
   659  		name:   "HelloRetryRequest",
   660  		args:   []string{"-cipher", "ECDHE-RSA-AES128-GCM-SHA256", "-curves", "P-256"},
   661  		config: config,
   662  	}
   663  
   664  	runClientTestTLS13(t, test)
   665  }
   666  
   667  func TestHandshakeClientECDHERSAChaCha20(t *testing.T) {
   668  	config := testConfig.Clone()
   669  	config.CipherSuites = []uint16{TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305}
   670  
   671  	test := &clientTest{
   672  		name:   "ECDHE-RSA-CHACHA20-POLY1305",
   673  		args:   []string{"-cipher", "ECDHE-RSA-CHACHA20-POLY1305"},
   674  		config: config,
   675  	}
   676  
   677  	runClientTestTLS12(t, test)
   678  }
   679  
   680  func TestHandshakeClientECDHEECDSAChaCha20(t *testing.T) {
   681  	config := testConfig.Clone()
   682  	config.CipherSuites = []uint16{TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305}
   683  
   684  	test := &clientTest{
   685  		name:   "ECDHE-ECDSA-CHACHA20-POLY1305",
   686  		args:   []string{"-cipher", "ECDHE-ECDSA-CHACHA20-POLY1305"},
   687  		config: config,
   688  		cert:   testECDSACertificate,
   689  		key:    testECDSAPrivateKey,
   690  	}
   691  
   692  	runClientTestTLS12(t, test)
   693  }
   694  
   695  func TestHandshakeClientAES128SHA256(t *testing.T) {
   696  	test := &clientTest{
   697  		name: "AES128-SHA256",
   698  		args: []string{"-ciphersuites", "TLS_AES_128_GCM_SHA256"},
   699  	}
   700  	runClientTestTLS13(t, test)
   701  }
   702  func TestHandshakeClientAES256SHA384(t *testing.T) {
   703  	test := &clientTest{
   704  		name: "AES256-SHA384",
   705  		args: []string{"-ciphersuites", "TLS_AES_256_GCM_SHA384"},
   706  	}
   707  	runClientTestTLS13(t, test)
   708  }
   709  func TestHandshakeClientCHACHA20SHA256(t *testing.T) {
   710  	test := &clientTest{
   711  		name: "CHACHA20-SHA256",
   712  		args: []string{"-ciphersuites", "TLS_CHACHA20_POLY1305_SHA256"},
   713  	}
   714  	runClientTestTLS13(t, test)
   715  }
   716  
   717  func TestHandshakeClientECDSATLS13(t *testing.T) {
   718  	test := &clientTest{
   719  		name: "ECDSA",
   720  		cert: testECDSACertificate,
   721  		key:  testECDSAPrivateKey,
   722  	}
   723  	runClientTestTLS13(t, test)
   724  }
   725  
   726  func TestHandshakeClientEd25519(t *testing.T) {
   727  	test := &clientTest{
   728  		name: "Ed25519",
   729  		cert: testEd25519Certificate,
   730  		key:  testEd25519PrivateKey,
   731  	}
   732  	runClientTestTLS12(t, test)
   733  	runClientTestTLS13(t, test)
   734  
   735  	config := testConfig.Clone()
   736  	cert, _ := X509KeyPair([]byte(clientEd25519CertificatePEM), []byte(clientEd25519KeyPEM))
   737  	config.Certificates = []Certificate{cert}
   738  
   739  	test = &clientTest{
   740  		name:   "ClientCert-Ed25519",
   741  		args:   []string{"-Verify", "1"},
   742  		config: config,
   743  	}
   744  
   745  	runClientTestTLS12(t, test)
   746  	runClientTestTLS13(t, test)
   747  }
   748  
   749  func TestHandshakeClientCertRSA(t *testing.T) {
   750  	config := testConfig.Clone()
   751  	cert, _ := X509KeyPair([]byte(clientCertificatePEM), []byte(clientKeyPEM))
   752  	config.Certificates = []Certificate{cert}
   753  
   754  	test := &clientTest{
   755  		name:   "ClientCert-RSA-RSA",
   756  		args:   []string{"-cipher", "AES128", "-Verify", "1"},
   757  		config: config,
   758  	}
   759  
   760  	runClientTestTLS10(t, test)
   761  	runClientTestTLS12(t, test)
   762  
   763  	test = &clientTest{
   764  		name:   "ClientCert-RSA-ECDSA",
   765  		args:   []string{"-cipher", "ECDHE-ECDSA-AES128-SHA", "-Verify", "1"},
   766  		config: config,
   767  		cert:   testECDSACertificate,
   768  		key:    testECDSAPrivateKey,
   769  	}
   770  
   771  	runClientTestTLS10(t, test)
   772  	runClientTestTLS12(t, test)
   773  	runClientTestTLS13(t, test)
   774  
   775  	test = &clientTest{
   776  		name:   "ClientCert-RSA-AES256-GCM-SHA384",
   777  		args:   []string{"-cipher", "ECDHE-RSA-AES256-GCM-SHA384", "-Verify", "1"},
   778  		config: config,
   779  		cert:   testRSACertificate,
   780  		key:    testRSAPrivateKey,
   781  	}
   782  
   783  	runClientTestTLS12(t, test)
   784  }
   785  
   786  func TestHandshakeClientCertECDSA(t *testing.T) {
   787  	config := testConfig.Clone()
   788  	cert, _ := X509KeyPair([]byte(clientECDSACertificatePEM), []byte(clientECDSAKeyPEM))
   789  	config.Certificates = []Certificate{cert}
   790  
   791  	test := &clientTest{
   792  		name:   "ClientCert-ECDSA-RSA",
   793  		args:   []string{"-cipher", "AES128", "-Verify", "1"},
   794  		config: config,
   795  	}
   796  
   797  	runClientTestTLS10(t, test)
   798  	runClientTestTLS12(t, test)
   799  	runClientTestTLS13(t, test)
   800  
   801  	test = &clientTest{
   802  		name:   "ClientCert-ECDSA-ECDSA",
   803  		args:   []string{"-cipher", "ECDHE-ECDSA-AES128-SHA", "-Verify", "1"},
   804  		config: config,
   805  		cert:   testECDSACertificate,
   806  		key:    testECDSAPrivateKey,
   807  	}
   808  
   809  	runClientTestTLS10(t, test)
   810  	runClientTestTLS12(t, test)
   811  }
   812  
   813  // TestHandshakeClientCertRSAPSS tests rsa_pss_rsae_sha256 signatures from both
   814  // client and server certificates. It also serves from both sides a certificate
   815  // signed itself with RSA-PSS, mostly to check that crypto/x509 chain validation
   816  // works.
   817  func TestHandshakeClientCertRSAPSS(t *testing.T) {
   818  	cert, err := x509.ParseCertificate(testRSAPSSCertificate)
   819  	if err != nil {
   820  		panic(err)
   821  	}
   822  	rootCAs := x509.NewCertPool()
   823  	rootCAs.AddCert(cert)
   824  
   825  	config := testConfig.Clone()
   826  	// Use GetClientCertificate to bypass the client certificate selection logic.
   827  	config.GetClientCertificate = func(*CertificateRequestInfo) (*Certificate, error) {
   828  		return &Certificate{
   829  			Certificate: [][]byte{testRSAPSSCertificate},
   830  			PrivateKey:  testRSAPrivateKey,
   831  		}, nil
   832  	}
   833  	config.RootCAs = rootCAs
   834  
   835  	test := &clientTest{
   836  		name: "ClientCert-RSA-RSAPSS",
   837  		args: []string{"-cipher", "AES128", "-Verify", "1", "-client_sigalgs",
   838  			"rsa_pss_rsae_sha256", "-sigalgs", "rsa_pss_rsae_sha256"},
   839  		config: config,
   840  		cert:   testRSAPSSCertificate,
   841  		key:    testRSAPrivateKey,
   842  	}
   843  	runClientTestTLS12(t, test)
   844  	runClientTestTLS13(t, test)
   845  }
   846  
   847  func TestHandshakeClientCertRSAPKCS1v15(t *testing.T) {
   848  	config := testConfig.Clone()
   849  	cert, _ := X509KeyPair([]byte(clientCertificatePEM), []byte(clientKeyPEM))
   850  	config.Certificates = []Certificate{cert}
   851  
   852  	test := &clientTest{
   853  		name: "ClientCert-RSA-RSAPKCS1v15",
   854  		args: []string{"-cipher", "AES128", "-Verify", "1", "-client_sigalgs",
   855  			"rsa_pkcs1_sha256", "-sigalgs", "rsa_pkcs1_sha256"},
   856  		config: config,
   857  	}
   858  
   859  	runClientTestTLS12(t, test)
   860  }
   861  
   862  func TestClientKeyUpdate(t *testing.T) {
   863  	test := &clientTest{
   864  		name:          "KeyUpdate",
   865  		args:          []string{"-state"},
   866  		sendKeyUpdate: true,
   867  	}
   868  	runClientTestTLS13(t, test)
   869  }
   870  
   871  func TestResumption(t *testing.T) {
   872  	t.Run("TLSv12", func(t *testing.T) { testResumption(t, VersionTLS12) })
   873  	t.Run("TLSv13", func(t *testing.T) { testResumption(t, VersionTLS13) })
   874  }
   875  
// testResumption walks a single client session cache through a scripted
// series of handshakes at the given protocol version. After each step it
// checks whether the session was resumed and how the cached session ticket
// changed. It is skipped in -short mode.
func testResumption(t *testing.T, version uint16) {
	if testing.Short() {
		t.Skip("skipping in -short mode")
	}
	serverConfig := &Config{
		MaxVersion:   version,
		CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA},
		Certificates: testConfig.Certificates,
	}

	issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
	if err != nil {
		panic(err)
	}

	rootCAs := x509.NewCertPool()
	rootCAs.AddCert(issuer)

	clientConfig := &Config{
		MaxVersion:         version,
		CipherSuites:       []uint16{TLS_RSA_WITH_RC4_128_SHA},
		ClientSessionCache: NewLRUClientSessionCache(32),
		RootCAs:            rootCAs,
		ServerName:         "example.golang",
	}

	// testResumeState performs one handshake with the current configs and
	// fails the test unless DidResume equals didResume. A resumed connection
	// must still report peer certificates and verified chains. Every
	// connection must report the client's configured ServerName.
	testResumeState := func(test string, didResume bool) {
		_, hs, err := testHandshake(t, clientConfig, serverConfig)
		if err != nil {
			t.Fatalf("%s: handshake failed: %s", test, err)
		}
		if hs.DidResume != didResume {
			t.Fatalf("%s resumed: %v, expected: %v", test, hs.DidResume, didResume)
		}
		if didResume && (hs.PeerCertificates == nil || hs.VerifiedChains == nil) {
			t.Fatalf("expected non-nil certificates after resumption. Got peerCertificates: %#v, verifiedCertificates: %#v", hs.PeerCertificates, hs.VerifiedChains)
		}
		if got, want := hs.ServerName, clientConfig.ServerName; got != want {
			t.Errorf("%s: server name %s, want %s", test, got, want)
		}
	}

	// The helpers below reach into the client's *lruSessionCache. Each one
	// operates on the entry at the front of its internal list (presumably the
	// most recently stored session; confirm against lruSessionCache).

	// getTicket returns the session ticket stored in that cache entry.
	getTicket := func() []byte {
		return clientConfig.ClientSessionCache.(*lruSessionCache).q.Front().Value.(*lruSessionCacheEntry).state.sessionTicket
	}
	// deleteTicket removes that entry. Put with a nil session deletes the key.
	deleteTicket := func() {
		ticketKey := clientConfig.ClientSessionCache.(*lruSessionCache).q.Front().Value.(*lruSessionCacheEntry).sessionKey
		clientConfig.ClientSessionCache.Put(ticketKey, nil)
	}
	// corruptTicket flips the first byte of that entry's master secret, so the
	// next handshake that tries to resume with it fails.
	corruptTicket := func() {
		clientConfig.ClientSessionCache.(*lruSessionCache).q.Front().Value.(*lruSessionCacheEntry).state.masterSecret[0] ^= 0xff
	}
	// randomKey draws a fresh 32-byte session ticket key from the server
	// config's random source.
	randomKey := func() [32]byte {
		var k [32]byte
		if _, err := io.ReadFull(serverConfig.rand(), k[:]); err != nil {
			t.Fatalf("Failed to read new SessionTicketKey: %s", err)
		}
		return k
	}

	// An empty cache forces a full handshake; the next one can resume.
	testResumeState("Handshake", false)
	ticket := getTicket()
	testResumeState("Resume", true)
	// TLS 1.2 resumption keeps the original ticket. In TLS 1.3 the server
	// issues a new ticket on resumption.
	if !bytes.Equal(ticket, getTicket()) && version != VersionTLS13 {
		t.Fatal("first ticket doesn't match ticket after resumption")
	}
	if bytes.Equal(ticket, getTicket()) && version == VersionTLS13 {
		t.Fatal("ticket didn't change after resumption")
	}

	// An old session ticket can resume, but the server will provide a ticket encrypted with a fresh key.
	serverConfig.Time = func() time.Time { return time.Now().Add(24*time.Hour + time.Minute) }
	testResumeState("ResumeWithOldTicket", true)
	if bytes.Equal(ticket[:ticketKeyNameLen], getTicket()[:ticketKeyNameLen]) {
		t.Fatal("old first ticket matches the fresh one")
	}

	// Now the session ticket key is expired, so a full handshake should occur.
	serverConfig.Time = func() time.Time { return time.Now().Add(24*8*time.Hour + time.Minute) }
	testResumeState("ResumeWithExpiredTicket", false)
	if bytes.Equal(ticket, getTicket()) {
		t.Fatal("expired first ticket matches the fresh one")
	}

	// Switching to an unrelated ticket key makes the cached ticket unusable.
	// The full handshake that follows yields a ticket that resumes again.
	serverConfig.Time = func() time.Time { return time.Now() } // reset the time back
	key1 := randomKey()
	serverConfig.SetSessionTicketKeys([][32]byte{key1})

	testResumeState("InvalidSessionTicketKey", false)
	testResumeState("ResumeAfterInvalidSessionTicketKey", true)

	// key2 becomes the primary key, while tickets under key1 are still
	// accepted. Resuming therefore works and hands out a new ticket.
	key2 := randomKey()
	serverConfig.SetSessionTicketKeys([][32]byte{key2, key1})
	ticket = getTicket()
	testResumeState("KeyChange", true)
	if bytes.Equal(ticket, getTicket()) {
		t.Fatal("new ticket wasn't included while resuming")
	}
	testResumeState("KeyChangeFinish", true)

	// Age the session ticket a bit, but not yet expired.
	serverConfig.Time = func() time.Time { return time.Now().Add(24*time.Hour + time.Minute) }
	testResumeState("OldSessionTicket", true)
	ticket = getTicket()
	// Expire the session ticket, which would force a full handshake.
	serverConfig.Time = func() time.Time { return time.Now().Add(24*8*time.Hour + time.Minute) }
	testResumeState("ExpiredSessionTicket", false)
	if bytes.Equal(ticket, getTicket()) {
		t.Fatal("new ticket wasn't provided after old ticket expired")
	}

	// Age the session ticket a bit at a time, but don't expire it.
	// The Time closure captures d by reference, so every handshake sees the
	// current accumulated offset.
	d := 0 * time.Hour
	for i := 0; i < 13; i++ {
		d += 12 * time.Hour
		serverConfig.Time = func() time.Time { return time.Now().Add(d) }
		testResumeState("OldSessionTicket", true)
	}
	// Expire it (now a little more than 7 days) and make sure a full
	// handshake occurs for TLS 1.2. Resumption should still occur for
	// TLS 1.3 since the client should be using a fresh ticket sent over
	// by the server.
	d += 12 * time.Hour
	serverConfig.Time = func() time.Time { return time.Now().Add(d) }
	if version == VersionTLS13 {
		testResumeState("ExpiredSessionTicket", true)
	} else {
		testResumeState("ExpiredSessionTicket", false)
	}
	if bytes.Equal(ticket, getTicket()) {
		t.Fatal("new ticket wasn't provided after old ticket expired")
	}

	// Reset serverConfig to ensure that calling SetSessionTicketKeys
	// before the serverConfig is used works.
	serverConfig = &Config{
		MaxVersion:   version,
		CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA},
		Certificates: testConfig.Certificates,
	}
	serverConfig.SetSessionTicketKeys([][32]byte{key2})

	testResumeState("FreshConfig", true)

	// In TLS 1.3, cross-cipher suite resumption is allowed as long as the KDF
	// hash matches. Also, Config.CipherSuites does not apply to TLS 1.3.
	if version != VersionTLS13 {
		clientConfig.CipherSuites = []uint16{TLS_ECDHE_RSA_WITH_RC4_128_SHA}
		testResumeState("DifferentCipherSuite", false)
		testResumeState("DifferentCipherSuiteRecovers", true)
	}

	deleteTicket()
	testResumeState("WithoutSessionTicket", false)

	// Session resumption should work when using client certificates
	deleteTicket()
	serverConfig.ClientCAs = rootCAs
	serverConfig.ClientAuth = RequireAndVerifyClientCert
	clientConfig.Certificates = serverConfig.Certificates
	testResumeState("InitialHandshake", false)
	testResumeState("WithClientCertificates", true)
	serverConfig.ClientAuth = NoClientCert

	// Tickets should be removed from the session cache on TLS handshake
	// failure, and the client should recover from a corrupted PSK
	testResumeState("FetchTicketToCorrupt", false)
	corruptTicket()
	_, _, err = testHandshake(t, clientConfig, serverConfig)
	if err == nil {
		t.Fatalf("handshake did not fail with a corrupted client secret")
	}
	testResumeState("AfterHandshakeFailure", false)

	// Without a session cache there is nothing to resume from.
	clientConfig.ClientSessionCache = nil
	testResumeState("WithoutSessionCache", false)
}
  1053  
  1054  func TestLRUClientSessionCache(t *testing.T) {
  1055  	// Initialize cache of capacity 4.
  1056  	cache := NewLRUClientSessionCache(4)
  1057  	cs := make([]ClientSessionState, 6)
  1058  	keys := []string{"0", "1", "2", "3", "4", "5", "6"}
  1059  
  1060  	// Add 4 entries to the cache and look them up.
  1061  	for i := 0; i < 4; i++ {
  1062  		cache.Put(keys[i], &cs[i])
  1063  	}
  1064  	for i := 0; i < 4; i++ {
  1065  		if s, ok := cache.Get(keys[i]); !ok || s != &cs[i] {
  1066  			t.Fatalf("session cache failed lookup for added key: %s", keys[i])
  1067  		}
  1068  	}
  1069  
  1070  	// Add 2 more entries to the cache. First 2 should be evicted.
  1071  	for i := 4; i < 6; i++ {
  1072  		cache.Put(keys[i], &cs[i])
  1073  	}
  1074  	for i := 0; i < 2; i++ {
  1075  		if s, ok := cache.Get(keys[i]); ok || s != nil {
  1076  			t.Fatalf("session cache should have evicted key: %s", keys[i])
  1077  		}
  1078  	}
  1079  
  1080  	// Touch entry 2. LRU should evict 3 next.
  1081  	cache.Get(keys[2])
  1082  	cache.Put(keys[0], &cs[0])
  1083  	if s, ok := cache.Get(keys[3]); ok || s != nil {
  1084  		t.Fatalf("session cache should have evicted key 3")
  1085  	}
  1086  
  1087  	// Update entry 0 in place.
  1088  	cache.Put(keys[0], &cs[3])
  1089  	if s, ok := cache.Get(keys[0]); !ok || s != &cs[3] {
  1090  		t.Fatalf("session cache failed update for key 0")
  1091  	}
  1092  
  1093  	// Calling Put with a nil entry deletes the key.
  1094  	cache.Put(keys[0], nil)
  1095  	if _, ok := cache.Get(keys[0]); ok {
  1096  		t.Fatalf("session cache failed to delete key 0")
  1097  	}
  1098  
  1099  	// Delete entry 2. LRU should keep 4 and 5
  1100  	cache.Put(keys[2], nil)
  1101  	if _, ok := cache.Get(keys[2]); ok {
  1102  		t.Fatalf("session cache failed to delete key 4")
  1103  	}
  1104  	for i := 4; i < 6; i++ {
  1105  		if s, ok := cache.Get(keys[i]); !ok || s != &cs[i] {
  1106  			t.Fatalf("session cache should not have deleted key: %s", keys[i])
  1107  		}
  1108  	}
  1109  }
  1110  
  1111  func TestKeyLogTLS12(t *testing.T) {
  1112  	var serverBuf, clientBuf bytes.Buffer
  1113  
  1114  	clientConfig := testConfig.Clone()
  1115  	clientConfig.KeyLogWriter = &clientBuf
  1116  	clientConfig.MaxVersion = VersionTLS12
  1117  
  1118  	serverConfig := testConfig.Clone()
  1119  	serverConfig.KeyLogWriter = &serverBuf
  1120  	serverConfig.MaxVersion = VersionTLS12
  1121  
  1122  	c, s := localPipe(t)
  1123  	done := make(chan bool)
  1124  
  1125  	go func() {
  1126  		defer close(done)
  1127  
  1128  		if err := Server(s, serverConfig).Handshake(); err != nil {
  1129  			t.Errorf("server: %s", err)
  1130  			return
  1131  		}
  1132  		s.Close()
  1133  	}()
  1134  
  1135  	if err := Client(c, clientConfig).Handshake(); err != nil {
  1136  		t.Fatalf("client: %s", err)
  1137  	}
  1138  
  1139  	c.Close()
  1140  	<-done
  1141  
  1142  	checkKeylogLine := func(side, loggedLine string) {
  1143  		if len(loggedLine) == 0 {
  1144  			t.Fatalf("%s: no keylog line was produced", side)
  1145  		}
  1146  		const expectedLen = 13 /* "CLIENT_RANDOM" */ +
  1147  			1 /* space */ +
  1148  			32*2 /* hex client nonce */ +
  1149  			1 /* space */ +
  1150  			48*2 /* hex master secret */ +
  1151  			1 /* new line */
  1152  		if len(loggedLine) != expectedLen {
  1153  			t.Fatalf("%s: keylog line has incorrect length (want %d, got %d): %q", side, expectedLen, len(loggedLine), loggedLine)
  1154  		}
  1155  		if !strings.HasPrefix(loggedLine, "CLIENT_RANDOM "+strings.Repeat("0", 64)+" ") {
  1156  			t.Fatalf("%s: keylog line has incorrect structure or nonce: %q", side, loggedLine)
  1157  		}
  1158  	}
  1159  
  1160  	checkKeylogLine("client", clientBuf.String())
  1161  	checkKeylogLine("server", serverBuf.String())
  1162  }
  1163  
  1164  func TestKeyLogTLS13(t *testing.T) {
  1165  	var serverBuf, clientBuf bytes.Buffer
  1166  
  1167  	clientConfig := testConfig.Clone()
  1168  	clientConfig.KeyLogWriter = &clientBuf
  1169  
  1170  	serverConfig := testConfig.Clone()
  1171  	serverConfig.KeyLogWriter = &serverBuf
  1172  
  1173  	c, s := localPipe(t)
  1174  	done := make(chan bool)
  1175  
  1176  	go func() {
  1177  		defer close(done)
  1178  
  1179  		if err := Server(s, serverConfig).Handshake(); err != nil {
  1180  			t.Errorf("server: %s", err)
  1181  			return
  1182  		}
  1183  		s.Close()
  1184  	}()
  1185  
  1186  	if err := Client(c, clientConfig).Handshake(); err != nil {
  1187  		t.Fatalf("client: %s", err)
  1188  	}
  1189  
  1190  	c.Close()
  1191  	<-done
  1192  
  1193  	checkKeylogLines := func(side, loggedLines string) {
  1194  		loggedLines = strings.TrimSpace(loggedLines)
  1195  		lines := strings.Split(loggedLines, "\n")
  1196  		if len(lines) != 4 {
  1197  			t.Errorf("Expected the %s to log 4 lines, got %d", side, len(lines))
  1198  		}
  1199  	}
  1200  
  1201  	checkKeylogLines("client", clientBuf.String())
  1202  	checkKeylogLines("server", serverBuf.String())
  1203  }
  1204  
  1205  func TestHandshakeClientALPNMatch(t *testing.T) {
  1206  	config := testConfig.Clone()
  1207  	config.NextProtos = []string{"proto2", "proto1"}
  1208  
  1209  	test := &clientTest{
  1210  		name: "ALPN",
  1211  		// Note that this needs OpenSSL 1.0.2 because that is the first
  1212  		// version that supports the -alpn flag.
  1213  		args:   []string{"-alpn", "proto1,proto2"},
  1214  		config: config,
  1215  		validate: func(state ConnectionState) error {
  1216  			// The server's preferences should override the client.
  1217  			if state.NegotiatedProtocol != "proto1" {
  1218  				return fmt.Errorf("Got protocol %q, wanted proto1", state.NegotiatedProtocol)
  1219  			}
  1220  			return nil
  1221  		},
  1222  	}
  1223  	runClientTestTLS12(t, test)
  1224  	runClientTestTLS13(t, test)
  1225  }
  1226  
  1227  func TestServerSelectingUnconfiguredApplicationProtocol(t *testing.T) {
  1228  	// This checks that the server can't select an application protocol that the
  1229  	// client didn't offer.
  1230  
  1231  	c, s := localPipe(t)
  1232  	errChan := make(chan error, 1)
  1233  
  1234  	go func() {
  1235  		client := Client(c, &Config{
  1236  			ServerName:   "foo",
  1237  			CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
  1238  			NextProtos:   []string{"http", "something-else"},
  1239  		})
  1240  		errChan <- client.Handshake()
  1241  	}()
  1242  
  1243  	var header [5]byte
  1244  	if _, err := io.ReadFull(s, header[:]); err != nil {
  1245  		t.Fatal(err)
  1246  	}
  1247  	recordLen := int(header[3])<<8 | int(header[4])
  1248  
  1249  	record := make([]byte, recordLen)
  1250  	if _, err := io.ReadFull(s, record); err != nil {
  1251  		t.Fatal(err)
  1252  	}
  1253  
  1254  	serverHello := &serverHelloMsg{
  1255  		vers:         VersionTLS12,
  1256  		random:       make([]byte, 32),
  1257  		cipherSuite:  TLS_RSA_WITH_AES_128_GCM_SHA256,
  1258  		alpnProtocol: "how-about-this",
  1259  	}
  1260  	serverHelloBytes := serverHello.marshal()
  1261  
  1262  	s.Write([]byte{
  1263  		byte(recordTypeHandshake),
  1264  		byte(VersionTLS12 >> 8),
  1265  		byte(VersionTLS12 & 0xff),
  1266  		byte(len(serverHelloBytes) >> 8),
  1267  		byte(len(serverHelloBytes)),
  1268  	})
  1269  	s.Write(serverHelloBytes)
  1270  	s.Close()
  1271  
  1272  	if err := <-errChan; !strings.Contains(err.Error(), "server selected unadvertised ALPN protocol") {
  1273  		t.Fatalf("Expected error about unconfigured cipher suite but got %q", err)
  1274  	}
  1275  }
  1276  
// sctsBase64 contains data from `openssl s_client -serverinfo 18 -connect ritter.vg:443`.
// It is a base64-encoded serverinfo blob for TLS extension type 18
// (signed_certificate_timestamp). TestHandshakClientSCTs decodes it, hands it
// to the reference server as extension data, and expects three SCTs at fixed
// byte offsets within it.
const sctsBase64 = "ABIBaQFnAHUApLkJkLQYWBSHuxOizGdwCjw1mAT5G9+443fNDsgN3BAAAAFHl5nuFgAABAMARjBEAiAcS4JdlW5nW9sElUv2zvQyPoZ6ejKrGGB03gjaBZFMLwIgc1Qbbn+hsH0RvObzhS+XZhr3iuQQJY8S9G85D9KeGPAAdgBo9pj4H2SCvjqM7rkoHUz8cVFdZ5PURNEKZ6y7T0/7xAAAAUeX4bVwAAAEAwBHMEUCIDIhFDgG2HIuADBkGuLobU5a4dlCHoJLliWJ1SYT05z6AiEAjxIoZFFPRNWMGGIjskOTMwXzQ1Wh2e7NxXE1kd1J0QsAdgDuS723dc5guuFCaR+r4Z5mow9+X7By2IMAxHuJeqj9ywAAAUhcZIqHAAAEAwBHMEUCICmJ1rBT09LpkbzxtUC+Hi7nXLR0J+2PmwLp+sJMuqK+AiEAr0NkUnEVKVhAkccIFpYDqHOlZaBsuEhWWrYpg2RtKp0="
  1279  
  1280  func TestHandshakClientSCTs(t *testing.T) {
  1281  	config := testConfig.Clone()
  1282  
  1283  	scts, err := base64.StdEncoding.DecodeString(sctsBase64)
  1284  	if err != nil {
  1285  		t.Fatal(err)
  1286  	}
  1287  
  1288  	// Note that this needs OpenSSL 1.0.2 because that is the first
  1289  	// version that supports the -serverinfo flag.
  1290  	test := &clientTest{
  1291  		name:       "SCT",
  1292  		config:     config,
  1293  		extensions: [][]byte{scts},
  1294  		validate: func(state ConnectionState) error {
  1295  			expectedSCTs := [][]byte{
  1296  				scts[8:125],
  1297  				scts[127:245],
  1298  				scts[247:],
  1299  			}
  1300  			if n := len(state.SignedCertificateTimestamps); n != len(expectedSCTs) {
  1301  				return fmt.Errorf("Got %d scts, wanted %d", n, len(expectedSCTs))
  1302  			}
  1303  			for i, expected := range expectedSCTs {
  1304  				if sct := state.SignedCertificateTimestamps[i]; !bytes.Equal(sct, expected) {
  1305  					return fmt.Errorf("SCT #%d contained %x, expected %x", i, sct, expected)
  1306  				}
  1307  			}
  1308  			return nil
  1309  		},
  1310  	}
  1311  	runClientTestTLS12(t, test)
  1312  
  1313  	// TLS 1.3 moved SCTs to the Certificate extensions and -serverinfo only
  1314  	// supports ServerHello extensions.
  1315  }
  1316  
  1317  func TestRenegotiationRejected(t *testing.T) {
  1318  	config := testConfig.Clone()
  1319  	test := &clientTest{
  1320  		name:                        "RenegotiationRejected",
  1321  		args:                        []string{"-state"},
  1322  		config:                      config,
  1323  		numRenegotiations:           1,
  1324  		renegotiationExpectedToFail: 1,
  1325  		checkRenegotiationError: func(renegotiationNum int, err error) error {
  1326  			if err == nil {
  1327  				return errors.New("expected error from renegotiation but got nil")
  1328  			}
  1329  			if !strings.Contains(err.Error(), "no renegotiation") {
  1330  				return fmt.Errorf("expected renegotiation to be rejected but got %q", err)
  1331  			}
  1332  			return nil
  1333  		},
  1334  	}
  1335  	runClientTestTLS12(t, test)
  1336  }
  1337  
  1338  func TestRenegotiateOnce(t *testing.T) {
  1339  	config := testConfig.Clone()
  1340  	config.Renegotiation = RenegotiateOnceAsClient
  1341  
  1342  	test := &clientTest{
  1343  		name:              "RenegotiateOnce",
  1344  		args:              []string{"-state"},
  1345  		config:            config,
  1346  		numRenegotiations: 1,
  1347  	}
  1348  
  1349  	runClientTestTLS12(t, test)
  1350  }
  1351  
  1352  func TestRenegotiateTwice(t *testing.T) {
  1353  	config := testConfig.Clone()
  1354  	config.Renegotiation = RenegotiateFreelyAsClient
  1355  
  1356  	test := &clientTest{
  1357  		name:              "RenegotiateTwice",
  1358  		args:              []string{"-state"},
  1359  		config:            config,
  1360  		numRenegotiations: 2,
  1361  	}
  1362  
  1363  	runClientTestTLS12(t, test)
  1364  }
  1365  
  1366  func TestRenegotiateTwiceRejected(t *testing.T) {
  1367  	config := testConfig.Clone()
  1368  	config.Renegotiation = RenegotiateOnceAsClient
  1369  
  1370  	test := &clientTest{
  1371  		name:                        "RenegotiateTwiceRejected",
  1372  		args:                        []string{"-state"},
  1373  		config:                      config,
  1374  		numRenegotiations:           2,
  1375  		renegotiationExpectedToFail: 2,
  1376  		checkRenegotiationError: func(renegotiationNum int, err error) error {
  1377  			if renegotiationNum == 1 {
  1378  				return err
  1379  			}
  1380  
  1381  			if err == nil {
  1382  				return errors.New("expected error from renegotiation but got nil")
  1383  			}
  1384  			if !strings.Contains(err.Error(), "no renegotiation") {
  1385  				return fmt.Errorf("expected renegotiation to be rejected but got %q", err)
  1386  			}
  1387  			return nil
  1388  		},
  1389  	}
  1390  
  1391  	runClientTestTLS12(t, test)
  1392  }
  1393  
  1394  func TestHandshakeClientExportKeyingMaterial(t *testing.T) {
  1395  	test := &clientTest{
  1396  		name:   "ExportKeyingMaterial",
  1397  		config: testConfig.Clone(),
  1398  		validate: func(state ConnectionState) error {
  1399  			if km, err := state.ExportKeyingMaterial("test", nil, 42); err != nil {
  1400  				return fmt.Errorf("ExportKeyingMaterial failed: %v", err)
  1401  			} else if len(km) != 42 {
  1402  				return fmt.Errorf("Got %d bytes from ExportKeyingMaterial, wanted %d", len(km), 42)
  1403  			}
  1404  			return nil
  1405  		},
  1406  	}
  1407  	runClientTestTLS10(t, test)
  1408  	runClientTestTLS12(t, test)
  1409  	runClientTestTLS13(t, test)
  1410  }
  1411  
  1412  var hostnameInSNITests = []struct {
  1413  	in, out string
  1414  }{
  1415  	// Opaque string
  1416  	{"", ""},
  1417  	{"localhost", "localhost"},
  1418  	{"foo, bar, baz and qux", "foo, bar, baz and qux"},
  1419  
  1420  	// DNS hostname
  1421  	{"golang.org", "golang.org"},
  1422  	{"golang.org.", "golang.org"},
  1423  
  1424  	// Literal IPv4 address
  1425  	{"1.2.3.4", ""},
  1426  
  1427  	// Literal IPv6 address
  1428  	{"::1", ""},
  1429  	{"::1%lo0", ""}, // with zone identifier
  1430  	{"[::1]", ""},   // as per RFC 5952 we allow the [] style as IPv6 literal
  1431  	{"[::1%lo0]", ""},
  1432  }
  1433  
  1434  func TestHostnameInSNI(t *testing.T) {
  1435  	for _, tt := range hostnameInSNITests {
  1436  		c, s := localPipe(t)
  1437  
  1438  		go func(host string) {
  1439  			Client(c, &Config{ServerName: host, InsecureSkipVerify: true}).Handshake()
  1440  		}(tt.in)
  1441  
  1442  		var header [5]byte
  1443  		if _, err := io.ReadFull(s, header[:]); err != nil {
  1444  			t.Fatal(err)
  1445  		}
  1446  		recordLen := int(header[3])<<8 | int(header[4])
  1447  
  1448  		record := make([]byte, recordLen)
  1449  		if _, err := io.ReadFull(s, record[:]); err != nil {
  1450  			t.Fatal(err)
  1451  		}
  1452  
  1453  		c.Close()
  1454  		s.Close()
  1455  
  1456  		var m clientHelloMsg
  1457  		if !m.unmarshal(record) {
  1458  			t.Errorf("unmarshaling ClientHello for %q failed", tt.in)
  1459  			continue
  1460  		}
  1461  		if tt.in != tt.out && m.serverName == tt.in {
  1462  			t.Errorf("prohibited %q found in ClientHello: %x", tt.in, record)
  1463  		}
  1464  		if m.serverName != tt.out {
  1465  			t.Errorf("expected %q not found in ClientHello: %x", tt.out, record)
  1466  		}
  1467  	}
  1468  }
  1469  
  1470  func TestServerSelectingUnconfiguredCipherSuite(t *testing.T) {
  1471  	// This checks that the server can't select a cipher suite that the
  1472  	// client didn't offer. See #13174.
  1473  
  1474  	c, s := localPipe(t)
  1475  	errChan := make(chan error, 1)
  1476  
  1477  	go func() {
  1478  		client := Client(c, &Config{
  1479  			ServerName:   "foo",
  1480  			CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
  1481  		})
  1482  		errChan <- client.Handshake()
  1483  	}()
  1484  
  1485  	var header [5]byte
  1486  	if _, err := io.ReadFull(s, header[:]); err != nil {
  1487  		t.Fatal(err)
  1488  	}
  1489  	recordLen := int(header[3])<<8 | int(header[4])
  1490  
  1491  	record := make([]byte, recordLen)
  1492  	if _, err := io.ReadFull(s, record); err != nil {
  1493  		t.Fatal(err)
  1494  	}
  1495  
  1496  	// Create a ServerHello that selects a different cipher suite than the
  1497  	// sole one that the client offered.
  1498  	serverHello := &serverHelloMsg{
  1499  		vers:        VersionTLS12,
  1500  		random:      make([]byte, 32),
  1501  		cipherSuite: TLS_RSA_WITH_AES_256_GCM_SHA384,
  1502  	}
  1503  	serverHelloBytes := serverHello.marshal()
  1504  
  1505  	s.Write([]byte{
  1506  		byte(recordTypeHandshake),
  1507  		byte(VersionTLS12 >> 8),
  1508  		byte(VersionTLS12 & 0xff),
  1509  		byte(len(serverHelloBytes) >> 8),
  1510  		byte(len(serverHelloBytes)),
  1511  	})
  1512  	s.Write(serverHelloBytes)
  1513  	s.Close()
  1514  
  1515  	if err := <-errChan; !strings.Contains(err.Error(), "unconfigured cipher") {
  1516  		t.Fatalf("Expected error about unconfigured cipher suite but got %q", err)
  1517  	}
  1518  }
  1519  
  1520  func TestVerifyConnection(t *testing.T) {
  1521  	t.Run("TLSv12", func(t *testing.T) { testVerifyConnection(t, VersionTLS12) })
  1522  	t.Run("TLSv13", func(t *testing.T) { testVerifyConnection(t, VersionTLS13) })
  1523  }
  1524  
  1525  func testVerifyConnection(t *testing.T, version uint16) {
  1526  	checkFields := func(c ConnectionState, called *int, errorType string) error {
  1527  		if c.Version != version {
  1528  			return fmt.Errorf("%s: got Version %v, want %v", errorType, c.Version, version)
  1529  		}
  1530  		if c.HandshakeComplete {
  1531  			return fmt.Errorf("%s: got HandshakeComplete, want false", errorType)
  1532  		}
  1533  		if c.ServerName != "example.golang" {
  1534  			return fmt.Errorf("%s: got ServerName %s, want %s", errorType, c.ServerName, "example.golang")
  1535  		}
  1536  		if c.NegotiatedProtocol != "protocol1" {
  1537  			return fmt.Errorf("%s: got NegotiatedProtocol %s, want %s", errorType, c.NegotiatedProtocol, "protocol1")
  1538  		}
  1539  		if c.CipherSuite == 0 {
  1540  			return fmt.Errorf("%s: got CipherSuite 0, want non-zero", errorType)
  1541  		}
  1542  		wantDidResume := false
  1543  		if *called == 2 { // if this is the second time, then it should be a resumption
  1544  			wantDidResume = true
  1545  		}
  1546  		if c.DidResume != wantDidResume {
  1547  			return fmt.Errorf("%s: got DidResume %t, want %t", errorType, c.DidResume, wantDidResume)
  1548  		}
  1549  		return nil
  1550  	}
  1551  
  1552  	tests := []struct {
  1553  		name            string
  1554  		configureServer func(*Config, *int)
  1555  		configureClient func(*Config, *int)
  1556  	}{
  1557  		{
  1558  			name: "RequireAndVerifyClientCert",
  1559  			configureServer: func(config *Config, called *int) {
  1560  				config.ClientAuth = RequireAndVerifyClientCert
  1561  				config.VerifyConnection = func(c ConnectionState) error {
  1562  					*called++
  1563  					if l := len(c.PeerCertificates); l != 1 {
  1564  						return fmt.Errorf("server: got len(PeerCertificates) = %d, wanted 1", l)
  1565  					}
  1566  					if len(c.VerifiedChains) == 0 {
  1567  						return fmt.Errorf("server: got len(VerifiedChains) = 0, wanted non-zero")
  1568  					}
  1569  					return checkFields(c, called, "server")
  1570  				}
  1571  			},
  1572  			configureClient: func(config *Config, called *int) {
  1573  				config.VerifyConnection = func(c ConnectionState) error {
  1574  					*called++
  1575  					if l := len(c.PeerCertificates); l != 1 {
  1576  						return fmt.Errorf("client: got len(PeerCertificates) = %d, wanted 1", l)
  1577  					}
  1578  					if len(c.VerifiedChains) == 0 {
  1579  						return fmt.Errorf("client: got len(VerifiedChains) = 0, wanted non-zero")
  1580  					}
  1581  					if c.DidResume {
  1582  						return nil
  1583  						// The SCTs and OCSP Response are dropped on resumption.
  1584  						// See http://golang.org/issue/39075.
  1585  					}
  1586  					if len(c.OCSPResponse) == 0 {
  1587  						return fmt.Errorf("client: got len(OCSPResponse) = 0, wanted non-zero")
  1588  					}
  1589  					if len(c.SignedCertificateTimestamps) == 0 {
  1590  						return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero")
  1591  					}
  1592  					return checkFields(c, called, "client")
  1593  				}
  1594  			},
  1595  		},
  1596  		{
  1597  			name: "InsecureSkipVerify",
  1598  			configureServer: func(config *Config, called *int) {
  1599  				config.ClientAuth = RequireAnyClientCert
  1600  				config.InsecureSkipVerify = true
  1601  				config.VerifyConnection = func(c ConnectionState) error {
  1602  					*called++
  1603  					if l := len(c.PeerCertificates); l != 1 {
  1604  						return fmt.Errorf("server: got len(PeerCertificates) = %d, wanted 1", l)
  1605  					}
  1606  					if c.VerifiedChains != nil {
  1607  						return fmt.Errorf("server: got Verified Chains %v, want nil", c.VerifiedChains)
  1608  					}
  1609  					return checkFields(c, called, "server")
  1610  				}
  1611  			},
  1612  			configureClient: func(config *Config, called *int) {
  1613  				config.InsecureSkipVerify = true
  1614  				config.VerifyConnection = func(c ConnectionState) error {
  1615  					*called++
  1616  					if l := len(c.PeerCertificates); l != 1 {
  1617  						return fmt.Errorf("client: got len(PeerCertificates) = %d, wanted 1", l)
  1618  					}
  1619  					if c.VerifiedChains != nil {
  1620  						return fmt.Errorf("server: got Verified Chains %v, want nil", c.VerifiedChains)
  1621  					}
  1622  					if c.DidResume {
  1623  						return nil
  1624  						// The SCTs and OCSP Response are dropped on resumption.
  1625  						// See http://golang.org/issue/39075.
  1626  					}
  1627  					if len(c.OCSPResponse) == 0 {
  1628  						return fmt.Errorf("client: got len(OCSPResponse) = 0, wanted non-zero")
  1629  					}
  1630  					if len(c.SignedCertificateTimestamps) == 0 {
  1631  						return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero")
  1632  					}
  1633  					return checkFields(c, called, "client")
  1634  				}
  1635  			},
  1636  		},
  1637  		{
  1638  			name: "NoClientCert",
  1639  			configureServer: func(config *Config, called *int) {
  1640  				config.ClientAuth = NoClientCert
  1641  				config.VerifyConnection = func(c ConnectionState) error {
  1642  					*called++
  1643  					return checkFields(c, called, "server")
  1644  				}
  1645  			},
  1646  			configureClient: func(config *Config, called *int) {
  1647  				config.VerifyConnection = func(c ConnectionState) error {
  1648  					*called++
  1649  					return checkFields(c, called, "client")
  1650  				}
  1651  			},
  1652  		},
  1653  		{
  1654  			name: "RequestClientCert",
  1655  			configureServer: func(config *Config, called *int) {
  1656  				config.ClientAuth = RequestClientCert
  1657  				config.VerifyConnection = func(c ConnectionState) error {
  1658  					*called++
  1659  					return checkFields(c, called, "server")
  1660  				}
  1661  			},
  1662  			configureClient: func(config *Config, called *int) {
  1663  				config.Certificates = nil // clear the client cert
  1664  				config.VerifyConnection = func(c ConnectionState) error {
  1665  					*called++
  1666  					if l := len(c.PeerCertificates); l != 1 {
  1667  						return fmt.Errorf("client: got len(PeerCertificates) = %d, wanted 1", l)
  1668  					}
  1669  					if len(c.VerifiedChains) == 0 {
  1670  						return fmt.Errorf("client: got len(VerifiedChains) = 0, wanted non-zero")
  1671  					}
  1672  					if c.DidResume {
  1673  						return nil
  1674  						// The SCTs and OCSP Response are dropped on resumption.
  1675  						// See http://golang.org/issue/39075.
  1676  					}
  1677  					if len(c.OCSPResponse) == 0 {
  1678  						return fmt.Errorf("client: got len(OCSPResponse) = 0, wanted non-zero")
  1679  					}
  1680  					if len(c.SignedCertificateTimestamps) == 0 {
  1681  						return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero")
  1682  					}
  1683  					return checkFields(c, called, "client")
  1684  				}
  1685  			},
  1686  		},
  1687  	}
  1688  	for _, test := range tests {
  1689  		issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
  1690  		if err != nil {
  1691  			panic(err)
  1692  		}
  1693  		rootCAs := x509.NewCertPool()
  1694  		rootCAs.AddCert(issuer)
  1695  
  1696  		var serverCalled, clientCalled int
  1697  
  1698  		serverConfig := &Config{
  1699  			MaxVersion:   version,
  1700  			Certificates: []Certificate{testConfig.Certificates[0]},
  1701  			ClientCAs:    rootCAs,
  1702  			NextProtos:   []string{"protocol1"},
  1703  		}
  1704  		serverConfig.Certificates[0].SignedCertificateTimestamps = [][]byte{[]byte("dummy sct 1"), []byte("dummy sct 2")}
  1705  		serverConfig.Certificates[0].OCSPStaple = []byte("dummy ocsp")
  1706  		test.configureServer(serverConfig, &serverCalled)
  1707  
  1708  		clientConfig := &Config{
  1709  			MaxVersion:         version,
  1710  			ClientSessionCache: NewLRUClientSessionCache(32),
  1711  			RootCAs:            rootCAs,
  1712  			ServerName:         "example.golang",
  1713  			Certificates:       []Certificate{testConfig.Certificates[0]},
  1714  			NextProtos:         []string{"protocol1"},
  1715  		}
  1716  		test.configureClient(clientConfig, &clientCalled)
  1717  
  1718  		testHandshakeState := func(name string, didResume bool) {
  1719  			_, hs, err := testHandshake(t, clientConfig, serverConfig)
  1720  			if err != nil {
  1721  				t.Fatalf("%s: handshake failed: %s", name, err)
  1722  			}
  1723  			if hs.DidResume != didResume {
  1724  				t.Errorf("%s: resumed: %v, expected: %v", name, hs.DidResume, didResume)
  1725  			}
  1726  			wantCalled := 1
  1727  			if didResume {
  1728  				wantCalled = 2 // resumption would mean this is the second time it was called in this test
  1729  			}
  1730  			if clientCalled != wantCalled {
  1731  				t.Errorf("%s: expected client VerifyConnection called %d times, did %d times", name, wantCalled, clientCalled)
  1732  			}
  1733  			if serverCalled != wantCalled {
  1734  				t.Errorf("%s: expected server VerifyConnection called %d times, did %d times", name, wantCalled, serverCalled)
  1735  			}
  1736  		}
  1737  		testHandshakeState(fmt.Sprintf("%s-FullHandshake", test.name), false)
  1738  		testHandshakeState(fmt.Sprintf("%s-Resumption", test.name), true)
  1739  	}
  1740  }
  1741  
  1742  func TestVerifyPeerCertificate(t *testing.T) {
  1743  	t.Run("TLSv12", func(t *testing.T) { testVerifyPeerCertificate(t, VersionTLS12) })
  1744  	t.Run("TLSv13", func(t *testing.T) { testVerifyPeerCertificate(t, VersionTLS13) })
  1745  }
  1746  
  1747  func testVerifyPeerCertificate(t *testing.T, version uint16) {
  1748  	issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
  1749  	if err != nil {
  1750  		panic(err)
  1751  	}
  1752  
  1753  	rootCAs := x509.NewCertPool()
  1754  	rootCAs.AddCert(issuer)
  1755  
  1756  	now := func() time.Time { return time.Unix(1476984729, 0) }
  1757  
  1758  	sentinelErr := errors.New("TestVerifyPeerCertificate")
  1759  
  1760  	verifyPeerCertificateCallback := func(called *bool, rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
  1761  		if l := len(rawCerts); l != 1 {
  1762  			return fmt.Errorf("got len(rawCerts) = %d, wanted 1", l)
  1763  		}
  1764  		if len(validatedChains) == 0 {
  1765  			return errors.New("got len(validatedChains) = 0, wanted non-zero")
  1766  		}
  1767  		*called = true
  1768  		return nil
  1769  	}
  1770  	verifyConnectionCallback := func(called *bool, isClient bool, c ConnectionState) error {
  1771  		if l := len(c.PeerCertificates); l != 1 {
  1772  			return fmt.Errorf("got len(PeerCertificates) = %d, wanted 1", l)
  1773  		}
  1774  		if len(c.VerifiedChains) == 0 {
  1775  			return fmt.Errorf("got len(VerifiedChains) = 0, wanted non-zero")
  1776  		}
  1777  		if isClient && len(c.OCSPResponse) == 0 {
  1778  			return fmt.Errorf("got len(OCSPResponse) = 0, wanted non-zero")
  1779  		}
  1780  		*called = true
  1781  		return nil
  1782  	}
  1783  
  1784  	tests := []struct {
  1785  		configureServer func(*Config, *bool)
  1786  		configureClient func(*Config, *bool)
  1787  		validate        func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error)
  1788  	}{
  1789  		{
  1790  			configureServer: func(config *Config, called *bool) {
  1791  				config.InsecureSkipVerify = false
  1792  				config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
  1793  					return verifyPeerCertificateCallback(called, rawCerts, validatedChains)
  1794  				}
  1795  			},
  1796  			configureClient: func(config *Config, called *bool) {
  1797  				config.InsecureSkipVerify = false
  1798  				config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
  1799  					return verifyPeerCertificateCallback(called, rawCerts, validatedChains)
  1800  				}
  1801  			},
  1802  			validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
  1803  				if clientErr != nil {
  1804  					t.Errorf("test[%d]: client handshake failed: %v", testNo, clientErr)
  1805  				}
  1806  				if serverErr != nil {
  1807  					t.Errorf("test[%d]: server handshake failed: %v", testNo, serverErr)
  1808  				}
  1809  				if !clientCalled {
  1810  					t.Errorf("test[%d]: client did not call callback", testNo)
  1811  				}
  1812  				if !serverCalled {
  1813  					t.Errorf("test[%d]: server did not call callback", testNo)
  1814  				}
  1815  			},
  1816  		},
  1817  		{
  1818  			configureServer: func(config *Config, called *bool) {
  1819  				config.InsecureSkipVerify = false
  1820  				config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
  1821  					return sentinelErr
  1822  				}
  1823  			},
  1824  			configureClient: func(config *Config, called *bool) {
  1825  				config.VerifyPeerCertificate = nil
  1826  			},
  1827  			validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
  1828  				if serverErr != sentinelErr {
  1829  					t.Errorf("#%d: got server error %v, wanted sentinelErr", testNo, serverErr)
  1830  				}
  1831  			},
  1832  		},
  1833  		{
  1834  			configureServer: func(config *Config, called *bool) {
  1835  				config.InsecureSkipVerify = false
  1836  			},
  1837  			configureClient: func(config *Config, called *bool) {
  1838  				config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
  1839  					return sentinelErr
  1840  				}
  1841  			},
  1842  			validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
  1843  				if clientErr != sentinelErr {
  1844  					t.Errorf("#%d: got client error %v, wanted sentinelErr", testNo, clientErr)
  1845  				}
  1846  			},
  1847  		},
  1848  		{
  1849  			configureServer: func(config *Config, called *bool) {
  1850  				config.InsecureSkipVerify = false
  1851  			},
  1852  			configureClient: func(config *Config, called *bool) {
  1853  				config.InsecureSkipVerify = true
  1854  				config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
  1855  					if l := len(rawCerts); l != 1 {
  1856  						return fmt.Errorf("got len(rawCerts) = %d, wanted 1", l)
  1857  					}
  1858  					// With InsecureSkipVerify set, this
  1859  					// callback should still be called but
  1860  					// validatedChains must be empty.
  1861  					if l := len(validatedChains); l != 0 {
  1862  						return fmt.Errorf("got len(validatedChains) = %d, wanted zero", l)
  1863  					}
  1864  					*called = true
  1865  					return nil
  1866  				}
  1867  			},
  1868  			validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
  1869  				if clientErr != nil {
  1870  					t.Errorf("test[%d]: client handshake failed: %v", testNo, clientErr)
  1871  				}
  1872  				if serverErr != nil {
  1873  					t.Errorf("test[%d]: server handshake failed: %v", testNo, serverErr)
  1874  				}
  1875  				if !clientCalled {
  1876  					t.Errorf("test[%d]: client did not call callback", testNo)
  1877  				}
  1878  			},
  1879  		},
  1880  		{
  1881  			configureServer: func(config *Config, called *bool) {
  1882  				config.InsecureSkipVerify = false
  1883  				config.VerifyConnection = func(c ConnectionState) error {
  1884  					return verifyConnectionCallback(called, false, c)
  1885  				}
  1886  			},
  1887  			configureClient: func(config *Config, called *bool) {
  1888  				config.InsecureSkipVerify = false
  1889  				config.VerifyConnection = func(c ConnectionState) error {
  1890  					return verifyConnectionCallback(called, true, c)
  1891  				}
  1892  			},
  1893  			validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
  1894  				if clientErr != nil {
  1895  					t.Errorf("test[%d]: client handshake failed: %v", testNo, clientErr)
  1896  				}
  1897  				if serverErr != nil {
  1898  					t.Errorf("test[%d]: server handshake failed: %v", testNo, serverErr)
  1899  				}
  1900  				if !clientCalled {
  1901  					t.Errorf("test[%d]: client did not call callback", testNo)
  1902  				}
  1903  				if !serverCalled {
  1904  					t.Errorf("test[%d]: server did not call callback", testNo)
  1905  				}
  1906  			},
  1907  		},
  1908  		{
  1909  			configureServer: func(config *Config, called *bool) {
  1910  				config.InsecureSkipVerify = false
  1911  				config.VerifyConnection = func(c ConnectionState) error {
  1912  					return sentinelErr
  1913  				}
  1914  			},
  1915  			configureClient: func(config *Config, called *bool) {
  1916  				config.InsecureSkipVerify = false
  1917  				config.VerifyConnection = nil
  1918  			},
  1919  			validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
  1920  				if serverErr != sentinelErr {
  1921  					t.Errorf("#%d: got server error %v, wanted sentinelErr", testNo, serverErr)
  1922  				}
  1923  			},
  1924  		},
  1925  		{
  1926  			configureServer: func(config *Config, called *bool) {
  1927  				config.InsecureSkipVerify = false
  1928  				config.VerifyConnection = nil
  1929  			},
  1930  			configureClient: func(config *Config, called *bool) {
  1931  				config.InsecureSkipVerify = false
  1932  				config.VerifyConnection = func(c ConnectionState) error {
  1933  					return sentinelErr
  1934  				}
  1935  			},
  1936  			validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
  1937  				if clientErr != sentinelErr {
  1938  					t.Errorf("#%d: got client error %v, wanted sentinelErr", testNo, clientErr)
  1939  				}
  1940  			},
  1941  		},
  1942  		{
  1943  			configureServer: func(config *Config, called *bool) {
  1944  				config.InsecureSkipVerify = false
  1945  				config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
  1946  					return verifyPeerCertificateCallback(called, rawCerts, validatedChains)
  1947  				}
  1948  				config.VerifyConnection = func(c ConnectionState) error {
  1949  					return sentinelErr
  1950  				}
  1951  			},
  1952  			configureClient: func(config *Config, called *bool) {
  1953  				config.InsecureSkipVerify = false
  1954  				config.VerifyPeerCertificate = nil
  1955  				config.VerifyConnection = nil
  1956  			},
  1957  			validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
  1958  				if serverErr != sentinelErr {
  1959  					t.Errorf("#%d: got server error %v, wanted sentinelErr", testNo, serverErr)
  1960  				}
  1961  				if !serverCalled {
  1962  					t.Errorf("test[%d]: server did not call callback", testNo)
  1963  				}
  1964  			},
  1965  		},
  1966  		{
  1967  			configureServer: func(config *Config, called *bool) {
  1968  				config.InsecureSkipVerify = false
  1969  				config.VerifyPeerCertificate = nil
  1970  				config.VerifyConnection = nil
  1971  			},
  1972  			configureClient: func(config *Config, called *bool) {
  1973  				config.InsecureSkipVerify = false
  1974  				config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
  1975  					return verifyPeerCertificateCallback(called, rawCerts, validatedChains)
  1976  				}
  1977  				config.VerifyConnection = func(c ConnectionState) error {
  1978  					return sentinelErr
  1979  				}
  1980  			},
  1981  			validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
  1982  				if clientErr != sentinelErr {
  1983  					t.Errorf("#%d: got client error %v, wanted sentinelErr", testNo, clientErr)
  1984  				}
  1985  				if !clientCalled {
  1986  					t.Errorf("test[%d]: client did not call callback", testNo)
  1987  				}
  1988  			},
  1989  		},
  1990  	}
  1991  
  1992  	for i, test := range tests {
  1993  		c, s := localPipe(t)
  1994  		done := make(chan error)
  1995  
  1996  		var clientCalled, serverCalled bool
  1997  
  1998  		go func() {
  1999  			config := testConfig.Clone()
  2000  			config.ServerName = "example.golang"
  2001  			config.ClientAuth = RequireAndVerifyClientCert
  2002  			config.ClientCAs = rootCAs
  2003  			config.Time = now
  2004  			config.MaxVersion = version
  2005  			config.Certificates = make([]Certificate, 1)
  2006  			config.Certificates[0].Certificate = [][]byte{testRSACertificate}
  2007  			config.Certificates[0].PrivateKey = testRSAPrivateKey
  2008  			config.Certificates[0].SignedCertificateTimestamps = [][]byte{[]byte("dummy sct 1"), []byte("dummy sct 2")}
  2009  			config.Certificates[0].OCSPStaple = []byte("dummy ocsp")
  2010  			test.configureServer(config, &serverCalled)
  2011  
  2012  			err = Server(s, config).Handshake()
  2013  			s.Close()
  2014  			done <- err
  2015  		}()
  2016  
  2017  		config := testConfig.Clone()
  2018  		config.ServerName = "example.golang"
  2019  		config.RootCAs = rootCAs
  2020  		config.Time = now
  2021  		config.MaxVersion = version
  2022  		test.configureClient(config, &clientCalled)
  2023  		clientErr := Client(c, config).Handshake()
  2024  		c.Close()
  2025  		serverErr := <-done
  2026  
  2027  		test.validate(t, i, clientCalled, serverCalled, clientErr, serverErr)
  2028  	}
  2029  }
  2030  
  2031  // brokenConn wraps a net.Conn and causes all Writes after a certain number to
  2032  // fail with brokenConnErr.
  2033  type brokenConn struct {
  2034  	net.Conn
  2035  
  2036  	// breakAfter is the number of successful writes that will be allowed
  2037  	// before all subsequent writes fail.
  2038  	breakAfter int
  2039  
  2040  	// numWrites is the number of writes that have been done.
  2041  	numWrites int
  2042  }
  2043  
  2044  // brokenConnErr is the error that brokenConn returns once exhausted.
  2045  var brokenConnErr = errors.New("too many writes to brokenConn")
  2046  
  2047  func (b *brokenConn) Write(data []byte) (int, error) {
  2048  	if b.numWrites >= b.breakAfter {
  2049  		return 0, brokenConnErr
  2050  	}
  2051  
  2052  	b.numWrites++
  2053  	return b.Conn.Write(data)
  2054  }
  2055  
  2056  func TestFailedWrite(t *testing.T) {
  2057  	// Test that a write error during the handshake is returned.
  2058  	for _, breakAfter := range []int{0, 1} {
  2059  		c, s := localPipe(t)
  2060  		done := make(chan bool)
  2061  
  2062  		go func() {
  2063  			Server(s, testConfig).Handshake()
  2064  			s.Close()
  2065  			done <- true
  2066  		}()
  2067  
  2068  		brokenC := &brokenConn{Conn: c, breakAfter: breakAfter}
  2069  		err := Client(brokenC, testConfig).Handshake()
  2070  		if err != brokenConnErr {
  2071  			t.Errorf("#%d: expected error from brokenConn but got %q", breakAfter, err)
  2072  		}
  2073  		brokenC.Close()
  2074  
  2075  		<-done
  2076  	}
  2077  }
  2078  
  2079  // writeCountingConn wraps a net.Conn and counts the number of Write calls.
  2080  type writeCountingConn struct {
  2081  	net.Conn
  2082  
  2083  	// numWrites is the number of writes that have been done.
  2084  	numWrites int
  2085  }
  2086  
  2087  func (wcc *writeCountingConn) Write(data []byte) (int, error) {
  2088  	wcc.numWrites++
  2089  	return wcc.Conn.Write(data)
  2090  }
  2091  
  2092  func TestBuffering(t *testing.T) {
  2093  	t.Run("TLSv12", func(t *testing.T) { testBuffering(t, VersionTLS12) })
  2094  	t.Run("TLSv13", func(t *testing.T) { testBuffering(t, VersionTLS13) })
  2095  }
  2096  
  2097  func testBuffering(t *testing.T, version uint16) {
  2098  	c, s := localPipe(t)
  2099  	done := make(chan bool)
  2100  
  2101  	clientWCC := &writeCountingConn{Conn: c}
  2102  	serverWCC := &writeCountingConn{Conn: s}
  2103  
  2104  	go func() {
  2105  		config := testConfig.Clone()
  2106  		config.MaxVersion = version
  2107  		Server(serverWCC, config).Handshake()
  2108  		serverWCC.Close()
  2109  		done <- true
  2110  	}()
  2111  
  2112  	err := Client(clientWCC, testConfig).Handshake()
  2113  	if err != nil {
  2114  		t.Fatal(err)
  2115  	}
  2116  	clientWCC.Close()
  2117  	<-done
  2118  
  2119  	var expectedClient, expectedServer int
  2120  	if version == VersionTLS13 {
  2121  		expectedClient = 2
  2122  		expectedServer = 1
  2123  	} else {
  2124  		expectedClient = 2
  2125  		expectedServer = 2
  2126  	}
  2127  
  2128  	if n := clientWCC.numWrites; n != expectedClient {
  2129  		t.Errorf("expected client handshake to complete with %d writes, but saw %d", expectedClient, n)
  2130  	}
  2131  
  2132  	if n := serverWCC.numWrites; n != expectedServer {
  2133  		t.Errorf("expected server handshake to complete with %d writes, but saw %d", expectedServer, n)
  2134  	}
  2135  }
  2136  
  2137  func TestAlertFlushing(t *testing.T) {
  2138  	c, s := localPipe(t)
  2139  	done := make(chan bool)
  2140  
  2141  	clientWCC := &writeCountingConn{Conn: c}
  2142  	serverWCC := &writeCountingConn{Conn: s}
  2143  
  2144  	serverConfig := testConfig.Clone()
  2145  
  2146  	// Cause a signature-time error
  2147  	brokenKey := rsa.PrivateKey{PublicKey: testRSAPrivateKey.PublicKey}
  2148  	brokenKey.D = big.NewInt(42)
  2149  	serverConfig.Certificates = []Certificate{{
  2150  		Certificate: [][]byte{testRSACertificate},
  2151  		PrivateKey:  &brokenKey,
  2152  	}}
  2153  
  2154  	go func() {
  2155  		Server(serverWCC, serverConfig).Handshake()
  2156  		serverWCC.Close()
  2157  		done <- true
  2158  	}()
  2159  
  2160  	err := Client(clientWCC, testConfig).Handshake()
  2161  	if err == nil {
  2162  		t.Fatal("client unexpectedly returned no error")
  2163  	}
  2164  
  2165  	const expectedError = "remote error: tls: internal error"
  2166  	if e := err.Error(); !strings.Contains(e, expectedError) {
  2167  		t.Fatalf("expected to find %q in error but error was %q", expectedError, e)
  2168  	}
  2169  	clientWCC.Close()
  2170  	<-done
  2171  
  2172  	if n := serverWCC.numWrites; n != 1 {
  2173  		t.Errorf("expected server handshake to complete with one write, but saw %d", n)
  2174  	}
  2175  }
  2176  
  2177  func TestHandshakeRace(t *testing.T) {
  2178  	if testing.Short() {
  2179  		t.Skip("skipping in -short mode")
  2180  	}
  2181  	t.Parallel()
  2182  	// This test races a Read and Write to try and complete a handshake in
  2183  	// order to provide some evidence that there are no races or deadlocks
  2184  	// in the handshake locking.
  2185  	for i := 0; i < 32; i++ {
  2186  		c, s := localPipe(t)
  2187  
  2188  		go func() {
  2189  			server := Server(s, testConfig)
  2190  			if err := server.Handshake(); err != nil {
  2191  				panic(err)
  2192  			}
  2193  
  2194  			var request [1]byte
  2195  			if n, err := server.Read(request[:]); err != nil || n != 1 {
  2196  				panic(err)
  2197  			}
  2198  
  2199  			server.Write(request[:])
  2200  			server.Close()
  2201  		}()
  2202  
  2203  		startWrite := make(chan struct{})
  2204  		startRead := make(chan struct{})
  2205  		readDone := make(chan struct{}, 1)
  2206  
  2207  		client := Client(c, testConfig)
  2208  		go func() {
  2209  			<-startWrite
  2210  			var request [1]byte
  2211  			client.Write(request[:])
  2212  		}()
  2213  
  2214  		go func() {
  2215  			<-startRead
  2216  			var reply [1]byte
  2217  			if _, err := io.ReadFull(client, reply[:]); err != nil {
  2218  				panic(err)
  2219  			}
  2220  			c.Close()
  2221  			readDone <- struct{}{}
  2222  		}()
  2223  
  2224  		if i&1 == 1 {
  2225  			startWrite <- struct{}{}
  2226  			startRead <- struct{}{}
  2227  		} else {
  2228  			startRead <- struct{}{}
  2229  			startWrite <- struct{}{}
  2230  		}
  2231  		<-readDone
  2232  	}
  2233  }
  2234  
  2235  var getClientCertificateTests = []struct {
  2236  	setup               func(*Config, *Config)
  2237  	expectedClientError string
  2238  	verify              func(*testing.T, int, *ConnectionState)
  2239  }{
  2240  	{
  2241  		func(clientConfig, serverConfig *Config) {
  2242  			// Returning a Certificate with no certificate data
  2243  			// should result in an empty message being sent to the
  2244  			// server.
  2245  			serverConfig.ClientCAs = nil
  2246  			clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) {
  2247  				if len(cri.SignatureSchemes) == 0 {
  2248  					panic("empty SignatureSchemes")
  2249  				}
  2250  				if len(cri.AcceptableCAs) != 0 {
  2251  					panic("AcceptableCAs should have been empty")
  2252  				}
  2253  				return new(Certificate), nil
  2254  			}
  2255  		},
  2256  		"",
  2257  		func(t *testing.T, testNum int, cs *ConnectionState) {
  2258  			if l := len(cs.PeerCertificates); l != 0 {
  2259  				t.Errorf("#%d: expected no certificates but got %d", testNum, l)
  2260  			}
  2261  		},
  2262  	},
  2263  	{
  2264  		func(clientConfig, serverConfig *Config) {
  2265  			// With TLS 1.1, the SignatureSchemes should be
  2266  			// synthesised from the supported certificate types.
  2267  			clientConfig.MaxVersion = VersionTLS11
  2268  			clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) {
  2269  				if len(cri.SignatureSchemes) == 0 {
  2270  					panic("empty SignatureSchemes")
  2271  				}
  2272  				return new(Certificate), nil
  2273  			}
  2274  		},
  2275  		"",
  2276  		func(t *testing.T, testNum int, cs *ConnectionState) {
  2277  			if l := len(cs.PeerCertificates); l != 0 {
  2278  				t.Errorf("#%d: expected no certificates but got %d", testNum, l)
  2279  			}
  2280  		},
  2281  	},
  2282  	{
  2283  		func(clientConfig, serverConfig *Config) {
  2284  			// Returning an error should abort the handshake with
  2285  			// that error.
  2286  			clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) {
  2287  				return nil, errors.New("GetClientCertificate")
  2288  			}
  2289  		},
  2290  		"GetClientCertificate",
  2291  		func(t *testing.T, testNum int, cs *ConnectionState) {
  2292  		},
  2293  	},
  2294  	{
  2295  		func(clientConfig, serverConfig *Config) {
  2296  			clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) {
  2297  				if len(cri.AcceptableCAs) == 0 {
  2298  					panic("empty AcceptableCAs")
  2299  				}
  2300  				cert := &Certificate{
  2301  					Certificate: [][]byte{testRSACertificate},
  2302  					PrivateKey:  testRSAPrivateKey,
  2303  				}
  2304  				return cert, nil
  2305  			}
  2306  		},
  2307  		"",
  2308  		func(t *testing.T, testNum int, cs *ConnectionState) {
  2309  			if len(cs.VerifiedChains) == 0 {
  2310  				t.Errorf("#%d: expected some verified chains, but found none", testNum)
  2311  			}
  2312  		},
  2313  	},
  2314  }
  2315  
  2316  func TestGetClientCertificate(t *testing.T) {
  2317  	t.Run("TLSv12", func(t *testing.T) { testGetClientCertificate(t, VersionTLS12) })
  2318  	t.Run("TLSv13", func(t *testing.T) { testGetClientCertificate(t, VersionTLS13) })
  2319  }
  2320  
  2321  func testGetClientCertificate(t *testing.T, version uint16) {
  2322  	issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
  2323  	if err != nil {
  2324  		panic(err)
  2325  	}
  2326  
  2327  	for i, test := range getClientCertificateTests {
  2328  		serverConfig := testConfig.Clone()
  2329  		serverConfig.ClientAuth = VerifyClientCertIfGiven
  2330  		serverConfig.RootCAs = x509.NewCertPool()
  2331  		serverConfig.RootCAs.AddCert(issuer)
  2332  		serverConfig.ClientCAs = serverConfig.RootCAs
  2333  		serverConfig.Time = func() time.Time { return time.Unix(1476984729, 0) }
  2334  		serverConfig.MaxVersion = version
  2335  
  2336  		clientConfig := testConfig.Clone()
  2337  		clientConfig.MaxVersion = version
  2338  
  2339  		test.setup(clientConfig, serverConfig)
  2340  
  2341  		type serverResult struct {
  2342  			cs  ConnectionState
  2343  			err error
  2344  		}
  2345  
  2346  		c, s := localPipe(t)
  2347  		done := make(chan serverResult)
  2348  
  2349  		go func() {
  2350  			defer s.Close()
  2351  			server := Server(s, serverConfig)
  2352  			err := server.Handshake()
  2353  
  2354  			var cs ConnectionState
  2355  			if err == nil {
  2356  				cs = server.ConnectionState()
  2357  			}
  2358  			done <- serverResult{cs, err}
  2359  		}()
  2360  
  2361  		clientErr := Client(c, clientConfig).Handshake()
  2362  		c.Close()
  2363  
  2364  		result := <-done
  2365  
  2366  		if clientErr != nil {
  2367  			if len(test.expectedClientError) == 0 {
  2368  				t.Errorf("#%d: client error: %v", i, clientErr)
  2369  			} else if got := clientErr.Error(); got != test.expectedClientError {
  2370  				t.Errorf("#%d: expected client error %q, but got %q", i, test.expectedClientError, got)
  2371  			} else {
  2372  				test.verify(t, i, &result.cs)
  2373  			}
  2374  		} else if len(test.expectedClientError) > 0 {
  2375  			t.Errorf("#%d: expected client error %q, but got no error", i, test.expectedClientError)
  2376  		} else if err := result.err; err != nil {
  2377  			t.Errorf("#%d: server error: %v", i, err)
  2378  		} else {
  2379  			test.verify(t, i, &result.cs)
  2380  		}
  2381  	}
  2382  }
  2383  
  2384  func TestRSAPSSKeyError(t *testing.T) {
  2385  	// crypto/tls does not support the rsa_pss_pss_* SignatureSchemes. If support for
  2386  	// public keys with OID RSASSA-PSS is added to crypto/x509, they will be misused with
  2387  	// the rsa_pss_rsae_* SignatureSchemes. Assert that RSASSA-PSS certificates don't
  2388  	// parse, or that they don't carry *rsa.PublicKey keys.
  2389  	b, _ := pem.Decode([]byte(`
  2390  -----BEGIN CERTIFICATE-----
  2391  MIIDZTCCAhygAwIBAgIUCF2x0FyTgZG0CC9QTDjGWkB5vgEwPgYJKoZIhvcNAQEK
  2392  MDGgDTALBglghkgBZQMEAgGhGjAYBgkqhkiG9w0BAQgwCwYJYIZIAWUDBAIBogQC
  2393  AgDeMBIxEDAOBgNVBAMMB1JTQS1QU1MwHhcNMTgwNjI3MjI0NDM2WhcNMTgwNzI3
  2394  MjI0NDM2WjASMRAwDgYDVQQDDAdSU0EtUFNTMIIBIDALBgkqhkiG9w0BAQoDggEP
  2395  ADCCAQoCggEBANxDm0f76JdI06YzsjB3AmmjIYkwUEGxePlafmIASFjDZl/elD0Z
  2396  /a7xLX468b0qGxLS5al7XCcEprSdsDR6DF5L520+pCbpfLyPOjuOvGmk9KzVX4x5
  2397  b05YXYuXdsQ0Kjxcx2i3jjCday6scIhMJVgBZxTEyMj1thPQM14SHzKCd/m6HmCL
  2398  QmswpH2yMAAcBRWzRpp/vdH5DeOJEB3aelq7094no731mrLUCHRiZ1htq8BDB3ou
  2399  czwqgwspbqZ4dnMXl2MvfySQ5wJUxQwILbiuAKO2lVVPUbFXHE9pgtznNoPvKwQT
  2400  JNcX8ee8WIZc2SEGzofjk3NpjR+2ADB2u3sCAwEAAaNTMFEwHQYDVR0OBBYEFNEz
  2401  AdyJ2f+fU+vSCS6QzohnOnprMB8GA1UdIwQYMBaAFNEzAdyJ2f+fU+vSCS6Qzohn
  2402  OnprMA8GA1UdEwEB/wQFMAMBAf8wPgYJKoZIhvcNAQEKMDGgDTALBglghkgBZQME
  2403  AgGhGjAYBgkqhkiG9w0BAQgwCwYJYIZIAWUDBAIBogQCAgDeA4IBAQCjEdrR5aab
  2404  sZmCwrMeKidXgfkmWvfuLDE+TCbaqDZp7BMWcMQXT9O0UoUT5kqgKj2ARm2pEW0Z
  2405  H3Z1vj3bbds72qcDIJXp+l0fekyLGeCrX/CbgnMZXEP7+/+P416p34ChR1Wz4dU1
  2406  KD3gdsUuTKKeMUog3plxlxQDhRQmiL25ygH1LmjLd6dtIt0GVRGr8lj3euVeprqZ
  2407  bZ3Uq5eLfsn8oPgfC57gpO6yiN+UURRTlK3bgYvLh4VWB3XXk9UaQZ7Mq1tpXjoD
  2408  HYFybkWzibkZp4WRo+Fa28rirH+/wHt0vfeN7UCceURZEx4JaxIIfe4ku7uDRhJi
  2409  RwBA9Xk1KBNF
  2410  -----END CERTIFICATE-----`))
  2411  	if b == nil {
  2412  		t.Fatal("Failed to decode certificate")
  2413  	}
  2414  	cert, err := x509.ParseCertificate(b.Bytes)
  2415  	if err != nil {
  2416  		return
  2417  	}
  2418  	if _, ok := cert.PublicKey.(*rsa.PublicKey); ok {
  2419  		t.Error("A RSASSA-PSS certificate was parsed like a PKCS#1 v1.5 one, and it will be mistakenly used with rsa_pss_rsae_* signature algorithms")
  2420  	}
  2421  }
  2422  
  2423  func TestCloseClientConnectionOnIdleServer(t *testing.T) {
  2424  	clientConn, serverConn := localPipe(t)
  2425  	client := Client(clientConn, testConfig.Clone())
  2426  	go func() {
  2427  		var b [1]byte
  2428  		serverConn.Read(b[:])
  2429  		client.Close()
  2430  	}()
  2431  	client.SetWriteDeadline(time.Now().Add(time.Minute))
  2432  	err := client.Handshake()
  2433  	if err != nil {
  2434  		if err, ok := err.(net.Error); ok && err.Timeout() {
  2435  			t.Errorf("Expected a closed network connection error but got '%s'", err.Error())
  2436  		}
  2437  	} else {
  2438  		t.Errorf("Error expected, but no error returned")
  2439  	}
  2440  }
  2441  
  2442  func testDowngradeCanary(t *testing.T, clientVersion, serverVersion uint16) error {
  2443  	defer func() { testingOnlyForceDowngradeCanary = false }()
  2444  	testingOnlyForceDowngradeCanary = true
  2445  
  2446  	clientConfig := testConfig.Clone()
  2447  	clientConfig.MaxVersion = clientVersion
  2448  	serverConfig := testConfig.Clone()
  2449  	serverConfig.MaxVersion = serverVersion
  2450  	_, _, err := testHandshake(t, clientConfig, serverConfig)
  2451  	return err
  2452  }
  2453  
  2454  func TestDowngradeCanary(t *testing.T) {
  2455  	if err := testDowngradeCanary(t, VersionTLS13, VersionTLS12); err == nil {
  2456  		t.Errorf("downgrade from TLS 1.3 to TLS 1.2 was not detected")
  2457  	}
  2458  	if testing.Short() {
  2459  		t.Skip("skipping the rest of the checks in short mode")
  2460  	}
  2461  	if err := testDowngradeCanary(t, VersionTLS13, VersionTLS11); err == nil {
  2462  		t.Errorf("downgrade from TLS 1.3 to TLS 1.1 was not detected")
  2463  	}
  2464  	if err := testDowngradeCanary(t, VersionTLS13, VersionTLS10); err == nil {
  2465  		t.Errorf("downgrade from TLS 1.3 to TLS 1.0 was not detected")
  2466  	}
  2467  	if err := testDowngradeCanary(t, VersionTLS12, VersionTLS11); err == nil {
  2468  		t.Errorf("downgrade from TLS 1.2 to TLS 1.1 was not detected")
  2469  	}
  2470  	if err := testDowngradeCanary(t, VersionTLS12, VersionTLS10); err == nil {
  2471  		t.Errorf("downgrade from TLS 1.2 to TLS 1.0 was not detected")
  2472  	}
  2473  	if err := testDowngradeCanary(t, VersionTLS13, VersionTLS13); err != nil {
  2474  		t.Errorf("server unexpectedly sent downgrade canary for TLS 1.3")
  2475  	}
  2476  	if err := testDowngradeCanary(t, VersionTLS12, VersionTLS12); err != nil {
  2477  		t.Errorf("client didn't ignore expected TLS 1.2 canary")
  2478  	}
  2479  	if err := testDowngradeCanary(t, VersionTLS11, VersionTLS11); err != nil {
  2480  		t.Errorf("client unexpectedly reacted to a canary in TLS 1.1")
  2481  	}
  2482  	if err := testDowngradeCanary(t, VersionTLS10, VersionTLS10); err != nil {
  2483  		t.Errorf("client unexpectedly reacted to a canary in TLS 1.0")
  2484  	}
  2485  }
  2486  
  2487  func TestResumptionKeepsOCSPAndSCT(t *testing.T) {
  2488  	t.Run("TLSv12", func(t *testing.T) { testResumptionKeepsOCSPAndSCT(t, VersionTLS12) })
  2489  	t.Run("TLSv13", func(t *testing.T) { testResumptionKeepsOCSPAndSCT(t, VersionTLS13) })
  2490  }
  2491  
  2492  func testResumptionKeepsOCSPAndSCT(t *testing.T, ver uint16) {
  2493  	issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
  2494  	if err != nil {
  2495  		t.Fatalf("failed to parse test issuer")
  2496  	}
  2497  	roots := x509.NewCertPool()
  2498  	roots.AddCert(issuer)
  2499  	clientConfig := &Config{
  2500  		MaxVersion:         ver,
  2501  		ClientSessionCache: NewLRUClientSessionCache(32),
  2502  		ServerName:         "example.golang",
  2503  		RootCAs:            roots,
  2504  	}
  2505  	serverConfig := testConfig.Clone()
  2506  	serverConfig.MaxVersion = ver
  2507  	serverConfig.Certificates[0].OCSPStaple = []byte{1, 2, 3}
  2508  	serverConfig.Certificates[0].SignedCertificateTimestamps = [][]byte{{4, 5, 6}}
  2509  
  2510  	_, ccs, err := testHandshake(t, clientConfig, serverConfig)
  2511  	if err != nil {
  2512  		t.Fatalf("handshake failed: %s", err)
  2513  	}
  2514  	// after a new session we expect to see OCSPResponse and
  2515  	// SignedCertificateTimestamps populated as usual
  2516  	if !bytes.Equal(ccs.OCSPResponse, serverConfig.Certificates[0].OCSPStaple) {
  2517  		t.Errorf("client ConnectionState contained unexpected OCSPResponse: wanted %v, got %v",
  2518  			serverConfig.Certificates[0].OCSPStaple, ccs.OCSPResponse)
  2519  	}
  2520  	if !reflect.DeepEqual(ccs.SignedCertificateTimestamps, serverConfig.Certificates[0].SignedCertificateTimestamps) {
  2521  		t.Errorf("client ConnectionState contained unexpected SignedCertificateTimestamps: wanted %v, got %v",
  2522  			serverConfig.Certificates[0].SignedCertificateTimestamps, ccs.SignedCertificateTimestamps)
  2523  	}
  2524  
  2525  	// if the server doesn't send any SCTs, repopulate the old SCTs
  2526  	oldSCTs := serverConfig.Certificates[0].SignedCertificateTimestamps
  2527  	serverConfig.Certificates[0].SignedCertificateTimestamps = nil
  2528  	_, ccs, err = testHandshake(t, clientConfig, serverConfig)
  2529  	if err != nil {
  2530  		t.Fatalf("handshake failed: %s", err)
  2531  	}
  2532  	if !ccs.DidResume {
  2533  		t.Fatalf("expected session to be resumed")
  2534  	}
  2535  	// after a resumed session we also expect to see OCSPResponse
  2536  	// and SignedCertificateTimestamps populated
  2537  	if !bytes.Equal(ccs.OCSPResponse, serverConfig.Certificates[0].OCSPStaple) {
  2538  		t.Errorf("client ConnectionState contained unexpected OCSPResponse after resumption: wanted %v, got %v",
  2539  			serverConfig.Certificates[0].OCSPStaple, ccs.OCSPResponse)
  2540  	}
  2541  	if !reflect.DeepEqual(ccs.SignedCertificateTimestamps, oldSCTs) {
  2542  		t.Errorf("client ConnectionState contained unexpected SignedCertificateTimestamps after resumption: wanted %v, got %v",
  2543  			oldSCTs, ccs.SignedCertificateTimestamps)
  2544  	}
  2545  
  2546  	//  Only test overriding the SCTs for TLS 1.2, since in 1.3
  2547  	// the server won't send the message containing them
  2548  	if ver == VersionTLS13 {
  2549  		return
  2550  	}
  2551  
  2552  	// if the server changes the SCTs it sends, they should override the saved SCTs
  2553  	serverConfig.Certificates[0].SignedCertificateTimestamps = [][]byte{{7, 8, 9}}
  2554  	_, ccs, err = testHandshake(t, clientConfig, serverConfig)
  2555  	if err != nil {
  2556  		t.Fatalf("handshake failed: %s", err)
  2557  	}
  2558  	if !ccs.DidResume {
  2559  		t.Fatalf("expected session to be resumed")
  2560  	}
  2561  	if !reflect.DeepEqual(ccs.SignedCertificateTimestamps, serverConfig.Certificates[0].SignedCertificateTimestamps) {
  2562  		t.Errorf("client ConnectionState contained unexpected SignedCertificateTimestamps after resumption: wanted %v, got %v",
  2563  			serverConfig.Certificates[0].SignedCertificateTimestamps, ccs.SignedCertificateTimestamps)
  2564  	}
  2565  }
  2566  
  2567  // TestClientHandshakeContextCancellation tests that cancelling
  2568  // the context given to the client side conn.HandshakeContext
  2569  // interrupts the in-progress handshake.
  2570  func TestClientHandshakeContextCancellation(t *testing.T) {
  2571  	c, s := localPipe(t)
  2572  	ctx, cancel := context.WithCancel(context.Background())
  2573  	unblockServer := make(chan struct{})
  2574  	defer close(unblockServer)
  2575  	go func() {
  2576  		cancel()
  2577  		<-unblockServer
  2578  		_ = s.Close()
  2579  	}()
  2580  	cli := Client(c, testConfig)
  2581  	// Initiates client side handshake, which will block until the client hello is read
  2582  	// by the server, unless the cancellation works.
  2583  	err := cli.HandshakeContext(ctx)
  2584  	if err == nil {
  2585  		t.Fatal("Client handshake did not error when the context was canceled")
  2586  	}
  2587  	if err != context.Canceled {
  2588  		t.Errorf("Unexpected client handshake error: %v", err)
  2589  	}
  2590  	if runtime.GOARCH == "wasm" {
  2591  		t.Skip("conn.Close does not error as expected when called multiple times on WASM")
  2592  	}
  2593  	err = cli.Close()
  2594  	if err == nil {
  2595  		t.Error("Client connection was not closed when the context was canceled")
  2596  	}
  2597  }
  2598  

View as plain text